@genai-fi/nanogpt 0.6.0 → 0.6.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/Generator.js +7 -7
- package/dist/NanoGPTModel.js +70 -121
- package/dist/RealDiv-BYViZwhN.js +540 -0
- package/dist/Reshape-t7Kcikjk.js +127 -0
- package/dist/TeachableLLM.d.ts +2 -0
- package/dist/TeachableLLM.js +34 -27
- package/dist/{TiedEmbedding-BhxWO8QR.js → TiedEmbedding-9WeDwvjO.js} +12 -13
- package/dist/{axis_util-D17qZRQm.js → axis_util-Bu4h7XWV.js} +14 -12
- package/dist/{broadcast_to-BMQLjvt_.js → broadcast_to-DARN-DBD.js} +2 -2
- package/dist/{concat-DhZfF1GY.js → concat-5aPGqw3Z.js} +3 -3
- package/dist/{dataset-oilnemHf.js → dataset-pgqp-YfL.js} +3 -3
- package/dist/{dropout-CrMQPCeG.js → dropout-Bciw46HT.js} +7 -7
- package/dist/{gather-DZCMHZuN.js → gather-DjyCjmOD.js} +1 -1
- package/dist/gpgpu_math-CNslybmD.js +3115 -0
- package/dist/{index-bMBtI-WR.js → index-BAzbokzv.js} +846 -649
- package/dist/{kernel_funcs_utils-CNmjLWnB.js → kernel_funcs_utils-CUxJCg0g.js} +232 -138
- package/dist/layers/BaseLayer.js +2 -2
- package/dist/layers/CausalSelfAttention.js +6 -6
- package/dist/layers/MLP.js +5 -5
- package/dist/layers/RMSNorm.js +3 -3
- package/dist/layers/RoPECache.js +13 -33
- package/dist/layers/TiedEmbedding.js +6 -7
- package/dist/layers/TransformerBlock.js +1 -1
- package/dist/loader/load.d.ts +13 -0
- package/dist/loader/load.js +27 -0
- package/dist/loader/loadHF.d.ts +7 -0
- package/dist/loader/loadHF.js +22 -0
- package/dist/{utilities/load.d.ts → loader/loadTransformers.d.ts} +11 -11
- package/dist/loader/loadTransformers.js +28 -0
- package/dist/loader/newZipLoad.d.ts +8 -0
- package/dist/loader/newZipLoad.js +21 -0
- package/dist/loader/oldZipLoad.d.ts +7 -0
- package/dist/loader/oldZipLoad.js +76 -0
- package/dist/{log_sum_exp-BHdkCb4s.js → log_sum_exp-YEo2h3gb.js} +14 -14
- package/dist/main.js +23 -20
- package/dist/{mat_mul-BsrLfy81.js → mat_mul-7121rsJk.js} +1 -1
- package/dist/{max-DechV4Bc.js → max-DtlIuVeW.js} +1 -1
- package/dist/mulmat_packed_gpu-D4nKF7Je.js +71 -0
- package/dist/{norm-B9hWHZH1.js → norm-CzltS9Fz.js} +16 -16
- package/dist/{ones-g0K8jVwm.js → ones-BBlSRqn1.js} +2 -2
- package/dist/ops/appendCache.js +3 -3
- package/dist/ops/attentionMask.js +1 -1
- package/dist/ops/cpu/appendCache.js +2 -2
- package/dist/ops/cpu/attentionMask.js +6 -6
- package/dist/ops/cpu/fusedSoftmax.js +2 -2
- package/dist/ops/cpu/gatherSub.js +9 -9
- package/dist/ops/cpu/gelu.js +1 -1
- package/dist/ops/cpu/matMulGelu.js +1 -1
- package/dist/ops/cpu/matMulMul.js +1 -1
- package/dist/ops/cpu/mulDropout.js +1 -1
- package/dist/ops/cpu/normRMS.js +1 -1
- package/dist/ops/cpu/qkv.js +3 -3
- package/dist/ops/cpu/rope.js +5 -5
- package/dist/ops/cpu/scatterSub.js +17 -48
- package/dist/ops/fusedSoftmax.js +1 -1
- package/dist/ops/gatherSub.js +1 -1
- package/dist/ops/gelu.js +1 -1
- package/dist/ops/grads/attentionMask.js +1 -1
- package/dist/ops/grads/fusedSoftmax.js +4 -4
- package/dist/ops/grads/gelu.js +1 -1
- package/dist/ops/grads/matMulGelu.js +1 -1
- package/dist/ops/grads/normRMS.js +1 -1
- package/dist/ops/grads/qkv.js +1 -1
- package/dist/ops/grads/rope.js +1 -1
- package/dist/ops/matMulGelu.js +1 -1
- package/dist/ops/matMulMul.js +1 -1
- package/dist/ops/mulDrop.js +1 -1
- package/dist/ops/node/sparseCrossEntropy.js +1 -1
- package/dist/ops/normRMS.js +1 -1
- package/dist/ops/qkv.js +1 -1
- package/dist/ops/rope.js +8 -4
- package/dist/ops/scatterSub.js +1 -1
- package/dist/ops/webgl/appendCache.js +1 -1
- package/dist/ops/webgl/attentionMask.js +1 -1
- package/dist/ops/webgl/fusedSoftmax.js +29 -560
- package/dist/ops/webgl/gatherSub.js +1 -1
- package/dist/ops/webgl/gelu.js +2 -2
- package/dist/ops/webgl/log.js +3 -3
- package/dist/ops/webgl/matMulGelu.js +46 -113
- package/dist/ops/webgl/matMulMul.js +1 -1
- package/dist/ops/webgl/mulDropout.js +1 -1
- package/dist/ops/webgl/normRMS.js +2 -2
- package/dist/ops/webgl/qkv.js +1 -1
- package/dist/ops/webgl/rope.js +1 -1
- package/dist/ops/webgl/scatterSub.js +1 -1
- package/dist/{ops-Mv7Ta72x.js → ops-C0sQEcPw.js} +117 -109
- package/dist/{random_width-BBAWzDym.js → random_width-DWzaOgrn.js} +6925 -6291
- package/dist/{range-DMaG9A3G.js → range-DYsrnfiy.js} +1 -1
- package/dist/{gpgpu_math-Ctc31slO.js → reciprocal-CJQeasVa.js} +7 -5
- package/dist/register_all_kernels-BfFCQAqs.js +21397 -0
- package/dist/{reshape-T4yDEqoF.js → reshape-krWGKraP.js} +1 -1
- package/dist/scatter_nd_util-93ln7Hut.js +46 -0
- package/dist/selu_util-sntGesxr.js +740 -0
- package/dist/{shared-XNAoXhOa.js → shared-Ca6iDobD.js} +1462 -1089
- package/dist/{sin-EEhbrRO_.js → sin-D_h-qCSx.js} +1 -1
- package/dist/{softmax-B2_IKPDR.js → softmax-fsdtf6JC.js} +1 -1
- package/dist/{split-dcks18H1.js → split-eiktj-6L.js} +1 -1
- package/dist/{stack-lpJ5kYvE.js → stack-dfEEz2OY.js} +2 -2
- package/dist/{sum-CutF5lj2.js → sum-BE_Irnim.js} +1 -1
- package/dist/{tensor-C15NA2LA.js → tensor-Xyi595sG.js} +1 -1
- package/dist/{tensor2d-DZ_e5eKM.js → tensor2d-CPEkynbH.js} +1 -1
- package/dist/training/AdamExt.js +1 -1
- package/dist/training/DatasetBuilder.js +2 -2
- package/dist/training/FullTrainer.js +1 -1
- package/dist/training/Trainer.js +3 -3
- package/dist/training/sparseCrossEntropy.js +5 -5
- package/dist/utilities/dummy.d.ts +6 -0
- package/dist/utilities/dummy.js +31 -10
- package/dist/utilities/generate.js +3 -3
- package/dist/utilities/profile.d.ts +5 -0
- package/dist/utilities/profile.js +10 -7
- package/dist/utilities/safetensors.js +2 -2
- package/dist/utilities/save.js +1 -1
- package/dist/utilities/weights.js +2 -2
- package/dist/{variable-CdRKKp8x.js → variable-wSS22xj5.js} +1 -1
- package/dist/{zeros-CAbHfODe.js → zeros-YJDE7oRb.js} +4 -4
- package/package.json +2 -8
- package/dist/Reshape-CLOrdpve.js +0 -212
- package/dist/slice_util-Ddk0uxGJ.js +0 -49
- package/dist/tfjs_backend-BDb8r9qx.js +0 -1010
- package/dist/utilities/load.js +0 -99
package/dist/ops/webgl/fusedSoftmax.js
CHANGED
|
@@ -1,544 +1,13 @@
|
|
|
1
|
-
import {
|
|
2
|
-
import {
|
|
3
|
-
import { r as
|
|
4
|
-
import {
|
|
5
|
-
|
|
6
|
-
/**
|
|
7
|
-
* @license
|
|
8
|
-
* Copyright 2017 Google LLC. All Rights Reserved.
|
|
9
|
-
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
10
|
-
* you may not use this file except in compliance with the License.
|
|
11
|
-
* You may obtain a copy of the License at
|
|
12
|
-
*
|
|
13
|
-
* http://www.apache.org/licenses/LICENSE-2.0
|
|
14
|
-
*
|
|
15
|
-
* Unless required by applicable law or agreed to in writing, software
|
|
16
|
-
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
17
|
-
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
18
|
-
* See the License for the specific language governing permissions and
|
|
19
|
-
* limitations under the License.
|
|
20
|
-
* =============================================================================
|
|
21
|
-
*/
|
|
22
|
-
const q = 30;
|
|
23
|
-
function F(a) {
|
|
24
|
-
return a <= q ? a : E(a, Math.floor(Math.sqrt(a)));
|
|
25
|
-
}
|
|
26
|
-
/**
|
|
27
|
-
* @license
|
|
28
|
-
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
29
|
-
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
30
|
-
* you may not use this file except in compliance with the License.
|
|
31
|
-
* You may obtain a copy of the License at
|
|
32
|
-
*
|
|
33
|
-
* http://www.apache.org/licenses/LICENSE-2.0
|
|
34
|
-
*
|
|
35
|
-
* Unless required by applicable law or agreed to in writing, software
|
|
36
|
-
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
37
|
-
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
38
|
-
* See the License for the specific language governing permissions and
|
|
39
|
-
* limitations under the License.
|
|
40
|
-
* =============================================================================
|
|
41
|
-
*/
|
|
42
|
-
class A {
|
|
43
|
-
constructor(o, t) {
|
|
44
|
-
this.variableNames = ["x"];
|
|
45
|
-
const { windowSize: e, batchSize: n, inSize: u, outSize: l } = o;
|
|
46
|
-
this.outputShape = [n, l];
|
|
47
|
-
const s = Math.floor(e / 4) * 4, c = e % 4;
|
|
48
|
-
let i = "sumValue += dot(values, ones);";
|
|
49
|
-
if (t != null) {
|
|
50
|
-
const p = 1 / t;
|
|
51
|
-
i = `sumValue += dot(values * ${T(p) ? p.toPrecision(2) : p}, ones);`;
|
|
52
|
-
}
|
|
53
|
-
let r = "";
|
|
54
|
-
u % e > 0 && (r = `
|
|
55
|
-
if (inIdx < 0 || inIdx >= ${u}) {
|
|
56
|
-
return 0.0;
|
|
57
|
-
}
|
|
58
|
-
`), this.userCode = `
|
|
59
|
-
const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);
|
|
60
|
-
|
|
61
|
-
float getValue(int batch, int inIdx) {
|
|
62
|
-
${r}
|
|
63
|
-
return getX(batch, inIdx);
|
|
64
|
-
}
|
|
65
|
-
|
|
66
|
-
void main() {
|
|
67
|
-
ivec2 coords = getOutputCoords();
|
|
68
|
-
int batch = coords[0];
|
|
69
|
-
int outIdx = coords[1];
|
|
70
|
-
int inOffset = outIdx * ${e};
|
|
71
|
-
|
|
72
|
-
float sumValue = 0.0;
|
|
73
|
-
|
|
74
|
-
for (int i = 0; i < ${s}; i += 4) {
|
|
75
|
-
int inIdx = inOffset + i;
|
|
76
|
-
vec4 values = vec4(
|
|
77
|
-
getValue(batch, inIdx),
|
|
78
|
-
getValue(batch, inIdx + 1),
|
|
79
|
-
getValue(batch, inIdx + 2),
|
|
80
|
-
getValue(batch, inIdx + 3)
|
|
81
|
-
);
|
|
82
|
-
|
|
83
|
-
${i}
|
|
84
|
-
}
|
|
85
|
-
|
|
86
|
-
int inIdx = inOffset + ${s};
|
|
87
|
-
if (${c === 1}) {
|
|
88
|
-
vec4 values = vec4(getValue(batch, inIdx), 0.0, 0.0, 0.0);
|
|
89
|
-
|
|
90
|
-
${i}
|
|
91
|
-
} else if (${c === 2}) {
|
|
92
|
-
vec4 values = vec4(
|
|
93
|
-
getValue(batch, inIdx),
|
|
94
|
-
getValue(batch, inIdx + 1), 0.0, 0.0);
|
|
95
|
-
|
|
96
|
-
${i}
|
|
97
|
-
} else if (${c === 3}) {
|
|
98
|
-
vec4 values = vec4(
|
|
99
|
-
getValue(batch, inIdx),
|
|
100
|
-
getValue(batch, inIdx + 1),
|
|
101
|
-
getValue(batch, inIdx + 2), 0.0);
|
|
102
|
-
|
|
103
|
-
${i}
|
|
104
|
-
}
|
|
105
|
-
setOutput(sumValue);
|
|
106
|
-
}
|
|
107
|
-
`;
|
|
108
|
-
}
|
|
109
|
-
}
|
|
110
|
-
/**
|
|
111
|
-
* @license
|
|
112
|
-
* Copyright 2017 Google LLC. All Rights Reserved.
|
|
113
|
-
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
114
|
-
* you may not use this file except in compliance with the License.
|
|
115
|
-
* You may obtain a copy of the License at
|
|
116
|
-
*
|
|
117
|
-
* http://www.apache.org/licenses/LICENSE-2.0
|
|
118
|
-
*
|
|
119
|
-
* Unless required by applicable law or agreed to in writing, software
|
|
120
|
-
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
121
|
-
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
122
|
-
* See the License for the specific language governing permissions and
|
|
123
|
-
* limitations under the License.
|
|
124
|
-
* =============================================================================
|
|
125
|
-
*/
|
|
126
|
-
class j {
|
|
127
|
-
constructor(o, t) {
|
|
128
|
-
this.variableNames = ["x"];
|
|
129
|
-
const { windowSize: e, batchSize: n, inSize: u, outSize: l } = o;
|
|
130
|
-
this.outputShape = [n, l];
|
|
131
|
-
let s = "0.0", c = "";
|
|
132
|
-
t === "prod" ? s = "1.0" : t === "min" ? (s = "1.0 / 1e-20", c = "min") : t === "max" && (s = "-1.0 / 1e-20", c = "max");
|
|
133
|
-
let i = `${t}(${t}(${t}(minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])`;
|
|
134
|
-
t === "sum" ? i = "sumValue" : t === "prod" ? i = "prodValue" : t === "all" ? i = "allValue" : t === "any" && (i = "anyValue");
|
|
135
|
-
const r = Math.floor(e / 4) * 4, p = e % 4;
|
|
136
|
-
let h = `
|
|
137
|
-
if (${t === "sum"}) {
|
|
138
|
-
sumValue += dot(values, ones);
|
|
139
|
-
} else if (${t === "prod"}) {
|
|
140
|
-
vec2 tmp = vec2(values[0], values[1]) * vec2(values[2], values[3]);
|
|
141
|
-
prodValue *= tmp[0] * tmp[1];
|
|
142
|
-
} else {
|
|
143
|
-
minMaxValue = ${c}(values, minMaxValue);
|
|
144
|
-
if (${t === "min"} || ${t === "max"}) {
|
|
145
|
-
minMaxValue = ${c}(values, minMaxValue);
|
|
146
|
-
bvec4 isNaN = isnan(values);
|
|
147
|
-
if (isNaN.r || isNaN.g || isNaN.b || isNaN.a) {
|
|
148
|
-
minMaxValue = vec4(NAN);
|
|
149
|
-
}
|
|
150
|
-
}
|
|
151
|
-
}
|
|
152
|
-
`, d = "vec4";
|
|
153
|
-
t === "all" ? (s = "1.0", h = `
|
|
154
|
-
bool reducedAllValue = all(values);
|
|
155
|
-
float floatedReducedAllValue = float(reducedAllValue);
|
|
156
|
-
allValue = float(allValue >= 1.0 && floatedReducedAllValue >= 1.0);
|
|
157
|
-
`, d = "bvec4") : t === "any" && (s = "0.0", h = `
|
|
158
|
-
bool reducedAnyValue = any(values);
|
|
159
|
-
float floatedReducedAnyValue = float(reducedAnyValue);
|
|
160
|
-
anyValue = float(anyValue >= 1.0 || floatedReducedAnyValue >= 1.0);
|
|
161
|
-
`, d = "bvec4");
|
|
162
|
-
let f = "";
|
|
163
|
-
u % e > 0 && (f = `
|
|
164
|
-
if (inIdx < 0 || inIdx >= ${u}) {
|
|
165
|
-
return initializationValue;
|
|
166
|
-
}
|
|
167
|
-
`), this.userCode = `
|
|
168
|
-
const float initializationValue = ${s};
|
|
169
|
-
const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);
|
|
170
|
-
|
|
171
|
-
float getValue(int batch, int inIdx) {
|
|
172
|
-
${f}
|
|
173
|
-
return getX(batch, inIdx);
|
|
174
|
-
}
|
|
175
|
-
|
|
176
|
-
void main() {
|
|
177
|
-
ivec2 coords = getOutputCoords();
|
|
178
|
-
int batch = coords[0];
|
|
179
|
-
int outIdx = coords[1];
|
|
180
|
-
int inOffset = outIdx * ${e};
|
|
181
|
-
|
|
182
|
-
vec4 minMaxValue = vec4(${s});
|
|
183
|
-
float prodValue = 1.0;
|
|
184
|
-
float sumValue = 0.0;
|
|
185
|
-
float allValue = 1.0;
|
|
186
|
-
float anyValue = 0.0;
|
|
187
|
-
|
|
188
|
-
for (int i = 0; i < ${r}; i += 4) {
|
|
189
|
-
int inIdx = inOffset + i;
|
|
190
|
-
${d} values = ${d}(
|
|
191
|
-
getValue(batch, inIdx),
|
|
192
|
-
getValue(batch, inIdx + 1),
|
|
193
|
-
getValue(batch, inIdx + 2),
|
|
194
|
-
getValue(batch, inIdx + 3)
|
|
195
|
-
);
|
|
196
|
-
|
|
197
|
-
${h}
|
|
198
|
-
}
|
|
199
|
-
|
|
200
|
-
int inIdx = inOffset + ${r};
|
|
201
|
-
if (${p === 1}) {
|
|
202
|
-
${d} values = ${d}(
|
|
203
|
-
getValue(batch, inIdx),
|
|
204
|
-
initializationValue,
|
|
205
|
-
initializationValue,
|
|
206
|
-
initializationValue
|
|
207
|
-
);
|
|
208
|
-
|
|
209
|
-
${h}
|
|
210
|
-
} else if (${p === 2}) {
|
|
211
|
-
${d} values = ${d}(
|
|
212
|
-
getValue(batch, inIdx),
|
|
213
|
-
getValue(batch, inIdx + 1),
|
|
214
|
-
initializationValue,
|
|
215
|
-
initializationValue
|
|
216
|
-
);
|
|
217
|
-
|
|
218
|
-
${h}
|
|
219
|
-
} else if (${p === 3}) {
|
|
220
|
-
${d} values = ${d}(
|
|
221
|
-
getValue(batch, inIdx),
|
|
222
|
-
getValue(batch, inIdx + 1),
|
|
223
|
-
getValue(batch, inIdx + 2),
|
|
224
|
-
initializationValue
|
|
225
|
-
);
|
|
226
|
-
|
|
227
|
-
${h}
|
|
228
|
-
}
|
|
229
|
-
setOutput(${i});
|
|
230
|
-
}
|
|
231
|
-
`;
|
|
232
|
-
}
|
|
233
|
-
}
|
|
234
|
-
/**
|
|
235
|
-
* @license
|
|
236
|
-
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
237
|
-
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
238
|
-
* you may not use this file except in compliance with the License.
|
|
239
|
-
* You may obtain a copy of the License at
|
|
240
|
-
*
|
|
241
|
-
* http://www.apache.org/licenses/LICENSE-2.0
|
|
242
|
-
*
|
|
243
|
-
* Unless required by applicable law or agreed to in writing, software
|
|
244
|
-
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
245
|
-
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
246
|
-
* See the License for the specific language governing permissions and
|
|
247
|
-
* limitations under the License.
|
|
248
|
-
* =============================================================================
|
|
249
|
-
*/
|
|
250
|
-
function H(a) {
|
|
251
|
-
const o = [];
|
|
252
|
-
for (; o.length === 0 || o[o.length - 1].outSize !== 1; ) {
|
|
253
|
-
const t = o.length ? o[o.length - 1].outSize : a[1], e = F(t);
|
|
254
|
-
o.push({
|
|
255
|
-
inSize: t,
|
|
256
|
-
windowSize: e,
|
|
257
|
-
outSize: Math.ceil(t / e)
|
|
258
|
-
});
|
|
259
|
-
}
|
|
260
|
-
return o;
|
|
261
|
-
}
|
|
262
|
-
function N(a, o, t, e) {
|
|
263
|
-
const n = H(a.shape);
|
|
264
|
-
let u = a;
|
|
265
|
-
for (let l = 0; l < n.length; l++) {
|
|
266
|
-
const { inSize: s, windowSize: c, outSize: i } = n[l];
|
|
267
|
-
let r, p;
|
|
268
|
-
t === "mean" ? r = l === 0 ? new A({ windowSize: c, inSize: s, batchSize: a.shape[0], outSize: i }, s) : new A({ windowSize: c, inSize: s, batchSize: a.shape[0], outSize: i }) : r = new j({ windowSize: c, inSize: s, batchSize: a.shape[0], outSize: i }, t), p = u, u = e.runWebGLProgram(r, [u], o), p.dataId !== a.dataId && e.disposeIntermediateTensorInfo(p);
|
|
269
|
-
}
|
|
270
|
-
return u;
|
|
271
|
-
}
|
|
272
|
-
/**
|
|
273
|
-
* @license
|
|
274
|
-
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
275
|
-
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
276
|
-
* you may not use this file except in compliance with the License.
|
|
277
|
-
* You may obtain a copy of the License at
|
|
278
|
-
*
|
|
279
|
-
* http://www.apache.org/licenses/LICENSE-2.0
|
|
280
|
-
*
|
|
281
|
-
* Unless required by applicable law or agreed to in writing, software
|
|
282
|
-
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
283
|
-
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
284
|
-
* See the License for the specific language governing permissions and
|
|
285
|
-
* limitations under the License.
|
|
286
|
-
* =============================================================================
|
|
287
|
-
*/
|
|
288
|
-
function X(a, o, t, e) {
|
|
289
|
-
const n = V(o), l = V(a.shape) / n, s = b({ inputs: { x: a }, attrs: { shape: [l, n] }, backend: e }), c = N(s, a.dtype, "max", e), i = b({ inputs: { x: c }, attrs: { shape: t }, backend: e });
|
|
290
|
-
return e.disposeIntermediateTensorInfo(s), e.disposeIntermediateTensorInfo(c), i;
|
|
291
|
-
}
|
|
292
|
-
/**
|
|
293
|
-
* @license
|
|
294
|
-
* Copyright 2017 Google LLC. All Rights Reserved.
|
|
295
|
-
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
296
|
-
* you may not use this file except in compliance with the License.
|
|
297
|
-
* You may obtain a copy of the License at
|
|
298
|
-
*
|
|
299
|
-
* http://www.apache.org/licenses/LICENSE-2.0
|
|
300
|
-
*
|
|
301
|
-
* Unless required by applicable law or agreed to in writing, software
|
|
302
|
-
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
303
|
-
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
304
|
-
* See the License for the specific language governing permissions and
|
|
305
|
-
* limitations under the License.
|
|
306
|
-
* =============================================================================
|
|
307
|
-
*/
|
|
308
|
-
class Y {
|
|
309
|
-
constructor(o, t) {
|
|
310
|
-
this.variableNames = ["A"];
|
|
311
|
-
const e = new Array(o.length);
|
|
312
|
-
for (let l = 0; l < e.length; l++)
|
|
313
|
-
e[l] = o[t[l]];
|
|
314
|
-
this.outputShape = e, this.rank = e.length;
|
|
315
|
-
const n = C(this.rank), u = Z(t);
|
|
316
|
-
this.userCode = `
|
|
317
|
-
void main() {
|
|
318
|
-
${n} resRC = getOutputCoords();
|
|
319
|
-
setOutput(getA(${u}));
|
|
320
|
-
}
|
|
321
|
-
`;
|
|
322
|
-
}
|
|
323
|
-
}
|
|
324
|
-
function Z(a) {
|
|
325
|
-
const o = a.length;
|
|
326
|
-
if (o > 6)
|
|
327
|
-
throw Error(`Transpose for rank ${o} is not yet supported`);
|
|
328
|
-
const t = ["resRC.x", "resRC.y", "resRC.z", "resRC.w", "resRC.u", "resRC.v"], e = new Array(o);
|
|
329
|
-
for (let n = 0; n < a.length; n++)
|
|
330
|
-
e[a[n]] = t[n];
|
|
331
|
-
return e.join();
|
|
332
|
-
}
|
|
333
|
-
/**
|
|
334
|
-
* @license
|
|
335
|
-
* Copyright 2019 Google LLC. All Rights Reserved.
|
|
336
|
-
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
337
|
-
* you may not use this file except in compliance with the License.
|
|
338
|
-
* You may obtain a copy of the License at
|
|
339
|
-
*
|
|
340
|
-
* http://www.apache.org/licenses/LICENSE-2.0
|
|
341
|
-
*
|
|
342
|
-
* Unless required by applicable law or agreed to in writing, software
|
|
343
|
-
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
344
|
-
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
345
|
-
* See the License for the specific language governing permissions and
|
|
346
|
-
* limitations under the License.
|
|
347
|
-
* =============================================================================
|
|
348
|
-
*/
|
|
349
|
-
class J {
|
|
350
|
-
constructor(o, t) {
|
|
351
|
-
this.variableNames = ["A"], this.packedInputs = !0, this.packedOutput = !0;
|
|
352
|
-
const e = new Array(o.length);
|
|
353
|
-
for (let r = 0; r < e.length; r++)
|
|
354
|
-
e[r] = o[t[r]];
|
|
355
|
-
if (this.outputShape = e, this.rank = e.length, this.rank > 6)
|
|
356
|
-
throw Error(`Packed transpose for rank ${this.rank} is not yet supported.`);
|
|
357
|
-
const n = C(this.rank), u = U("rc", this.rank), l = new Array(this.rank);
|
|
358
|
-
for (let r = 0; r < t.length; r++)
|
|
359
|
-
l[t[r]] = u[r];
|
|
360
|
-
const s = `vec2(${l.slice(-2).join()})`, c = `++${u[this.rank - 1]} < ${e[this.rank - 1]}`, i = `getChannel(getA(${l.join()}), ${s})`;
|
|
361
|
-
this.userCode = `
|
|
362
|
-
void main() {
|
|
363
|
-
${n} rc = getOutputCoords();
|
|
364
|
-
vec4 result = vec4(0.);
|
|
365
|
-
result[0] = ${i};
|
|
366
|
-
if(${c}) {
|
|
367
|
-
result[1] = ${i};
|
|
368
|
-
}
|
|
369
|
-
--${u[this.rank - 1]};
|
|
370
|
-
if(++${u[this.rank - 2]} < ${e[this.rank - 2]}) {
|
|
371
|
-
result[2] = ${i};
|
|
372
|
-
if(${c}) {
|
|
373
|
-
result[3] = ${i};
|
|
374
|
-
}
|
|
375
|
-
}
|
|
376
|
-
setOutput(result);
|
|
377
|
-
}
|
|
378
|
-
`;
|
|
379
|
-
}
|
|
380
|
-
}
|
|
381
|
-
/**
|
|
382
|
-
* @license
|
|
383
|
-
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
384
|
-
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
385
|
-
* you may not use this file except in compliance with the License.
|
|
386
|
-
* You may obtain a copy of the License at
|
|
387
|
-
*
|
|
388
|
-
* http://www.apache.org/licenses/LICENSE-2.0
|
|
389
|
-
*
|
|
390
|
-
* Unless required by applicable law or agreed to in writing, software
|
|
391
|
-
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
392
|
-
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
393
|
-
* See the License for the specific language governing permissions and
|
|
394
|
-
* limitations under the License.
|
|
395
|
-
* =============================================================================
|
|
396
|
-
*/
|
|
397
|
-
function D(a, o, t) {
|
|
398
|
-
const e = L().getBool("WEBGL_PACK_ARRAY_OPERATIONS") ? new J(a.shape, o) : new Y(a.shape, o);
|
|
399
|
-
return t.runWebGLProgram(e, [a], a.dtype);
|
|
400
|
-
}
|
|
401
|
-
/**
|
|
402
|
-
* @license
|
|
403
|
-
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
404
|
-
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
405
|
-
* you may not use this file except in compliance with the License.
|
|
406
|
-
* You may obtain a copy of the License at
|
|
407
|
-
*
|
|
408
|
-
* http://www.apache.org/licenses/LICENSE-2.0
|
|
409
|
-
*
|
|
410
|
-
* Unless required by applicable law or agreed to in writing, software
|
|
411
|
-
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
412
|
-
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
413
|
-
* See the License for the specific language governing permissions and
|
|
414
|
-
* limitations under the License.
|
|
415
|
-
* =============================================================================
|
|
416
|
-
*/
|
|
417
|
-
function Q(a) {
|
|
418
|
-
const { inputs: o, backend: t, attrs: e } = a, { x: n } = o, { reductionIndices: u, keepDims: l } = e, s = n.shape.length, c = w(u, n.shape);
|
|
419
|
-
let i = c;
|
|
420
|
-
const r = k(i, s), p = r != null, h = t.shouldExecuteOnCPU([n]);
|
|
421
|
-
let d = n;
|
|
422
|
-
if (p) {
|
|
423
|
-
if (h) {
|
|
424
|
-
const I = t.texData.get(d.dataId).values, g = new Array(s);
|
|
425
|
-
for (let $ = 0; $ < g.length; $++)
|
|
426
|
-
g[$] = n.shape[r[$]];
|
|
427
|
-
const z = B(I, n.shape, n.dtype, r, g);
|
|
428
|
-
d = t.makeTensorInfo(g, n.dtype);
|
|
429
|
-
const M = t.texData.get(d.dataId);
|
|
430
|
-
M.values = z;
|
|
431
|
-
} else
|
|
432
|
-
d = D(n, r, t);
|
|
433
|
-
i = R(i.length, s);
|
|
434
|
-
}
|
|
435
|
-
P("max", i, s);
|
|
436
|
-
const [f, v] = y(d.shape, i);
|
|
437
|
-
let x = f;
|
|
438
|
-
l && (x = O(f, c));
|
|
439
|
-
let m;
|
|
440
|
-
if (h) {
|
|
441
|
-
const I = t.texData.get(d.dataId).values, g = G(I, V(v), x, n.dtype);
|
|
442
|
-
m = t.makeTensorInfo(x, n.dtype);
|
|
443
|
-
const z = t.texData.get(m.dataId);
|
|
444
|
-
z.values = g;
|
|
445
|
-
} else
|
|
446
|
-
m = X(d, v, x, t);
|
|
447
|
-
return p && t.disposeIntermediateTensorInfo(d), m;
|
|
448
|
-
}
|
|
449
|
-
/**
|
|
450
|
-
* @license
|
|
451
|
-
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
452
|
-
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
453
|
-
* you may not use this file except in compliance with the License.
|
|
454
|
-
* You may obtain a copy of the License at
|
|
455
|
-
*
|
|
456
|
-
* http://www.apache.org/licenses/LICENSE-2.0
|
|
457
|
-
*
|
|
458
|
-
* Unless required by applicable law or agreed to in writing, software
|
|
459
|
-
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
460
|
-
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
461
|
-
* See the License for the specific language governing permissions and
|
|
462
|
-
* limitations under the License.
|
|
463
|
-
* =============================================================================
|
|
464
|
-
*/
|
|
465
|
-
function ee(a, o, t, e) {
|
|
466
|
-
const n = o, u = a.shape.length, l = w(n, a.shape);
|
|
467
|
-
let s = l;
|
|
468
|
-
const c = k(s, u), i = c != null;
|
|
469
|
-
let r = a;
|
|
470
|
-
i && (r = D(a, c, e), s = R(s.length, u)), P("sum", s, u);
|
|
471
|
-
const [p, h] = y(r.shape, s);
|
|
472
|
-
let d = p;
|
|
473
|
-
t && (d = O(p, l));
|
|
474
|
-
const f = V(h), x = V(a.shape) / f, m = b({ inputs: { x: r }, attrs: { shape: [x, f] }, backend: e }), S = K(a.dtype), I = N(m, S, "sum", e), g = b({ inputs: { x: I }, attrs: { shape: d }, backend: e });
|
|
475
|
-
return e.disposeIntermediateTensorInfo(m), e.disposeIntermediateTensorInfo(I), i && e.disposeIntermediateTensorInfo(r), g;
|
|
476
|
-
}
|
|
477
|
-
/**
|
|
478
|
-
* @license
|
|
479
|
-
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
480
|
-
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
481
|
-
* you may not use this file except in compliance with the License.
|
|
482
|
-
* You may obtain a copy of the License at
|
|
483
|
-
*
|
|
484
|
-
* http://www.apache.org/licenses/LICENSE-2.0
|
|
485
|
-
*
|
|
486
|
-
* Unless required by applicable law or agreed to in writing, software
|
|
487
|
-
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
488
|
-
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
489
|
-
* See the License for the specific language governing permissions and
|
|
490
|
-
* limitations under the License.
|
|
491
|
-
* =============================================================================
|
|
492
|
-
*/
|
|
493
|
-
function te(a) {
|
|
494
|
-
const { inputs: o, backend: t, attrs: e } = a, { x: n } = o, { axis: u, keepDims: l } = e;
|
|
495
|
-
return ee(n, u, l, t);
|
|
496
|
-
}
|
|
497
|
-
/**
|
|
498
|
-
* @license
|
|
499
|
-
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
500
|
-
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
501
|
-
* you may not use this file except in compliance with the License.
|
|
502
|
-
* You may obtain a copy of the License at
|
|
503
|
-
*
|
|
504
|
-
* http://www.apache.org/licenses/LICENSE-2.0
|
|
505
|
-
*
|
|
506
|
-
* Unless required by applicable law or agreed to in writing, software
|
|
507
|
-
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
508
|
-
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
509
|
-
* See the License for the specific language governing permissions and
|
|
510
|
-
* limitations under the License.
|
|
511
|
-
* =============================================================================
|
|
512
|
-
*/
|
|
513
|
-
const ae = `
|
|
514
|
-
if (a == b) {
|
|
515
|
-
return 1.0;
|
|
516
|
-
};
|
|
517
|
-
return a / b;`, se = `
|
|
518
|
-
// vec4 one = vec4(equal(a, b));
|
|
519
|
-
// return one + (vec4(1.0) - one) * a / b;
|
|
520
|
-
vec4 result = a / b;
|
|
521
|
-
if(a.x == b.x) {
|
|
522
|
-
result.x = 1.;
|
|
523
|
-
}
|
|
524
|
-
if(a.y == b.y) {
|
|
525
|
-
result.y = 1.;
|
|
526
|
-
}
|
|
527
|
-
if(a.z == b.z) {
|
|
528
|
-
result.z = 1.;
|
|
529
|
-
}
|
|
530
|
-
if(a.w == b.w) {
|
|
531
|
-
result.w = 1.;
|
|
532
|
-
}
|
|
533
|
-
|
|
534
|
-
return result;
|
|
535
|
-
`, oe = _({ opSnippet: ae, packedOpSnippet: se, checkOutOfBounds: !0 });
|
|
536
|
-
class ne {
|
|
1
|
+
import { m as b, s as I, r as k } from "../../RealDiv-BYViZwhN.js";
|
|
2
|
+
import { r as v } from "../../Reshape-t7Kcikjk.js";
|
|
3
|
+
import { r as w, p as P } from "../../index-BAzbokzv.js";
|
|
4
|
+
import { e as S } from "../../axis_util-Bu4h7XWV.js";
|
|
5
|
+
class T {
|
|
537
6
|
variableNames = ["logits", "maxLogits"];
|
|
538
7
|
outputShape;
|
|
539
8
|
userCode;
|
|
540
|
-
constructor(
|
|
541
|
-
this.outputShape =
|
|
9
|
+
constructor(t) {
|
|
10
|
+
this.outputShape = t, this.userCode = `
|
|
542
11
|
void main() {
|
|
543
12
|
ivec4 coords = getOutputCoords(); // [batch, nh, t1, t2]
|
|
544
13
|
int b = coords.x;
|
|
@@ -552,7 +21,7 @@ class ne {
|
|
|
552
21
|
`;
|
|
553
22
|
}
|
|
554
23
|
}
|
|
555
|
-
class ie {
|
|
24
|
+
class C {
|
|
556
25
|
variableNames = ["exp", "sum"];
|
|
557
26
|
outputShape;
|
|
558
27
|
userCode;
|
|
@@ -560,8 +29,8 @@ class ie {
|
|
|
560
29
|
{ name: "dropoutRate", type: "float" },
|
|
561
30
|
{ name: "seed", type: "float" }
|
|
562
31
|
];
|
|
563
|
-
constructor(
|
|
564
|
-
this.outputShape =
|
|
32
|
+
constructor(t) {
|
|
33
|
+
this.outputShape = t, this.userCode = `
|
|
565
34
|
float random(ivec4 coords) {
|
|
566
35
|
float x = float(coords.x * 4096 + coords.y * 256 + coords.z * 16 + coords.w);
|
|
567
36
|
return fract(sin(seed + x) * 43758.5453123);
|
|
@@ -579,33 +48,33 @@ class ie {
|
|
|
579
48
|
`;
|
|
580
49
|
}
|
|
581
50
|
}
|
|
582
|
-
function
|
|
583
|
-
const { inputs:
|
|
51
|
+
function L(r) {
|
|
52
|
+
const { inputs: t, attrs: m } = r, { logits: e } = t, { dim: u, dropoutRate: n, seed: c } = m, o = r.backend;
|
|
584
53
|
if (!e)
|
|
585
54
|
throw new Error("Error in softmax: input logits is null");
|
|
586
|
-
const
|
|
55
|
+
const i = P([u], e.shape), d = b({
|
|
587
56
|
inputs: { x: e },
|
|
588
|
-
backend:
|
|
589
|
-
attrs: { reductionIndices:
|
|
590
|
-
}),
|
|
591
|
-
|
|
592
|
-
const
|
|
593
|
-
if (
|
|
594
|
-
const
|
|
595
|
-
[
|
|
596
|
-
[
|
|
57
|
+
backend: o,
|
|
58
|
+
attrs: { reductionIndices: i, keepDims: !1 }
|
|
59
|
+
}), f = S(d.shape, i), l = new T(e.shape), s = o.runWebGLProgram(l, [e, d], "float32");
|
|
60
|
+
o.disposeIntermediateTensorInfo(d);
|
|
61
|
+
const p = I({ inputs: { x: s }, backend: o, attrs: { axis: i, keepDims: !1 } }), a = v({ inputs: { x: p }, backend: o, attrs: { shape: f } });
|
|
62
|
+
if (n !== void 0 && n > 0) {
|
|
63
|
+
const g = new C(e.shape), h = o.runWebGLProgram(g, [s, a], "float32", [
|
|
64
|
+
[n],
|
|
65
|
+
[c ?? Math.random() * 1e4]
|
|
597
66
|
]);
|
|
598
|
-
return
|
|
67
|
+
return o.disposeIntermediateTensorInfo(s), o.disposeIntermediateTensorInfo(p), o.disposeIntermediateTensorInfo(a), h;
|
|
599
68
|
}
|
|
600
|
-
const
|
|
601
|
-
return
|
|
69
|
+
const x = k({ inputs: { a: s, b: a }, backend: o });
|
|
70
|
+
return o.disposeIntermediateTensorInfo(s), o.disposeIntermediateTensorInfo(p), o.disposeIntermediateTensorInfo(a), x;
|
|
602
71
|
}
|
|
603
|
-
const
|
|
72
|
+
const E = {
|
|
604
73
|
kernelName: "FusedSoftmax",
|
|
605
74
|
backendName: "webgl",
|
|
606
|
-
kernelFunc:
|
|
75
|
+
kernelFunc: L
|
|
607
76
|
};
|
|
608
|
-
|
|
77
|
+
w(E);
|
|
609
78
|
export {
|
|
610
|
-
|
|
79
|
+
L as softmax
|
|
611
80
|
};
|
package/dist/ops/webgl/gelu.js
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
|
-
import { r as a } from "../../index-bMBtI-WR.js";
|
|
2
|
-
import { u as s, C as x } from "../../kernel_funcs_utils-CNmjLWnB.js";
|
|
1
|
+
import { r as a } from "../../index-BAzbokzv.js";
|
|
2
|
+
import { u as s, C as x } from "../../kernel_funcs_utils-CUxJCg0g.js";
|
|
3
3
|
const t = 0.7978845608028654, r = 0.044715, c = x + `
|
|
4
4
|
float x3 = x * x * x;
|
|
5
5
|
float inner = x + ${r} * x3;
|
package/dist/ops/webgl/log.js
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
import { r,
|
|
2
|
-
import { u as s,
|
|
3
|
-
import { l } from "../../shared-XNAoXhOa.js";
|
|
1
|
+
import { r, a9 as e } from "../../index-BAzbokzv.js";
|
|
2
|
+
import { u as s, l as N } from "../../kernel_funcs_utils-CUxJCg0g.js";
|
|
3
|
+
import { aG as l } from "../../shared-Ca6iDobD.js";
|
|
4
4
|
/**
|
|
5
5
|
* @license
|
|
6
6
|
* Copyright 2020 Google LLC. All Rights Reserved.
|