@genai-fi/nanogpt 0.4.0 → 0.4.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87)
  1. package/dist/Generator.js +3 -3
  2. package/dist/NanoGPTModel.js +83 -70
  3. package/dist/TeachableLLM.js +1 -1
  4. package/dist/{random_width-CMHmdbSu.js → TiedEmbedding-CnJ1bx4q.js} +760 -719
  5. package/dist/{axis_util-DeydwOoC.js → axis_util-BgTGy5w8.js} +1 -1
  6. package/dist/{concat-DS_qH7MI.js → concat-CuRsVY-K.js} +1 -1
  7. package/dist/dropout-DfDdklfL.js +193 -0
  8. package/dist/{gather-BUmJIS8n.js → gather-ZYRWhmXR.js} +1 -1
  9. package/dist/gelu-CnCt17Lk.js +26 -0
  10. package/dist/{index-XjBAhiFO.js → index-C4JCoBvj.js} +61 -61
  11. package/dist/kernel_funcs_utils-CAd1h9X1.js +388 -0
  12. package/dist/layers/CausalSelfAttention.js +73 -72
  13. package/dist/layers/MLP.d.ts +3 -1
  14. package/dist/layers/MLP.js +93 -5
  15. package/dist/layers/RMSNorm.js +3 -3
  16. package/dist/layers/RoPECache.js +3 -3
  17. package/dist/layers/TiedEmbedding.js +6 -46
  18. package/dist/layers/TransformerBlock.js +2 -2
  19. package/dist/{log_sum_exp-DJPkVZZn.js → log_sum_exp-BswFnwOb.js} +5 -5
  20. package/dist/main.js +1 -1
  21. package/dist/{mat_mul-CKwFEV1Q.js → mat_mul-415y5Qn2.js} +1 -1
  22. package/dist/{max-DJvEiCAJ.js → max-CP_9O2Yd.js} +1 -1
  23. package/dist/{moments-CrWRPcR3.js → moments-CjeIaVdp.js} +3 -3
  24. package/dist/{norm-BzY929B_.js → norm-CZM380I3.js} +5 -5
  25. package/dist/{ones-BO01zpJG.js → ones-Bf3YR48P.js} +2 -2
  26. package/dist/ops/appendCache.js +1 -1
  27. package/dist/ops/attentionMask.d.ts +1 -1
  28. package/dist/ops/attentionMask.js +4 -4
  29. package/dist/ops/cpu/appendCache.js +2 -2
  30. package/dist/ops/cpu/attentionMask.js +13 -9
  31. package/dist/ops/cpu/fusedSoftmax.js +2 -2
  32. package/dist/ops/cpu/gatherSub.js +3 -3
  33. package/dist/ops/cpu/gelu.d.ts +1 -0
  34. package/dist/ops/cpu/gelu.js +40 -0
  35. package/dist/ops/cpu/mulDropout.js +1 -1
  36. package/dist/ops/cpu/qkv.js +3 -3
  37. package/dist/ops/cpu/rope.js +5 -5
  38. package/dist/ops/cpu/scatterSub.js +4 -4
  39. package/dist/ops/fusedSoftmax.js +1 -1
  40. package/dist/ops/gatherSub.js +1 -1
  41. package/dist/ops/gelu.d.ts +3 -0
  42. package/dist/ops/gelu.js +8 -0
  43. package/dist/ops/grads/attentionMask.js +1 -1
  44. package/dist/ops/grads/fusedSoftmax.js +2 -2
  45. package/dist/ops/grads/gelu.d.ts +2 -0
  46. package/dist/ops/grads/gelu.js +5 -0
  47. package/dist/ops/grads/qkv.js +1 -1
  48. package/dist/ops/grads/rope.js +1 -1
  49. package/dist/ops/mulDrop.js +1 -1
  50. package/dist/ops/node/sparseCrossEntropy.js +1 -1
  51. package/dist/ops/qkv.js +1 -1
  52. package/dist/ops/scatterSub.js +1 -1
  53. package/dist/ops/webgl/appendCache.js +1 -1
  54. package/dist/ops/webgl/attentionMask.js +19 -18
  55. package/dist/ops/webgl/fusedSoftmax.js +489 -788
  56. package/dist/ops/webgl/gatherSub.js +1 -1
  57. package/dist/ops/webgl/gelu.d.ts +2 -0
  58. package/dist/ops/webgl/gelu.js +50 -0
  59. package/dist/ops/webgl/mulDropout.js +1 -1
  60. package/dist/ops/webgl/qkv.js +1 -1
  61. package/dist/ops/webgl/rope.js +1 -1
  62. package/dist/ops/webgl/scatterSub.js +1 -1
  63. package/dist/{range-DQMNzBWs.js → range-9AzeApCc.js} +1 -1
  64. package/dist/{reshape-DFzh97Sc.js → reshape-Boe4DuIO.js} +1 -1
  65. package/dist/{sin-BYM-U4Ut.js → sin-KmhiDuMa.js} +1 -1
  66. package/dist/{slice_util-CnVNPQI-.js → slice_util-19zDNNSn.js} +2 -2
  67. package/dist/{softmax-4DOn6cPq.js → softmax-Cujsg4ay.js} +1 -1
  68. package/dist/{split-CkbeVdF8.js → split-DbcNm1-i.js} +1 -1
  69. package/dist/{stack-DaIMO5iX.js → stack-D1YjmgKN.js} +1 -1
  70. package/dist/{sum-C6u3xMi3.js → sum-R28pucR5.js} +1 -1
  71. package/dist/{tensor-Cu1fU7H7.js → tensor-BVeHdl7V.js} +1 -1
  72. package/dist/{tensor2d-D0CKdG6B.js → tensor2d-DqFGNs_K.js} +1 -1
  73. package/dist/{tfjs_backend-Bzl2SrRo.js → tfjs_backend-Cug-PH75.js} +826 -1015
  74. package/dist/training/AdamExt.js +1 -1
  75. package/dist/training/DatasetBuilder.js +3 -3
  76. package/dist/training/FullTrainer.js +1 -1
  77. package/dist/training/Trainer.js +5 -5
  78. package/dist/training/sparseCrossEntropy.js +4 -4
  79. package/dist/utilities/dummy.js +2 -2
  80. package/dist/utilities/generate.js +3 -3
  81. package/dist/utilities/load.js +1 -1
  82. package/dist/utilities/profile.js +1 -1
  83. package/dist/utilities/weights.js +2 -2
  84. package/dist/{variable-BS4AKqNU.js → variable-LJT9Ld63.js} +1 -1
  85. package/dist/{zeros-CmJFiC84.js → zeros-dnQxFgAD.js} +1 -1
  86. package/package.json +1 -1
  87. package/dist/MLP-KHhikThU.js +0 -83
@@ -0,0 +1,388 @@
1
+ import { aj as E, ak as D, af as B, al as w, n as N, am as v } from "./index-C4JCoBvj.js";
2
+ /**
3
+ * @license
4
+ * Copyright 2018 Google LLC. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 (the "License");
6
+ * you may not use this file except in compliance with the License.
7
+ * You may obtain a copy of the License at
8
+ *
9
+ * http://www.apache.org/licenses/LICENSE-2.0
10
+ *
11
+ * Unless required by applicable law or agreed to in writing, software
12
+ * distributed under the License is distributed on an "AS IS" BASIS,
13
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ * See the License for the specific language governing permissions and
15
+ * limitations under the License.
16
+ * =============================================================================
17
+ */
18
/**
 * Maps every element of `t` through `E`. Per the error text below, `E` is a
 * UTF-8 decoder for encoded string bytes (TODO confirm against its definition).
 * Any exception raised while mapping is rethrown with added context.
 */
function C(t) {
  let decoded;
  try {
    decoded = t.map((bytes) => E(bytes));
  } catch (err) {
    throw new Error(`Failed to decode encoded string bytes into utf-8, error: ${err}`);
  }
  return decoded;
}
25
/**
 * Maps every element of `t` through `D`. `D` is presumably the string
 * encoder that complements `C` (TODO confirm).
 */
function F(t) {
  const encode = (s) => D(s);
  return t.map(encode);
}
28
+ /**
29
+ * @license
30
+ * Copyright 2017 Google LLC. All Rights Reserved.
31
+ * Licensed under the Apache License, Version 2.0 (the "License");
32
+ * you may not use this file except in compliance with the License.
33
+ * You may obtain a copy of the License at
34
+ *
35
+ * http://www.apache.org/licenses/LICENSE-2.0
36
+ *
37
+ * Unless required by applicable law or agreed to in writing, software
38
+ * distributed under the License is distributed on an "AS IS" BASIS,
39
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
40
+ * See the License for the specific language governing permissions and
41
+ * limitations under the License.
42
+ * =============================================================================
43
+ */
44
/**
 * Returns the shader integer type name used to hold output coordinates for a
 * tensor of rank `t`: "int" for rank <= 1, otherwise "ivec2" … "ivec6".
 * @throws Error for any other rank (e.g. 7 or non-integer values).
 */
function R(t) {
  if (t <= 1) {
    return "int";
  }
  // switch compares with ===, matching the original equality checks.
  switch (t) {
    case 2:
      return "ivec2";
    case 3:
      return "ivec3";
    case 4:
      return "ivec4";
    case 5:
      return "ivec5";
    case 6:
      return "ivec6";
    default:
      throw Error(`GPU for rank ${t} is not yet supported`);
  }
}
59
+ /**
60
+ * @license
61
+ * Copyright 2017 Google LLC. All Rights Reserved.
62
+ * Licensed under the Apache License, Version 2.0 (the "License");
63
+ * you may not use this file except in compliance with the License.
64
+ * You may obtain a copy of the License at
65
+ *
66
+ * http://www.apache.org/licenses/LICENSE-2.0
67
+ *
68
+ * Unless required by applicable law or agreed to in writing, software
69
+ * distributed under the License is distributed on an "AS IS" BASIS,
70
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
71
+ * See the License for the specific language governing permissions and
72
+ * limitations under the License.
73
+ * =============================================================================
74
+ */
75
/**
 * Whether a program of output rank `t` should use shape uniforms. This needs
 * both the WEBGL_USE_SHAPES_UNIFORMS environment flag and a rank of at most 4.
 * The flag is always queried, even for large ranks, as in the original.
 */
function y(t) {
  const flagEnabled = B().getBool("WEBGL_USE_SHAPES_UNIFORMS");
  return flagEnabled && t <= 4;
}
78
+ /**
79
+ * @license
80
+ * Copyright 2018 Google LLC. All Rights Reserved.
81
+ * Licensed under the Apache License, Version 2.0 (the "License");
82
+ * you may not use this file except in compliance with the License.
83
+ * You may obtain a copy of the License at
84
+ *
85
+ * http://www.apache.org/licenses/LICENSE-2.0
86
+ *
87
+ * Unless required by applicable law or agreed to in writing, software
88
+ * distributed under the License is distributed on an "AS IS" BASIS,
89
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
90
+ * See the License for the specific language governing permissions and
91
+ * limitations under the License.
92
+ * =============================================================================
93
+ */
94
/**
 * Builds component accessors `${t}.x`, `${t}.y`, … for the first `e` of the
 * components x, y, z, w, u, v. Array.prototype.slice semantics apply to `e`,
 * so negative or fractional values behave as they do for slice.
 */
function k(t, e) {
  const components = "xyzwuv".split("").slice(0, e);
  const accessors = [];
  for (const component of components) {
    accessors.push(`${t}.${component}`);
  }
  return accessors;
}
97
/**
 * Like `k`, but for rank 1 the coordinate is a scalar, so the bare name `t`
 * is returned in a one-element array.
 */
function _(t, e) {
  if (e === 1) {
    return [t];
  }
  return k(t, e);
}
100
+ /**
101
+ * @license
102
+ * Copyright 2017 Google LLC. All Rights Reserved.
103
+ * Licensed under the Apache License, Version 2.0 (the "License");
104
+ * you may not use this file except in compliance with the License.
105
+ * You may obtain a copy of the License at
106
+ *
107
+ * http://www.apache.org/licenses/LICENSE-2.0
108
+ *
109
+ * Unless required by applicable law or agreed to in writing, software
110
+ * distributed under the License is distributed on an "AS IS" BASIS,
111
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
112
+ * See the License for the specific language governing permissions and
113
+ * limitations under the License.
114
+ * =============================================================================
115
+ */
116
/**
 * Unpacked WebGL program that applies a scalar binary operation to inputs A and B.
 * `opSnippet` becomes the body of `float binaryOperation(float a, float b)`.
 * The output shape comes from `w(aShape, bShape)` (presumably the broadcast
 * shape; TODO confirm).
 */
class A {
  constructor(opSnippet, aShape, bShape) {
    this.variableNames = ["A", "B"];
    this.outputShape = w(aShape, bShape);
    this.enableShapeUniforms = y(this.outputShape.length);
    this.userCode = [
      "",
      "float binaryOperation(float a, float b) {",
      `${opSnippet}`,
      "}",
      "",
      "void main() {",
      "float a = getAAtOutCoords();",
      "float b = getBAtOutCoords();",
      "setOutput(binaryOperation(a, b));",
      "}",
      "",
    ].join("\n");
  }
}
131
+ /**
132
+ * @license
133
+ * Copyright 2018 Google LLC. All Rights Reserved.
134
+ * Licensed under the Apache License, Version 2.0 (the "License");
135
+ * you may not use this file except in compliance with the License.
136
+ * You may obtain a copy of the License at
137
+ *
138
+ * http://www.apache.org/licenses/LICENSE-2.0
139
+ *
140
+ * Unless required by applicable law or agreed to in writing, software
141
+ * distributed under the License is distributed on an "AS IS" BASIS,
142
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
143
+ * See the License for the specific language governing permissions and
144
+ * limitations under the License.
145
+ * =============================================================================
146
+ */
147
/**
 * Packed (vec4) WebGL program that applies a binary operation to inputs A and B.
 *
 * @param opSnippet GLSL body of `vec4 binaryOperation(vec4 a, vec4 b)`.
 * @param aShape Shape of input A.
 * @param bShape Shape of input B.
 * @param checkOutOfBounds When true, lanes of the packed result that fall
 *   outside the output shape are zeroed before `setOutput`.
 */
class z {
  constructor(opSnippet, aShape, bShape, checkOutOfBounds = false) {
    this.variableNames = ["A", "B"];
    this.supportsBroadcasting = true;
    this.packedInputs = true;
    this.packedOutput = true;
    this.outputShape = w(aShape, bShape);
    const rank = this.outputShape.length;
    this.enableShapeUniforms = y(rank);

    let boundsFix = "";
    if (checkOutOfBounds) {
      if (rank === 0 || N(this.outputShape) === 1) {
        // Only the x lane of a single-element output is meaningful.
        boundsFix = ["", "result.y = 0.;", "result.z = 0.;", "result.w = 0.;", ""].join("\n");
      } else {
        boundsFix = ["", `${R(rank)} coords = getOutputCoords();`, ""].join("\n");
        if (rank === 1) {
          // Bound is read from the outShape uniform, or baked in as a literal.
          const limit = this.enableShapeUniforms ? "outShape" : `${this.outputShape[0]}`;
          boundsFix += [
            "",
            `result.y = (coords + 1) >= ${limit} ? 0. : result.y;`,
            "result.z = 0.;",
            "result.w = 0.;",
            "",
          ].join("\n");
        } else {
          const coords = _("coords", rank);
          const rowLimit = this.enableShapeUniforms
            ? `outShape[${rank} - 2]`
            : `${this.outputShape[rank - 2]}`;
          const colLimit = this.enableShapeUniforms
            ? `outShape[${rank} - 1]`
            : `${this.outputShape[rank - 1]}`;
          boundsFix += [
            "",
            "bool nextRowOutOfBounds =",
            `(${coords[rank - 2]} + 1) >= ${rowLimit};`,
            "bool nextColOutOfBounds =",
            `(${coords[rank - 1]} + 1) >= ${colLimit};`,
            "result.y = nextColOutOfBounds ? 0. : result.y;",
            "result.z = nextRowOutOfBounds ? 0. : result.z;",
            "result.w = nextColOutOfBounds || nextRowOutOfBounds ? 0. : result.w;",
            "",
          ].join("\n");
        }
      }
    }

    this.userCode = [
      "",
      "vec4 binaryOperation(vec4 a, vec4 b) {",
      `${opSnippet}`,
      "}",
      "",
      "void main() {",
      "vec4 a = getAAtOutCoords();",
      "vec4 b = getBAtOutCoords();",
      "",
      "vec4 result = binaryOperation(a, b);",
      `${boundsFix}`,
      "",
      "setOutput(result);",
      "}",
      "",
    ].join("\n");
  }
}
209
+ /**
210
+ * @license
211
+ * Copyright 2020 Google LLC. All Rights Reserved.
212
+ * Licensed under the Apache License, Version 2.0 (the "License");
213
+ * you may not use this file except in compliance with the License.
214
+ * You may obtain a copy of the License at
215
+ *
216
+ * http://www.apache.org/licenses/LICENSE-2.0
217
+ *
218
+ * Unless required by applicable law or agreed to in writing, software
219
+ * distributed under the License is distributed on an "AS IS" BASIS,
220
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
221
+ * See the License for the specific language governing permissions and
222
+ * limitations under the License.
223
+ * =============================================================================
224
+ */
225
/**
 * Identity kernel. It increments the backend's reference count for `x`'s data
 * and returns a new TensorInfo that aliases the same dataId, shape and dtype.
 */
function P(t) {
  const { inputs, backend } = t;
  const { x } = inputs;
  backend.incRef(x.dataId);
  return {
    dataId: x.dataId,
    shape: x.shape,
    dtype: x.dtype,
  };
}
229
+ /**
230
+ * @license
231
+ * Copyright 2020 Google LLC. All Rights Reserved.
232
+ * Licensed under the Apache License, Version 2.0 (the "License");
233
+ * you may not use this file except in compliance with the License.
234
+ * You may obtain a copy of the License at
235
+ *
236
+ * http://www.apache.org/licenses/LICENSE-2.0
237
+ *
238
+ * Unless required by applicable law or agreed to in writing, software
239
+ * distributed under the License is distributed on an "AS IS" BASIS,
240
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
241
+ * See the License for the specific language governing permissions and
242
+ * limitations under the License.
243
+ * =============================================================================
244
+ */
245
/**
 * Complex kernel. It allocates a "complex64" TensorInfo with the shape of
 * `real`. Aliases of `real` and `imag`, made via the identity kernel `P`, are
 * stored on its texData entry as `complexTensorInfos`.
 */
function G(t) {
  const { inputs, backend } = t;
  const { real, imag } = inputs;
  const complexInfo = backend.makeTensorInfo(real.shape, "complex64");
  const texEntry = backend.texData.get(complexInfo.dataId);
  const realAlias = P({ inputs: { x: real }, backend });
  const imagAlias = P({ inputs: { x: imag }, backend });
  texEntry.complexTensorInfos = { real: realAlias, imag: imagAlias };
  return complexInfo;
}
249
+ /**
250
+ * @license
251
+ * Copyright 2017 Google LLC. All Rights Reserved.
252
+ * Licensed under the Apache License, Version 2.0 (the "License");
253
+ * you may not use this file except in compliance with the License.
254
+ * You may obtain a copy of the License at
255
+ *
256
+ * http://www.apache.org/licenses/LICENSE-2.0
257
+ *
258
+ * Unless required by applicable law or agreed to in writing, software
259
+ * distributed under the License is distributed on an "AS IS" BASIS,
260
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
261
+ * See the License for the specific language governing permissions and
262
+ * limitations under the License.
263
+ * =============================================================================
264
+ */
265
/**
 * Unpacked WebGL program that applies a scalar unary operation to input A.
 * `opSnippet` becomes the body of `float unaryOperation(float x)`. The output
 * shape is the given shape, kept by reference.
 */
class V {
  constructor(outputShape, opSnippet) {
    this.variableNames = ["A"];
    this.outputShape = outputShape;
    this.enableShapeUniforms = y(this.outputShape.length);
    this.userCode = [
      "",
      "float unaryOperation(float x) {",
      `${opSnippet}`,
      "}",
      "",
      "void main() {",
      "float x = getAAtOutCoords();",
      "float y = unaryOperation(x);",
      "",
      "setOutput(y);",
      "}",
      "",
    ].join("\n");
  }
}
281
/** GLSL statement that returns the input `x` unchanged when it is NaN. */
const H = `if (isnan(x)) return x;`;
282
+ /**
283
+ * @license
284
+ * Copyright 2018 Google LLC. All Rights Reserved.
285
+ * Licensed under the Apache License, Version 2.0 (the "License");
286
+ * you may not use this file except in compliance with the License.
287
+ * You may obtain a copy of the License at
288
+ *
289
+ * http://www.apache.org/licenses/LICENSE-2.0
290
+ *
291
+ * Unless required by applicable law or agreed to in writing, software
292
+ * distributed under the License is distributed on an "AS IS" BASIS,
293
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
294
+ * See the License for the specific language governing permissions and
295
+ * limitations under the License.
296
+ * =============================================================================
297
+ */
298
/**
 * Packed (vec4) WebGL program that applies a unary operation to input A.
 * `opSnippet` becomes the body of `vec4 unaryOperation(vec4 x)`. Inputs and
 * output are flagged as packed.
 */
class L {
  constructor(outputShape, opSnippet) {
    this.variableNames = ["A"];
    this.packedInputs = true;
    this.packedOutput = true;
    this.outputShape = outputShape;
    this.enableShapeUniforms = y(this.outputShape.length);
    this.userCode = [
      "",
      "vec4 unaryOperation(vec4 x) {",
      `${opSnippet}`,
      "}",
      "",
      "void main() {",
      "vec4 x = getAAtOutCoords();",
      "vec4 y = unaryOperation(x);",
      "",
      "setOutput(y);",
      "}",
      "",
    ].join("\n");
  }
}
314
+ /**
315
+ * @license
316
+ * Copyright 2020 Google LLC. All Rights Reserved.
317
+ * Licensed under the Apache License, Version 2.0 (the "License");
318
+ * you may not use this file except in compliance with the License.
319
+ * You may obtain a copy of the License at
320
+ *
321
+ * http://www.apache.org/licenses/LICENSE-2.0
322
+ *
323
+ * Unless required by applicable law or agreed to in writing, software
324
+ * distributed under the License is distributed on an "AS IS" BASIS,
325
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
326
+ * See the License for the specific language governing permissions and
327
+ * limitations under the License.
328
+ * =============================================================================
329
+ */
330
/**
 * Builds a unary kernel function for the WebGL backend.
 *
 * If the backend chooses CPU execution and a `cpuKernelImpl` is provided, that
 * implementation runs on the raw values. Otherwise a WebGL program is run: the
 * packed `L` when the WEBGL_PACK_UNARY_OPERATIONS flag is set and a packed
 * snippet exists, else the unpacked `V`. The output dtype is `dtype` when
 * given, else the input's dtype.
 */
function K({ opSnippet, packedOpSnippet, cpuKernelImpl, dtype }) {
  return ({ inputs, backend }) => {
    const { x } = inputs;
    const outDtype = dtype || x.dtype;

    if (backend.shouldExecuteOnCPU([x]) && cpuKernelImpl != null) {
      const { values } = backend.texData.get(x.dataId);
      const outValues = cpuKernelImpl(values, outDtype);
      return backend.makeTensorInfo(x.shape, outDtype, outValues);
    }

    const usePacked = B().getBool("WEBGL_PACK_UNARY_OPERATIONS") && packedOpSnippet != null;
    const program = usePacked ? new L(x.shape, packedOpSnippet) : new V(x.shape, opSnippet);
    return backend.runWebGLProgram(program, [x], outDtype);
  };
}
342
/**
 * Builds a binary kernel function (inputs `a`, `b`) for the WebGL backend.
 *
 * Execution paths:
 * - complex64 inputs (only when `supportsComplex`): the real and imaginary
 *   parts are each run through the unpacked program `A`, then recombined
 *   with `G`. The intermediate part tensors are disposed afterwards.
 * - CPU path (either operand is a string, or the backend asks for CPU), when
 *   `cpuKernelImpl` is provided: values are read from texData, and string
 *   operands are decoded with `C`. `cpuKernelImpl` must return
 *   `[values, shape]`.
 * - otherwise: the packed program `z` runs when WEBGL_PACK_BINARY_OPERATIONS
 *   is set and a packed snippet exists; the unpacked `A` runs otherwise.
 *
 * The output dtype is `dtype` when given, else `v(a.dtype, b.dtype)`
 * (presumably dtype upcasting; TODO confirm).
 */
function Y({
  opSnippet,
  packedOpSnippet,
  checkOutOfBounds = false,
  supportsComplex = false,
  cpuKernelImpl,
  dtype,
}) {
  return ({ inputs, backend }) => {
    const { a, b } = inputs;

    if (supportsComplex && a.dtype === "complex64") {
      const aData = backend.texData.get(a.dataId);
      const bData = backend.texData.get(b.dataId);
      const [real, imag] = [
        [aData.complexTensorInfos.real, bData.complexTensorInfos.real],
        [aData.complexTensorInfos.imag, bData.complexTensorInfos.imag],
      ].map(([aPart, bPart]) => {
        const aInfo = { dataId: aPart.dataId, dtype: aPart.dtype, shape: a.shape };
        const bInfo = { dataId: bPart.dataId, dtype: bPart.dtype, shape: b.shape };
        const program = new A(opSnippet, a.shape, b.shape);
        return backend.runWebGLProgram(program, [aInfo, bInfo], v(aPart.dtype, bPart.dtype));
      });
      const complexOutput = G({ inputs: { real, imag }, backend });
      backend.disposeIntermediateTensorInfo(real);
      backend.disposeIntermediateTensorInfo(imag);
      return complexOutput;
    }

    const outDtype = dtype || v(a.dtype, b.dtype);

    if ((a.dtype === "string" || b.dtype === "string" || backend.shouldExecuteOnCPU([a, b])) && cpuKernelImpl != null) {
      const aVals = backend.texData.get(a.dataId).values;
      const bVals = backend.texData.get(b.dataId).values;
      // Decode each operand according to its OWN dtype. Keying both on a.dtype
      // left a string `b` undecoded and sent a numeric `b` to the UTF-8 decoder.
      const decodedA = a.dtype === "string" ? C(aVals) : aVals;
      const decodedB = b.dtype === "string" ? C(bVals) : bVals;
      const [outValues, outShape] = cpuKernelImpl(a.shape, b.shape, decodedA, decodedB, outDtype);
      const outInfo = backend.makeTensorInfo(outShape, outDtype);
      backend.texData.get(outInfo.dataId).values = outValues;
      return outInfo;
    }

    const usePacked = B().getBool("WEBGL_PACK_BINARY_OPERATIONS") && packedOpSnippet != null;
    const program = usePacked
      ? new z(packedOpSnippet, a.shape, b.shape, checkOutOfBounds)
      : new A(opSnippet, a.shape, b.shape);
    return backend.runWebGLProgram(program, [a, b], outDtype);
  };
}
379
+ export {
380
+ H as C,
381
+ F as a,
382
+ y as b,
383
+ k as c,
384
+ Y as d,
385
+ C as f,
386
+ R as g,
387
+ K as u
388
+ };
@@ -1,17 +1,18 @@
1
- import { attentionMask as I } from "../ops/attentionMask.js";
2
- import y from "./BaseLayer.js";
3
- import { qkv as z } from "../ops/qkv.js";
4
- import { rope as P } from "../ops/rope.js";
1
+ import { attentionMask as P } from "../ops/attentionMask.js";
2
+ import T from "./BaseLayer.js";
3
+ import { qkv as y } from "../ops/qkv.js";
4
+ import { rope as w } from "../ops/rope.js";
5
5
  import { appendCache as E } from "../ops/appendCache.js";
6
- import { D as $, F as _, t as x, c as L, e as v, H as W } from "../index-XjBAhiFO.js";
7
- import { fusedSoftmax as S } from "../ops/fusedSoftmax.js";
8
- import { l as M, w as O, r as T, d as N, a as U } from "../tfjs_backend-Bzl2SrRo.js";
9
- import { o as q } from "../ones-BO01zpJG.js";
10
- import { z as B } from "../zeros-CmJFiC84.js";
11
- import { v as g } from "../variable-BS4AKqNU.js";
12
- import { m as C } from "../mat_mul-CKwFEV1Q.js";
13
- import { r as D } from "../reshape-DFzh97Sc.js";
14
- class nt extends y {
6
+ import { D as z, F as S, t as $, c as L, e as j, H as O } from "../index-C4JCoBvj.js";
7
+ import { fusedSoftmax as _ } from "../ops/fusedSoftmax.js";
8
+ import { l as W, w as M, d as x } from "../tfjs_backend-Cug-PH75.js";
9
+ import { o as N } from "../ones-Bf3YR48P.js";
10
+ import { z as q } from "../zeros-dnQxFgAD.js";
11
+ import { v as k } from "../variable-LJT9Ld63.js";
12
+ import { r as C, d as I } from "../dropout-DfDdklfL.js";
13
+ import { r as B } from "../reshape-Boe4DuIO.js";
14
+ import { m as F } from "../mat_mul-415y5Qn2.js";
15
+ class nt extends T {
15
16
  cAttn = null;
16
17
  cProj = null;
17
18
  bias;
@@ -22,17 +23,17 @@ class nt extends y {
22
23
  units;
23
24
  projUnits;
24
25
  constructor(t, s) {
25
- super(s), this.index = t, this.units = s.gpt.nEmbed * 3, this.projUnits = s.gpt.nEmbed, this.bias = M.bandPart(q([s.gpt.blockSize, s.gpt.blockSize]), -1, 0).cast("bool"), this.divisor = 1 / Math.sqrt(s.gpt.nEmbed / s.gpt.nHead);
26
- const e = B([s.gpt.blockSize, s.gpt.blockSize]), n = $([s.gpt.blockSize, s.gpt.blockSize], Number.NEGATIVE_INFINITY);
27
- this.maskInf = O(this.bias, e, n);
26
+ super(s), this.index = t, this.units = s.gpt.nEmbed * 3, this.projUnits = s.gpt.nEmbed, this.bias = W.bandPart(N([s.gpt.blockSize, s.gpt.blockSize]), -1, 0).cast("bool"), this.divisor = 1 / Math.sqrt(s.gpt.nEmbed / s.gpt.nHead);
27
+ const e = q([s.gpt.blockSize, s.gpt.blockSize]), i = z([s.gpt.blockSize, s.gpt.blockSize], Number.NEGATIVE_INFINITY);
28
+ this.maskInf = M(this.bias, e, i);
28
29
  }
29
30
  build() {
30
- this.cAttn === null && (this.cAttn = g(
31
- T([this.config.gpt.nEmbed, this.units], 0, 0.02),
31
+ this.cAttn === null && (this.cAttn = k(
32
+ C([this.config.gpt.nEmbed, this.units], 0, 0.02),
32
33
  !0
33
34
  //`block_${this.index}_attn_cAttn_kernel`
34
- )), this.cProj === null && (this.cProj = g(
35
- T([this.projUnits, this.config.gpt.nEmbed], 0, 0.02),
35
+ )), this.cProj === null && (this.cProj = k(
36
+ C([this.projUnits, this.config.gpt.nEmbed], 0, 0.02),
36
37
  !0
37
38
  //`block_${this.index}_attn_cProj_kernel`
38
39
  ));
@@ -55,57 +56,54 @@ class nt extends y {
55
56
  const s = t.get(`block_${this.index}_cAttn`)?.[0], e = t.get(`block_${this.index}_cProj`)?.[0];
56
57
  if (!s) throw new Error(`Weights for block_${this.index}_cAttn not found`);
57
58
  if (!e) throw new Error(`Weights for block_${this.index}_cProj not found`);
58
- this.cAttn ? this.cAttn.assign(s) : this.cAttn = g(s, !0), this.cProj ? this.cProj.assign(e) : this.cProj = g(e, !0);
59
+ this.cAttn ? this.cAttn.assign(s) : this.cAttn = k(s, !0), this.cProj ? this.cProj.assign(e) : this.cProj = k(e, !0);
59
60
  }
60
- getAttentionScores(t, s, e, n) {
61
- const o = I(t, s, this.maskInf, this.divisor);
62
- return S(o, e ? this.config.gpt.dropout : 0, n);
61
+ getAttentionScores(t, s, e, i) {
62
+ const o = P(t, s, this.divisor, this.maskInf);
63
+ return _(o, e ? this.config.gpt.dropout : 0, i);
63
64
  }
64
65
  // Attention with optional past. If pastLen > 0 and T_cur == 1, no mask needed.
65
- getAttentionScoresWithPast(t, s, e, n, o) {
66
- const i = t.shape[2];
67
- let r = C(t, s, !1, !0).mul(this.divisor);
68
- if (i > 1 && n > 0)
69
- throw new Error("Cannot use past with T_cur > 1");
70
- if (i > 1) {
71
- const c = this.maskInf.slice([0, 0], [i, i]).expandDims(0).expandDims(0);
72
- r = r.add(c);
73
- }
74
- return S(r, e ? this.config.gpt.dropout : 0, o);
66
+ getAttentionScoresWithPast(t, s, e) {
67
+ const i = P(t, s, this.divisor, void 0, e);
68
+ return _(i, 0, 0);
75
69
  }
76
70
  getQKV(t) {
77
- return z(t, this.cAttn, this.config.gpt.nHead);
71
+ return y(t, this.cAttn, this.config.gpt.nHead);
78
72
  }
79
73
  getOutputProjection(t) {
80
- const s = t.shape[0], e = t.shape[2], n = this.config.gpt.nEmbed, o = t.transpose([0, 2, 1, 3]), i = D(o, [s, e, n]);
81
- return N(i, this.cProj);
74
+ const s = t.shape[0], e = t.shape[2], i = this.config.gpt.nEmbed, o = t.transpose([0, 2, 1, 3]), n = B(o, [s, e, i]);
75
+ return x(n, this.cProj);
82
76
  }
83
77
  updateCache(t, s, e) {
84
- const n = this.config.gpt.blockSize, o = t.shape[2], i = Math.min(e?.length || 0, n - o), a = e ? E(e.k, t, n) : t, r = e ? E(e.v, s, n) : s;
78
+ const i = this.config.gpt.blockSize, o = t.shape[2], n = Math.min(e?.length || 0, i - o), r = e ? E(e.k, t, i) : t, a = e ? E(e.v, s, i) : s;
85
79
  return {
86
- k: _(a),
87
- v: _(r),
88
- length: i + o,
80
+ k: S(r),
81
+ v: S(a),
82
+ length: n + o,
89
83
  cumulativeLength: e ? e.cumulativeLength + o : o
90
84
  };
91
85
  }
92
- forward(t, s = !1, e, n = !1, o) {
93
- return x(() => {
86
+ forward(t, s = !1, e, i = !1, o) {
87
+ return $(() => {
94
88
  this.startMemory();
95
- const [i, a, r] = this.getQKV(t), c = o ? o.cumulativeLength : 0, h = this.config.layerConfig.ropeCache, d = h ? P(i, h, c) : i, p = h ? P(a, h, c) : a;
96
- h && (i.dispose(), a.dispose());
97
- const f = o ? o.length : 0, l = this.updateCache(p, r, o), m = l.k, b = l.v;
98
- o && (p.dispose(), r.dispose());
99
- let u;
100
- f > 0 ? u = this.getAttentionScoresWithPast(d, m, s, f, e) : u = this.getAttentionScores(d, m, s, e);
101
- const k = C(u, b), A = this.getOutputProjection(k), w = n ? u.mean(1) : void 0;
102
- return this.endMemory("CausalSelfAttention"), { output: A, attention: w, presentKV: l };
89
+ const [n, r, a] = this.getQKV(t), p = o ? o.cumulativeLength : 0, c = this.config.layerConfig.ropeCache, u = c ? w(n, c, p) : n, f = c ? w(r, c, p) : r;
90
+ c && (n.dispose(), r.dispose());
91
+ const g = o ? o.length : 0, d = this.updateCache(f, a, o), l = d.k, m = d.v;
92
+ o && (f.dispose(), a.dispose());
93
+ let h;
94
+ g > 0 ? h = this.getAttentionScoresWithPast(u, l, g) : h = this.getAttentionScores(u, l, s, e), u.dispose(), s && l.dispose();
95
+ const b = F(h, m);
96
+ i || h.dispose(), s && m.dispose();
97
+ const A = this.getOutputProjection(b);
98
+ b.dispose();
99
+ const v = i ? h.mean(1) : void 0;
100
+ return this.endMemory("CausalSelfAttention"), { output: A, attention: v, presentKV: s ? void 0 : d };
103
101
  });
104
102
  }
105
- call(t, s = !1, e = !1, n) {
106
- if (n && !this.config.gpt.useRope)
103
+ call(t, s = !1, e = !1, i) {
104
+ if (i && !this.config.gpt.useRope)
107
105
  throw new Error("Cannot use pastKV without RoPE enabled");
108
- if (s && n)
106
+ if (s && i)
109
107
  throw new Error("Cannot use pastKV during training");
110
108
  if (t.shape.length !== 3)
111
109
  throw new Error(`Input tensor must be rank 3 [B, T, C], got shape ${t.shape}`);
@@ -114,30 +112,33 @@ class nt extends y {
114
112
  this.build();
115
113
  const o = Math.random() * 1e9;
116
114
  if (s && this.config.layerConfig.checkpointAttention) {
117
- const a = L(
115
+ const r = L(
118
116
  // @ts-expect-error Invalid params
119
- (r, c, h, d) => {
120
- const p = this.forward(r, !0, o);
121
- p.presentKV?.k.dispose(), p.presentKV?.v.dispose(), d([r]);
122
- const f = (l, m) => {
123
- const [b] = m, u = v().state.activeTape;
124
- v().state.activeTape = [];
125
- const k = W((A, w, H) => {
126
- const j = this.forward(A, !0, o);
127
- return j.presentKV?.k.dispose(), j.presentKV?.v.dispose(), j.output;
128
- })([b, c, h], l);
129
- return v().state.activeTape = u, k;
117
+ (a, p, c, u) => {
118
+ const f = this.forward(a, !0, o);
119
+ u([a]);
120
+ const g = (d, l) => {
121
+ const [m] = l, h = j().state.activeTape;
122
+ j().state.activeTape = [];
123
+ const b = O((A, v, R) => this.forward(A, !0, o).output)([m, p, c], d);
124
+ return j().state.activeTape = h, b;
130
125
  };
131
- return { value: p.output, gradFunc: f };
126
+ return { value: f.output, gradFunc: g };
132
127
  }
133
128
  )(t, this.cAttn, this.cProj);
134
129
  if (this.config.gpt.dropout > 0) {
135
- const r = U(a, this.config.gpt.dropout);
136
- return a.dispose(), { output: r };
130
+ const a = I(r, this.config.gpt.dropout);
131
+ return r.dispose(), { output: a };
132
+ } else
133
+ return { output: r };
134
+ } else {
135
+ const n = this.forward(t, s, o, e, i);
136
+ if (this.config.gpt.dropout > 0) {
137
+ const r = I(n.output, this.config.gpt.dropout);
138
+ return n.output.dispose(), { output: r, attention: n.attention, presentKV: n.presentKV };
137
139
  } else
138
- return { output: a };
139
- } else
140
- return this.forward(t, s, o, e, n);
140
+ return n;
141
+ }
141
142
  }
142
143
  dispose() {
143
144
  this.cAttn?.dispose(), this.cProj?.dispose(), this.bias.dispose(), this.maskInf.dispose();
@@ -3,15 +3,17 @@ import { default as BaseLayer, GPTLayerConfig } from './BaseLayer';
3
3
/**
 * Type declaration for the transformer MLP layer (post-0.4.2 surface).
 * The members below are declarations only; behaviour lives in MLP.js.
 */
export default class MLP extends BaseLayer {
    private cFc;
    private cProj;
    private hiddenUnits;
    private index;
    private _trainable;
    private build;
    constructor(index: number, config: GPTLayerConfig);
    get variables(): Variable[];
    get trainable(): boolean;
    set trainable(value: boolean);
    saveWeights(map: Map<string, Tensor[]>): void;
    loadWeights(weights: Map<string, Tensor[]>): void;
    /** Forward computation without a training flag (presumably the inference path; TODO confirm in MLP.js). */
    forward(x: Tensor): Tensor;
    call(x: Tensor, training?: boolean): Tensor;
    dispose(): void;
}