@genai-fi/nanogpt 0.4.4 → 0.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116) hide show
  1. package/dist/BaseLayer-BhrMN8JO.js +135 -0
  2. package/dist/Generator.js +44 -41
  3. package/dist/NanoGPTModel.d.ts +12 -16
  4. package/dist/NanoGPTModel.js +128 -138
  5. package/dist/{Reshape-CiAY8ltP.js → Reshape-BE5rA4rT.js} +8 -8
  6. package/dist/TeachableLLM.js +8 -5
  7. package/dist/{TiedEmbedding-DznFwzcB.js → TiedEmbedding-DsDRvLB0.js} +751 -768
  8. package/dist/{axis_util-QP0LdI1v.js → axis_util-97KkkyRQ.js} +1 -1
  9. package/dist/broadcast_to-CMlkG8NS.js +44 -0
  10. package/dist/{concat-DvWM7HGZ.js → concat-Cxbo2sOz.js} +3 -3
  11. package/dist/{dropout-DFEXTPV0.js → dropout-kbDY39Ci.js} +1 -1
  12. package/dist/{gather-C5D8PxwA.js → gather-Bxe1Qip8.js} +4 -4
  13. package/dist/{gpgpu_math-CUzjlO9A.js → gpgpu_math-C0zyxKFi.js} +1 -1
  14. package/dist/{index--6vO-cOz.js → index-iNhkcAEQ.js} +82 -82
  15. package/dist/{kernel_funcs_utils-C6YBCuOt.js → kernel_funcs_utils-C4eIk4fE.js} +20 -20
  16. package/dist/layers/BaseLayer.d.ts +28 -4
  17. package/dist/layers/BaseLayer.js +3 -16
  18. package/dist/layers/CausalSelfAttention.d.ts +22 -24
  19. package/dist/layers/CausalSelfAttention.js +73 -127
  20. package/dist/layers/MLP.d.ts +8 -15
  21. package/dist/layers/MLP.js +43 -81
  22. package/dist/layers/RMSNorm.d.ts +5 -11
  23. package/dist/layers/RMSNorm.js +13 -29
  24. package/dist/layers/RoPECache.js +14 -12
  25. package/dist/layers/TiedEmbedding.d.ts +6 -16
  26. package/dist/layers/TiedEmbedding.js +5 -5
  27. package/dist/layers/TransformerBlock.d.ts +12 -16
  28. package/dist/layers/TransformerBlock.js +20 -41
  29. package/dist/{log_sum_exp-CiEy1aUe.js → log_sum_exp-CkumwesB.js} +11 -11
  30. package/dist/main.js +22 -19
  31. package/dist/{mat_mul-BEHRPMh0.js → mat_mul-D0SifYfJ.js} +3 -3
  32. package/dist/{max-BUShNgfh.js → max-CYaAjEEp.js} +3 -3
  33. package/dist/{moments-DYOHXoRV.js → moments-B06NlR_V.js} +6 -6
  34. package/dist/{norm-DSva3hI3.js → norm-D3676xIo.js} +7 -7
  35. package/dist/{ones-D6kB8bdY.js → ones-BIeFnPHR.js} +2 -2
  36. package/dist/ops/appendCache.js +4 -4
  37. package/dist/ops/attentionMask.d.ts +1 -1
  38. package/dist/ops/attentionMask.js +4 -4
  39. package/dist/ops/cpu/appendCache.js +2 -2
  40. package/dist/ops/cpu/attentionMask.js +14 -15
  41. package/dist/ops/cpu/fusedSoftmax.js +2 -2
  42. package/dist/ops/cpu/gatherSub.js +5 -5
  43. package/dist/ops/cpu/gelu.js +1 -1
  44. package/dist/ops/cpu/matMulGelu.js +1 -1
  45. package/dist/ops/cpu/matMulMul.d.ts +1 -0
  46. package/dist/ops/cpu/matMulMul.js +17 -0
  47. package/dist/ops/cpu/mulDropout.js +1 -1
  48. package/dist/ops/cpu/normRMS.d.ts +1 -0
  49. package/dist/ops/cpu/normRMS.js +39 -0
  50. package/dist/ops/cpu/qkv.js +3 -3
  51. package/dist/ops/cpu/rope.js +5 -5
  52. package/dist/ops/cpu/scatterSub.js +8 -8
  53. package/dist/ops/fusedSoftmax.js +1 -1
  54. package/dist/ops/gatherSub.js +1 -1
  55. package/dist/ops/gelu.js +1 -1
  56. package/dist/ops/grads/attentionMask.js +13 -9
  57. package/dist/ops/grads/fusedSoftmax.js +12 -9
  58. package/dist/ops/grads/gelu.js +1 -1
  59. package/dist/ops/grads/matMulGelu.js +1 -1
  60. package/dist/ops/grads/normRMS.d.ts +2 -0
  61. package/dist/ops/grads/normRMS.js +20 -0
  62. package/dist/ops/grads/qkv.js +19 -9
  63. package/dist/ops/grads/rope.js +1 -1
  64. package/dist/ops/matMulGelu.js +1 -1
  65. package/dist/ops/matMulMul.d.ts +2 -0
  66. package/dist/ops/matMulMul.js +9 -0
  67. package/dist/ops/mulDrop.js +1 -1
  68. package/dist/ops/node/sparseCrossEntropy.js +1 -1
  69. package/dist/ops/normRMS.d.ts +2 -0
  70. package/dist/ops/normRMS.js +10 -0
  71. package/dist/ops/qkv.js +1 -1
  72. package/dist/ops/scatterSub.js +1 -1
  73. package/dist/ops/webgl/appendCache.js +1 -1
  74. package/dist/ops/webgl/attentionMask.js +13 -12
  75. package/dist/ops/webgl/fusedSoftmax.js +43 -40
  76. package/dist/ops/webgl/gatherSub.js +1 -1
  77. package/dist/ops/webgl/gelu.js +2 -2
  78. package/dist/ops/webgl/matMulGelu.d.ts +3 -2
  79. package/dist/ops/webgl/matMulGelu.js +77 -75
  80. package/dist/ops/webgl/matMulMul.d.ts +14 -0
  81. package/dist/ops/webgl/matMulMul.js +28 -0
  82. package/dist/ops/webgl/mulDropout.js +1 -1
  83. package/dist/ops/webgl/normRMS.d.ts +1 -0
  84. package/dist/ops/webgl/normRMS.js +86 -0
  85. package/dist/ops/webgl/qkv.js +1 -1
  86. package/dist/ops/webgl/rope.js +1 -1
  87. package/dist/ops/webgl/scatterSub.js +1 -1
  88. package/dist/ops-ObfXLHYQ.js +1269 -0
  89. package/dist/{range-C_vpUjBu.js → range-BsFU-SNG.js} +1 -1
  90. package/dist/{reshape-z51Eu-re.js → reshape-DxTPgnwL.js} +3 -3
  91. package/dist/{sin-H567uayl.js → sin-BOX-JVAj.js} +5 -5
  92. package/dist/slice_util-D-kaD4ZV.js +49 -0
  93. package/dist/{softmax-Dsxflvdl.js → softmax-BjsptB07.js} +2 -2
  94. package/dist/{split-B_k_jwud.js → split-BCbrzthj.js} +4 -4
  95. package/dist/{stack-CmqSdsfs.js → stack--cqr9Dgc.js} +2 -2
  96. package/dist/{sum-DdkDf2MG.js → sum-B_92TaHD.js} +5 -5
  97. package/dist/{tensor-BGYi41cj.js → tensor-CfiPXsW4.js} +1 -1
  98. package/dist/{tensor2d-DUr_htjt.js → tensor2d-tSxWdFMH.js} +1 -1
  99. package/dist/tfjs_backend-NucKez4s.js +1010 -0
  100. package/dist/training/AdamExt.js +1 -1
  101. package/dist/training/DatasetBuilder.js +44 -44
  102. package/dist/training/Evaluator.js +6 -6
  103. package/dist/training/FullTrainer.js +1 -1
  104. package/dist/training/Trainer.js +7 -7
  105. package/dist/training/sparseCrossEntropy.js +4 -4
  106. package/dist/utilities/dummy.js +10 -10
  107. package/dist/utilities/generate.js +3 -3
  108. package/dist/utilities/load.js +1 -1
  109. package/dist/utilities/profile.js +1 -1
  110. package/dist/utilities/save.js +10 -8
  111. package/dist/utilities/weights.js +2 -2
  112. package/dist/{zeros-8xl-W2DC.js → zeros-NMYTayy7.js} +3 -3
  113. package/package.json +1 -1
  114. package/dist/slice_util-BdhYwFY_.js +0 -90
  115. package/dist/tfjs_backend-DuKis_xG.js +0 -2271
  116. package/dist/variable-BJTZ3jOy.js +0 -23
@@ -1,4 +1,4 @@
1
- import { j as c } from "./index--6vO-cOz.js";
1
+ import { k as c } from "./index-iNhkcAEQ.js";
2
2
  /**
3
3
  * @license
4
4
  * Copyright 2017 Google LLC. All Rights Reserved.
@@ -0,0 +1,44 @@
1
+ import { o as h, i as f, l as p, x as g, E as u, T } from "./index-iNhkcAEQ.js";
2
+ import { r as b } from "./reshape-DxTPgnwL.js";
3
+ /**
4
+ * @license
5
+ * Copyright 2020 Google LLC. All Rights Reserved.
6
+ * Licensed under the Apache License, Version 2.0 (the "License");
7
+ * you may not use this file except in compliance with the License.
8
+ * You may obtain a copy of the License at
9
+ *
10
+ * http://www.apache.org/licenses/LICENSE-2.0
11
+ *
12
+ * Unless required by applicable law or agreed to in writing, software
13
+ * distributed under the License is distributed on an "AS IS" BASIS,
14
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ * See the License for the specific language governing permissions and
16
+ * limitations under the License.
17
+ * =============================================================================
18
+ */
19
+ function m(e, r) {
20
+ let n = f(e, "broadcastTo", "x");
21
+ const a = n.shape;
22
+ if (p(r), r.length < n.rank)
23
+ throw new Error(`broadcastTo(): shape.length=${r.length} < input.rank=${n.rank}.`);
24
+ if (r.length > n.rank) {
25
+ const t = n.shape.slice();
26
+ for (; t.length < r.length; )
27
+ t.unshift(1);
28
+ n = b(n, t);
29
+ }
30
+ const s = n.shape, o = Array.from(r);
31
+ for (let t = r.length - 1; t >= 0; t--)
32
+ if (s[t] === r[t])
33
+ o[t] = 1;
34
+ else if (n.shape[t] !== 1)
35
+ throw new Error(`broadcastTo(): [${a}] cannot be broadcast to [${r}].`);
36
+ if (o.map((t, l) => t > 1 ? l : -1).filter((t) => t >= 0).length === 0)
37
+ return g(n);
38
+ const i = { x: n }, c = { reps: o };
39
+ return u.runKernel(T, i, c);
40
+ }
41
+ const E = /* @__PURE__ */ h({ broadcastTo_: m });
42
+ export {
43
+ E as b
44
+ };
@@ -1,4 +1,4 @@
1
- import { o as s, j as a, i, w as p, E as l, C as f } from "./index--6vO-cOz.js";
1
+ import { o as s, k as a, j as p, x as i, E as l, C as f } from "./index-iNhkcAEQ.js";
2
2
  /**
3
3
  * @license
4
4
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -17,13 +17,13 @@ import { o as s, j as a, i, w as p, E as l, C as f } from "./index--6vO-cOz.js";
17
17
  */
18
18
  function h(o, e = 0) {
19
19
  a(o.length >= 1, () => "Pass at least one tensor to concat");
20
- const t = i(o, "tensors", "concat", "string_or_numeric");
20
+ const t = p(o, "tensors", "concat", "string_or_numeric");
21
21
  if (t[0].dtype === "complex64" && t.forEach((n) => {
22
22
  if (n.dtype !== "complex64")
23
23
  throw new Error(`Cannot concatenate complex64 tensors with a tensor
24
24
  with dtype ${n.dtype}. `);
25
25
  }), t.length === 1)
26
- return p(t[0]);
26
+ return i(t[0]);
27
27
  const r = t, c = { axis: e };
28
28
  return l.runKernel(f, r, c);
29
29
  }
@@ -1,4 +1,4 @@
1
- import { o as l, h, E as m, af as p, k as c, ag as d, ad as g, j as u, ah as V, ai as v, a8 as N, b as w } from "./index--6vO-cOz.js";
1
+ import { o as l, i as h, E as m, ag as p, l as c, ah as d, ae as g, k as u, V, ai as v, a9 as N, b as w } from "./index-iNhkcAEQ.js";
2
2
  import { s as f } from "./index-C4L8Cm77.js";
3
3
  /**
4
4
  * @license
@@ -1,4 +1,4 @@
1
- import { o as h, h as t, E as g, G as p } from "./index--6vO-cOz.js";
1
+ import { o as g, i as t, E as h, G as p } from "./index-iNhkcAEQ.js";
2
2
  /**
3
3
  * @license
4
4
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -16,10 +16,10 @@ import { o as h, h as t, E as g, G as p } from "./index--6vO-cOz.js";
16
16
  * =============================================================================
17
17
  */
18
18
  function u(n, s, r = 0, e = 0) {
19
- const o = t(n, "x", "gather"), a = t(s, "indices", "gather", "int32"), c = { x: o, indices: a }, i = { axis: r, batchDims: e };
20
- return g.runKernel(p, c, i);
19
+ const o = t(n, "x", "gather"), a = t(s, "indices", "gather", "int32"), i = { x: o, indices: a }, c = { axis: r, batchDims: e };
20
+ return h.runKernel(p, i, c);
21
21
  }
22
- const d = /* @__PURE__ */ h({ gather_: u });
22
+ const d = /* @__PURE__ */ g({ gather_: u });
23
23
  export {
24
24
  d as g
25
25
  };
@@ -1,4 +1,4 @@
1
- import { L as e } from "./index--6vO-cOz.js";
1
+ import { K as e } from "./index-iNhkcAEQ.js";
2
2
  /**
3
3
  * @license
4
4
  * Copyright 2017 Google LLC. All Rights Reserved.
@@ -4001,81 +4001,81 @@ function As() {
4001
4001
  */
4002
4002
  As();
4003
4003
  export {
4004
- xe as $,
4004
+ Oa as $,
4005
4005
  Ss as A,
4006
4006
  Zs as B,
4007
4007
  or as C,
4008
- ss as D,
4008
+ Wa as D,
4009
4009
  g as E,
4010
4010
  Bn as F,
4011
4011
  Pr as G,
4012
- To as H,
4013
- Fs as I,
4014
- kn as J,
4015
- En as K,
4016
- k as L,
4012
+ Fs as H,
4013
+ kn as I,
4014
+ En as J,
4015
+ k as K,
4016
+ Lr as L,
4017
4017
  ta as M,
4018
- Lr as N,
4019
- rs as O,
4018
+ rs as N,
4019
+ de as O,
4020
4020
  ba as P,
4021
- de as Q,
4021
+ Ea as Q,
4022
4022
  Ia as R,
4023
4023
  qa as S,
4024
- Ea as T,
4024
+ Qa as T,
4025
4025
  Zt as U,
4026
- dr as V,
4027
- Oa as W,
4026
+ D as V,
4027
+ To as W,
4028
4028
  De as X,
4029
4029
  ar as Y,
4030
4030
  ne as Z,
4031
- aa as _,
4031
+ dr as _,
4032
4032
  M as a,
4033
4033
  xs as a$,
4034
- V as a0,
4035
- oa as a1,
4036
- ns as a2,
4037
- nt as a3,
4038
- Qa as a4,
4039
- Ca as a5,
4040
- Fr as a6,
4041
- qr as a7,
4042
- S as a8,
4043
- la as a9,
4044
- Wr as aA,
4045
- jr as aB,
4046
- Kr as aC,
4047
- ha as aD,
4048
- Jr as aE,
4049
- ia as aF,
4050
- Sa as aG,
4051
- Ta as aH,
4052
- Aa as aI,
4053
- Ra as aJ,
4054
- $a as aK,
4055
- Ds as aL,
4056
- ro as aM,
4057
- no as aN,
4058
- eo as aO,
4059
- Io as aP,
4060
- oo as aQ,
4061
- yr as aR,
4062
- $r as aS,
4063
- ao as aT,
4064
- da as aU,
4065
- ma as aV,
4066
- ga as aW,
4067
- Na as aX,
4068
- va as aY,
4069
- to as aZ,
4070
- yo as a_,
4071
- ua as aa,
4072
- Za as ab,
4073
- $t as ac,
4074
- Rt as ad,
4075
- Rs as ae,
4076
- xr as af,
4077
- Wn as ag,
4078
- D as ah,
4034
+ aa as a0,
4035
+ xe as a1,
4036
+ V as a2,
4037
+ oa as a3,
4038
+ ns as a4,
4039
+ nt as a5,
4040
+ Ca as a6,
4041
+ Fr as a7,
4042
+ qr as a8,
4043
+ S as a9,
4044
+ _a as aA,
4045
+ er as aB,
4046
+ Pa as aC,
4047
+ Ar as aD,
4048
+ Rr as aE,
4049
+ _r as aF,
4050
+ Or as aG,
4051
+ Gr as aH,
4052
+ jr as aI,
4053
+ Kr as aJ,
4054
+ ha as aK,
4055
+ Jr as aL,
4056
+ ia as aM,
4057
+ Ta as aN,
4058
+ $a as aO,
4059
+ Ds as aP,
4060
+ no as aQ,
4061
+ eo as aR,
4062
+ yr as aS,
4063
+ $r as aT,
4064
+ ao as aU,
4065
+ da as aV,
4066
+ ma as aW,
4067
+ ga as aX,
4068
+ Na as aY,
4069
+ va as aZ,
4070
+ to as a_,
4071
+ la as aa,
4072
+ ua as ab,
4073
+ Za as ac,
4074
+ $t as ad,
4075
+ Rt as ae,
4076
+ Rs as af,
4077
+ xr as ag,
4078
+ Wn as ah,
4079
4079
  x as ai,
4080
4080
  F as aj,
4081
4081
  pe as ak,
@@ -4084,16 +4084,16 @@ export {
4084
4084
  jt as an,
4085
4085
  ue as ao,
4086
4086
  za as ap,
4087
- _a as aq,
4088
- er as ar,
4089
- rr as as,
4090
- Pa as at,
4091
- Ar as au,
4092
- Br as av,
4093
- Rr as aw,
4094
- _r as ax,
4095
- Or as ay,
4096
- Gr as az,
4087
+ rr as aq,
4088
+ Br as ar,
4089
+ Wr as as,
4090
+ Sa as at,
4091
+ Aa as au,
4092
+ Ra as av,
4093
+ ro as aw,
4094
+ Io as ax,
4095
+ oo as ay,
4096
+ yo as az,
4097
4097
  b,
4098
4098
  Vs as b$,
4099
4099
  $s as b0,
@@ -4212,24 +4212,24 @@ export {
4212
4212
  go as d,
4213
4213
  mo as e,
4214
4214
  K as f,
4215
- lo as g,
4216
- T as h,
4217
- In as i,
4218
- y as j,
4219
- xt as k,
4220
- Ge as l,
4215
+ ss as g,
4216
+ lo as h,
4217
+ T as i,
4218
+ In as j,
4219
+ y as k,
4220
+ xt as l,
4221
4221
  po as m,
4222
- z as n,
4222
+ Ge as n,
4223
4223
  N as o,
4224
- q as p,
4225
- Ba as q,
4224
+ z as p,
4225
+ q,
4226
4226
  co as r,
4227
4227
  tt as s,
4228
4228
  E as t,
4229
- Ka as u,
4229
+ Ba as u,
4230
4230
  ls as v,
4231
- qn as w,
4232
- Ft as x,
4233
- Wa as y,
4231
+ Ka as w,
4232
+ qn as x,
4233
+ Ft as y,
4234
4234
  C as z
4235
4235
  };
@@ -1,5 +1,5 @@
1
- import { an as D, ao as N, O as w, n as R, Q as v, L as P } from "./index--6vO-cOz.js";
2
- import { u as g } from "./gpgpu_math-CUzjlO9A.js";
1
+ import { an as D, ao as N, N as w, p as R, O as v, K as P } from "./index-iNhkcAEQ.js";
2
+ import { u as g } from "./gpgpu_math-C0zyxKFi.js";
3
3
  /**
4
4
  * @license
5
5
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -23,7 +23,7 @@ function B(t) {
23
23
  throw new Error(`Failed to decode encoded string bytes into utf-8, error: ${e}`);
24
24
  }
25
25
  }
26
- function H(t) {
26
+ function K(t) {
27
27
  return t.map((e) => N(e));
28
28
  }
29
29
  /**
@@ -127,12 +127,12 @@ class C {
127
127
  * =============================================================================
128
128
  */
129
129
  class _ {
130
- constructor(e, o, u, d = !1) {
130
+ constructor(e, o, u, p = !1) {
131
131
  this.variableNames = ["A", "B"], this.supportsBroadcasting = !0, this.packedInputs = !0, this.packedOutput = !0, this.outputShape = w(o, u);
132
132
  const a = this.outputShape.length;
133
133
  this.enableShapeUniforms = g(a);
134
134
  let n = "";
135
- if (d)
135
+ if (p)
136
136
  if (a === 0 || R(this.outputShape) === 1)
137
137
  n = `
138
138
  result.y = 0.;
@@ -225,7 +225,7 @@ function A(t) {
225
225
  * =============================================================================
226
226
  */
227
227
  function G(t) {
228
- const { inputs: e, backend: o } = t, { real: u, imag: d } = e, a = o.makeTensorInfo(u.shape, "complex64"), n = o.texData.get(a.dataId), l = A({ inputs: { x: u }, backend: o }), s = A({ inputs: { x: d }, backend: o });
228
+ const { inputs: e, backend: o } = t, { real: u, imag: p } = e, a = o.makeTensorInfo(u.shape, "complex64"), n = o.texData.get(a.dataId), l = A({ inputs: { x: u }, backend: o }), s = A({ inputs: { x: p }, backend: o });
229
229
  return n.complexTensorInfos = { real: l, imag: s }, a;
230
230
  }
231
231
  /**
@@ -260,7 +260,7 @@ class V {
260
260
  `;
261
261
  }
262
262
  }
263
- const K = "if (isnan(x)) return x;";
263
+ const H = "if (isnan(x)) return x;";
264
264
  /**
265
265
  * @license
266
266
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -310,8 +310,8 @@ class L {
310
310
  * =============================================================================
311
311
  */
312
312
  function Y({ opSnippet: t, packedOpSnippet: e, cpuKernelImpl: o, dtype: u }) {
313
- return ({ inputs: d, backend: a }) => {
314
- const { x: n } = d, l = a, s = u || n.dtype;
313
+ return ({ inputs: p, backend: a }) => {
314
+ const { x: n } = p, l = a, s = u || n.dtype;
315
315
  if (l.shouldExecuteOnCPU([n]) && o != null) {
316
316
  const c = l.texData.get(n.dataId), x = o(c.values, s);
317
317
  return l.makeTensorInfo(n.shape, s, x);
@@ -321,7 +321,7 @@ function Y({ opSnippet: t, packedOpSnippet: e, cpuKernelImpl: o, dtype: u }) {
321
321
  return i ? r = new L(n.shape, e) : r = new V(n.shape, t), l.runWebGLProgram(r, [n], s);
322
322
  };
323
323
  }
324
- function Q({ opSnippet: t, packedOpSnippet: e, checkOutOfBounds: o = !1, supportsComplex: u = !1, cpuKernelImpl: d, dtype: a }) {
324
+ function j({ opSnippet: t, packedOpSnippet: e, checkOutOfBounds: o = !1, supportsComplex: u = !1, cpuKernelImpl: p, dtype: a }) {
325
325
  return ({ inputs: n, backend: l }) => {
326
326
  const { a: s, b: i } = n, r = l;
327
327
  if (u && s.dtype === "complex64") {
@@ -329,29 +329,29 @@ function Q({ opSnippet: t, packedOpSnippet: e, checkOutOfBounds: o = !1, support
329
329
  [h.complexTensorInfos.real, f.complexTensorInfos.real],
330
330
  [h.complexTensorInfos.imag, f.complexTensorInfos.imag]
331
331
  ].map((S) => {
332
- const [p, m] = S, $ = {
333
- dataId: p.dataId,
334
- dtype: p.dtype,
332
+ const [d, m] = S, $ = {
333
+ dataId: d.dataId,
334
+ dtype: d.dtype,
335
335
  shape: s.shape
336
336
  }, T = {
337
337
  dataId: m.dataId,
338
338
  dtype: m.dtype,
339
339
  shape: i.shape
340
340
  }, U = new C(t, s.shape, i.shape);
341
- return r.runWebGLProgram(U, [$, T], v(p.dtype, m.dtype));
341
+ return r.runWebGLProgram(U, [$, T], v(d.dtype, m.dtype));
342
342
  }), I = G({ inputs: { real: O, imag: y }, backend: r });
343
343
  return r.disposeIntermediateTensorInfo(O), r.disposeIntermediateTensorInfo(y), I;
344
344
  }
345
345
  const c = a || v(s.dtype, i.dtype);
346
- if ((s.dtype === "string" || i.dtype === "string" || r.shouldExecuteOnCPU([s, i])) && d != null) {
346
+ if ((s.dtype === "string" || i.dtype === "string" || r.shouldExecuteOnCPU([s, i])) && p != null) {
347
347
  const h = r.texData.get(s.dataId).values, f = r.texData.get(i.dataId).values, O = s.dtype === "string" ? (
348
348
  // tslint:disable-next-line: no-any
349
349
  B(h)
350
350
  ) : h, y = s.dtype === "string" ? (
351
351
  // tslint:disable-next-line: no-any
352
352
  B(f)
353
- ) : f, [I, S] = d(s.shape, i.shape, O, y, c), p = r.makeTensorInfo(S, c), m = r.texData.get(p.dataId);
354
- return m.values = I, p;
353
+ ) : f, [I, S] = p(s.shape, i.shape, O, y, c), d = r.makeTensorInfo(S, c), m = r.texData.get(d.dataId);
354
+ return m.values = I, d;
355
355
  }
356
356
  const x = P().getBool("WEBGL_PACK_BINARY_OPERATIONS") && e != null;
357
357
  let b;
@@ -359,10 +359,10 @@ function Q({ opSnippet: t, packedOpSnippet: e, checkOutOfBounds: o = !1, support
359
359
  };
360
360
  }
361
361
  export {
362
- K as C,
363
- H as a,
362
+ H as C,
363
+ K as a,
364
364
  E as b,
365
- Q as c,
365
+ j as c,
366
366
  B as f,
367
367
  k as g,
368
368
  Y as u
@@ -1,9 +1,9 @@
1
1
  import { GPTConfig } from '../config';
2
2
  import { default as MemoryProfiler } from '../utilities/profile';
3
3
  import { default as RoPECache } from './RoPECache';
4
+ import { Tensor, Variable } from '@tensorflow/tfjs-core';
4
5
  export interface LayerConfig {
5
- checkpointAttention?: boolean;
6
- checkpointMLP?: boolean;
6
+ checkpointing?: boolean;
7
7
  profiler?: MemoryProfiler;
8
8
  ropeCache?: RoPECache;
9
9
  }
@@ -11,10 +11,34 @@ export interface GPTLayerConfig {
11
11
  gpt: GPTConfig;
12
12
  layerConfig: LayerConfig;
13
13
  }
14
- export default abstract class BaseLayer {
14
+ export interface ForwardAttributes {
15
+ training: boolean;
16
+ }
17
+ export default abstract class BaseLayer<ATTR extends ForwardAttributes = ForwardAttributes> {
18
+ readonly parent?: BaseLayer;
15
19
  readonly config: GPTLayerConfig;
16
- constructor(config: GPTLayerConfig);
20
+ private _variables;
21
+ private _trainable;
22
+ readonly children: BaseLayer[];
23
+ constructor(config: GPTLayerConfig, parent?: BaseLayer);
17
24
  getProfiler(): MemoryProfiler | undefined;
18
25
  startMemory(): void;
19
26
  endMemory(label: string): void;
27
+ addVariable(name: string, variable?: Variable): void;
28
+ get variables(): Variable[];
29
+ get trainableVariables(): Variable[];
30
+ get trainable(): boolean;
31
+ set trainable(value: boolean);
32
+ getVariable(name: string): Variable;
33
+ hasVariable(name: string): boolean;
34
+ setVariable(name: string, variable: Variable): void;
35
+ saveWeights(map: Map<string, Tensor[]>): void;
36
+ loadWeights(weights: Map<string, Tensor[]>): void;
37
+ dispose(): void;
38
+ protected build(): void;
39
+ protected dropout(x: Tensor): Tensor;
40
+ abstract forward(attrs: ATTR, ...x: Tensor[]): Tensor | Tensor[];
41
+ call(attrs: ATTR, ...x: Tensor[]): Tensor | Tensor[];
42
+ callCheckpoint(attrs: ATTR, ...x: Tensor[]): Tensor;
43
+ private checkpointingFn;
20
44
  }
@@ -1,18 +1,5 @@
1
- class o {
2
- config;
3
- constructor(r) {
4
- this.config = r;
5
- }
6
- getProfiler() {
7
- return this.config.layerConfig.profiler;
8
- }
9
- startMemory() {
10
- this.config.layerConfig.profiler?.startMemory();
11
- }
12
- endMemory(r) {
13
- this.config.layerConfig.profiler?.endMemory(r);
14
- }
15
- }
1
+ import "../index-iNhkcAEQ.js";
2
+ import { B as a } from "../BaseLayer-BhrMN8JO.js";
16
3
  export {
17
- o as default
4
+ a as default
18
5
  };
@@ -1,38 +1,36 @@
1
- import { default as BaseLayer, GPTLayerConfig } from './BaseLayer';
2
- import { Tensor, Variable } from '@tensorflow/tfjs-core';
1
+ import { default as BaseLayer, ForwardAttributes, GPTLayerConfig } from './BaseLayer';
2
+ import { Tensor } from '@tensorflow/tfjs-core';
3
3
  export type KVCache = {
4
- k: Tensor;
5
- v: Tensor;
4
+ k?: Tensor;
5
+ v?: Tensor;
6
6
  length: number;
7
7
  cumulativeLength: number;
8
8
  };
9
- export default class CausalSelfAttention extends BaseLayer {
10
- private cAttn;
11
- private cProj;
12
- private bias;
13
- private maskInf;
9
+ export interface AttentionScores {
10
+ head: number;
11
+ block: number;
12
+ attentionOut?: Tensor;
13
+ }
14
+ interface AttentionForwardAttributes extends ForwardAttributes {
15
+ attentionScores?: AttentionScores;
16
+ pastKV?: KVCache;
17
+ seed?: number;
18
+ }
19
+ export default class CausalSelfAttention extends BaseLayer<AttentionForwardAttributes> {
14
20
  private divisor;
15
21
  private index;
16
- private _trainable;
17
22
  private units;
18
23
  private projUnits;
19
- constructor(index: number, config: GPTLayerConfig);
20
- private build;
21
- get variables(): Variable[];
22
- get trainable(): boolean;
23
- set trainable(value: boolean);
24
- saveWeights(map: Map<string, Tensor[]>): void;
25
- loadWeights(weights: Map<string, Tensor[]>): void;
24
+ private ATTN;
25
+ private PROJ;
26
+ constructor(index: number, config: GPTLayerConfig, parent?: BaseLayer);
27
+ protected build(): void;
26
28
  private getAttentionScores;
27
29
  private getAttentionScoresWithPast;
28
30
  private getQKV;
29
31
  private getOutputProjection;
30
32
  private updateCache;
31
- private forward;
32
- call(x: Tensor, training?: boolean, includeAttention?: boolean, pastKV?: KVCache): {
33
- output: Tensor;
34
- attention?: Tensor;
35
- presentKV?: KVCache;
36
- };
37
- dispose(): void;
33
+ forward(attr: AttentionForwardAttributes, x: Tensor): Tensor;
34
+ protected dropout(x: Tensor): Tensor;
38
35
  }
36
+ export {};