@genai-fi/nanogpt 0.6.0 → 0.6.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. package/dist/Generator.js +7 -7
  2. package/dist/NanoGPTModel.js +70 -121
  3. package/dist/RealDiv-7xu-pkZN.js +540 -0
  4. package/dist/Reshape-BYC1oUku.js +127 -0
  5. package/dist/TeachableLLM.d.ts +2 -0
  6. package/dist/TeachableLLM.js +34 -27
  7. package/dist/{TiedEmbedding-BhxWO8QR.js → TiedEmbedding-C1HBot-5.js} +12 -13
  8. package/dist/{axis_util-D17qZRQm.js → axis_util-CCNL7jea.js} +14 -12
  9. package/dist/{broadcast_to-BMQLjvt_.js → broadcast_to-CddAF879.js} +2 -2
  10. package/dist/{concat-DhZfF1GY.js → concat-XOK9ANZu.js} +7 -7
  11. package/dist/{dataset-oilnemHf.js → dataset-BFFipD1c.js} +5 -5
  12. package/dist/{dropout-CrMQPCeG.js → dropout-xlKRoJyU.js} +9 -9
  13. package/dist/{gather-DZCMHZuN.js → gather-DKtUaTtA.js} +1 -1
  14. package/dist/gpgpu_math-B_ycgZ4W.js +3115 -0
  15. package/dist/{index-bMBtI-WR.js → index-CamYe_M8.js} +843 -646
  16. package/dist/{kernel_funcs_utils-CNmjLWnB.js → kernel_funcs_utils-D5MS0JFg.js} +232 -138
  17. package/dist/layers/BaseLayer.js +2 -2
  18. package/dist/layers/CausalSelfAttention.js +6 -6
  19. package/dist/layers/MLP.js +5 -5
  20. package/dist/layers/RMSNorm.js +3 -3
  21. package/dist/layers/RoPECache.js +13 -33
  22. package/dist/layers/TiedEmbedding.js +6 -7
  23. package/dist/layers/TransformerBlock.js +1 -1
  24. package/dist/{log_sum_exp-BHdkCb4s.js → log_sum_exp-CV_5-TTu.js} +15 -15
  25. package/dist/main.js +23 -20
  26. package/dist/{mat_mul-BsrLfy81.js → mat_mul-CAbRFWUj.js} +4 -4
  27. package/dist/{max-DechV4Bc.js → max-JBBv7aUf.js} +3 -3
  28. package/dist/mulmat_packed_gpu-DW4doKL_.js +71 -0
  29. package/dist/{norm-B9hWHZH1.js → norm-B9dQTFYn.js} +12 -12
  30. package/dist/{ones-g0K8jVwm.js → ones-CMHNqMr6.js} +2 -2
  31. package/dist/ops/appendCache.js +3 -3
  32. package/dist/ops/attentionMask.js +1 -1
  33. package/dist/ops/cpu/appendCache.js +2 -2
  34. package/dist/ops/cpu/attentionMask.js +5 -5
  35. package/dist/ops/cpu/fusedSoftmax.js +2 -2
  36. package/dist/ops/cpu/gatherSub.js +5 -5
  37. package/dist/ops/cpu/gelu.js +1 -1
  38. package/dist/ops/cpu/matMulGelu.js +1 -1
  39. package/dist/ops/cpu/matMulMul.js +1 -1
  40. package/dist/ops/cpu/mulDropout.js +1 -1
  41. package/dist/ops/cpu/normRMS.js +1 -1
  42. package/dist/ops/cpu/qkv.js +3 -3
  43. package/dist/ops/cpu/rope.js +5 -5
  44. package/dist/ops/cpu/scatterSub.js +18 -49
  45. package/dist/ops/fusedSoftmax.js +1 -1
  46. package/dist/ops/gatherSub.js +1 -1
  47. package/dist/ops/gelu.js +1 -1
  48. package/dist/ops/grads/attentionMask.js +1 -1
  49. package/dist/ops/grads/fusedSoftmax.js +2 -2
  50. package/dist/ops/grads/gelu.js +1 -1
  51. package/dist/ops/grads/matMulGelu.js +1 -1
  52. package/dist/ops/grads/normRMS.js +1 -1
  53. package/dist/ops/grads/qkv.js +1 -1
  54. package/dist/ops/grads/rope.js +1 -1
  55. package/dist/ops/matMulGelu.js +1 -1
  56. package/dist/ops/matMulMul.js +1 -1
  57. package/dist/ops/mulDrop.js +1 -1
  58. package/dist/ops/node/sparseCrossEntropy.js +1 -1
  59. package/dist/ops/normRMS.js +1 -1
  60. package/dist/ops/qkv.js +1 -1
  61. package/dist/ops/rope.js +8 -4
  62. package/dist/ops/scatterSub.js +1 -1
  63. package/dist/ops/webgl/appendCache.js +1 -1
  64. package/dist/ops/webgl/attentionMask.js +1 -1
  65. package/dist/ops/webgl/fusedSoftmax.js +29 -560
  66. package/dist/ops/webgl/gatherSub.js +1 -1
  67. package/dist/ops/webgl/gelu.js +2 -2
  68. package/dist/ops/webgl/log.js +3 -3
  69. package/dist/ops/webgl/matMulGelu.js +48 -115
  70. package/dist/ops/webgl/matMulMul.js +1 -1
  71. package/dist/ops/webgl/mulDropout.js +1 -1
  72. package/dist/ops/webgl/normRMS.js +2 -2
  73. package/dist/ops/webgl/qkv.js +1 -1
  74. package/dist/ops/webgl/rope.js +1 -1
  75. package/dist/ops/webgl/scatterSub.js +1 -1
  76. package/dist/{ops-Mv7Ta72x.js → ops-DqtYemmV.js} +143 -135
  77. package/dist/{random_width-BBAWzDym.js → random_width-CLMQG5Jn.js} +6925 -6291
  78. package/dist/{range-DMaG9A3G.js → range-DqYjKnuG.js} +1 -1
  79. package/dist/{gpgpu_math-Ctc31slO.js → reciprocal-z49filta.js} +7 -5
  80. package/dist/register_all_kernels-COt6wLD0.js +21397 -0
  81. package/dist/{reshape-T4yDEqoF.js → reshape-C45vIIRU.js} +1 -1
  82. package/dist/scatter_nd_util-qgtnviTE.js +46 -0
  83. package/dist/selu_util-4QV_GXTB.js +740 -0
  84. package/dist/{shared-XNAoXhOa.js → shared-ByfrGA97.js} +1462 -1089
  85. package/dist/{sin-EEhbrRO_.js → sin-9JBrfVaB.js} +1 -1
  86. package/dist/{softmax-B2_IKPDR.js → softmax-DvMvui-_.js} +1 -1
  87. package/dist/{split-dcks18H1.js → split-DxrHrPFK.js} +4 -4
  88. package/dist/{stack-lpJ5kYvE.js → stack-DgaoDmnF.js} +1 -1
  89. package/dist/{sum-CutF5lj2.js → sum-BpcpxNEh.js} +3 -3
  90. package/dist/{tensor-C15NA2LA.js → tensor-CDz5x1mP.js} +1 -1
  91. package/dist/{tensor2d-DZ_e5eKM.js → tensor2d-jO8JY5Jd.js} +1 -1
  92. package/dist/training/AdamExt.js +1 -1
  93. package/dist/training/DatasetBuilder.js +2 -2
  94. package/dist/training/FullTrainer.js +1 -1
  95. package/dist/training/Trainer.js +3 -3
  96. package/dist/training/sparseCrossEntropy.js +4 -4
  97. package/dist/utilities/dummy.d.ts +6 -0
  98. package/dist/utilities/dummy.js +31 -10
  99. package/dist/utilities/generate.js +3 -3
  100. package/dist/utilities/load.js +1 -1
  101. package/dist/utilities/profile.d.ts +5 -0
  102. package/dist/utilities/profile.js +10 -7
  103. package/dist/utilities/safetensors.js +2 -2
  104. package/dist/utilities/weights.js +2 -2
  105. package/dist/{variable-CdRKKp8x.js → variable-CLVXjN7F.js} +1 -1
  106. package/dist/{zeros-CAbHfODe.js → zeros-DUkkVccu.js} +8 -8
  107. package/package.json +3 -9
  108. package/dist/Reshape-CLOrdpve.js +0 -212
  109. package/dist/slice_util-Ddk0uxGJ.js +0 -49
  110. package/dist/tfjs_backend-BDb8r9qx.js +0 -1010
package/dist/ops/qkv.js CHANGED
@@ -1,4 +1,4 @@
1
- import { e as o } from "../index-bMBtI-WR.js";
1
+ import { e as o } from "../index-CamYe_M8.js";
2
2
  import "./cpu/qkv.js";
3
3
  import "./webgl/qkv.js";
4
4
  import "./grads/qkv.js";
package/dist/ops/rope.js CHANGED
@@ -1,10 +1,14 @@
1
- import { engine as n } from "@tensorflow/tfjs";
1
+ import { e as p } from "../index-CamYe_M8.js";
2
+ import "../random_width-CLMQG5Jn.js";
3
+ import "../register_all_kernels-COt6wLD0.js";
4
+ import "../index-Tf7vU29b.js";
5
+ import "../dataset-BFFipD1c.js";
2
6
  import "./cpu/rope.js";
3
7
  import "./webgl/rope.js";
4
8
  import "./grads/rope.js";
5
- function s(o, e, r) {
6
- return e.ensureRopeCache(o.shape[1] + r), n().runKernel("Rope", { x: o, sin: e.getSin(), cos: e.getCos() }, { pastLen: r });
9
+ function C(r, o, e) {
10
+ return o.ensureRopeCache(r.shape[1] + e), p().runKernel("Rope", { x: r, sin: o.getSin(), cos: o.getCos() }, { pastLen: e });
7
11
  }
8
12
  export {
9
- s as rope
13
+ C as rope
10
14
  };
@@ -1,4 +1,4 @@
1
- import { e as i } from "../index-bMBtI-WR.js";
1
+ import { e as i } from "../index-CamYe_M8.js";
2
2
  import "./cpu/scatterSub.js";
3
3
  import "./webgl/scatterSub.js";
4
4
  function c(t, r, e) {
@@ -1,4 +1,4 @@
1
- import { r as p } from "../../index-bMBtI-WR.js";
1
+ import { r as p } from "../../index-CamYe_M8.js";
2
2
  class m {
3
3
  variableNames = ["cache", "item"];
4
4
  outputShape;
@@ -1,4 +1,4 @@
1
- import { r as m } from "../../index-bMBtI-WR.js";
1
+ import { r as m } from "../../index-CamYe_M8.js";
2
2
  class h {
3
3
  variableNames = ["q", "k"];
4
4
  outputShape;
@@ -1,544 +1,13 @@
1
- import { ao as E, ap as T, q as V, N as L, a1 as w, aq as K, r as W } from "../../index-bMBtI-WR.js";
2
- import { t as B, m as G } from "../../shared-XNAoXhOa.js";
3
- import { r as b } from "../../Reshape-CLOrdpve.js";
4
- import { g as C, a as U, b as _ } from "../../kernel_funcs_utils-CNmjLWnB.js";
5
- import { g as k, a as R, b as P, c as y, e as O } from "../../axis_util-D17qZRQm.js";
6
- /**
7
- * @license
8
- * Copyright 2017 Google LLC. All Rights Reserved.
9
- * Licensed under the Apache License, Version 2.0 (the "License");
10
- * you may not use this file except in compliance with the License.
11
- * You may obtain a copy of the License at
12
- *
13
- * http://www.apache.org/licenses/LICENSE-2.0
14
- *
15
- * Unless required by applicable law or agreed to in writing, software
16
- * distributed under the License is distributed on an "AS IS" BASIS,
17
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
- * See the License for the specific language governing permissions and
19
- * limitations under the License.
20
- * =============================================================================
21
- */
22
- const q = 30;
23
- function F(a) {
24
- return a <= q ? a : E(a, Math.floor(Math.sqrt(a)));
25
- }
26
- /**
27
- * @license
28
- * Copyright 2020 Google LLC. All Rights Reserved.
29
- * Licensed under the Apache License, Version 2.0 (the "License");
30
- * you may not use this file except in compliance with the License.
31
- * You may obtain a copy of the License at
32
- *
33
- * http://www.apache.org/licenses/LICENSE-2.0
34
- *
35
- * Unless required by applicable law or agreed to in writing, software
36
- * distributed under the License is distributed on an "AS IS" BASIS,
37
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
38
- * See the License for the specific language governing permissions and
39
- * limitations under the License.
40
- * =============================================================================
41
- */
42
- class A {
43
- constructor(o, t) {
44
- this.variableNames = ["x"];
45
- const { windowSize: e, batchSize: n, inSize: u, outSize: l } = o;
46
- this.outputShape = [n, l];
47
- const s = Math.floor(e / 4) * 4, c = e % 4;
48
- let i = "sumValue += dot(values, ones);";
49
- if (t != null) {
50
- const p = 1 / t;
51
- i = `sumValue += dot(values * ${T(p) ? p.toPrecision(2) : p}, ones);`;
52
- }
53
- let r = "";
54
- u % e > 0 && (r = `
55
- if (inIdx < 0 || inIdx >= ${u}) {
56
- return 0.0;
57
- }
58
- `), this.userCode = `
59
- const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);
60
-
61
- float getValue(int batch, int inIdx) {
62
- ${r}
63
- return getX(batch, inIdx);
64
- }
65
-
66
- void main() {
67
- ivec2 coords = getOutputCoords();
68
- int batch = coords[0];
69
- int outIdx = coords[1];
70
- int inOffset = outIdx * ${e};
71
-
72
- float sumValue = 0.0;
73
-
74
- for (int i = 0; i < ${s}; i += 4) {
75
- int inIdx = inOffset + i;
76
- vec4 values = vec4(
77
- getValue(batch, inIdx),
78
- getValue(batch, inIdx + 1),
79
- getValue(batch, inIdx + 2),
80
- getValue(batch, inIdx + 3)
81
- );
82
-
83
- ${i}
84
- }
85
-
86
- int inIdx = inOffset + ${s};
87
- if (${c === 1}) {
88
- vec4 values = vec4(getValue(batch, inIdx), 0.0, 0.0, 0.0);
89
-
90
- ${i}
91
- } else if (${c === 2}) {
92
- vec4 values = vec4(
93
- getValue(batch, inIdx),
94
- getValue(batch, inIdx + 1), 0.0, 0.0);
95
-
96
- ${i}
97
- } else if (${c === 3}) {
98
- vec4 values = vec4(
99
- getValue(batch, inIdx),
100
- getValue(batch, inIdx + 1),
101
- getValue(batch, inIdx + 2), 0.0);
102
-
103
- ${i}
104
- }
105
- setOutput(sumValue);
106
- }
107
- `;
108
- }
109
- }
110
- /**
111
- * @license
112
- * Copyright 2017 Google LLC. All Rights Reserved.
113
- * Licensed under the Apache License, Version 2.0 (the "License");
114
- * you may not use this file except in compliance with the License.
115
- * You may obtain a copy of the License at
116
- *
117
- * http://www.apache.org/licenses/LICENSE-2.0
118
- *
119
- * Unless required by applicable law or agreed to in writing, software
120
- * distributed under the License is distributed on an "AS IS" BASIS,
121
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
122
- * See the License for the specific language governing permissions and
123
- * limitations under the License.
124
- * =============================================================================
125
- */
126
- class j {
127
- constructor(o, t) {
128
- this.variableNames = ["x"];
129
- const { windowSize: e, batchSize: n, inSize: u, outSize: l } = o;
130
- this.outputShape = [n, l];
131
- let s = "0.0", c = "";
132
- t === "prod" ? s = "1.0" : t === "min" ? (s = "1.0 / 1e-20", c = "min") : t === "max" && (s = "-1.0 / 1e-20", c = "max");
133
- let i = `${t}(${t}(${t}(minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])`;
134
- t === "sum" ? i = "sumValue" : t === "prod" ? i = "prodValue" : t === "all" ? i = "allValue" : t === "any" && (i = "anyValue");
135
- const r = Math.floor(e / 4) * 4, p = e % 4;
136
- let h = `
137
- if (${t === "sum"}) {
138
- sumValue += dot(values, ones);
139
- } else if (${t === "prod"}) {
140
- vec2 tmp = vec2(values[0], values[1]) * vec2(values[2], values[3]);
141
- prodValue *= tmp[0] * tmp[1];
142
- } else {
143
- minMaxValue = ${c}(values, minMaxValue);
144
- if (${t === "min"} || ${t === "max"}) {
145
- minMaxValue = ${c}(values, minMaxValue);
146
- bvec4 isNaN = isnan(values);
147
- if (isNaN.r || isNaN.g || isNaN.b || isNaN.a) {
148
- minMaxValue = vec4(NAN);
149
- }
150
- }
151
- }
152
- `, d = "vec4";
153
- t === "all" ? (s = "1.0", h = `
154
- bool reducedAllValue = all(values);
155
- float floatedReducedAllValue = float(reducedAllValue);
156
- allValue = float(allValue >= 1.0 && floatedReducedAllValue >= 1.0);
157
- `, d = "bvec4") : t === "any" && (s = "0.0", h = `
158
- bool reducedAnyValue = any(values);
159
- float floatedReducedAnyValue = float(reducedAnyValue);
160
- anyValue = float(anyValue >= 1.0 || floatedReducedAnyValue >= 1.0);
161
- `, d = "bvec4");
162
- let f = "";
163
- u % e > 0 && (f = `
164
- if (inIdx < 0 || inIdx >= ${u}) {
165
- return initializationValue;
166
- }
167
- `), this.userCode = `
168
- const float initializationValue = ${s};
169
- const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);
170
-
171
- float getValue(int batch, int inIdx) {
172
- ${f}
173
- return getX(batch, inIdx);
174
- }
175
-
176
- void main() {
177
- ivec2 coords = getOutputCoords();
178
- int batch = coords[0];
179
- int outIdx = coords[1];
180
- int inOffset = outIdx * ${e};
181
-
182
- vec4 minMaxValue = vec4(${s});
183
- float prodValue = 1.0;
184
- float sumValue = 0.0;
185
- float allValue = 1.0;
186
- float anyValue = 0.0;
187
-
188
- for (int i = 0; i < ${r}; i += 4) {
189
- int inIdx = inOffset + i;
190
- ${d} values = ${d}(
191
- getValue(batch, inIdx),
192
- getValue(batch, inIdx + 1),
193
- getValue(batch, inIdx + 2),
194
- getValue(batch, inIdx + 3)
195
- );
196
-
197
- ${h}
198
- }
199
-
200
- int inIdx = inOffset + ${r};
201
- if (${p === 1}) {
202
- ${d} values = ${d}(
203
- getValue(batch, inIdx),
204
- initializationValue,
205
- initializationValue,
206
- initializationValue
207
- );
208
-
209
- ${h}
210
- } else if (${p === 2}) {
211
- ${d} values = ${d}(
212
- getValue(batch, inIdx),
213
- getValue(batch, inIdx + 1),
214
- initializationValue,
215
- initializationValue
216
- );
217
-
218
- ${h}
219
- } else if (${p === 3}) {
220
- ${d} values = ${d}(
221
- getValue(batch, inIdx),
222
- getValue(batch, inIdx + 1),
223
- getValue(batch, inIdx + 2),
224
- initializationValue
225
- );
226
-
227
- ${h}
228
- }
229
- setOutput(${i});
230
- }
231
- `;
232
- }
233
- }
234
- /**
235
- * @license
236
- * Copyright 2020 Google LLC. All Rights Reserved.
237
- * Licensed under the Apache License, Version 2.0 (the "License");
238
- * you may not use this file except in compliance with the License.
239
- * You may obtain a copy of the License at
240
- *
241
- * http://www.apache.org/licenses/LICENSE-2.0
242
- *
243
- * Unless required by applicable law or agreed to in writing, software
244
- * distributed under the License is distributed on an "AS IS" BASIS,
245
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
246
- * See the License for the specific language governing permissions and
247
- * limitations under the License.
248
- * =============================================================================
249
- */
250
- function H(a) {
251
- const o = [];
252
- for (; o.length === 0 || o[o.length - 1].outSize !== 1; ) {
253
- const t = o.length ? o[o.length - 1].outSize : a[1], e = F(t);
254
- o.push({
255
- inSize: t,
256
- windowSize: e,
257
- outSize: Math.ceil(t / e)
258
- });
259
- }
260
- return o;
261
- }
262
- function N(a, o, t, e) {
263
- const n = H(a.shape);
264
- let u = a;
265
- for (let l = 0; l < n.length; l++) {
266
- const { inSize: s, windowSize: c, outSize: i } = n[l];
267
- let r, p;
268
- t === "mean" ? r = l === 0 ? new A({ windowSize: c, inSize: s, batchSize: a.shape[0], outSize: i }, s) : new A({ windowSize: c, inSize: s, batchSize: a.shape[0], outSize: i }) : r = new j({ windowSize: c, inSize: s, batchSize: a.shape[0], outSize: i }, t), p = u, u = e.runWebGLProgram(r, [u], o), p.dataId !== a.dataId && e.disposeIntermediateTensorInfo(p);
269
- }
270
- return u;
271
- }
272
- /**
273
- * @license
274
- * Copyright 2020 Google LLC. All Rights Reserved.
275
- * Licensed under the Apache License, Version 2.0 (the "License");
276
- * you may not use this file except in compliance with the License.
277
- * You may obtain a copy of the License at
278
- *
279
- * http://www.apache.org/licenses/LICENSE-2.0
280
- *
281
- * Unless required by applicable law or agreed to in writing, software
282
- * distributed under the License is distributed on an "AS IS" BASIS,
283
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
284
- * See the License for the specific language governing permissions and
285
- * limitations under the License.
286
- * =============================================================================
287
- */
288
- function X(a, o, t, e) {
289
- const n = V(o), l = V(a.shape) / n, s = b({ inputs: { x: a }, attrs: { shape: [l, n] }, backend: e }), c = N(s, a.dtype, "max", e), i = b({ inputs: { x: c }, attrs: { shape: t }, backend: e });
290
- return e.disposeIntermediateTensorInfo(s), e.disposeIntermediateTensorInfo(c), i;
291
- }
292
- /**
293
- * @license
294
- * Copyright 2017 Google LLC. All Rights Reserved.
295
- * Licensed under the Apache License, Version 2.0 (the "License");
296
- * you may not use this file except in compliance with the License.
297
- * You may obtain a copy of the License at
298
- *
299
- * http://www.apache.org/licenses/LICENSE-2.0
300
- *
301
- * Unless required by applicable law or agreed to in writing, software
302
- * distributed under the License is distributed on an "AS IS" BASIS,
303
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
304
- * See the License for the specific language governing permissions and
305
- * limitations under the License.
306
- * =============================================================================
307
- */
308
- class Y {
309
- constructor(o, t) {
310
- this.variableNames = ["A"];
311
- const e = new Array(o.length);
312
- for (let l = 0; l < e.length; l++)
313
- e[l] = o[t[l]];
314
- this.outputShape = e, this.rank = e.length;
315
- const n = C(this.rank), u = Z(t);
316
- this.userCode = `
317
- void main() {
318
- ${n} resRC = getOutputCoords();
319
- setOutput(getA(${u}));
320
- }
321
- `;
322
- }
323
- }
324
- function Z(a) {
325
- const o = a.length;
326
- if (o > 6)
327
- throw Error(`Transpose for rank ${o} is not yet supported`);
328
- const t = ["resRC.x", "resRC.y", "resRC.z", "resRC.w", "resRC.u", "resRC.v"], e = new Array(o);
329
- for (let n = 0; n < a.length; n++)
330
- e[a[n]] = t[n];
331
- return e.join();
332
- }
333
- /**
334
- * @license
335
- * Copyright 2019 Google LLC. All Rights Reserved.
336
- * Licensed under the Apache License, Version 2.0 (the "License");
337
- * you may not use this file except in compliance with the License.
338
- * You may obtain a copy of the License at
339
- *
340
- * http://www.apache.org/licenses/LICENSE-2.0
341
- *
342
- * Unless required by applicable law or agreed to in writing, software
343
- * distributed under the License is distributed on an "AS IS" BASIS,
344
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
345
- * See the License for the specific language governing permissions and
346
- * limitations under the License.
347
- * =============================================================================
348
- */
349
- class J {
350
- constructor(o, t) {
351
- this.variableNames = ["A"], this.packedInputs = !0, this.packedOutput = !0;
352
- const e = new Array(o.length);
353
- for (let r = 0; r < e.length; r++)
354
- e[r] = o[t[r]];
355
- if (this.outputShape = e, this.rank = e.length, this.rank > 6)
356
- throw Error(`Packed transpose for rank ${this.rank} is not yet supported.`);
357
- const n = C(this.rank), u = U("rc", this.rank), l = new Array(this.rank);
358
- for (let r = 0; r < t.length; r++)
359
- l[t[r]] = u[r];
360
- const s = `vec2(${l.slice(-2).join()})`, c = `++${u[this.rank - 1]} < ${e[this.rank - 1]}`, i = `getChannel(getA(${l.join()}), ${s})`;
361
- this.userCode = `
362
- void main() {
363
- ${n} rc = getOutputCoords();
364
- vec4 result = vec4(0.);
365
- result[0] = ${i};
366
- if(${c}) {
367
- result[1] = ${i};
368
- }
369
- --${u[this.rank - 1]};
370
- if(++${u[this.rank - 2]} < ${e[this.rank - 2]}) {
371
- result[2] = ${i};
372
- if(${c}) {
373
- result[3] = ${i};
374
- }
375
- }
376
- setOutput(result);
377
- }
378
- `;
379
- }
380
- }
381
- /**
382
- * @license
383
- * Copyright 2020 Google LLC. All Rights Reserved.
384
- * Licensed under the Apache License, Version 2.0 (the "License");
385
- * you may not use this file except in compliance with the License.
386
- * You may obtain a copy of the License at
387
- *
388
- * http://www.apache.org/licenses/LICENSE-2.0
389
- *
390
- * Unless required by applicable law or agreed to in writing, software
391
- * distributed under the License is distributed on an "AS IS" BASIS,
392
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
393
- * See the License for the specific language governing permissions and
394
- * limitations under the License.
395
- * =============================================================================
396
- */
397
- function D(a, o, t) {
398
- const e = L().getBool("WEBGL_PACK_ARRAY_OPERATIONS") ? new J(a.shape, o) : new Y(a.shape, o);
399
- return t.runWebGLProgram(e, [a], a.dtype);
400
- }
401
- /**
402
- * @license
403
- * Copyright 2020 Google LLC. All Rights Reserved.
404
- * Licensed under the Apache License, Version 2.0 (the "License");
405
- * you may not use this file except in compliance with the License.
406
- * You may obtain a copy of the License at
407
- *
408
- * http://www.apache.org/licenses/LICENSE-2.0
409
- *
410
- * Unless required by applicable law or agreed to in writing, software
411
- * distributed under the License is distributed on an "AS IS" BASIS,
412
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
413
- * See the License for the specific language governing permissions and
414
- * limitations under the License.
415
- * =============================================================================
416
- */
417
- function Q(a) {
418
- const { inputs: o, backend: t, attrs: e } = a, { x: n } = o, { reductionIndices: u, keepDims: l } = e, s = n.shape.length, c = w(u, n.shape);
419
- let i = c;
420
- const r = k(i, s), p = r != null, h = t.shouldExecuteOnCPU([n]);
421
- let d = n;
422
- if (p) {
423
- if (h) {
424
- const I = t.texData.get(d.dataId).values, g = new Array(s);
425
- for (let $ = 0; $ < g.length; $++)
426
- g[$] = n.shape[r[$]];
427
- const z = B(I, n.shape, n.dtype, r, g);
428
- d = t.makeTensorInfo(g, n.dtype);
429
- const M = t.texData.get(d.dataId);
430
- M.values = z;
431
- } else
432
- d = D(n, r, t);
433
- i = R(i.length, s);
434
- }
435
- P("max", i, s);
436
- const [f, v] = y(d.shape, i);
437
- let x = f;
438
- l && (x = O(f, c));
439
- let m;
440
- if (h) {
441
- const I = t.texData.get(d.dataId).values, g = G(I, V(v), x, n.dtype);
442
- m = t.makeTensorInfo(x, n.dtype);
443
- const z = t.texData.get(m.dataId);
444
- z.values = g;
445
- } else
446
- m = X(d, v, x, t);
447
- return p && t.disposeIntermediateTensorInfo(d), m;
448
- }
449
- /**
450
- * @license
451
- * Copyright 2020 Google LLC. All Rights Reserved.
452
- * Licensed under the Apache License, Version 2.0 (the "License");
453
- * you may not use this file except in compliance with the License.
454
- * You may obtain a copy of the License at
455
- *
456
- * http://www.apache.org/licenses/LICENSE-2.0
457
- *
458
- * Unless required by applicable law or agreed to in writing, software
459
- * distributed under the License is distributed on an "AS IS" BASIS,
460
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
461
- * See the License for the specific language governing permissions and
462
- * limitations under the License.
463
- * =============================================================================
464
- */
465
- function ee(a, o, t, e) {
466
- const n = o, u = a.shape.length, l = w(n, a.shape);
467
- let s = l;
468
- const c = k(s, u), i = c != null;
469
- let r = a;
470
- i && (r = D(a, c, e), s = R(s.length, u)), P("sum", s, u);
471
- const [p, h] = y(r.shape, s);
472
- let d = p;
473
- t && (d = O(p, l));
474
- const f = V(h), x = V(a.shape) / f, m = b({ inputs: { x: r }, attrs: { shape: [x, f] }, backend: e }), S = K(a.dtype), I = N(m, S, "sum", e), g = b({ inputs: { x: I }, attrs: { shape: d }, backend: e });
475
- return e.disposeIntermediateTensorInfo(m), e.disposeIntermediateTensorInfo(I), i && e.disposeIntermediateTensorInfo(r), g;
476
- }
477
- /**
478
- * @license
479
- * Copyright 2020 Google LLC. All Rights Reserved.
480
- * Licensed under the Apache License, Version 2.0 (the "License");
481
- * you may not use this file except in compliance with the License.
482
- * You may obtain a copy of the License at
483
- *
484
- * http://www.apache.org/licenses/LICENSE-2.0
485
- *
486
- * Unless required by applicable law or agreed to in writing, software
487
- * distributed under the License is distributed on an "AS IS" BASIS,
488
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
489
- * See the License for the specific language governing permissions and
490
- * limitations under the License.
491
- * =============================================================================
492
- */
493
- function te(a) {
494
- const { inputs: o, backend: t, attrs: e } = a, { x: n } = o, { axis: u, keepDims: l } = e;
495
- return ee(n, u, l, t);
496
- }
497
- /**
498
- * @license
499
- * Copyright 2020 Google LLC. All Rights Reserved.
500
- * Licensed under the Apache License, Version 2.0 (the "License");
501
- * you may not use this file except in compliance with the License.
502
- * You may obtain a copy of the License at
503
- *
504
- * http://www.apache.org/licenses/LICENSE-2.0
505
- *
506
- * Unless required by applicable law or agreed to in writing, software
507
- * distributed under the License is distributed on an "AS IS" BASIS,
508
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
509
- * See the License for the specific language governing permissions and
510
- * limitations under the License.
511
- * =============================================================================
512
- */
513
- const ae = `
514
- if (a == b) {
515
- return 1.0;
516
- };
517
- return a / b;`, se = `
518
- // vec4 one = vec4(equal(a, b));
519
- // return one + (vec4(1.0) - one) * a / b;
520
- vec4 result = a / b;
521
- if(a.x == b.x) {
522
- result.x = 1.;
523
- }
524
- if(a.y == b.y) {
525
- result.y = 1.;
526
- }
527
- if(a.z == b.z) {
528
- result.z = 1.;
529
- }
530
- if(a.w == b.w) {
531
- result.w = 1.;
532
- }
533
-
534
- return result;
535
- `, oe = _({ opSnippet: ae, packedOpSnippet: se, checkOutOfBounds: !0 });
536
- class ne {
1
+ import { m as b, s as I, r as k } from "../../RealDiv-7xu-pkZN.js";
2
+ import { r as v } from "../../Reshape-BYC1oUku.js";
3
+ import { r as w, p as P } from "../../index-CamYe_M8.js";
4
+ import { e as S } from "../../axis_util-CCNL7jea.js";
5
+ class T {
537
6
  variableNames = ["logits", "maxLogits"];
538
7
  outputShape;
539
8
  userCode;
540
- constructor(o) {
541
- this.outputShape = o, this.userCode = `
9
+ constructor(t) {
10
+ this.outputShape = t, this.userCode = `
542
11
  void main() {
543
12
  ivec4 coords = getOutputCoords(); // [batch, nh, t1, t2]
544
13
  int b = coords.x;
@@ -552,7 +21,7 @@ class ne {
552
21
  `;
553
22
  }
554
23
  }
555
- class ie {
24
+ class C {
556
25
  variableNames = ["exp", "sum"];
557
26
  outputShape;
558
27
  userCode;
@@ -560,8 +29,8 @@ class ie {
560
29
  { name: "dropoutRate", type: "float" },
561
30
  { name: "seed", type: "float" }
562
31
  ];
563
- constructor(o) {
564
- this.outputShape = o, this.userCode = `
32
+ constructor(t) {
33
+ this.outputShape = t, this.userCode = `
565
34
  float random(ivec4 coords) {
566
35
  float x = float(coords.x * 4096 + coords.y * 256 + coords.z * 16 + coords.w);
567
36
  return fract(sin(seed + x) * 43758.5453123);
@@ -579,33 +48,33 @@ class ie {
579
48
  `;
580
49
  }
581
50
  }
582
- function re(a) {
583
- const { inputs: o, attrs: t } = a, { logits: e } = o, { dim: n, dropoutRate: u, seed: l } = t, s = a.backend;
51
+ function L(r) {
52
+ const { inputs: t, attrs: m } = r, { logits: e } = t, { dim: u, dropoutRate: n, seed: c } = m, o = r.backend;
584
53
  if (!e)
585
54
  throw new Error("Error in softmax: input logits is null");
586
- const c = w([n], e.shape), i = Q({
55
+ const i = P([u], e.shape), d = b({
587
56
  inputs: { x: e },
588
- backend: s,
589
- attrs: { reductionIndices: c, keepDims: !1 }
590
- }), r = O(i.shape, c), p = new ne(e.shape), h = s.runWebGLProgram(p, [e, i], "float32");
591
- s.disposeIntermediateTensorInfo(i);
592
- const d = te({ inputs: { x: h }, backend: s, attrs: { axis: c, keepDims: !1 } }), f = b({ inputs: { x: d }, backend: s, attrs: { shape: r } });
593
- if (u !== void 0 && u > 0) {
594
- const x = new ie(e.shape), m = s.runWebGLProgram(x, [h, f], "float32", [
595
- [u],
596
- [l ?? Math.random() * 1e4]
57
+ backend: o,
58
+ attrs: { reductionIndices: i, keepDims: !1 }
59
+ }), f = S(d.shape, i), l = new T(e.shape), s = o.runWebGLProgram(l, [e, d], "float32");
60
+ o.disposeIntermediateTensorInfo(d);
61
+ const p = I({ inputs: { x: s }, backend: o, attrs: { axis: i, keepDims: !1 } }), a = v({ inputs: { x: p }, backend: o, attrs: { shape: f } });
62
+ if (n !== void 0 && n > 0) {
63
+ const g = new C(e.shape), h = o.runWebGLProgram(g, [s, a], "float32", [
64
+ [n],
65
+ [c ?? Math.random() * 1e4]
597
66
  ]);
598
- return s.disposeIntermediateTensorInfo(h), s.disposeIntermediateTensorInfo(d), s.disposeIntermediateTensorInfo(f), m;
67
+ return o.disposeIntermediateTensorInfo(s), o.disposeIntermediateTensorInfo(p), o.disposeIntermediateTensorInfo(a), h;
599
68
  }
600
- const v = oe({ inputs: { a: h, b: f }, backend: s });
601
- return s.disposeIntermediateTensorInfo(h), s.disposeIntermediateTensorInfo(d), s.disposeIntermediateTensorInfo(f), v;
69
+ const x = k({ inputs: { a: s, b: a }, backend: o });
70
+ return o.disposeIntermediateTensorInfo(s), o.disposeIntermediateTensorInfo(p), o.disposeIntermediateTensorInfo(a), x;
602
71
  }
603
- const ue = {
72
+ const E = {
604
73
  kernelName: "FusedSoftmax",
605
74
  backendName: "webgl",
606
- kernelFunc: re
75
+ kernelFunc: L
607
76
  };
608
- W(ue);
77
+ w(E);
609
78
  export {
610
- re as softmax
79
+ L as softmax
611
80
  };
@@ -1,4 +1,4 @@
1
- import { r as l } from "../../index-bMBtI-WR.js";
1
+ import { r as l } from "../../index-CamYe_M8.js";
2
2
  class u {
3
3
  variableNames = ["labels", "logits", "values"];
4
4
  outputShape;
@@ -1,5 +1,5 @@
1
- import { r as a } from "../../index-bMBtI-WR.js";
2
- import { u as s, C as x } from "../../kernel_funcs_utils-CNmjLWnB.js";
1
+ import { r as a } from "../../index-CamYe_M8.js";
2
+ import { u as s, C as x } from "../../kernel_funcs_utils-D5MS0JFg.js";
3
3
  const t = 0.7978845608028654, r = 0.044715, c = x + `
4
4
  float x3 = x * x * x;
5
5
  float inner = x + ${r} * x3;