@genai-fi/nanogpt 0.3.1 → 0.4.0

This diff shows the contents of two publicly released versions of this package, as published to a supported registry. It is provided for informational purposes only.
Files changed (101)
  1. package/dist/Generator.js +22 -22
  2. package/dist/MLP-KHhikThU.js +83 -0
  3. package/dist/NanoGPTModel.d.ts +2 -3
  4. package/dist/NanoGPTModel.js +79 -79
  5. package/dist/TeachableLLM.d.ts +4 -3
  6. package/dist/TeachableLLM.js +16 -13
  7. package/dist/Trainer.js +20 -13
  8. package/dist/axis_util-DeydwOoC.js +69 -0
  9. package/dist/{concat-BIZS_td9.js → concat-DS_qH7MI.js} +5 -5
  10. package/dist/config.js +7 -8
  11. package/dist/{gather-BPGW8RsB.js → gather-BUmJIS8n.js} +1 -1
  12. package/dist/{index-pWA4_lUh.js → index-XjBAhiFO.js} +1272 -1174
  13. package/dist/layers/BaseLayer.d.ts +14 -2
  14. package/dist/layers/BaseLayer.js +9 -9
  15. package/dist/layers/CausalSelfAttention.d.ts +4 -8
  16. package/dist/layers/CausalSelfAttention.js +108 -82
  17. package/dist/layers/MLP.d.ts +2 -3
  18. package/dist/layers/MLP.js +5 -62
  19. package/dist/layers/RMSNorm.d.ts +2 -2
  20. package/dist/layers/RMSNorm.js +11 -11
  21. package/dist/layers/RoPECache.js +3 -3
  22. package/dist/layers/TiedEmbedding.js +7 -6
  23. package/dist/layers/TransformerBlock.d.ts +2 -6
  24. package/dist/layers/TransformerBlock.js +9 -12
  25. package/dist/{sum-C7Mgy9Bw.js → log_sum_exp-DJPkVZZn.js} +32 -54
  26. package/dist/main.js +22 -19
  27. package/dist/{mat_mul-D7_a4KJn.js → mat_mul-CKwFEV1Q.js} +1 -1
  28. package/dist/max-DJvEiCAJ.js +25 -0
  29. package/dist/moments-CrWRPcR3.js +53 -0
  30. package/dist/norm-BzY929B_.js +86 -0
  31. package/dist/{ones-Cog-G2ag.js → ones-BO01zpJG.js} +2 -2
  32. package/dist/ops/appendCache.js +1 -1
  33. package/dist/ops/attentionMask.js +1 -1
  34. package/dist/ops/cpu/appendCache.js +2 -2
  35. package/dist/ops/cpu/attentionMask.js +2 -2
  36. package/dist/ops/cpu/fusedSoftmax.d.ts +9 -0
  37. package/dist/ops/cpu/fusedSoftmax.js +23 -0
  38. package/dist/ops/cpu/gatherSub.js +3 -3
  39. package/dist/ops/cpu/mulDropout.d.ts +1 -0
  40. package/dist/ops/cpu/mulDropout.js +17 -0
  41. package/dist/ops/cpu/qkv.js +3 -3
  42. package/dist/ops/cpu/rope.js +5 -5
  43. package/dist/ops/cpu/scatterSub.js +27 -27
  44. package/dist/ops/fusedSoftmax.d.ts +2 -0
  45. package/dist/ops/fusedSoftmax.js +10 -0
  46. package/dist/ops/gatherSub.js +1 -1
  47. package/dist/ops/grads/attentionMask.js +1 -1
  48. package/dist/ops/grads/fusedSoftmax.d.ts +2 -0
  49. package/dist/ops/grads/fusedSoftmax.js +17 -0
  50. package/dist/ops/grads/qkv.js +1 -1
  51. package/dist/ops/grads/rope.js +1 -1
  52. package/dist/ops/mulDrop.d.ts +2 -0
  53. package/dist/ops/mulDrop.js +9 -0
  54. package/dist/ops/node/sparseCrossEntropy.js +1 -1
  55. package/dist/ops/qkv.js +1 -1
  56. package/dist/ops/scatterSub.js +1 -1
  57. package/dist/ops/webgl/appendCache.js +1 -1
  58. package/dist/ops/webgl/attentionMask.js +1 -1
  59. package/dist/ops/webgl/fusedSoftmax.d.ts +11 -0
  60. package/dist/ops/webgl/fusedSoftmax.js +3930 -0
  61. package/dist/ops/webgl/gatherSub.js +1 -1
  62. package/dist/ops/webgl/mulDropout.d.ts +1 -0
  63. package/dist/ops/webgl/mulDropout.js +41 -0
  64. package/dist/ops/webgl/qkv.js +1 -1
  65. package/dist/ops/webgl/rope.js +1 -1
  66. package/dist/ops/webgl/scatterSub.js +1 -1
  67. package/dist/{random_width-oeUIlUZj.js → random_width-CMHmdbSu.js} +4212 -6630
  68. package/dist/{range-CcDl05lo.js → range-DQMNzBWs.js} +1 -1
  69. package/dist/{reshape-C8CR_Bad.js → reshape-DFzh97Sc.js} +1 -1
  70. package/dist/{sin-BJIrfnj7.js → sin-BYM-U4Ut.js} +1 -1
  71. package/dist/slice_util-CnVNPQI-.js +90 -0
  72. package/dist/softmax-4DOn6cPq.js +28 -0
  73. package/dist/{split-DZbvruEP.js → split-CkbeVdF8.js} +3 -3
  74. package/dist/{stack-BMm-efee.js → stack-DaIMO5iX.js} +1 -1
  75. package/dist/sum-C6u3xMi3.js +27 -0
  76. package/dist/{tensor-DJVbYhh1.js → tensor-Cu1fU7H7.js} +1 -1
  77. package/dist/{tensor2d-ZuQSh2D-.js → tensor2d-D0CKdG6B.js} +1 -1
  78. package/dist/tfjs_backend-Bzl2SrRo.js +2460 -0
  79. package/dist/training/AdamExt.js +1 -1
  80. package/dist/training/DatasetBuilder.js +3 -3
  81. package/dist/training/FullTrainer.js +41 -33
  82. package/dist/training/Trainer.d.ts +6 -1
  83. package/dist/training/Trainer.js +13 -12
  84. package/dist/training/sparseCrossEntropy.js +12 -11
  85. package/dist/utilities/dummy.js +8 -8
  86. package/dist/utilities/generate.js +11 -11
  87. package/dist/utilities/load.js +1 -1
  88. package/dist/utilities/profile.js +1 -1
  89. package/dist/utilities/weights.js +2 -2
  90. package/dist/{variable-Dl_ub3pk.js → variable-BS4AKqNU.js} +1 -1
  91. package/dist/{zeros-CCy9C3uU.js → zeros-CmJFiC84.js} +1 -1
  92. package/package.json +1 -1
  93. package/dist/exports_layers-tbTBcwMM.js +0 -25
  94. package/dist/layers/LayerNorm.d.ts +0 -13
  95. package/dist/layers/LayerNorm.js +0 -33
  96. package/dist/moments-DfcpfwKi.js +0 -132
  97. package/dist/softmax-Be_lsqUc.js +0 -105
  98. package/dist/training/LayerTrainer.d.ts +0 -29
  99. package/dist/training/LayerTrainer.js +0 -90
  100. package/dist/training/lwSchedule.d.ts +0 -7
  101. package/dist/training/lwSchedule.js +0 -162
package/dist/layers/BaseLayer.d.ts
@@ -1,8 +1,20 @@
+import { GPTConfig } from '../config';
 import { default as MemoryProfiler } from '../utilities/profile';
+import { default as RoPECache } from './RoPECache';
+export interface LayerConfig {
+    checkpointAttention?: boolean;
+    checkpointMLP?: boolean;
+    profiler?: MemoryProfiler;
+    ropeCache?: RoPECache;
+}
+export interface GPTLayerConfig {
+    gpt: GPTConfig;
+    layerConfig: LayerConfig;
+}
 export default abstract class BaseLayer {
-    protected _profiler?: MemoryProfiler;
+    readonly config: GPTLayerConfig;
+    constructor(config: GPTLayerConfig);
     getProfiler(): MemoryProfiler | undefined;
-    setProfiler(value: MemoryProfiler | undefined): void;
     startMemory(): void;
     endMemory(label: string): void;
 }
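Taken together, the two interfaces above replace the old per-layer `setProfiler` plumbing: model-wide hyperparameters travel under `gpt`, while cross-cutting layer services (checkpointing flags, the profiler, the shared RoPE cache) travel under `layerConfig`, and every layer now receives the pair at construction. A hedged sketch of assembling one; the import paths and the concrete values are assumptions, while the field names are taken from usages elsewhere in this diff:

// Assumed deep-import paths; the package's export map may expose these differently.
import { GPTConfig } from '@genai-fi/nanogpt/dist/config';
import { GPTLayerConfig, LayerConfig } from '@genai-fi/nanogpt/dist/layers/BaseLayer';

// Hypothetical values. nEmbed, nHead, nLayer, blockSize, dropout and useRope all
// appear as s.gpt.* reads in the bundled layer code below.
const gpt = {
  nEmbed: 256,
  nHead: 8,
  nLayer: 4,
  blockSize: 128,
  dropout: 0.1,
  useRope: true,
} as GPTConfig; // cast: GPTConfig likely has more fields than shown here

const layerConfig: LayerConfig = {
  checkpointAttention: true, // recompute attention activations in the backward pass
  checkpointMLP: false,
  // profiler and ropeCache are optional and omitted here
};

const config: GPTLayerConfig = { gpt, layerConfig };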
package/dist/layers/BaseLayer.js
@@ -1,18 +1,18 @@
-class t {
-  _profiler;
-  getProfiler() {
-    return this._profiler;
+class o {
+  config;
+  constructor(r) {
+    this.config = r;
   }
-  setProfiler(r) {
-    this._profiler = r;
+  getProfiler() {
+    return this.config.layerConfig.profiler;
   }
   startMemory() {
-    this._profiler?.startMemory();
+    this.config.layerConfig.profiler?.startMemory();
   }
   endMemory(r) {
-    this._profiler?.endMemory(r);
+    this.config.layerConfig.profiler?.endMemory(r);
   }
 }
 export {
-  t as default
+  o as default
 };
package/dist/layers/CausalSelfAttention.d.ts
@@ -1,6 +1,4 @@
-import { GPTConfig } from '../config';
-import { default as RoPECache } from './RoPECache';
-import { default as BaseLayer } from './BaseLayer';
+import { default as BaseLayer, GPTLayerConfig } from './BaseLayer';
 import { Tensor, Variable } from '@tensorflow/tfjs-core';
 export type KVCache = {
     k: Tensor;
@@ -9,19 +7,16 @@ export type KVCache = {
     cumulativeLength: number;
 };
 export default class CausalSelfAttention extends BaseLayer {
-    private readonly ropeCache?;
-    private config;
     private cAttn;
     private cProj;
-    private attnDropout;
-    private residDropout;
     private bias;
     private maskInf;
     private divisor;
     private index;
     private _trainable;
     private units;
-    constructor(index: number, config: GPTConfig, ropeCache?: RoPECache | undefined);
+    private projUnits;
+    constructor(index: number, config: GPTLayerConfig);
     private build;
     get variables(): Variable[];
     get trainable(): boolean;
@@ -33,6 +28,7 @@ export default class CausalSelfAttention extends BaseLayer {
     private getQKV;
     private getOutputProjection;
     private updateCache;
+    private forward;
     call(x: Tensor, training?: boolean, includeAttention?: boolean, pastKV?: KVCache): {
         output: Tensor;
         attention?: Tensor;
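The declaration changes mirror the bundled code further down: the dropout layers are gone (folded into the fused softmax op), the projection is now a bare variable, and a private `forward` has been split out of `call` so the checkpointing path can wrap it. A hedged sketch of the incremental-decode loop the `call`/`KVCache` signature implies; the deep-import path and helper name are my own, and note the implementation's guards (`pastKV` requires `useRope` and is rejected when `training` is true):

import { Tensor } from '@tensorflow/tfjs-core';
// Assumed path; CausalSelfAttention and KVCache are declared above.
import CausalSelfAttention, { KVCache } from '@genai-fi/nanogpt/dist/layers/CausalSelfAttention';

// Feed one [B, 1, nEmbed] slice per step and thread presentKV back in as pastKV;
// cumulativeLength keeps the RoPE rotation offset advancing across steps.
function decodeSteps(attn: CausalSelfAttention, steps: Tensor[]): Tensor[] {
  let past: KVCache | undefined;
  const outputs: Tensor[] = [];
  for (const x of steps) {
    const { output, presentKV } = attn.call(x, false, false, past);
    outputs.push(output);
    past = presentKV;
  }
  return outputs;
}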
package/dist/layers/CausalSelfAttention.js
@@ -1,122 +1,148 @@
-import { attentionMask as C } from "../ops/attentionMask.js";
-import x from "./BaseLayer.js";
-import { qkv as y } from "../ops/qkv.js";
-import { rope as m } from "../ops/rope.js";
-import { appendCache as b } from "../ops/appendCache.js";
-import { w as j, x as f, t as z } from "../index-pWA4_lUh.js";
-import { r as w, l as E, w as D, b as T } from "../random_width-oeUIlUZj.js";
-import { d as L, a as k } from "../exports_layers-tbTBcwMM.js";
-import { o as W } from "../ones-Cog-G2ag.js";
-import { z as M } from "../zeros-CCy9C3uU.js";
-import { v as A } from "../variable-Dl_ub3pk.js";
-import { s as g } from "../softmax-Be_lsqUc.js";
-import { m as _ } from "../mat_mul-D7_a4KJn.js";
-import { r as $ } from "../reshape-C8CR_Bad.js";
-class K extends x {
-  constructor(s, t, i) {
-    super(), this.ropeCache = i, this.config = t, this.index = s, this.units = t.nEmbed * 3, this.cProj = L({
-      units: t.nEmbed,
-      useBias: t.biasInLinear,
-      name: `block_${s}_attn_cProj`,
-      kernelInitializer: w({
-        mean: 0,
-        stddev: 0.02 / Math.sqrt(2 * t.nLayer)
-      }),
-      biasInitializer: "zeros"
-    }), this.attnDropout = k({ rate: t.dropout }), this.residDropout = k({ rate: t.dropout }), this.bias = E.bandPart(W([t.blockSize, t.blockSize]), -1, 0).cast("bool"), this.divisor = 1 / Math.sqrt(t.nEmbed / t.nHead);
-    const e = M([t.blockSize, t.blockSize]), o = j([t.blockSize, t.blockSize], Number.NEGATIVE_INFINITY);
-    this.maskInf = D(this.bias, e, o);
-  }
-  config;
+import { attentionMask as I } from "../ops/attentionMask.js";
+import y from "./BaseLayer.js";
+import { qkv as z } from "../ops/qkv.js";
+import { rope as P } from "../ops/rope.js";
+import { appendCache as E } from "../ops/appendCache.js";
+import { D as $, F as _, t as x, c as L, e as v, H as W } from "../index-XjBAhiFO.js";
+import { fusedSoftmax as S } from "../ops/fusedSoftmax.js";
+import { l as M, w as O, r as T, d as N, a as U } from "../tfjs_backend-Bzl2SrRo.js";
+import { o as q } from "../ones-BO01zpJG.js";
+import { z as B } from "../zeros-CmJFiC84.js";
+import { v as g } from "../variable-BS4AKqNU.js";
+import { m as C } from "../mat_mul-CKwFEV1Q.js";
+import { r as D } from "../reshape-DFzh97Sc.js";
+class nt extends y {
   cAttn = null;
-  cProj;
-  attnDropout;
-  residDropout;
+  cProj = null;
   bias;
   maskInf;
   divisor;
   index;
   _trainable = !0;
   units;
+  projUnits;
+  constructor(t, s) {
+    super(s), this.index = t, this.units = s.gpt.nEmbed * 3, this.projUnits = s.gpt.nEmbed, this.bias = M.bandPart(q([s.gpt.blockSize, s.gpt.blockSize]), -1, 0).cast("bool"), this.divisor = 1 / Math.sqrt(s.gpt.nEmbed / s.gpt.nHead);
+    const e = B([s.gpt.blockSize, s.gpt.blockSize]), n = $([s.gpt.blockSize, s.gpt.blockSize], Number.NEGATIVE_INFINITY);
+    this.maskInf = O(this.bias, e, n);
+  }
   build() {
-    this.cAttn === null && (this.cAttn = A(
-      T([this.config.nEmbed, this.units], 0, 0.02),
+    this.cAttn === null && (this.cAttn = g(
+      T([this.config.gpt.nEmbed, this.units], 0, 0.02),
       !0
       //`block_${this.index}_attn_cAttn_kernel`
+    )), this.cProj === null && (this.cProj = g(
+      T([this.projUnits, this.config.gpt.nEmbed], 0, 0.02),
+      !0
+      //`block_${this.index}_attn_cProj_kernel`
     ));
   }
   get variables() {
     if (this.cAttn === null)
       throw new Error("Layer not built yet");
-    return [this.cAttn, ...this.cProj.trainableWeights.map((s) => s.read())];
+    return [this.cAttn, this.cProj];
   }
   get trainable() {
     return this._trainable;
   }
-  set trainable(s) {
-    this._trainable = s, this.cAttn && (this.cAttn.trainable = s), this.cProj.trainable = s;
+  set trainable(t) {
+    this._trainable = t, this.cAttn && (this.cAttn.trainable = t), this.cProj && (this.cProj.trainable = t);
   }
-  saveWeights(s) {
-    s.set(`block_${this.index}_cAttn`, this.cAttn ? [this.cAttn.clone()] : []), s.set(`block_${this.index}_cProj`, this.cProj.getWeights());
+  saveWeights(t) {
    t.set(`block_${this.index}_cAttn`, this.cAttn ? [this.cAttn.clone()] : []), t.set(`block_${this.index}_cProj`, this.cProj ? [this.cProj.clone()] : []);
   }
-  loadWeights(s) {
-    const t = s.get(`block_${this.index}_cAttn`)?.[0];
-    if (!t) throw new Error(`Weights for block_${this.index}_cAttn not found`);
-    this.cAttn ? this.cAttn.assign(t) : this.cAttn = A(t, !0), this.cProj.setWeights(s.get(`block_${this.index}_cProj`) || []);
+  loadWeights(t) {
+    const s = t.get(`block_${this.index}_cAttn`)?.[0], e = t.get(`block_${this.index}_cProj`)?.[0];
+    if (!s) throw new Error(`Weights for block_${this.index}_cAttn not found`);
+    if (!e) throw new Error(`Weights for block_${this.index}_cProj not found`);
+    this.cAttn ? this.cAttn.assign(s) : this.cAttn = g(s, !0), this.cProj ? this.cProj.assign(e) : this.cProj = g(e, !0);
   }
-  getAttentionScores(s, t, i) {
-    const e = C(s, t, this.maskInf, this.divisor), o = g(e, -1);
-    return this.attnDropout.apply(o, { training: i });
+  getAttentionScores(t, s, e, n) {
+    const o = I(t, s, this.maskInf, this.divisor);
+    return S(o, e ? this.config.gpt.dropout : 0, n);
   }
   // Attention with optional past. If pastLen > 0 and T_cur == 1, no mask needed.
-  getAttentionScoresWithPast(s, t, i, e) {
-    const o = s.shape[2];
-    let n = _(s, t, !1, !0).mul(this.divisor);
-    if (o > 1 && e > 0)
+  getAttentionScoresWithPast(t, s, e, n, o) {
+    const i = t.shape[2];
+    let r = C(t, s, !1, !0).mul(this.divisor);
+    if (i > 1 && n > 0)
       throw new Error("Cannot use past with T_cur > 1");
-    if (o > 1) {
-      const h = this.maskInf.slice([0, 0], [o, o]).expandDims(0).expandDims(0);
-      n = n.add(h);
+    if (i > 1) {
+      const c = this.maskInf.slice([0, 0], [i, i]).expandDims(0).expandDims(0);
+      r = r.add(c);
     }
-    const a = g(n, -1);
-    return this.attnDropout.apply(a, { training: i });
+    return S(r, e ? this.config.gpt.dropout : 0, o);
   }
-  getQKV(s) {
-    return y(s, this.cAttn, this.config.nHead);
+  getQKV(t) {
+    return z(t, this.cAttn, this.config.gpt.nHead);
   }
-  getOutputProjection(s, t) {
-    const i = s.shape[0], e = s.shape[2], o = this.config.nEmbed, r = s.transpose([0, 2, 1, 3]), n = $(r, [i, e, o]), a = this.cProj.apply(n);
-    return this.residDropout.apply(a, { training: t });
+  getOutputProjection(t) {
+    const s = t.shape[0], e = t.shape[2], n = this.config.gpt.nEmbed, o = t.transpose([0, 2, 1, 3]), i = D(o, [s, e, n]);
+    return N(i, this.cProj);
   }
-  updateCache(s, t, i) {
-    const e = this.config.blockSize, o = s.shape[2], r = Math.min(i?.length || 0, e - o), n = i ? b(i.k, s, e) : s, a = i ? b(i.v, t, e) : t;
+  updateCache(t, s, e) {
+    const n = this.config.gpt.blockSize, o = t.shape[2], i = Math.min(e?.length || 0, n - o), a = e ? E(e.k, t, n) : t, r = e ? E(e.v, s, n) : s;
     return {
-      k: f(n),
-      v: f(a),
-      length: r + o,
-      cumulativeLength: i ? i.cumulativeLength + o : o
+      k: _(a),
+      v: _(r),
+      length: i + o,
+      cumulativeLength: e ? e.cumulativeLength + o : o
     };
   }
-  // Added optional KV cache support (pastKV). Returns presentKV for chaining.
-  call(s, t = !1, i = !1, e) {
-    if (e && !this.config.useRope)
-      throw new Error("Cannot use pastKV without RoPE enabled");
-    return this.build(), z(() => {
+  forward(t, s = !1, e, n = !1, o) {
+    return x(() => {
       this.startMemory();
-      const [o, r, n] = this.getQKV(s), a = e ? e.cumulativeLength : 0, h = this.ropeCache ? m(o, this.ropeCache, a) : o, p = this.ropeCache ? m(r, this.ropeCache, a) : r;
-      this.ropeCache && (o.dispose(), r.dispose());
-      const u = e ? e.length : 0, l = this.updateCache(p, n, e), d = l.k, v = l.v;
-      e && (p.dispose(), n.dispose());
-      let c;
-      u > 0 ? c = this.getAttentionScoresWithPast(h, d, t, u) : c = this.getAttentionScores(h, d, t);
-      const P = _(c, v), I = this.getOutputProjection(P, t), S = i ? c.mean(1) : void 0;
-      return this.endMemory("CausalSelfAttention"), { output: I, attention: S, presentKV: l };
+      const [i, a, r] = this.getQKV(t), c = o ? o.cumulativeLength : 0, h = this.config.layerConfig.ropeCache, d = h ? P(i, h, c) : i, p = h ? P(a, h, c) : a;
+      h && (i.dispose(), a.dispose());
+      const f = o ? o.length : 0, l = this.updateCache(p, r, o), m = l.k, b = l.v;
+      o && (p.dispose(), r.dispose());
+      let u;
+      f > 0 ? u = this.getAttentionScoresWithPast(d, m, s, f, e) : u = this.getAttentionScores(d, m, s, e);
+      const k = C(u, b), A = this.getOutputProjection(k), w = n ? u.mean(1) : void 0;
+      return this.endMemory("CausalSelfAttention"), { output: A, attention: w, presentKV: l };
     });
   }
+  call(t, s = !1, e = !1, n) {
+    if (n && !this.config.gpt.useRope)
+      throw new Error("Cannot use pastKV without RoPE enabled");
+    if (s && n)
+      throw new Error("Cannot use pastKV during training");
+    if (t.shape.length !== 3)
+      throw new Error(`Input tensor must be rank 3 [B, T, C], got shape ${t.shape}`);
+    if (t.shape[2] !== this.config.gpt.nEmbed)
+      throw new Error(`Input tensor last dimension must be ${this.config.gpt.nEmbed}, got ${t.shape[2]}`);
+    this.build();
+    const o = Math.random() * 1e9;
+    if (s && this.config.layerConfig.checkpointAttention) {
+      const a = L(
+        // @ts-expect-error Invalid params
+        (r, c, h, d) => {
+          const p = this.forward(r, !0, o);
+          p.presentKV?.k.dispose(), p.presentKV?.v.dispose(), d([r]);
+          const f = (l, m) => {
+            const [b] = m, u = v().state.activeTape;
+            v().state.activeTape = [];
+            const k = W((A, w, H) => {
+              const j = this.forward(A, !0, o);
+              return j.presentKV?.k.dispose(), j.presentKV?.v.dispose(), j.output;
+            })([b, c, h], l);
+            return v().state.activeTape = u, k;
+          };
+          return { value: p.output, gradFunc: f };
+        }
+      )(t, this.cAttn, this.cProj);
+      if (this.config.gpt.dropout > 0) {
+        const r = U(a, this.config.gpt.dropout);
+        return a.dispose(), { output: r };
+      } else
+        return { output: a };
+    } else
+      return this.forward(t, s, o, e, n);
+  }
   dispose() {
-    this.cAttn?.dispose(), this.cProj.dispose(), this.attnDropout.dispose(), this.residDropout.dispose(), this.bias.dispose(), this.maskInf.dispose();
+    this.cAttn?.dispose(), this.cProj?.dispose(), this.bias.dispose(), this.maskInf.dispose();
   }
 }
 export {
-  K as default
+  nt as default
 };
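The substantive change above is the `checkpointAttention` training path: `call` wraps `forward` in what appears to be tf.js `customGrad` (bundler alias `L`), saves only the input, and re-runs `forward` inside `gradFunc` with the engine's tape (`v().state.activeTape`) temporarily cleared, passing the same seed `o` so the dropout masks of the two passes agree. A minimal sketch of that recompute-in-backward pattern in plain tf.js, my own illustration under those assumptions rather than the package's code:

import * as tf from '@tensorflow/tfjs';

// f stands in for the attention forward pass; its intermediates are discarded.
const f = (x: tf.Tensor) => x.square().sum();

const checkpointed = tf.customGrad((x, save) => {
  const input = x as tf.Tensor;
  (save as (tensors: tf.Tensor[]) => void)([input]); // keep only the input, not the activations
  return {
    value: f(input),
    gradFunc: (dy, saved) => {
      const [xSaved] = saved;
      // Recompute the forward pass to rebuild the graph, then backprop through it.
      return tf.grad((xi) => f(xi))(xSaved, dy as tf.Tensor);
    },
  };
});

const y = checkpointed(tf.randomNormal([4, 8])); // used like any other op

This trades one extra forward pass per step for not retaining the attention activations, which is where most of the training memory in a long-context block goes.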
package/dist/layers/MLP.d.ts
@@ -1,13 +1,12 @@
 import { Tensor, Variable } from '@tensorflow/tfjs-core';
-import { GPTConfig } from '../config';
-import { default as BaseLayer } from './BaseLayer';
+import { default as BaseLayer, GPTLayerConfig } from './BaseLayer';
 export default class MLP extends BaseLayer {
     private cFc;
     private cProj;
     private dropout;
     private index;
     private _trainable;
-    constructor(index: number, config: GPTConfig);
+    constructor(index: number, config: GPTLayerConfig);
     get variables(): Variable[];
     get trainable(): boolean;
     set trainable(value: boolean);
package/dist/layers/MLP.js
@@ -1,64 +1,7 @@
-import { t as n } from "../index-pWA4_lUh.js";
-import l from "./BaseLayer.js";
-import { r as s } from "../random_width-oeUIlUZj.js";
-import { d as i, a as c } from "../exports_layers-tbTBcwMM.js";
-class u extends l {
-  cFc;
-  cProj;
-  dropout;
-  index;
-  _trainable = !0;
-  constructor(t, e) {
-    super(), this.index = t, this.cFc = i({
-      units: e.mlpFactor * e.nEmbed,
-      activation: "gelu",
-      useBias: e.biasInLinear,
-      kernelInitializer: s({
-        mean: 0,
-        stddev: 0.02
-      }),
-      biasInitializer: "zeros",
-      name: `block_${t}_mlp_cFc`
-    }), this.cProj = i({
-      units: e.nEmbed,
-      useBias: e.biasInLinear,
-      kernelInitializer: s({
-        mean: 0,
-        stddev: 0.02 / Math.sqrt(2 * e.nLayer)
-      }),
-      biasInitializer: "zeros",
-      name: `block_${t}_mlp_cProj`
-    }), this.dropout = c({ rate: e.dropout });
-  }
-  get variables() {
-    return [
-      ...this.cFc.trainableWeights.map((t) => t.read()),
-      ...this.cProj.trainableWeights.map((t) => t.read())
-    ];
-  }
-  get trainable() {
-    return this._trainable;
-  }
-  set trainable(t) {
-    this._trainable = t, this.cFc.trainable = t, this.cProj.trainable = t;
-  }
-  saveWeights(t) {
-    t.set(`block_${this.index}_mlpHidden`, this.cFc.getWeights()), t.set(`block_${this.index}_mlpOut`, this.cProj.getWeights());
-  }
-  loadWeights(t) {
-    this.cFc.setWeights(t.get(`block_${this.index}_mlpHidden`) || []), this.cProj.setWeights(t.get(`block_${this.index}_mlpOut`) || []);
-  }
-  call(t, e = !1) {
-    return n(() => {
-      this.startMemory();
-      const r = this.cFc.apply(t), a = this.cProj.apply(r), o = this.dropout.apply(a, { training: e });
-      return this.endMemory("MLP"), o;
-    });
-  }
-  dispose() {
-    this.cFc.dispose(), this.cProj.dispose(), this.dropout.dispose();
-  }
-}
+import "../index-XjBAhiFO.js";
+import "./BaseLayer.js";
+import "../random_width-CMHmdbSu.js";
+import { M as i } from "../MLP-KHhikThU.js";
 export {
-  u as default
+  i as default
 };
package/dist/layers/RMSNorm.d.ts
@@ -1,9 +1,9 @@
 import { Tensor, Variable } from '@tensorflow/tfjs-core';
-import { default as BaseLayer } from './BaseLayer';
+import { default as BaseLayer, GPTLayerConfig } from './BaseLayer';
 export default class RMSNorm extends BaseLayer {
     private gamma;
     private epsilon;
-    constructor(shape: number[], epsilon?: number, name?: string);
+    constructor(config: GPTLayerConfig, epsilon?: number, name?: string);
     get trainableWeights(): Variable[];
     set trainable(value: boolean);
     getWeights(): Tensor[];
package/dist/layers/RMSNorm.js
@@ -1,29 +1,29 @@
-import { t as r } from "../index-pWA4_lUh.js";
+import { t as r } from "../index-XjBAhiFO.js";
 import m from "./BaseLayer.js";
-import { v as i } from "../variable-Dl_ub3pk.js";
-import { o } from "../ones-Cog-G2ag.js";
+import { v as i } from "../variable-BS4AKqNU.js";
+import { o } from "../ones-BO01zpJG.js";
 class d extends m {
   gamma;
   epsilon;
-  constructor(a, s = 1e-8, t = "") {
-    super(), this.epsilon = s, this.gamma = i(o(a), !0, `${t}_gamma`, "float32");
+  constructor(t, s = 1e-8, a = "") {
+    super(t), this.epsilon = s, this.gamma = i(o([t.gpt.nEmbed]), !0, `${a}_gamma`, "float32");
   }
   get trainableWeights() {
     return [this.gamma];
   }
-  set trainable(a) {
-    this.gamma.trainable = a;
+  set trainable(t) {
+    this.gamma.trainable = t;
   }
   getWeights() {
     return [this.gamma];
   }
-  setWeights(a) {
-    this.gamma.assign(a[0]);
+  setWeights(t) {
+    this.gamma.assign(t[0]);
   }
-  apply(a) {
+  apply(t) {
     return r(() => {
       this.startMemory();
-      const t = a.square().mean(-1, !0).add(this.epsilon).rsqrt(), e = a.mul(t).mul(this.gamma);
+      const a = t.square().mean(-1, !0).add(this.epsilon).rsqrt(), e = t.mul(a).mul(this.gamma);
       return this.endMemory("RMSNorm"), e;
     });
   }
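For reference, the minified `apply` above is plain RMSNorm, y = x · rsqrt(mean(x², last axis) + ε) · γ, now sized from `nEmbed` in the shared config rather than an explicit shape argument. A de-minified sketch of the same computation (the helper name is mine):

import * as tf from '@tensorflow/tfjs';

// Equivalent to the minified body: scale by the inverse root-mean-square of the
// last axis, then by the learned gain vector gamma of shape [nEmbed].
function rmsNorm(x: tf.Tensor, gamma: tf.Tensor, epsilon = 1e-8): tf.Tensor {
  return tf.tidy(() => {
    const invRms = x.square().mean(-1, true).add(epsilon).rsqrt();
    return x.mul(invRms).mul(gamma);
  });
}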
package/dist/layers/RoPECache.js
@@ -1,6 +1,6 @@
-import { o as h, h as c, E as f, I as l, f as n, J as m, t as u, x as p } from "../index-pWA4_lUh.js";
-import { c as d, s as C } from "../sin-BJIrfnj7.js";
-import { r as a } from "../range-CcDl05lo.js";
+import { o as h, h as c, E as f, N as l, f as n, O as m, t as u, F as p } from "../index-XjBAhiFO.js";
+import { c as d, s as C } from "../sin-BYM-U4Ut.js";
+import { r as a } from "../range-DQMNzBWs.js";
 /**
  * @license
  * Copyright 2018 Google LLC. All Rights Reserved.
package/dist/layers/TiedEmbedding.js
@@ -1,8 +1,9 @@
-import { r as t, d as s } from "../random_width-oeUIlUZj.js";
-import "../index-pWA4_lUh.js";
-import { v as r } from "../variable-Dl_ub3pk.js";
-import { g as d } from "../gather-BPGW8RsB.js";
-class b {
+import { r as t } from "../random_width-CMHmdbSu.js";
+import "../index-XjBAhiFO.js";
+import { d as s } from "../tfjs_backend-Bzl2SrRo.js";
+import { v as r } from "../variable-BS4AKqNU.js";
+import { g as d } from "../gather-BUmJIS8n.js";
+class n {
   vocabSize;
   embedDim;
   tiedWeights;
@@ -43,5 +44,5 @@ class b {
   }
 }
 export {
-  b as default
+  n as default
 };
package/dist/layers/TransformerBlock.d.ts
@@ -1,8 +1,5 @@
-import { GPTConfig } from '../config';
 import { KVCache } from './CausalSelfAttention';
-import { default as RoPECache } from './RoPECache';
-import { default as MemoryProfiler } from '../utilities/profile';
-import { default as BaseLayer } from './BaseLayer';
+import { default as BaseLayer, GPTLayerConfig } from './BaseLayer';
 import { Tensor, Variable } from '@tensorflow/tfjs-core';
 export default class Block extends BaseLayer {
     private ln1;
@@ -12,8 +9,7 @@ export default class Block extends BaseLayer {
     private index;
     private _trainable;
     skipped: boolean;
-    constructor(index: number, config: GPTConfig, ropeCache?: RoPECache);
-    setProfiler(value: MemoryProfiler | undefined): void;
+    constructor(index: number, config: GPTLayerConfig);
     get variables(): Variable[];
     get trainable(): boolean;
     set trainable(value: boolean);
package/dist/layers/TransformerBlock.js
@@ -1,9 +1,9 @@
 import h from "./CausalSelfAttention.js";
-import o from "./MLP.js";
-import r from "./RMSNorm.js";
+import { M as o } from "../MLP-KHhikThU.js";
+import a from "./RMSNorm.js";
 import p from "./BaseLayer.js";
-import { t as d } from "../index-pWA4_lUh.js";
-class g extends p {
+import { t as d } from "../index-XjBAhiFO.js";
+class W extends p {
   ln1;
   attn;
   ln2;
@@ -11,11 +11,8 @@ class g extends p {
   index;
   _trainable = !0;
   skipped = !1;
-  constructor(t, s, i) {
-    super(), this.index = t, this.ln1 = new r([s.nEmbed], 1e-8, `block_${this.index}_rms1`), this.attn = new h(this.index, s, i), this.ln2 = new r([s.nEmbed], 1e-8, `block_${this.index}_rms2`), this.mlp = new o(this.index, s);
-  }
-  setProfiler(t) {
-    this._profiler = t, this.attn.setProfiler(t), this.mlp.setProfiler(t), this.ln1.setProfiler(t), this.ln2.setProfiler(t);
+  constructor(t, s) {
+    super(s), this.index = t, this.ln1 = new a(s, 1e-8, `block_${this.index}_rms1`), this.attn = new h(this.index, s), this.ln2 = new a(s, 1e-8, `block_${this.index}_rms2`), this.mlp = new o(this.index, s);
   }
   get variables() {
     return [
@@ -45,9 +42,9 @@ class g extends p {
     return d(() => {
       if (this.skipped)
        return { output: t };
-      const l = this.ln1.apply(t), n = this.attn.call(l, s, i, e), a = t.add(n.output);
+      const l = this.ln1.apply(t), n = this.attn.call(l, s, i, e), r = t.add(n.output);
       return {
-        output: this.getMLPOutput(a, s),
+        output: this.getMLPOutput(r, s),
         attention: n.attention,
         cache: n.presentKV
      };
@@ -58,5 +55,5 @@ class g extends p {
   }
 }
 export {
-  g as default
+  W as default
 };
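The block itself keeps the usual pre-norm residual wiring: attention runs over `ln1(x)` and is added back to `x`, after which `getMLPOutput` applies the MLP branch, assumed by symmetry to be `h + mlp(ln2(h))`. In outline (a sketch; the function names are mine):

import { Tensor } from '@tensorflow/tfjs-core';

type Branch = (t: Tensor) => Tensor;

// h = x + attn(rms1(x));  y = h + mlp(rms2(h))
function blockForward(x: Tensor, rms1: Branch, attn: Branch, rms2: Branch, mlp: Branch): Tensor {
  const h = x.add(attn(rms1(x)));
  return h.add(mlp(rms2(h)));
}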