@genai-fi/nanogpt 0.4.0 → 0.4.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/Generator.js +3 -3
- package/dist/NanoGPTModel.js +83 -70
- package/dist/TeachableLLM.js +1 -1
- package/dist/{random_width-CMHmdbSu.js → TiedEmbedding-CnJ1bx4q.js} +760 -719
- package/dist/{axis_util-DeydwOoC.js → axis_util-BgTGy5w8.js} +1 -1
- package/dist/{concat-DS_qH7MI.js → concat-CuRsVY-K.js} +1 -1
- package/dist/dropout-DfDdklfL.js +193 -0
- package/dist/{gather-BUmJIS8n.js → gather-ZYRWhmXR.js} +1 -1
- package/dist/gelu-CnCt17Lk.js +26 -0
- package/dist/{index-XjBAhiFO.js → index-C4JCoBvj.js} +61 -61
- package/dist/kernel_funcs_utils-CAd1h9X1.js +388 -0
- package/dist/layers/CausalSelfAttention.js +73 -72
- package/dist/layers/MLP.d.ts +3 -1
- package/dist/layers/MLP.js +93 -5
- package/dist/layers/RMSNorm.js +3 -3
- package/dist/layers/RoPECache.js +3 -3
- package/dist/layers/TiedEmbedding.js +6 -46
- package/dist/layers/TransformerBlock.js +2 -2
- package/dist/{log_sum_exp-DJPkVZZn.js → log_sum_exp-BswFnwOb.js} +5 -5
- package/dist/main.js +1 -1
- package/dist/{mat_mul-CKwFEV1Q.js → mat_mul-415y5Qn2.js} +1 -1
- package/dist/{max-DJvEiCAJ.js → max-CP_9O2Yd.js} +1 -1
- package/dist/{moments-CrWRPcR3.js → moments-CjeIaVdp.js} +3 -3
- package/dist/{norm-BzY929B_.js → norm-CZM380I3.js} +5 -5
- package/dist/{ones-BO01zpJG.js → ones-Bf3YR48P.js} +2 -2
- package/dist/ops/appendCache.js +1 -1
- package/dist/ops/attentionMask.d.ts +1 -1
- package/dist/ops/attentionMask.js +4 -4
- package/dist/ops/cpu/appendCache.js +2 -2
- package/dist/ops/cpu/attentionMask.js +13 -9
- package/dist/ops/cpu/fusedSoftmax.js +2 -2
- package/dist/ops/cpu/gatherSub.js +3 -3
- package/dist/ops/cpu/gelu.d.ts +1 -0
- package/dist/ops/cpu/gelu.js +40 -0
- package/dist/ops/cpu/mulDropout.js +1 -1
- package/dist/ops/cpu/qkv.js +3 -3
- package/dist/ops/cpu/rope.js +5 -5
- package/dist/ops/cpu/scatterSub.js +4 -4
- package/dist/ops/fusedSoftmax.js +1 -1
- package/dist/ops/gatherSub.js +1 -1
- package/dist/ops/gelu.d.ts +3 -0
- package/dist/ops/gelu.js +8 -0
- package/dist/ops/grads/attentionMask.js +1 -1
- package/dist/ops/grads/fusedSoftmax.js +2 -2
- package/dist/ops/grads/gelu.d.ts +2 -0
- package/dist/ops/grads/gelu.js +5 -0
- package/dist/ops/grads/qkv.js +1 -1
- package/dist/ops/grads/rope.js +1 -1
- package/dist/ops/mulDrop.js +1 -1
- package/dist/ops/node/sparseCrossEntropy.js +1 -1
- package/dist/ops/qkv.js +1 -1
- package/dist/ops/scatterSub.js +1 -1
- package/dist/ops/webgl/appendCache.js +1 -1
- package/dist/ops/webgl/attentionMask.js +19 -18
- package/dist/ops/webgl/fusedSoftmax.js +489 -788
- package/dist/ops/webgl/gatherSub.js +1 -1
- package/dist/ops/webgl/gelu.d.ts +2 -0
- package/dist/ops/webgl/gelu.js +50 -0
- package/dist/ops/webgl/mulDropout.js +1 -1
- package/dist/ops/webgl/qkv.js +1 -1
- package/dist/ops/webgl/rope.js +1 -1
- package/dist/ops/webgl/scatterSub.js +1 -1
- package/dist/{range-DQMNzBWs.js → range-9AzeApCc.js} +1 -1
- package/dist/{reshape-DFzh97Sc.js → reshape-Boe4DuIO.js} +1 -1
- package/dist/{sin-BYM-U4Ut.js → sin-KmhiDuMa.js} +1 -1
- package/dist/{slice_util-CnVNPQI-.js → slice_util-19zDNNSn.js} +2 -2
- package/dist/{softmax-4DOn6cPq.js → softmax-Cujsg4ay.js} +1 -1
- package/dist/{split-CkbeVdF8.js → split-DbcNm1-i.js} +1 -1
- package/dist/{stack-DaIMO5iX.js → stack-D1YjmgKN.js} +1 -1
- package/dist/{sum-C6u3xMi3.js → sum-R28pucR5.js} +1 -1
- package/dist/{tensor-Cu1fU7H7.js → tensor-BVeHdl7V.js} +1 -1
- package/dist/{tensor2d-D0CKdG6B.js → tensor2d-DqFGNs_K.js} +1 -1
- package/dist/{tfjs_backend-Bzl2SrRo.js → tfjs_backend-Cug-PH75.js} +826 -1015
- package/dist/training/AdamExt.js +1 -1
- package/dist/training/DatasetBuilder.js +3 -3
- package/dist/training/FullTrainer.js +1 -1
- package/dist/training/Trainer.js +5 -5
- package/dist/training/sparseCrossEntropy.js +4 -4
- package/dist/utilities/dummy.js +2 -2
- package/dist/utilities/generate.js +3 -3
- package/dist/utilities/load.js +1 -1
- package/dist/utilities/profile.js +1 -1
- package/dist/utilities/weights.js +2 -2
- package/dist/{variable-BS4AKqNU.js → variable-LJT9Ld63.js} +1 -1
- package/dist/{zeros-CmJFiC84.js → zeros-dnQxFgAD.js} +1 -1
- package/package.json +1 -1
- package/dist/MLP-KHhikThU.js +0 -83
|
@@ -0,0 +1,388 @@
|
|
|
1
|
+
import { aj as E, ak as D, af as B, al as w, n as N, am as v } from "./index-C4JCoBvj.js";
|
|
2
|
+
/**
|
|
3
|
+
* @license
|
|
4
|
+
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
5
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
* you may not use this file except in compliance with the License.
|
|
7
|
+
* You may obtain a copy of the License at
|
|
8
|
+
*
|
|
9
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
*
|
|
11
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
* See the License for the specific language governing permissions and
|
|
15
|
+
* limitations under the License.
|
|
16
|
+
* =============================================================================
|
|
17
|
+
*/
|
|
18
|
+
function C(t) {
|
|
19
|
+
try {
|
|
20
|
+
return t.map((e) => E(e));
|
|
21
|
+
} catch (e) {
|
|
22
|
+
throw new Error(`Failed to decode encoded string bytes into utf-8, error: ${e}`);
|
|
23
|
+
}
|
|
24
|
+
}
|
|
25
|
+
function F(t) {
|
|
26
|
+
return t.map((e) => D(e));
|
|
27
|
+
}
|
|
28
|
+
/**
|
|
29
|
+
* @license
|
|
30
|
+
* Copyright 2017 Google LLC. All Rights Reserved.
|
|
31
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
32
|
+
* you may not use this file except in compliance with the License.
|
|
33
|
+
* You may obtain a copy of the License at
|
|
34
|
+
*
|
|
35
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
36
|
+
*
|
|
37
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
38
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
39
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
40
|
+
* See the License for the specific language governing permissions and
|
|
41
|
+
* limitations under the License.
|
|
42
|
+
* =============================================================================
|
|
43
|
+
*/
|
|
44
|
+
function R(t) {
|
|
45
|
+
if (t <= 1)
|
|
46
|
+
return "int";
|
|
47
|
+
if (t === 2)
|
|
48
|
+
return "ivec2";
|
|
49
|
+
if (t === 3)
|
|
50
|
+
return "ivec3";
|
|
51
|
+
if (t === 4)
|
|
52
|
+
return "ivec4";
|
|
53
|
+
if (t === 5)
|
|
54
|
+
return "ivec5";
|
|
55
|
+
if (t === 6)
|
|
56
|
+
return "ivec6";
|
|
57
|
+
throw Error(`GPU for rank ${t} is not yet supported`);
|
|
58
|
+
}
|
|
59
|
+
/**
|
|
60
|
+
* @license
|
|
61
|
+
* Copyright 2017 Google LLC. All Rights Reserved.
|
|
62
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
63
|
+
* you may not use this file except in compliance with the License.
|
|
64
|
+
* You may obtain a copy of the License at
|
|
65
|
+
*
|
|
66
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
67
|
+
*
|
|
68
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
69
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
70
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
71
|
+
* See the License for the specific language governing permissions and
|
|
72
|
+
* limitations under the License.
|
|
73
|
+
* =============================================================================
|
|
74
|
+
*/
|
|
75
|
+
function y(t) {
|
|
76
|
+
return B().getBool("WEBGL_USE_SHAPES_UNIFORMS") && t <= 4;
|
|
77
|
+
}
|
|
78
|
+
/**
|
|
79
|
+
* @license
|
|
80
|
+
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
81
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
82
|
+
* you may not use this file except in compliance with the License.
|
|
83
|
+
* You may obtain a copy of the License at
|
|
84
|
+
*
|
|
85
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
86
|
+
*
|
|
87
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
88
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
89
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
90
|
+
* See the License for the specific language governing permissions and
|
|
91
|
+
* limitations under the License.
|
|
92
|
+
* =============================================================================
|
|
93
|
+
*/
|
|
94
|
+
function k(t, e) {
|
|
95
|
+
return ["x", "y", "z", "w", "u", "v"].slice(0, e).map((a) => `${t}.${a}`);
|
|
96
|
+
}
|
|
97
|
+
function _(t, e) {
|
|
98
|
+
return e === 1 ? [t] : k(t, e);
|
|
99
|
+
}
|
|
100
|
+
/**
|
|
101
|
+
* @license
|
|
102
|
+
* Copyright 2017 Google LLC. All Rights Reserved.
|
|
103
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
104
|
+
* you may not use this file except in compliance with the License.
|
|
105
|
+
* You may obtain a copy of the License at
|
|
106
|
+
*
|
|
107
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
108
|
+
*
|
|
109
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
110
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
111
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
112
|
+
* See the License for the specific language governing permissions and
|
|
113
|
+
* limitations under the License.
|
|
114
|
+
* =============================================================================
|
|
115
|
+
*/
|
|
116
|
+
class A {
|
|
117
|
+
constructor(e, a, u) {
|
|
118
|
+
this.variableNames = ["A", "B"], this.outputShape = w(a, u), this.enableShapeUniforms = y(this.outputShape.length), this.userCode = `
|
|
119
|
+
float binaryOperation(float a, float b) {
|
|
120
|
+
${e}
|
|
121
|
+
}
|
|
122
|
+
|
|
123
|
+
void main() {
|
|
124
|
+
float a = getAAtOutCoords();
|
|
125
|
+
float b = getBAtOutCoords();
|
|
126
|
+
setOutput(binaryOperation(a, b));
|
|
127
|
+
}
|
|
128
|
+
`;
|
|
129
|
+
}
|
|
130
|
+
}
|
|
131
|
+
/**
|
|
132
|
+
* @license
|
|
133
|
+
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
134
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
135
|
+
* you may not use this file except in compliance with the License.
|
|
136
|
+
* You may obtain a copy of the License at
|
|
137
|
+
*
|
|
138
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
139
|
+
*
|
|
140
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
141
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
142
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
143
|
+
* See the License for the specific language governing permissions and
|
|
144
|
+
* limitations under the License.
|
|
145
|
+
* =============================================================================
|
|
146
|
+
*/
|
|
147
|
+
class z {
|
|
148
|
+
constructor(e, a, u, d = !1) {
|
|
149
|
+
this.variableNames = ["A", "B"], this.supportsBroadcasting = !0, this.packedInputs = !0, this.packedOutput = !0, this.outputShape = w(a, u);
|
|
150
|
+
const o = this.outputShape.length;
|
|
151
|
+
this.enableShapeUniforms = y(o);
|
|
152
|
+
let n = "";
|
|
153
|
+
if (d)
|
|
154
|
+
if (o === 0 || N(this.outputShape) === 1)
|
|
155
|
+
n = `
|
|
156
|
+
result.y = 0.;
|
|
157
|
+
result.z = 0.;
|
|
158
|
+
result.w = 0.;
|
|
159
|
+
`;
|
|
160
|
+
else if (n = `
|
|
161
|
+
${R(o)} coords = getOutputCoords();
|
|
162
|
+
`, o === 1)
|
|
163
|
+
this.enableShapeUniforms ? n += `
|
|
164
|
+
result.y = (coords + 1) >= outShape ? 0. : result.y;
|
|
165
|
+
result.z = 0.;
|
|
166
|
+
result.w = 0.;
|
|
167
|
+
` : n += `
|
|
168
|
+
result.y = (coords + 1) >= ${this.outputShape[0]} ? 0. : result.y;
|
|
169
|
+
result.z = 0.;
|
|
170
|
+
result.w = 0.;
|
|
171
|
+
`;
|
|
172
|
+
else {
|
|
173
|
+
const s = _("coords", o);
|
|
174
|
+
this.enableShapeUniforms ? n += `
|
|
175
|
+
bool nextRowOutOfBounds =
|
|
176
|
+
(${s[o - 2]} + 1) >= outShape[${o} - 2];
|
|
177
|
+
bool nextColOutOfBounds =
|
|
178
|
+
(${s[o - 1]} + 1) >= outShape[${o} - 1];
|
|
179
|
+
result.y = nextColOutOfBounds ? 0. : result.y;
|
|
180
|
+
result.z = nextRowOutOfBounds ? 0. : result.z;
|
|
181
|
+
result.w = nextColOutOfBounds || nextRowOutOfBounds ? 0. : result.w;
|
|
182
|
+
` : n += `
|
|
183
|
+
bool nextRowOutOfBounds =
|
|
184
|
+
(${s[o - 2]} + 1) >= ${this.outputShape[o - 2]};
|
|
185
|
+
bool nextColOutOfBounds =
|
|
186
|
+
(${s[o - 1]} + 1) >= ${this.outputShape[o - 1]};
|
|
187
|
+
result.y = nextColOutOfBounds ? 0. : result.y;
|
|
188
|
+
result.z = nextRowOutOfBounds ? 0. : result.z;
|
|
189
|
+
result.w = nextColOutOfBounds || nextRowOutOfBounds ? 0. : result.w;
|
|
190
|
+
`;
|
|
191
|
+
}
|
|
192
|
+
this.userCode = `
|
|
193
|
+
vec4 binaryOperation(vec4 a, vec4 b) {
|
|
194
|
+
${e}
|
|
195
|
+
}
|
|
196
|
+
|
|
197
|
+
void main() {
|
|
198
|
+
vec4 a = getAAtOutCoords();
|
|
199
|
+
vec4 b = getBAtOutCoords();
|
|
200
|
+
|
|
201
|
+
vec4 result = binaryOperation(a, b);
|
|
202
|
+
${n}
|
|
203
|
+
|
|
204
|
+
setOutput(result);
|
|
205
|
+
}
|
|
206
|
+
`;
|
|
207
|
+
}
|
|
208
|
+
}
|
|
209
|
+
/**
|
|
210
|
+
* @license
|
|
211
|
+
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
212
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
213
|
+
* you may not use this file except in compliance with the License.
|
|
214
|
+
* You may obtain a copy of the License at
|
|
215
|
+
*
|
|
216
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
217
|
+
*
|
|
218
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
219
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
220
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
221
|
+
* See the License for the specific language governing permissions and
|
|
222
|
+
* limitations under the License.
|
|
223
|
+
* =============================================================================
|
|
224
|
+
*/
|
|
225
|
+
function P(t) {
|
|
226
|
+
const { inputs: e, backend: a } = t, { x: u } = e;
|
|
227
|
+
return a.incRef(u.dataId), { dataId: u.dataId, shape: u.shape, dtype: u.dtype };
|
|
228
|
+
}
|
|
229
|
+
/**
|
|
230
|
+
* @license
|
|
231
|
+
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
232
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
233
|
+
* you may not use this file except in compliance with the License.
|
|
234
|
+
* You may obtain a copy of the License at
|
|
235
|
+
*
|
|
236
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
237
|
+
*
|
|
238
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
239
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
240
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
241
|
+
* See the License for the specific language governing permissions and
|
|
242
|
+
* limitations under the License.
|
|
243
|
+
* =============================================================================
|
|
244
|
+
*/
|
|
245
|
+
function G(t) {
|
|
246
|
+
const { inputs: e, backend: a } = t, { real: u, imag: d } = e, o = a.makeTensorInfo(u.shape, "complex64"), n = a.texData.get(o.dataId), l = P({ inputs: { x: u }, backend: a }), s = P({ inputs: { x: d }, backend: a });
|
|
247
|
+
return n.complexTensorInfos = { real: l, imag: s }, o;
|
|
248
|
+
}
|
|
249
|
+
/**
|
|
250
|
+
* @license
|
|
251
|
+
* Copyright 2017 Google LLC. All Rights Reserved.
|
|
252
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
253
|
+
* you may not use this file except in compliance with the License.
|
|
254
|
+
* You may obtain a copy of the License at
|
|
255
|
+
*
|
|
256
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
257
|
+
*
|
|
258
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
259
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
260
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
261
|
+
* See the License for the specific language governing permissions and
|
|
262
|
+
* limitations under the License.
|
|
263
|
+
* =============================================================================
|
|
264
|
+
*/
|
|
265
|
+
class V {
|
|
266
|
+
constructor(e, a) {
|
|
267
|
+
this.variableNames = ["A"], this.outputShape = e, this.enableShapeUniforms = y(this.outputShape.length), this.userCode = `
|
|
268
|
+
float unaryOperation(float x) {
|
|
269
|
+
${a}
|
|
270
|
+
}
|
|
271
|
+
|
|
272
|
+
void main() {
|
|
273
|
+
float x = getAAtOutCoords();
|
|
274
|
+
float y = unaryOperation(x);
|
|
275
|
+
|
|
276
|
+
setOutput(y);
|
|
277
|
+
}
|
|
278
|
+
`;
|
|
279
|
+
}
|
|
280
|
+
}
|
|
281
|
+
const H = "if (isnan(x)) return x;";
|
|
282
|
+
/**
|
|
283
|
+
* @license
|
|
284
|
+
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
285
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
286
|
+
* you may not use this file except in compliance with the License.
|
|
287
|
+
* You may obtain a copy of the License at
|
|
288
|
+
*
|
|
289
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
290
|
+
*
|
|
291
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
292
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
293
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
294
|
+
* See the License for the specific language governing permissions and
|
|
295
|
+
* limitations under the License.
|
|
296
|
+
* =============================================================================
|
|
297
|
+
*/
|
|
298
|
+
class L {
|
|
299
|
+
constructor(e, a) {
|
|
300
|
+
this.variableNames = ["A"], this.packedInputs = !0, this.packedOutput = !0, this.outputShape = e, this.enableShapeUniforms = y(this.outputShape.length), this.userCode = `
|
|
301
|
+
vec4 unaryOperation(vec4 x) {
|
|
302
|
+
${a}
|
|
303
|
+
}
|
|
304
|
+
|
|
305
|
+
void main() {
|
|
306
|
+
vec4 x = getAAtOutCoords();
|
|
307
|
+
vec4 y = unaryOperation(x);
|
|
308
|
+
|
|
309
|
+
setOutput(y);
|
|
310
|
+
}
|
|
311
|
+
`;
|
|
312
|
+
}
|
|
313
|
+
}
|
|
314
|
+
/**
|
|
315
|
+
* @license
|
|
316
|
+
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
317
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
318
|
+
* you may not use this file except in compliance with the License.
|
|
319
|
+
* You may obtain a copy of the License at
|
|
320
|
+
*
|
|
321
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
322
|
+
*
|
|
323
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
324
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
325
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
326
|
+
* See the License for the specific language governing permissions and
|
|
327
|
+
* limitations under the License.
|
|
328
|
+
* =============================================================================
|
|
329
|
+
*/
|
|
330
|
+
function K({ opSnippet: t, packedOpSnippet: e, cpuKernelImpl: a, dtype: u }) {
|
|
331
|
+
return ({ inputs: d, backend: o }) => {
|
|
332
|
+
const { x: n } = d, l = o, s = u || n.dtype;
|
|
333
|
+
if (l.shouldExecuteOnCPU([n]) && a != null) {
|
|
334
|
+
const p = l.texData.get(n.dataId), x = a(p.values, s);
|
|
335
|
+
return l.makeTensorInfo(n.shape, s, x);
|
|
336
|
+
}
|
|
337
|
+
const i = B().getBool("WEBGL_PACK_UNARY_OPERATIONS") && e != null;
|
|
338
|
+
let r;
|
|
339
|
+
return i ? r = new L(n.shape, e) : r = new V(n.shape, t), l.runWebGLProgram(r, [n], s);
|
|
340
|
+
};
|
|
341
|
+
}
|
|
342
|
+
function Y({ opSnippet: t, packedOpSnippet: e, checkOutOfBounds: a = !1, supportsComplex: u = !1, cpuKernelImpl: d, dtype: o }) {
|
|
343
|
+
return ({ inputs: n, backend: l }) => {
|
|
344
|
+
const { a: s, b: i } = n, r = l;
|
|
345
|
+
if (u && s.dtype === "complex64") {
|
|
346
|
+
const h = r.texData.get(s.dataId), f = r.texData.get(i.dataId), [g, m] = [
|
|
347
|
+
[h.complexTensorInfos.real, f.complexTensorInfos.real],
|
|
348
|
+
[h.complexTensorInfos.imag, f.complexTensorInfos.imag]
|
|
349
|
+
].map((S) => {
|
|
350
|
+
const [c, O] = S, $ = {
|
|
351
|
+
dataId: c.dataId,
|
|
352
|
+
dtype: c.dtype,
|
|
353
|
+
shape: s.shape
|
|
354
|
+
}, T = {
|
|
355
|
+
dataId: O.dataId,
|
|
356
|
+
dtype: O.dtype,
|
|
357
|
+
shape: i.shape
|
|
358
|
+
}, U = new A(t, s.shape, i.shape);
|
|
359
|
+
return r.runWebGLProgram(U, [$, T], v(c.dtype, O.dtype));
|
|
360
|
+
}), I = G({ inputs: { real: g, imag: m }, backend: r });
|
|
361
|
+
return r.disposeIntermediateTensorInfo(g), r.disposeIntermediateTensorInfo(m), I;
|
|
362
|
+
}
|
|
363
|
+
const p = o || v(s.dtype, i.dtype);
|
|
364
|
+
if ((s.dtype === "string" || i.dtype === "string" || r.shouldExecuteOnCPU([s, i])) && d != null) {
|
|
365
|
+
const h = r.texData.get(s.dataId).values, f = r.texData.get(i.dataId).values, g = s.dtype === "string" ? (
|
|
366
|
+
// tslint:disable-next-line: no-any
|
|
367
|
+
C(h)
|
|
368
|
+
) : h, m = s.dtype === "string" ? (
|
|
369
|
+
// tslint:disable-next-line: no-any
|
|
370
|
+
C(f)
|
|
371
|
+
) : f, [I, S] = d(s.shape, i.shape, g, m, p), c = r.makeTensorInfo(S, p), O = r.texData.get(c.dataId);
|
|
372
|
+
return O.values = I, c;
|
|
373
|
+
}
|
|
374
|
+
const x = B().getBool("WEBGL_PACK_BINARY_OPERATIONS") && e != null;
|
|
375
|
+
let b;
|
|
376
|
+
return x ? b = new z(e, s.shape, i.shape, a) : b = new A(t, s.shape, i.shape), r.runWebGLProgram(b, [s, i], p);
|
|
377
|
+
};
|
|
378
|
+
}
|
|
379
|
+
export {
|
|
380
|
+
H as C,
|
|
381
|
+
F as a,
|
|
382
|
+
y as b,
|
|
383
|
+
k as c,
|
|
384
|
+
Y as d,
|
|
385
|
+
C as f,
|
|
386
|
+
R as g,
|
|
387
|
+
K as u
|
|
388
|
+
};
|
|
@@ -1,17 +1,18 @@
|
|
|
1
|
-
import { attentionMask as
|
|
2
|
-
import
|
|
3
|
-
import { qkv as
|
|
4
|
-
import { rope as
|
|
1
|
+
import { attentionMask as P } from "../ops/attentionMask.js";
|
|
2
|
+
import T from "./BaseLayer.js";
|
|
3
|
+
import { qkv as y } from "../ops/qkv.js";
|
|
4
|
+
import { rope as w } from "../ops/rope.js";
|
|
5
5
|
import { appendCache as E } from "../ops/appendCache.js";
|
|
6
|
-
import { D as
|
|
7
|
-
import { fusedSoftmax as
|
|
8
|
-
import { l as
|
|
9
|
-
import { o as
|
|
10
|
-
import { z as
|
|
11
|
-
import { v as
|
|
12
|
-
import {
|
|
13
|
-
import { r as
|
|
14
|
-
|
|
6
|
+
import { D as z, F as S, t as $, c as L, e as j, H as O } from "../index-C4JCoBvj.js";
|
|
7
|
+
import { fusedSoftmax as _ } from "../ops/fusedSoftmax.js";
|
|
8
|
+
import { l as W, w as M, d as x } from "../tfjs_backend-Cug-PH75.js";
|
|
9
|
+
import { o as N } from "../ones-Bf3YR48P.js";
|
|
10
|
+
import { z as q } from "../zeros-dnQxFgAD.js";
|
|
11
|
+
import { v as k } from "../variable-LJT9Ld63.js";
|
|
12
|
+
import { r as C, d as I } from "../dropout-DfDdklfL.js";
|
|
13
|
+
import { r as B } from "../reshape-Boe4DuIO.js";
|
|
14
|
+
import { m as F } from "../mat_mul-415y5Qn2.js";
|
|
15
|
+
class nt extends T {
|
|
15
16
|
cAttn = null;
|
|
16
17
|
cProj = null;
|
|
17
18
|
bias;
|
|
@@ -22,17 +23,17 @@ class nt extends y {
|
|
|
22
23
|
units;
|
|
23
24
|
projUnits;
|
|
24
25
|
constructor(t, s) {
|
|
25
|
-
super(s), this.index = t, this.units = s.gpt.nEmbed * 3, this.projUnits = s.gpt.nEmbed, this.bias =
|
|
26
|
-
const e =
|
|
27
|
-
this.maskInf =
|
|
26
|
+
super(s), this.index = t, this.units = s.gpt.nEmbed * 3, this.projUnits = s.gpt.nEmbed, this.bias = W.bandPart(N([s.gpt.blockSize, s.gpt.blockSize]), -1, 0).cast("bool"), this.divisor = 1 / Math.sqrt(s.gpt.nEmbed / s.gpt.nHead);
|
|
27
|
+
const e = q([s.gpt.blockSize, s.gpt.blockSize]), i = z([s.gpt.blockSize, s.gpt.blockSize], Number.NEGATIVE_INFINITY);
|
|
28
|
+
this.maskInf = M(this.bias, e, i);
|
|
28
29
|
}
|
|
29
30
|
build() {
|
|
30
|
-
this.cAttn === null && (this.cAttn =
|
|
31
|
-
|
|
31
|
+
this.cAttn === null && (this.cAttn = k(
|
|
32
|
+
C([this.config.gpt.nEmbed, this.units], 0, 0.02),
|
|
32
33
|
!0
|
|
33
34
|
//`block_${this.index}_attn_cAttn_kernel`
|
|
34
|
-
)), this.cProj === null && (this.cProj =
|
|
35
|
-
|
|
35
|
+
)), this.cProj === null && (this.cProj = k(
|
|
36
|
+
C([this.projUnits, this.config.gpt.nEmbed], 0, 0.02),
|
|
36
37
|
!0
|
|
37
38
|
//`block_${this.index}_attn_cProj_kernel`
|
|
38
39
|
));
|
|
@@ -55,57 +56,54 @@ class nt extends y {
|
|
|
55
56
|
const s = t.get(`block_${this.index}_cAttn`)?.[0], e = t.get(`block_${this.index}_cProj`)?.[0];
|
|
56
57
|
if (!s) throw new Error(`Weights for block_${this.index}_cAttn not found`);
|
|
57
58
|
if (!e) throw new Error(`Weights for block_${this.index}_cProj not found`);
|
|
58
|
-
this.cAttn ? this.cAttn.assign(s) : this.cAttn =
|
|
59
|
+
this.cAttn ? this.cAttn.assign(s) : this.cAttn = k(s, !0), this.cProj ? this.cProj.assign(e) : this.cProj = k(e, !0);
|
|
59
60
|
}
|
|
60
|
-
getAttentionScores(t, s, e,
|
|
61
|
-
const o =
|
|
62
|
-
return
|
|
61
|
+
getAttentionScores(t, s, e, i) {
|
|
62
|
+
const o = P(t, s, this.divisor, this.maskInf);
|
|
63
|
+
return _(o, e ? this.config.gpt.dropout : 0, i);
|
|
63
64
|
}
|
|
64
65
|
// Attention with optional past. If pastLen > 0 and T_cur == 1, no mask needed.
|
|
65
|
-
getAttentionScoresWithPast(t, s, e
|
|
66
|
-
const i = t.
|
|
67
|
-
|
|
68
|
-
if (i > 1 && n > 0)
|
|
69
|
-
throw new Error("Cannot use past with T_cur > 1");
|
|
70
|
-
if (i > 1) {
|
|
71
|
-
const c = this.maskInf.slice([0, 0], [i, i]).expandDims(0).expandDims(0);
|
|
72
|
-
r = r.add(c);
|
|
73
|
-
}
|
|
74
|
-
return S(r, e ? this.config.gpt.dropout : 0, o);
|
|
66
|
+
getAttentionScoresWithPast(t, s, e) {
|
|
67
|
+
const i = P(t, s, this.divisor, void 0, e);
|
|
68
|
+
return _(i, 0, 0);
|
|
75
69
|
}
|
|
76
70
|
getQKV(t) {
|
|
77
|
-
return
|
|
71
|
+
return y(t, this.cAttn, this.config.gpt.nHead);
|
|
78
72
|
}
|
|
79
73
|
getOutputProjection(t) {
|
|
80
|
-
const s = t.shape[0], e = t.shape[2],
|
|
81
|
-
return
|
|
74
|
+
const s = t.shape[0], e = t.shape[2], i = this.config.gpt.nEmbed, o = t.transpose([0, 2, 1, 3]), n = B(o, [s, e, i]);
|
|
75
|
+
return x(n, this.cProj);
|
|
82
76
|
}
|
|
83
77
|
updateCache(t, s, e) {
|
|
84
|
-
const
|
|
78
|
+
const i = this.config.gpt.blockSize, o = t.shape[2], n = Math.min(e?.length || 0, i - o), r = e ? E(e.k, t, i) : t, a = e ? E(e.v, s, i) : s;
|
|
85
79
|
return {
|
|
86
|
-
k:
|
|
87
|
-
v:
|
|
88
|
-
length:
|
|
80
|
+
k: S(r),
|
|
81
|
+
v: S(a),
|
|
82
|
+
length: n + o,
|
|
89
83
|
cumulativeLength: e ? e.cumulativeLength + o : o
|
|
90
84
|
};
|
|
91
85
|
}
|
|
92
|
-
forward(t, s = !1, e,
|
|
93
|
-
return
|
|
86
|
+
forward(t, s = !1, e, i = !1, o) {
|
|
87
|
+
return $(() => {
|
|
94
88
|
this.startMemory();
|
|
95
|
-
const [
|
|
96
|
-
|
|
97
|
-
const
|
|
98
|
-
o && (
|
|
99
|
-
let
|
|
100
|
-
|
|
101
|
-
const
|
|
102
|
-
|
|
89
|
+
const [n, r, a] = this.getQKV(t), p = o ? o.cumulativeLength : 0, c = this.config.layerConfig.ropeCache, u = c ? w(n, c, p) : n, f = c ? w(r, c, p) : r;
|
|
90
|
+
c && (n.dispose(), r.dispose());
|
|
91
|
+
const g = o ? o.length : 0, d = this.updateCache(f, a, o), l = d.k, m = d.v;
|
|
92
|
+
o && (f.dispose(), a.dispose());
|
|
93
|
+
let h;
|
|
94
|
+
g > 0 ? h = this.getAttentionScoresWithPast(u, l, g) : h = this.getAttentionScores(u, l, s, e), u.dispose(), s && l.dispose();
|
|
95
|
+
const b = F(h, m);
|
|
96
|
+
i || h.dispose(), s && m.dispose();
|
|
97
|
+
const A = this.getOutputProjection(b);
|
|
98
|
+
b.dispose();
|
|
99
|
+
const v = i ? h.mean(1) : void 0;
|
|
100
|
+
return this.endMemory("CausalSelfAttention"), { output: A, attention: v, presentKV: s ? void 0 : d };
|
|
103
101
|
});
|
|
104
102
|
}
|
|
105
|
-
call(t, s = !1, e = !1,
|
|
106
|
-
if (
|
|
103
|
+
call(t, s = !1, e = !1, i) {
|
|
104
|
+
if (i && !this.config.gpt.useRope)
|
|
107
105
|
throw new Error("Cannot use pastKV without RoPE enabled");
|
|
108
|
-
if (s &&
|
|
106
|
+
if (s && i)
|
|
109
107
|
throw new Error("Cannot use pastKV during training");
|
|
110
108
|
if (t.shape.length !== 3)
|
|
111
109
|
throw new Error(`Input tensor must be rank 3 [B, T, C], got shape ${t.shape}`);
|
|
@@ -114,30 +112,33 @@ class nt extends y {
|
|
|
114
112
|
this.build();
|
|
115
113
|
const o = Math.random() * 1e9;
|
|
116
114
|
if (s && this.config.layerConfig.checkpointAttention) {
|
|
117
|
-
const
|
|
115
|
+
const r = L(
|
|
118
116
|
// @ts-expect-error Invalid params
|
|
119
|
-
(
|
|
120
|
-
const
|
|
121
|
-
|
|
122
|
-
const
|
|
123
|
-
const [
|
|
124
|
-
|
|
125
|
-
const
|
|
126
|
-
|
|
127
|
-
return j.presentKV?.k.dispose(), j.presentKV?.v.dispose(), j.output;
|
|
128
|
-
})([b, c, h], l);
|
|
129
|
-
return v().state.activeTape = u, k;
|
|
117
|
+
(a, p, c, u) => {
|
|
118
|
+
const f = this.forward(a, !0, o);
|
|
119
|
+
u([a]);
|
|
120
|
+
const g = (d, l) => {
|
|
121
|
+
const [m] = l, h = j().state.activeTape;
|
|
122
|
+
j().state.activeTape = [];
|
|
123
|
+
const b = O((A, v, R) => this.forward(A, !0, o).output)([m, p, c], d);
|
|
124
|
+
return j().state.activeTape = h, b;
|
|
130
125
|
};
|
|
131
|
-
return { value:
|
|
126
|
+
return { value: f.output, gradFunc: g };
|
|
132
127
|
}
|
|
133
128
|
)(t, this.cAttn, this.cProj);
|
|
134
129
|
if (this.config.gpt.dropout > 0) {
|
|
135
|
-
const
|
|
136
|
-
return
|
|
130
|
+
const a = I(r, this.config.gpt.dropout);
|
|
131
|
+
return r.dispose(), { output: a };
|
|
132
|
+
} else
|
|
133
|
+
return { output: r };
|
|
134
|
+
} else {
|
|
135
|
+
const n = this.forward(t, s, o, e, i);
|
|
136
|
+
if (this.config.gpt.dropout > 0) {
|
|
137
|
+
const r = I(n.output, this.config.gpt.dropout);
|
|
138
|
+
return n.output.dispose(), { output: r, attention: n.attention, presentKV: n.presentKV };
|
|
137
139
|
} else
|
|
138
|
-
return
|
|
139
|
-
}
|
|
140
|
-
return this.forward(t, s, o, e, n);
|
|
140
|
+
return n;
|
|
141
|
+
}
|
|
141
142
|
}
|
|
142
143
|
dispose() {
|
|
143
144
|
this.cAttn?.dispose(), this.cProj?.dispose(), this.bias.dispose(), this.maskInf.dispose();
|
package/dist/layers/MLP.d.ts
CHANGED
|
@@ -3,15 +3,17 @@ import { default as BaseLayer, GPTLayerConfig } from './BaseLayer';
|
|
|
3
3
|
export default class MLP extends BaseLayer {
|
|
4
4
|
private cFc;
|
|
5
5
|
private cProj;
|
|
6
|
-
private dropout;
|
|
7
6
|
private index;
|
|
8
7
|
private _trainable;
|
|
8
|
+
private hiddenUnits;
|
|
9
9
|
constructor(index: number, config: GPTLayerConfig);
|
|
10
|
+
private build;
|
|
10
11
|
get variables(): Variable[];
|
|
11
12
|
get trainable(): boolean;
|
|
12
13
|
set trainable(value: boolean);
|
|
13
14
|
saveWeights(map: Map<string, Tensor[]>): void;
|
|
14
15
|
loadWeights(weights: Map<string, Tensor[]>): void;
|
|
16
|
+
forward(x: Tensor): Tensor;
|
|
15
17
|
call(x: Tensor, training?: boolean): Tensor;
|
|
16
18
|
dispose(): void;
|
|
17
19
|
}
|