@genai-fi/nanogpt 0.5.0 → 0.5.2

This diff compares publicly available package versions as released to the supported registries. It is provided for informational purposes only and reflects the packages as they appear in the public registry.
Files changed (104)
  1. package/dist/Generator.js +95 -46
  2. package/dist/NanoGPTModel.d.ts +3 -2
  3. package/dist/NanoGPTModel.js +91 -76
  4. package/dist/{Reshape-BE5rA4rT.js → Reshape-Bt_t7RNz.js} +4 -4
  5. package/dist/TeachableLLM.js +1 -1
  6. package/dist/TiedEmbedding-DORsPlNL.js +44 -0
  7. package/dist/{axis_util-97KkkyRQ.js → axis_util-CVbf1vmL.js} +3 -3
  8. package/dist/{broadcast_to-CMlkG8NS.js → broadcast_to-BBoMQXbL.js} +4 -4
  9. package/dist/{concat-Cxbo2sOz.js → concat-BRRtq4S2.js} +1 -1
  10. package/dist/dataset-ZHEPJmED.js +1226 -0
  11. package/dist/{dropout-kbDY39Ci.js → dropout-lQm_YyX3.js} +1 -1
  12. package/dist/{gather-Bxe1Qip8.js → gather-BWyutxwi.js} +3 -3
  13. package/dist/{gpgpu_math-C0zyxKFi.js → gpgpu_math-Df7gzJWH.js} +1 -1
  14. package/dist/{index-iNhkcAEQ.js → index-CnHyhpKc.js} +32 -32
  15. package/dist/{kernel_funcs_utils-C4eIk4fE.js → kernel_funcs_utils-Dqo82NH4.js} +25 -25
  16. package/dist/layers/BaseLayer.js +114 -3
  17. package/dist/layers/CausalSelfAttention.d.ts +2 -3
  18. package/dist/layers/CausalSelfAttention.js +31 -30
  19. package/dist/layers/MLP.js +10 -9
  20. package/dist/layers/RMSNorm.js +12 -11
  21. package/dist/layers/RoPECache.js +3 -3
  22. package/dist/layers/TiedEmbedding.js +8 -6
  23. package/dist/layers/TransformerBlock.js +2 -2
  24. package/dist/{log_sum_exp-CkumwesB.js → log_sum_exp-CRH7Np9v.js} +12 -12
  25. package/dist/main.js +1 -1
  26. package/dist/{mat_mul-D0SifYfJ.js → mat_mul-DeGU1U_C.js} +3 -3
  27. package/dist/{max-CYaAjEEp.js → max-CcnEArWK.js} +3 -3
  28. package/dist/{moments-B06NlR_V.js → moments-DLTE6-1p.js} +4 -4
  29. package/dist/{norm-D3676xIo.js → norm-BpWsOapl.js} +5 -5
  30. package/dist/{ones-BIeFnPHR.js → ones-CDWGzVnm.js} +6 -6
  31. package/dist/ops/appendCache.js +3 -3
  32. package/dist/ops/attentionMask.js +1 -1
  33. package/dist/ops/cpu/appendCache.js +2 -2
  34. package/dist/ops/cpu/attentionMask.js +5 -5
  35. package/dist/ops/cpu/fusedSoftmax.js +2 -2
  36. package/dist/ops/cpu/gatherSub.js +5 -5
  37. package/dist/ops/cpu/gelu.js +1 -1
  38. package/dist/ops/cpu/matMulGelu.js +1 -1
  39. package/dist/ops/cpu/matMulMul.js +1 -1
  40. package/dist/ops/cpu/mulDropout.js +1 -1
  41. package/dist/ops/cpu/normRMS.js +1 -1
  42. package/dist/ops/cpu/qkv.js +3 -3
  43. package/dist/ops/cpu/rope.js +5 -5
  44. package/dist/ops/cpu/scatterSub.js +27 -27
  45. package/dist/ops/fusedSoftmax.js +1 -1
  46. package/dist/ops/gatherSub.js +1 -1
  47. package/dist/ops/gelu.js +1 -1
  48. package/dist/ops/grads/attentionMask.js +1 -1
  49. package/dist/ops/grads/fusedSoftmax.js +2 -2
  50. package/dist/ops/grads/gelu.js +1 -1
  51. package/dist/ops/grads/matMulGelu.js +1 -1
  52. package/dist/ops/grads/normRMS.js +1 -1
  53. package/dist/ops/grads/qkv.js +1 -1
  54. package/dist/ops/grads/rope.js +1 -1
  55. package/dist/ops/matMulGelu.js +1 -1
  56. package/dist/ops/matMulMul.js +1 -1
  57. package/dist/ops/mulDrop.js +1 -1
  58. package/dist/ops/node/sparseCrossEntropy.js +1 -1
  59. package/dist/ops/normRMS.js +1 -1
  60. package/dist/ops/qkv.js +1 -1
  61. package/dist/ops/scatterSub.js +1 -1
  62. package/dist/ops/webgl/appendCache.js +1 -1
  63. package/dist/ops/webgl/attentionMask.js +1 -1
  64. package/dist/ops/webgl/fusedSoftmax.js +36 -36
  65. package/dist/ops/webgl/gatherSub.js +1 -1
  66. package/dist/ops/webgl/gelu.js +2 -2
  67. package/dist/ops/webgl/matMulGelu.js +22 -22
  68. package/dist/ops/webgl/matMulMul.js +1 -1
  69. package/dist/ops/webgl/mulDropout.js +1 -1
  70. package/dist/ops/webgl/normRMS.js +2 -2
  71. package/dist/ops/webgl/qkv.js +1 -1
  72. package/dist/ops/webgl/rope.js +1 -1
  73. package/dist/ops/webgl/scatterSub.js +1 -1
  74. package/dist/{ops-ObfXLHYQ.js → ops-DzQTmLIl.js} +60 -60
  75. package/dist/{TiedEmbedding-DsDRvLB0.js → random_width-DI2h9CMs.js} +1215 -1250
  76. package/dist/{range-BsFU-SNG.js → range-CkOJ7090.js} +1 -1
  77. package/dist/{reshape-DxTPgnwL.js → reshape-CTIbqjwm.js} +1 -1
  78. package/dist/{sin-BOX-JVAj.js → sin-HzioENy_.js} +5 -5
  79. package/dist/{slice_util-D-kaD4ZV.js → slice_util-n4wHKmex.js} +1 -1
  80. package/dist/{softmax-BjsptB07.js → softmax-DX6qXAbm.js} +2 -2
  81. package/dist/{split-BCbrzthj.js → split-CVwhL8Oe.js} +3 -3
  82. package/dist/{stack--cqr9Dgc.js → stack-S2-D2JAQ.js} +1 -1
  83. package/dist/{sum-B_92TaHD.js → sum-UdfvaNhB.js} +4 -4
  84. package/dist/{tensor-CfiPXsW4.js → tensor-IZex6Bwp.js} +1 -1
  85. package/dist/{tensor2d-tSxWdFMH.js → tensor2d-CqtBzOKq.js} +1 -1
  86. package/dist/{tfjs_backend-NucKez4s.js → tfjs_backend-DX9yVvwk.js} +41 -41
  87. package/dist/tokeniser/CharTokeniser.js +27 -27
  88. package/dist/tokeniser/bpe.d.ts +1 -0
  89. package/dist/tokeniser/bpe.js +38 -35
  90. package/dist/training/AdamExt.js +1 -1
  91. package/dist/training/DatasetBuilder.js +22 -1242
  92. package/dist/training/FullTrainer.js +1 -1
  93. package/dist/training/Trainer.js +5 -5
  94. package/dist/training/sparseCrossEntropy.js +4 -4
  95. package/dist/utilities/dummy.js +2 -2
  96. package/dist/utilities/generate.js +3 -3
  97. package/dist/utilities/load.js +1 -1
  98. package/dist/utilities/profile.js +1 -1
  99. package/dist/utilities/save.js +5 -5
  100. package/dist/utilities/weights.js +2 -2
  101. package/dist/variable-BGvK-VN3.js +23 -0
  102. package/dist/{zeros-NMYTayy7.js → zeros-CYMicyqz.js} +3 -3
  103. package/package.json +1 -1
  104. package/dist/BaseLayer-BhrMN8JO.js +0 -135
@@ -1,4 +1,4 @@
- import { o as l, i as h, E as m, ag as p, l as c, ah as d, ae as g, k as u, V, ai as v, a9 as N, b as w } from "./index-iNhkcAEQ.js";
+ import { o as l, j as h, E as m, ak as p, n as c, al as d, ae as g, l as u, T as V, am as v, a9 as N, b as w } from "./index-CnHyhpKc.js";
  import { s as f } from "./index-C4L8Cm77.js";
  /**
  * @license
@@ -1,4 +1,4 @@
- import { o as g, i as t, E as h, G as p } from "./index-iNhkcAEQ.js";
+ import { o as g, j as t, E as h, G as p } from "./index-CnHyhpKc.js";
  /**
  * @license
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -16,8 +16,8 @@ import { o as g, i as t, E as h, G as p } from "./index-iNhkcAEQ.js";
  * =============================================================================
  */
  function u(n, s, r = 0, e = 0) {
- const o = t(n, "x", "gather"), a = t(s, "indices", "gather", "int32"), i = { x: o, indices: a }, c = { axis: r, batchDims: e };
- return h.runKernel(p, i, c);
+ const o = t(n, "x", "gather"), a = t(s, "indices", "gather", "int32"), c = { x: o, indices: a }, i = { axis: r, batchDims: e };
+ return h.runKernel(p, c, i);
  }
  const d = /* @__PURE__ */ g({ gather_: u });
  export {
@@ -1,4 +1,4 @@
- import { K as e } from "./index-iNhkcAEQ.js";
+ import { N as e } from "./index-CnHyhpKc.js";
  /**
  * @license
  * Copyright 2017 Google LLC. All Rights Reserved.
@@ -4005,26 +4005,26 @@ export {
  Ss as A,
  Zs as B,
  or as C,
- Wa as D,
+ Ft as D,
  g as E,
- Bn as F,
+ Wa as F,
  Pr as G,
- Fs as H,
- kn as I,
- En as J,
- k as K,
- Lr as L,
+ Bn as H,
+ Fs as I,
+ kn as J,
+ En as K,
+ Qa as L,
  ta as M,
- rs as N,
- de as O,
+ k as N,
+ Lr as O,
  ba as P,
- Ea as Q,
+ rs as Q,
  Ia as R,
  qa as S,
- Qa as T,
- Zt as U,
- D as V,
- To as W,
+ D as T,
+ de as U,
+ Ea as V,
+ Zt as W,
  De as X,
  ar as Y,
  ne as Z,
@@ -4074,13 +4074,13 @@ export {
  $t as ad,
  Rt as ae,
  Rs as af,
- xr as ag,
- Wn as ah,
- x as ai,
- F as aj,
- pe as ak,
- fo as al,
- dt as am,
+ F as ag,
+ pe as ah,
+ fo as ai,
+ dt as aj,
+ xr as ak,
+ Wn as al,
+ x as am,
  jt as an,
  ue as ao,
  za as ap,
@@ -4214,22 +4214,22 @@ export {
  K as f,
  ss as g,
  lo as h,
- T as i,
- In as j,
- y as k,
- xt as l,
+ To as i,
+ T as j,
+ In as k,
+ y as l,
  po as m,
- Ge as n,
+ xt as n,
  N as o,
- z as p,
- q,
+ Ge as p,
+ z as q,
  co as r,
  tt as s,
  E as t,
- Ba as u,
+ q as u,
  ls as v,
- Ka as w,
- qn as x,
- Ft as y,
+ Ba as w,
+ Ka as x,
+ qn as y,
  C as z
  };
@@ -1,5 +1,5 @@
- import { an as D, ao as N, N as w, p as R, O as v, K as P } from "./index-iNhkcAEQ.js";
- import { u as g } from "./gpgpu_math-C0zyxKFi.js";
+ import { an as D, ao as N, Q as w, q as R, U as v, N as P } from "./index-CnHyhpKc.js";
+ import { u as g } from "./gpgpu_math-Df7gzJWH.js";
  /**
  * @license
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -23,7 +23,7 @@ function B(t) {
  throw new Error(`Failed to decode encoded string bytes into utf-8, error: ${e}`);
  }
  }
- function K(t) {
+ function H(t) {
  return t.map((e) => N(e));
  }
  /**
@@ -127,12 +127,12 @@ class C {
  * =============================================================================
  */
  class _ {
- constructor(e, o, u, p = !1) {
+ constructor(e, o, u, d = !1) {
  this.variableNames = ["A", "B"], this.supportsBroadcasting = !0, this.packedInputs = !0, this.packedOutput = !0, this.outputShape = w(o, u);
  const a = this.outputShape.length;
  this.enableShapeUniforms = g(a);
  let n = "";
- if (p)
+ if (d)
  if (a === 0 || R(this.outputShape) === 1)
  n = `
  result.y = 0.;
@@ -225,7 +225,7 @@ function A(t) {
  * =============================================================================
  */
  function G(t) {
- const { inputs: e, backend: o } = t, { real: u, imag: p } = e, a = o.makeTensorInfo(u.shape, "complex64"), n = o.texData.get(a.dataId), l = A({ inputs: { x: u }, backend: o }), s = A({ inputs: { x: p }, backend: o });
+ const { inputs: e, backend: o } = t, { real: u, imag: d } = e, a = o.makeTensorInfo(u.shape, "complex64"), n = o.texData.get(a.dataId), l = A({ inputs: { x: u }, backend: o }), s = A({ inputs: { x: d }, backend: o });
  return n.complexTensorInfos = { real: l, imag: s }, a;
  }
  /**
@@ -260,7 +260,7 @@ class V {
  `;
  }
  }
- const H = "if (isnan(x)) return x;";
+ const K = "if (isnan(x)) return x;";
  /**
  * @license
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -310,8 +310,8 @@ class L {
  * =============================================================================
  */
  function Y({ opSnippet: t, packedOpSnippet: e, cpuKernelImpl: o, dtype: u }) {
- return ({ inputs: p, backend: a }) => {
- const { x: n } = p, l = a, s = u || n.dtype;
+ return ({ inputs: d, backend: a }) => {
+ const { x: n } = d, l = a, s = u || n.dtype;
  if (l.shouldExecuteOnCPU([n]) && o != null) {
  const c = l.texData.get(n.dataId), x = o(c.values, s);
  return l.makeTensorInfo(n.shape, s, x);
@@ -321,37 +321,37 @@ function Y({ opSnippet: t, packedOpSnippet: e, cpuKernelImpl: o, dtype: u }) {
  return i ? r = new L(n.shape, e) : r = new V(n.shape, t), l.runWebGLProgram(r, [n], s);
  };
  }
- function j({ opSnippet: t, packedOpSnippet: e, checkOutOfBounds: o = !1, supportsComplex: u = !1, cpuKernelImpl: p, dtype: a }) {
+ function q({ opSnippet: t, packedOpSnippet: e, checkOutOfBounds: o = !1, supportsComplex: u = !1, cpuKernelImpl: d, dtype: a }) {
  return ({ inputs: n, backend: l }) => {
  const { a: s, b: i } = n, r = l;
  if (u && s.dtype === "complex64") {
- const h = r.texData.get(s.dataId), f = r.texData.get(i.dataId), [O, y] = [
+ const h = r.texData.get(s.dataId), f = r.texData.get(i.dataId), [y, O] = [
  [h.complexTensorInfos.real, f.complexTensorInfos.real],
  [h.complexTensorInfos.imag, f.complexTensorInfos.imag]
  ].map((S) => {
- const [d, m] = S, $ = {
- dataId: d.dataId,
- dtype: d.dtype,
+ const [p, m] = S, $ = {
+ dataId: p.dataId,
+ dtype: p.dtype,
  shape: s.shape
  }, T = {
  dataId: m.dataId,
  dtype: m.dtype,
  shape: i.shape
  }, U = new C(t, s.shape, i.shape);
- return r.runWebGLProgram(U, [$, T], v(d.dtype, m.dtype));
- }), I = G({ inputs: { real: O, imag: y }, backend: r });
- return r.disposeIntermediateTensorInfo(O), r.disposeIntermediateTensorInfo(y), I;
+ return r.runWebGLProgram(U, [$, T], v(p.dtype, m.dtype));
+ }), I = G({ inputs: { real: y, imag: O }, backend: r });
+ return r.disposeIntermediateTensorInfo(y), r.disposeIntermediateTensorInfo(O), I;
  }
  const c = a || v(s.dtype, i.dtype);
- if ((s.dtype === "string" || i.dtype === "string" || r.shouldExecuteOnCPU([s, i])) && p != null) {
- const h = r.texData.get(s.dataId).values, f = r.texData.get(i.dataId).values, O = s.dtype === "string" ? (
+ if ((s.dtype === "string" || i.dtype === "string" || r.shouldExecuteOnCPU([s, i])) && d != null) {
+ const h = r.texData.get(s.dataId).values, f = r.texData.get(i.dataId).values, y = s.dtype === "string" ? (
  // tslint:disable-next-line: no-any
  B(h)
- ) : h, y = s.dtype === "string" ? (
+ ) : h, O = s.dtype === "string" ? (
  // tslint:disable-next-line: no-any
  B(f)
- ) : f, [I, S] = p(s.shape, i.shape, O, y, c), d = r.makeTensorInfo(S, c), m = r.texData.get(d.dataId);
- return m.values = I, d;
+ ) : f, [I, S] = d(s.shape, i.shape, y, O, c), p = r.makeTensorInfo(S, c), m = r.texData.get(p.dataId);
+ return m.values = I, p;
  }
  const x = P().getBool("WEBGL_PACK_BINARY_OPERATIONS") && e != null;
  let b;
@@ -359,10 +359,10 @@ function j({ opSnippet: t, packedOpSnippet: e, checkOutOfBounds: o = !1, support
  };
  }
  export {
- H as C,
- K as a,
+ K as C,
+ H as a,
  E as b,
- j as c,
+ q as c,
  B as f,
  k as g,
  Y as u
@@ -1,5 +1,116 @@
- import "../index-iNhkcAEQ.js";
- import { B as a } from "../BaseLayer-BhrMN8JO.js";
+ import { T as g, c as p, e as o, i as v } from "../index-CnHyhpKc.js";
+ import { v as _ } from "../variable-BGvK-VN3.js";
+ class M {
+ parent;
+ config;
+ _variables = /* @__PURE__ */ new Map();
+ _trainable = !0;
+ children = [];
+ constructor(t, r) {
+ this.config = t, this.parent = r, this.parent && this.parent.children.push(this);
+ }
+ getProfiler() {
+ return this.config.layerConfig.profiler;
+ }
+ startMemory() {
+ this.config.layerConfig.profiler?.startMemory();
+ }
+ endMemory(t) {
+ this.config.layerConfig.profiler?.endMemory(t);
+ }
+ addVariable(t, r) {
+ this._variables.set(t, r || null);
+ }
+ get variables() {
+ const t = Array.from(this._variables.values()).filter((e) => e !== null), r = this.children.flatMap((e) => e.variables);
+ return [...t, ...r];
+ }
+ get trainableVariables() {
+ const t = Array.from(this._variables.values()).filter(
+ (e) => e !== null && e.trainable
+ ), r = this.children.flatMap((e) => e.trainableVariables);
+ return [...t, ...r];
+ }
+ get trainable() {
+ return this._trainable;
+ }
+ set trainable(t) {
+ this._trainable = t, this._variables.forEach((r) => {
+ r && (r.trainable = t);
+ }), this.children.forEach((r) => {
+ r.trainable = t;
+ });
+ }
+ getVariable(t) {
+ const r = this._variables.get(t);
+ if (!r)
+ throw new Error(`Variable ${t} not found`);
+ return r;
+ }
+ hasVariable(t) {
+ return this._variables.get(t) !== null;
+ }
+ setVariable(t, r) {
+ if (!this._variables.has(t))
+ throw new Error(`Variable ${t} not found`);
+ this._variables.set(t, r);
+ }
+ saveWeights(t) {
+ this._variables.forEach((r, e) => {
+ r && t.set(e, [r.clone()]);
+ }), this.children.forEach((r) => {
+ r.saveWeights(t);
+ });
+ }
+ loadWeights(t) {
+ this._variables.forEach((r, e) => {
+ const i = t.get(e)?.[0];
+ if (!i)
+ throw new Error(`Weights for ${e} not found`);
+ r ? r.assign(i) : this._variables.set(e, _(i, this._trainable));
+ }), this.children.forEach((r) => {
+ r.loadWeights(t);
+ });
+ }
+ dispose() {
+ this._variables.forEach((t) => {
+ t?.dispose();
+ }), this._variables.clear();
+ }
+ build() {
+ }
+ dropout(t) {
+ return t;
+ }
+ call(t, ...r) {
+ this.build();
+ const e = this.forward(t, ...r);
+ if (t.training && e instanceof g) {
+ const i = this.dropout(e);
+ return i !== e && e.dispose(), i;
+ } else
+ return e;
+ }
+ callCheckpoint(t, ...r) {
+ return this.build(), this.checkpointingFn(t, ...r);
+ }
+ checkpointingFn(t, ...r) {
+ const e = this.trainableVariables, s = p((...a) => {
+ const l = a[a.length - 1], n = a.slice(0, r.length), h = this.forward(t, ...n);
+ return l(n), { value: h, gradFunc: (c, f) => {
+ const u = o().state.activeTape;
+ o().state.activeTape = [];
+ const b = v((...d) => this.forward(t, ...d.slice(0, n.length)))([...f, ...e], c);
+ return o().state.activeTape = u, b;
+ } };
+ })(...r, ...e);
+ if (t.training) {
+ const a = this.dropout(s);
+ return a !== s && s.dispose(), a;
+ } else
+ return s;
+ }
+ }
  export {
- a as default
+ M as default
  };
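
The new BaseLayer above folds the old BaseLayer-BhrMN8JO chunk into dist/layers/BaseLayer.js and adds gradient checkpointing via callCheckpoint/checkpointingFn. For orientation, here is a de-minified sketch of that checkpointing pattern, assuming the minified bindings are g = Tensor, p = customGrad, o = engine and v = grads from @tensorflow/tfjs-core; treat it as an illustration of the technique, not the published source.

import { customGrad, engine, grads, Tensor } from "@tensorflow/tfjs-core";

// Hedged sketch: run `forward` under gradient checkpointing. Activations are not
// kept on the tape during the forward pass; the backward pass recomputes them
// before taking gradients with respect to the inputs and the trainable variables.
function checkpointed(
  forward: (...xs: Tensor[]) => Tensor,
  inputs: Tensor[],
  vars: Tensor[]
): Tensor {
  const fn = customGrad((...args: any[]) => {
    const save = args[args.length - 1];              // GradSaveFunc supplied by customGrad
    const xs = args.slice(0, inputs.length) as Tensor[];
    save(xs);                                        // save only the inputs, not the activations
    return {
      value: forward(...xs),
      gradFunc: (dy: Tensor, saved: Tensor[]) => {
        const outerTape = engine().state.activeTape; // pause the outer tape
        engine().state.activeTape = [];
        const dxs = grads((...a: Tensor[]) =>
          forward(...a.slice(0, saved.length))
        )([...saved, ...vars], dy);                  // recompute forward, then differentiate
        engine().state.activeTape = outerTape;       // restore the outer tape
        return dxs;                                  // one gradient per input and per variable
      },
    };
  });
  return fn(...inputs, ...vars) as Tensor;
}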
@@ -7,9 +7,8 @@ export type KVCache = {
  cumulativeLength: number;
  };
  export interface AttentionScores {
- head: number;
- block: number;
- attentionOut?: Tensor;
+ meanOfHeads?: boolean;
+ attentionOut?: Tensor[];
  }
  interface AttentionForwardAttributes extends ForwardAttributes {
  attentionScores?: AttentionScores;
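
The AttentionScores change above drops the single head/block selection in favour of an optional per-block collection, with meanOfHeads left to the caller. A hedged caller-side sketch follows; only the AttentionScores fields come from the .d.ts above, while the model/forward call site and tensor shapes are assumptions for illustration.

import type { Tensor } from "@tensorflow/tfjs-core";

declare const tokens: Tensor;                                        // hypothetical input ids
declare const model: { forward(attrs: object, x: Tensor): Tensor };  // hypothetical entry point

// Collect an attention map from every transformer block instead of one (head, block) pair.
const attentionScores = {
  meanOfHeads: false,           // keep individual heads rather than averaging them
  attentionOut: [] as Tensor[], // each CausalSelfAttention pushes one tensor here
};

model.forward({ training: false, attentionScores }, tokens);

// One entry per block, roughly [nHead, T, T] per the reshape in CausalSelfAttention; dispose when done.
attentionScores.attentionOut.forEach((t) => t.dispose());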
@@ -1,15 +1,16 @@
- import { attentionMask as f } from "../ops/attentionMask.js";
- import { B as O, v as V } from "../BaseLayer-BhrMN8JO.js";
+ import { attentionMask as g } from "../ops/attentionMask.js";
+ import O from "./BaseLayer.js";
  import { qkv as P } from "../ops/qkv.js";
- import { rope as b } from "../ops/rope.js";
- import { appendCache as v } from "../ops/appendCache.js";
- import { F as c, t as C } from "../index-iNhkcAEQ.js";
+ import { rope as v } from "../ops/rope.js";
+ import { appendCache as V } from "../ops/appendCache.js";
+ import { H as c, t as C } from "../index-CnHyhpKc.js";
  import { fusedSoftmax as T } from "../ops/fusedSoftmax.js";
- import { d as y } from "../tfjs_backend-NucKez4s.js";
- import { r as k, d as L } from "../dropout-kbDY39Ci.js";
- import { r as N } from "../reshape-DxTPgnwL.js";
- import { m as R } from "../mat_mul-D0SifYfJ.js";
- class W extends O {
+ import { d as y } from "../tfjs_backend-DX9yVvwk.js";
+ import { v as b } from "../variable-BGvK-VN3.js";
+ import { r as k, d as L } from "../dropout-lQm_YyX3.js";
+ import { r as N } from "../reshape-CTIbqjwm.js";
+ import { m as R } from "../mat_mul-DeGU1U_C.js";
+ class $ extends O {
  divisor;
  index;
  units;
@@ -22,14 +23,14 @@ class W extends O {
  build() {
  this.hasVariable(this.ATTN) === !1 && this.setVariable(
  this.ATTN,
- V(
+ b(
  k([this.config.gpt.nEmbed, this.units], 0, 0.02),
  !0
  //`block_${this.index}_attn_cAttn_kernel`
  )
  ), this.hasVariable(this.PROJ) === !1 && this.setVariable(
  this.PROJ,
- V(
+ b(
  k([this.projUnits, this.config.gpt.nEmbed], 0, 0.02),
  !0
  //`block_${this.index}_attn_cProj_kernel`
@@ -37,12 +38,12 @@ class W extends O {
  );
  }
  getAttentionScores(t, i, s, o) {
- const e = f(t, i, this.divisor), n = T(e, s ? this.config.gpt.dropout : 0, o);
+ const e = g(t, i, this.divisor), n = T(e, s ? this.config.gpt.dropout : 0, o);
  return e.dispose(), n;
  }
  // Attention with optional past. If pastLen > 0 and T_cur == 1, no mask needed.
  getAttentionScoresWithPast(t, i, s) {
- const o = f(t, i, this.divisor, s), e = T(o, 0, 0);
+ const o = g(t, i, this.divisor, s), e = T(o, 0, 0);
  return o.dispose(), e;
  }
  getQKV(t) {
@@ -53,33 +54,33 @@ class W extends O {
  return n.dispose(), e.dispose(), p;
  }
  updateCache(t, i, s) {
- const o = this.config.gpt.blockSize, e = t.shape[2], n = s.length || 0, p = v(t, o, n, s.k);
+ const o = this.config.gpt.blockSize, e = t.shape[2], n = s.length || 0, p = V(t, o, n, s.k);
  t.dispose(), s.k && s.k.dispose();
- const r = v(i, o, n, s.v);
+ const a = V(i, o, n, s.v);
  i.dispose(), s.v && s.v.dispose();
  const d = Math.min(n + e, o), h = s.cumulativeLength + e;
- s.length = d, s.cumulativeLength = h, s.k = c(p), s.v = c(r);
+ s.length = d, s.cumulativeLength = h, s.k = c(p), s.v = c(a);
  }
  forward(t, i) {
  return C(() => {
  this.startMemory();
- const [s, o, e] = this.getQKV(i), n = t.pastKV ? t.pastKV.cumulativeLength : 0, p = this.config.layerConfig.ropeCache, r = p ? b(s, p, n) : s, d = p ? b(o, p, n) : o;
+ const [s, o, e] = this.getQKV(i), n = t.pastKV ? t.pastKV.cumulativeLength : 0, p = this.config.layerConfig.ropeCache, a = p ? v(s, p, n) : s, d = p ? v(o, p, n) : o;
  p && (s.dispose(), o.dispose());
  const h = t.pastKV ? t.pastKV.length : 0;
  t.pastKV && !t.training && this.updateCache(d, e, t.pastKV);
- const u = t.pastKV?.k ? t.pastKV.k : d, l = t.pastKV?.v ? t.pastKV.v : e;
- let a;
- h > 0 ? a = this.getAttentionScoresWithPast(r, u, h) : a = this.getAttentionScores(r, u, t.training, t.seed || 0), r.dispose(), t.pastKV || u.dispose();
- const m = R(a, l), g = t.attentionScores !== void 0 && t.attentionScores.block === this.index;
- g || a.dispose(), t.pastKV || l.dispose();
- const S = this.getOutputProjection(m);
- if (m.dispose(), g && t.attentionScores && t.attentionScores.head >= 0 && t.attentionScores.head < this.config.gpt.nHead) {
- const A = a.shape[0], K = a.shape[2];
- t.attentionScores.attentionOut = c(
- a.slice([0, t.attentionScores.head, 0, 0], [-1, 1, -1, -1]).reshape([A, K, -1])
+ const u = t.pastKV?.k ? t.pastKV.k : d, m = t.pastKV?.v ? t.pastKV.v : e;
+ let r;
+ h > 0 ? r = this.getAttentionScoresWithPast(a, u, h) : r = this.getAttentionScores(a, u, t.training, t.seed || 0), a.dispose(), t.pastKV || u.dispose();
+ const l = R(r, m), f = t.attentionScores !== void 0 && t.attentionScores.attentionOut !== void 0;
+ f || r.dispose(), t.pastKV || m.dispose();
+ const A = this.getOutputProjection(l);
+ if (l.dispose(), f && t.attentionScores && t.attentionScores.attentionOut !== void 0) {
+ const K = r.shape[1], S = r.shape[2];
+ t.attentionScores.attentionOut?.push(
+ c(r.slice([0, 0, 0, 0], [1, -1, -1, -1]).reshape([K, S, -1]))
  );
  }
- return this.endMemory("CausalSelfAttention"), S;
+ return this.endMemory("CausalSelfAttention"), A;
  });
  }
  dropout(t) {
@@ -91,5 +92,5 @@ class W extends O {
  }
  }
  export {
- W as default
+ $ as default
  };
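
For readers following the KV-cache path in the forward/updateCache code above, this is the rough shape of the cache object it manipulates. The field list is inferred from the minified code and the KVCache fragment earlier, so treat it as an illustration rather than the published typing.

import type { Tensor } from "@tensorflow/tfjs-core";

// Inferred sketch of the per-block KV cache used by CausalSelfAttention.
interface KVCacheSketch {
  k?: Tensor;               // cached keys, appended up to blockSize positions
  v?: Tensor;               // cached values, appended up to blockSize positions
  length: number;           // positions currently held (capped at blockSize)
  cumulativeLength: number; // total tokens seen so far, used as the RoPE offset
}

// During generation each block receives { training: false, pastKV }, so only the new
// token's Q/K/V are computed; per the comment in the code, no causal mask is needed
// when pastLen > 0 and the current chunk length is 1.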
@@ -1,10 +1,11 @@
- import { t as l } from "../index-iNhkcAEQ.js";
- import { B as u, v as o } from "../BaseLayer-BhrMN8JO.js";
+ import { t as l } from "../index-CnHyhpKc.js";
+ import u from "./BaseLayer.js";
  import { matMulGelu as M } from "../ops/matMulGelu.js";
- import { r as h, d as c } from "../dropout-kbDY39Ci.js";
- import { r as d } from "../reshape-DxTPgnwL.js";
- import { m as f } from "../mat_mul-D0SifYfJ.js";
- class O extends u {
+ import { v as o } from "../variable-BGvK-VN3.js";
+ import { r as h, d as f } from "../dropout-lQm_YyX3.js";
+ import { r as d } from "../reshape-CTIbqjwm.js";
+ import { m as c } from "../mat_mul-DeGU1U_C.js";
+ class V extends u {
  index;
  hiddenUnits;
  MLPHIDDEN;
@@ -36,7 +37,7 @@
  forward(i, t) {
  return l(() => {
  this.startMemory();
- const [s, r, e] = t.shape, n = d(t, [s * r, e]), a = M(n, this.getVariable(this.MLPHIDDEN)), p = f(a, this.getVariable(this.MLPOUT));
+ const [s, r, e] = t.shape, n = d(t, [s * r, e]), a = M(n, this.getVariable(this.MLPHIDDEN)), p = c(a, this.getVariable(this.MLPOUT));
  a.dispose();
  const m = d(p, [s, r, e]);
  return this.endMemory("MLP"), m;
@@ -44,12 +45,12 @@
  }
  dropout(i) {
  if (this.config.gpt.dropout > 0) {
- const t = c(i, this.config.gpt.dropout);
+ const t = f(i, this.config.gpt.dropout);
  return i.dispose(), t;
  }
  return i;
  }
  }
  export {
- O as default
+ V as default
  };
@@ -1,20 +1,21 @@
- import { t as e } from "../index-iNhkcAEQ.js";
- import { B as o, v as a } from "../BaseLayer-BhrMN8JO.js";
- import { normRMS as i } from "../ops/normRMS.js";
- import { o as M } from "../ones-BIeFnPHR.js";
- class l extends o {
+ import { t as s } from "../index-CnHyhpKc.js";
+ import e from "./BaseLayer.js";
+ import { normRMS as a } from "../ops/normRMS.js";
+ import { v as i } from "../variable-BGvK-VN3.js";
+ import { o as m } from "../ones-CDWGzVnm.js";
+ class f extends e {
  GAMMA;
- constructor(r, t = "", s) {
- super(r, s), this.GAMMA = t, this.addVariable(this.GAMMA, a(M([r.gpt.nEmbed]), !0, this.GAMMA, "float32"));
+ constructor(r, t = "", o) {
+ super(r, o), this.GAMMA = t, this.addVariable(this.GAMMA, i(m([r.gpt.nEmbed]), !0, this.GAMMA, "float32"));
  }
  forward(r, t) {
- return e(() => {
+ return s(() => {
  this.startMemory();
- const s = i(t, this.getVariable(this.GAMMA));
- return this.endMemory("RMSNorm"), s;
+ const o = a(t, this.getVariable(this.GAMMA));
+ return this.endMemory("RMSNorm"), o;
  });
  }
  }
  export {
- l as default
+ f as default
  };
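
The RMSNorm layer above only stores a learned gamma vector and delegates to the package's custom normRMS op. As a reminder of what RMS normalisation computes, here is a minimal sketch using stock tfjs ops; the epsilon value and the exact fused kernel are assumptions, only the gamma scaling over nEmbed is visible in the layer.

import * as tf from "@tensorflow/tfjs-core";

// y = x / sqrt(mean(x^2) over the embedding axis + eps) * gamma  (standard RMSNorm)
function normRMSSketch(x: tf.Tensor, gamma: tf.Tensor, eps = 1e-6): tf.Tensor {
  return tf.tidy(() => {
    const meanSquare = tf.mean(tf.square(x), -1, true); // mean of squares, keepDims
    return tf.mul(tf.div(x, tf.sqrt(tf.add(meanSquare, eps))), gamma);
  });
}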
@@ -1,6 +1,6 @@
- import { o as c, i as f, E as l, Q as m, f as n, U as u, t as p, F as a } from "../index-iNhkcAEQ.js";
- import { c as d, s as C } from "../sin-BOX-JVAj.js";
- import { r as h } from "../range-BsFU-SNG.js";
+ import { o as c, j as f, E as l, V as m, f as n, W as u, t as p, H as a } from "../index-CnHyhpKc.js";
+ import { c as d, s as C } from "../sin-HzioENy_.js";
+ import { r as h } from "../range-CkOJ7090.js";
  /**
  * @license
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -1,8 +1,10 @@
- import { T as a } from "../TiedEmbedding-DsDRvLB0.js";
- import "../index-iNhkcAEQ.js";
- import "../tfjs_backend-NucKez4s.js";
- import "../BaseLayer-BhrMN8JO.js";
- import "../gather-Bxe1Qip8.js";
+ import "../random_width-DI2h9CMs.js";
+ import "../index-CnHyhpKc.js";
+ import { T as f } from "../TiedEmbedding-DORsPlNL.js";
+ import "../tfjs_backend-DX9yVvwk.js";
+ import "./BaseLayer.js";
+ import "../variable-BGvK-VN3.js";
+ import "../gather-BWyutxwi.js";
  export {
- a as default
+ f as default
  };
@@ -1,8 +1,8 @@
  import l from "./CausalSelfAttention.js";
  import r from "./MLP.js";
  import o from "./RMSNorm.js";
- import { B as d } from "../BaseLayer-BhrMN8JO.js";
- import { t as p } from "../index-iNhkcAEQ.js";
+ import d from "./BaseLayer.js";
+ import { t as p } from "../index-CnHyhpKc.js";
  class k extends d {
  ln1;
  attn;