@genai-fi/nanogpt 0.3.2 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (98) hide show
  1. package/dist/Generator.js +22 -22
  2. package/dist/MLP-KHhikThU.js +83 -0
  3. package/dist/NanoGPTModel.d.ts +2 -3
  4. package/dist/NanoGPTModel.js +79 -79
  5. package/dist/TeachableLLM.js +16 -13
  6. package/dist/axis_util-DeydwOoC.js +69 -0
  7. package/dist/{concat-BIZS_td9.js → concat-DS_qH7MI.js} +5 -5
  8. package/dist/config.js +7 -8
  9. package/dist/{gather-BPGW8RsB.js → gather-BUmJIS8n.js} +1 -1
  10. package/dist/{index-pWA4_lUh.js → index-XjBAhiFO.js} +1272 -1174
  11. package/dist/layers/BaseLayer.d.ts +14 -2
  12. package/dist/layers/BaseLayer.js +9 -9
  13. package/dist/layers/CausalSelfAttention.d.ts +4 -8
  14. package/dist/layers/CausalSelfAttention.js +108 -82
  15. package/dist/layers/MLP.d.ts +2 -3
  16. package/dist/layers/MLP.js +5 -62
  17. package/dist/layers/RMSNorm.d.ts +2 -2
  18. package/dist/layers/RMSNorm.js +11 -11
  19. package/dist/layers/RoPECache.js +3 -3
  20. package/dist/layers/TiedEmbedding.js +7 -6
  21. package/dist/layers/TransformerBlock.d.ts +2 -6
  22. package/dist/layers/TransformerBlock.js +9 -12
  23. package/dist/{sum-C7Mgy9Bw.js → log_sum_exp-DJPkVZZn.js} +32 -54
  24. package/dist/main.js +22 -19
  25. package/dist/{mat_mul-D7_a4KJn.js → mat_mul-CKwFEV1Q.js} +1 -1
  26. package/dist/max-DJvEiCAJ.js +25 -0
  27. package/dist/moments-CrWRPcR3.js +53 -0
  28. package/dist/norm-BzY929B_.js +86 -0
  29. package/dist/{ones-Cog-G2ag.js → ones-BO01zpJG.js} +2 -2
  30. package/dist/ops/appendCache.js +1 -1
  31. package/dist/ops/attentionMask.js +1 -1
  32. package/dist/ops/cpu/appendCache.js +2 -2
  33. package/dist/ops/cpu/attentionMask.js +2 -2
  34. package/dist/ops/cpu/fusedSoftmax.d.ts +9 -0
  35. package/dist/ops/cpu/fusedSoftmax.js +23 -0
  36. package/dist/ops/cpu/gatherSub.js +3 -3
  37. package/dist/ops/cpu/mulDropout.d.ts +1 -0
  38. package/dist/ops/cpu/mulDropout.js +17 -0
  39. package/dist/ops/cpu/qkv.js +3 -3
  40. package/dist/ops/cpu/rope.js +5 -5
  41. package/dist/ops/cpu/scatterSub.js +27 -27
  42. package/dist/ops/fusedSoftmax.d.ts +2 -0
  43. package/dist/ops/fusedSoftmax.js +10 -0
  44. package/dist/ops/gatherSub.js +1 -1
  45. package/dist/ops/grads/attentionMask.js +1 -1
  46. package/dist/ops/grads/fusedSoftmax.d.ts +2 -0
  47. package/dist/ops/grads/fusedSoftmax.js +17 -0
  48. package/dist/ops/grads/qkv.js +1 -1
  49. package/dist/ops/grads/rope.js +1 -1
  50. package/dist/ops/mulDrop.d.ts +2 -0
  51. package/dist/ops/mulDrop.js +9 -0
  52. package/dist/ops/node/sparseCrossEntropy.js +1 -1
  53. package/dist/ops/qkv.js +1 -1
  54. package/dist/ops/scatterSub.js +1 -1
  55. package/dist/ops/webgl/appendCache.js +1 -1
  56. package/dist/ops/webgl/attentionMask.js +1 -1
  57. package/dist/ops/webgl/fusedSoftmax.d.ts +11 -0
  58. package/dist/ops/webgl/fusedSoftmax.js +3930 -0
  59. package/dist/ops/webgl/gatherSub.js +1 -1
  60. package/dist/ops/webgl/mulDropout.d.ts +1 -0
  61. package/dist/ops/webgl/mulDropout.js +41 -0
  62. package/dist/ops/webgl/qkv.js +1 -1
  63. package/dist/ops/webgl/rope.js +1 -1
  64. package/dist/ops/webgl/scatterSub.js +1 -1
  65. package/dist/{random_width-oeUIlUZj.js → random_width-CMHmdbSu.js} +4212 -6630
  66. package/dist/{range-CcDl05lo.js → range-DQMNzBWs.js} +1 -1
  67. package/dist/{reshape-C8CR_Bad.js → reshape-DFzh97Sc.js} +1 -1
  68. package/dist/{sin-BJIrfnj7.js → sin-BYM-U4Ut.js} +1 -1
  69. package/dist/slice_util-CnVNPQI-.js +90 -0
  70. package/dist/softmax-4DOn6cPq.js +28 -0
  71. package/dist/{split-DZbvruEP.js → split-CkbeVdF8.js} +3 -3
  72. package/dist/{stack-BMm-efee.js → stack-DaIMO5iX.js} +1 -1
  73. package/dist/sum-C6u3xMi3.js +27 -0
  74. package/dist/{tensor-DJVbYhh1.js → tensor-Cu1fU7H7.js} +1 -1
  75. package/dist/{tensor2d-ZuQSh2D-.js → tensor2d-D0CKdG6B.js} +1 -1
  76. package/dist/tfjs_backend-Bzl2SrRo.js +2460 -0
  77. package/dist/training/AdamExt.js +1 -1
  78. package/dist/training/DatasetBuilder.js +3 -3
  79. package/dist/training/FullTrainer.js +1 -1
  80. package/dist/training/Trainer.js +13 -12
  81. package/dist/training/sparseCrossEntropy.js +12 -11
  82. package/dist/utilities/dummy.js +8 -8
  83. package/dist/utilities/generate.js +11 -11
  84. package/dist/utilities/load.js +1 -1
  85. package/dist/utilities/profile.js +1 -1
  86. package/dist/utilities/weights.js +2 -2
  87. package/dist/{variable-Dl_ub3pk.js → variable-BS4AKqNU.js} +1 -1
  88. package/dist/{zeros-CCy9C3uU.js → zeros-CmJFiC84.js} +1 -1
  89. package/package.json +1 -1
  90. package/dist/exports_layers-tbTBcwMM.js +0 -25
  91. package/dist/layers/LayerNorm.d.ts +0 -13
  92. package/dist/layers/LayerNorm.js +0 -33
  93. package/dist/moments-DfcpfwKi.js +0 -132
  94. package/dist/softmax-Be_lsqUc.js +0 -105
  95. package/dist/training/LayerTrainer.d.ts +0 -29
  96. package/dist/training/LayerTrainer.js +0 -95
  97. package/dist/training/lwSchedule.d.ts +0 -7
  98. package/dist/training/lwSchedule.js +0 -162
package/dist/Generator.js CHANGED
@@ -1,7 +1,7 @@
1
1
  import { E as u } from "./index-Dwqa6Zy2.js";
2
- import "./index-pWA4_lUh.js";
3
- import { t as d } from "./tensor2d-ZuQSh2D-.js";
4
- import { c as k } from "./concat-BIZS_td9.js";
2
+ import "./index-XjBAhiFO.js";
3
+ import { t as d } from "./tensor2d-D0CKdG6B.js";
4
+ import { c as p } from "./concat-DS_qH7MI.js";
5
5
  class w extends u {
6
6
  constructor(s, e) {
7
7
  super(), this.model = s, this.tokeniser = e;
@@ -12,41 +12,41 @@ class w extends u {
12
12
  return d(e, [1, e[0].length], "int32");
13
13
  }
14
14
  async generateNoCache(s, e) {
15
- let t = await this.tokenisePrompt(s), o = s || "";
16
- const n = e?.maxLength ?? 1e3;
17
- for (let a = 0; a < n && this.active; a++) {
15
+ let t = await this.tokenisePrompt(s), n = s || "";
16
+ const o = e?.maxLength ?? 1e3;
17
+ for (let a = 0; a < o && this.active; a++) {
18
18
  const {
19
19
  output: i,
20
20
  attention: c,
21
21
  probabilities: l
22
22
  } = this.model.generate(t, void 0, e), h = t;
23
- t = k([t, i], 1), h.dispose();
23
+ t = p([t, i], 1), h.dispose();
24
24
  const r = await this.processResponse(i, c, l);
25
25
  if (i.dispose(), r === null)
26
26
  break;
27
- o += r;
27
+ n += r;
28
28
  }
29
- return t.dispose(), o;
29
+ return t.dispose(), n;
30
30
  }
31
31
  async processResponse(s, e, t) {
32
- const o = (await s.array())[0][0];
33
- if (o === this.tokeniser.eosToken)
32
+ const n = (await s.array())[0][0];
33
+ if (n === this.tokeniser.eosToken)
34
34
  return null;
35
- const n = await this.tokeniser.decode([o]);
35
+ const o = await this.tokeniser.decode([n]);
36
36
  let a;
37
37
  e && (a = await e.array(), e.dispose());
38
38
  let i;
39
- return t && (i = await t.array(), t.dispose()), this.emit("tokens", [o], n, a, i), n;
39
+ return t && (i = await t.array(), t.dispose()), this.emit("tokens", [n], o, a, i), o;
40
40
  }
41
41
  async generateCache(s, e) {
42
- let t = await this.tokenisePrompt(s), o = s || "";
43
- const n = new Array(this.model.config.nLayer).fill(void 0), a = e?.maxLength ?? 1e3;
42
+ let t = await this.tokenisePrompt(s), n = s || "";
43
+ const o = new Array(this.model.config.gpt.nLayer).fill(void 0), a = e?.maxLength ?? 1e3;
44
44
  for (let i = 0; i < a && this.active; i++) {
45
45
  const {
46
46
  output: c,
47
47
  attention: l,
48
48
  probabilities: h
49
- } = this.model.generate(t, n, {
49
+ } = this.model.generate(t, o, {
50
50
  ...e,
51
51
  usePadding: !1
52
52
  });
@@ -54,17 +54,17 @@ class w extends u {
54
54
  const r = await this.processResponse(c, l, h);
55
55
  if (r === null)
56
56
  break;
57
- o += r;
57
+ n += r;
58
58
  }
59
- return n.forEach((i) => {
59
+ return o.forEach((i) => {
60
60
  i && (i.k.dispose(), i.v.dispose());
61
- }), t.dispose(), o;
61
+ }), t.dispose(), n;
62
62
  }
63
63
  async generate(s, e) {
64
- const t = s && s.length > this.model.config.blockSize ? s.slice(-this.model.config.blockSize) : s;
64
+ const t = s && s.length > this.model.config.gpt.blockSize ? s.slice(-this.model.config.gpt.blockSize) : s;
65
65
  this.active = !0, this.emit("start");
66
- const n = await (this.model.config.useRope && !e?.noCache ? this.generateCache(t, e) : this.generateNoCache(t, e));
67
- return this.active = !1, this.emit("stop"), n;
66
+ const o = await (this.model.config.gpt.useRope && !e?.noCache && !e?.includeAttention ? this.generateCache(t, e) : this.generateNoCache(t, e));
67
+ return this.active = !1, this.emit("stop"), o;
68
68
  }
69
69
  stop() {
70
70
  this.active = !1;
@@ -0,0 +1,83 @@
1
+ import { t as d } from "./index-XjBAhiFO.js";
2
+ import c from "./layers/BaseLayer.js";
3
+ import { E as p, D as l, a as h, r as i } from "./random_width-CMHmdbSu.js";
4
+ /**
5
+ * @license
6
+ * Copyright 2018 Google LLC
7
+ *
8
+ * Use of this source code is governed by an MIT-style
9
+ * license that can be found in the LICENSE file or at
10
+ * https://opensource.org/licenses/MIT.
11
+ * =============================================================================
12
+ */
13
+ function r(s) {
14
+ return new h(s);
15
+ }
16
+ function u(s) {
17
+ return new l(s);
18
+ }
19
+ function g(s) {
20
+ return new p(s);
21
+ }
22
+ class P extends c {
23
+ cFc;
24
+ cProj;
25
+ dropout;
26
+ index;
27
+ _trainable = !0;
28
+ constructor(t, e) {
29
+ super(e), this.index = t, this.cFc = r({
30
+ units: e.gpt.mlpFactor * e.gpt.nEmbed,
31
+ activation: "gelu",
32
+ useBias: e.gpt.biasInLinear,
33
+ kernelInitializer: i({
34
+ mean: 0,
35
+ stddev: 0.02
36
+ }),
37
+ biasInitializer: "zeros",
38
+ name: `block_${t}_mlp_cFc`
39
+ }), this.cProj = r({
40
+ units: e.gpt.nEmbed,
41
+ useBias: e.gpt.biasInLinear,
42
+ kernelInitializer: i({
43
+ mean: 0,
44
+ stddev: 0.02 / Math.sqrt(2 * e.gpt.nLayer)
45
+ }),
46
+ biasInitializer: "zeros",
47
+ name: `block_${t}_mlp_cProj`
48
+ }), this.dropout = u({ rate: e.gpt.dropout });
49
+ }
50
+ get variables() {
51
+ return [
52
+ ...this.cFc.trainableWeights.map((t) => t.read()),
53
+ ...this.cProj.trainableWeights.map((t) => t.read())
54
+ ];
55
+ }
56
+ get trainable() {
57
+ return this._trainable;
58
+ }
59
+ set trainable(t) {
60
+ this._trainable = t, this.cFc.trainable = t, this.cProj.trainable = t;
61
+ }
62
+ saveWeights(t) {
63
+ t.set(`block_${this.index}_mlpHidden`, this.cFc.getWeights()), t.set(`block_${this.index}_mlpOut`, this.cProj.getWeights());
64
+ }
65
+ loadWeights(t) {
66
+ this.cFc.setWeights(t.get(`block_${this.index}_mlpHidden`) || []), this.cProj.setWeights(t.get(`block_${this.index}_mlpOut`) || []);
67
+ }
68
+ call(t, e = !1) {
69
+ return d(() => {
70
+ this.startMemory();
71
+ const a = this.cFc.apply(t), n = this.cProj.apply(a), o = this.dropout.apply(n, { training: e });
72
+ return this.endMemory("MLP"), o;
73
+ });
74
+ }
75
+ dispose() {
76
+ this.cFc.dispose(), this.cProj.dispose(), this.dropout.dispose();
77
+ }
78
+ }
79
+ export {
80
+ P as M,
81
+ u as d,
82
+ g as e
83
+ };
@@ -1,6 +1,5 @@
1
1
  import { GPTConfig } from './config';
2
2
  import { KVCache } from './layers/CausalSelfAttention';
3
- import { default as MemoryProfiler } from './utilities/profile';
4
3
  import { default as BaseLayer } from './layers/BaseLayer';
5
4
  import { Tensor, Variable } from '@tensorflow/tfjs-core';
6
5
  export interface TrainingLogEntry {
@@ -19,7 +18,6 @@ export interface GenerateOptions {
19
18
  includeProbabilities?: boolean;
20
19
  }
21
20
  export default class NanoGPT extends BaseLayer {
22
- readonly config: GPTConfig;
23
21
  private wte;
24
22
  private wpe?;
25
23
  private drop;
@@ -28,6 +26,8 @@ export default class NanoGPT extends BaseLayer {
28
26
  private ropeCache?;
29
27
  log: TrainingLogEntry[];
30
28
  constructor(config?: Partial<GPTConfig>);
29
+ get checkpointing(): boolean;
30
+ set checkpointing(value: boolean);
31
31
  get variables(): Variable[];
32
32
  saveWeights(): Map<string, Tensor[]>;
33
33
  loadWeights(weights: Map<string, Tensor[]>): void;
@@ -35,7 +35,6 @@ export default class NanoGPT extends BaseLayer {
35
35
  setSkipMask(mask: boolean[]): void;
36
36
  setTrainableMask(mask: boolean[]): void;
37
37
  set trainable(value: boolean);
38
- setProfiler(value: MemoryProfiler | undefined): void;
39
38
  private validateInput;
40
39
  private calculateLoss;
41
40
  private computeAttentionRollout;
@@ -1,18 +1,19 @@
1
- import { defaultConfig as F } from "./config.js";
2
- import L from "./layers/TransformerBlock.js";
3
- import P from "./layers/TiedEmbedding.js";
4
- import C from "./layers/RoPECache.js";
1
+ import { defaultConfig as x } from "./config.js";
2
+ import W from "./layers/TransformerBlock.js";
3
+ import F from "./layers/TiedEmbedding.js";
4
+ import P from "./layers/RoPECache.js";
5
5
  import q from "./layers/RMSNorm.js";
6
6
  import { estimateParameterCount as K } from "./utilities/parameters.js";
7
7
  import { createSoftmaxCrossEntropyWithGrad as N } from "./training/sparseCrossEntropy.js";
8
8
  import T from "./layers/BaseLayer.js";
9
- import { r as R, e as D, p as A } from "./random_width-oeUIlUZj.js";
10
- import { o as y, h as E, p as B, E as z, W as G, X as O, Y as Q, t as w, Z as X, f as _ } from "./index-pWA4_lUh.js";
11
- import { e as j, a as U } from "./exports_layers-tbTBcwMM.js";
12
- import { r as S } from "./reshape-C8CR_Bad.js";
13
- import { r as V } from "./range-CcDl05lo.js";
14
- import { g as Y } from "./gather-BPGW8RsB.js";
15
- import { s as Z } from "./softmax-Be_lsqUc.js";
9
+ import { r as R, p as D } from "./random_width-CMHmdbSu.js";
10
+ import { o as y, h as $, p as A, E as v, a6 as B, a7 as G, a8 as O, t as w, a5 as Q, f as C } from "./index-XjBAhiFO.js";
11
+ import { e as j, d as U } from "./MLP-KHhikThU.js";
12
+ import { r as _ } from "./reshape-DFzh97Sc.js";
13
+ import { r as V } from "./range-DQMNzBWs.js";
14
+ import { e as X } from "./tfjs_backend-Bzl2SrRo.js";
15
+ import { g as H } from "./gather-BUmJIS8n.js";
16
+ import { s as J } from "./softmax-4DOn6cPq.js";
16
17
  /**
17
18
  * @license
18
19
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -29,13 +30,13 @@ import { s as Z } from "./softmax-Be_lsqUc.js";
29
30
  * limitations under the License.
30
31
  * =============================================================================
31
32
  */
32
- function H(m, t) {
33
- let e = E(m, "a", "mod"), o = E(t, "b", "mod");
34
- [e, o] = B(e, o);
33
+ function Y(f, t) {
34
+ let e = $(f, "a", "mod"), o = $(t, "b", "mod");
35
+ [e, o] = A(e, o);
35
36
  const i = { a: e, b: o };
36
- return z.runKernel(G, i);
37
+ return v.runKernel(B, i);
37
38
  }
38
- const J = /* @__PURE__ */ y({ mod_: H });
39
+ const Z = /* @__PURE__ */ y({ mod_: Y });
39
40
  /**
40
41
  * @license
41
42
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -52,17 +53,17 @@ const J = /* @__PURE__ */ y({ mod_: H });
52
53
  * limitations under the License.
53
54
  * =============================================================================
54
55
  */
55
- function tt(m, t, e, o = !1) {
56
- const i = E(m, "logits", "multinomial"), s = i.size, r = i.rank;
56
+ function tt(f, t, e, o = !1) {
57
+ const i = $(f, "logits", "multinomial"), s = i.size, r = i.rank;
57
58
  if (s < 2)
58
59
  throw new Error(`Error in multinomial: you need at least 2 outcomes, but got ${s}.`);
59
60
  if (r > 2)
60
61
  throw new Error(`Rank of probabilities must be 1 or 2, but is ${r}`);
61
62
  e = e || Math.random();
62
- const n = { logits: r === 1 ? S(i, [1, -1]) : i }, h = { numSamples: t, seed: e, normalized: o }, a = z.runKernel(O, n, h);
63
- return r === 1 ? S(a, [a.size]) : a;
63
+ const n = { logits: r === 1 ? _(i, [1, -1]) : i }, h = { numSamples: t, seed: e, normalized: o }, l = v.runKernel(G, n, h);
64
+ return r === 1 ? _(l, [l.size]) : l;
64
65
  }
65
- const I = /* @__PURE__ */ y({ multinomial_: tt });
66
+ const M = /* @__PURE__ */ y({ multinomial_: tt });
66
67
  /**
67
68
  * @license
68
69
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -79,8 +80,8 @@ const I = /* @__PURE__ */ y({ multinomial_: tt });
79
80
  * limitations under the License.
80
81
  * =============================================================================
81
82
  */
82
- function et(m, t = 1, e = !0) {
83
- const o = E(m, "x", "topk");
83
+ function et(f, t = 1, e = !0) {
84
+ const o = $(f, "x", "topk");
84
85
  if (o.rank === 0)
85
86
  throw new Error("topk() expects the input to be of rank 1 or higher");
86
87
  const i = o.shape[o.shape.length - 1];
@@ -88,12 +89,11 @@ function et(m, t = 1, e = !0) {
88
89
  throw new Error(`'k' passed to topk() must be >= 0 but got ${t}`);
89
90
  if (t > i)
90
91
  throw new Error(`'k' passed to topk() must be <= the last dimension (${i}) but got ${t}`);
91
- const s = { x: o }, r = { k: t, sorted: e }, [l, n] = z.runKernel(Q, s, r);
92
- return { values: l, indices: n };
92
+ const s = { x: o }, r = { k: t, sorted: e }, [a, n] = v.runKernel(O, s, r);
93
+ return { values: a, indices: n };
93
94
  }
94
95
  const ot = /* @__PURE__ */ y({ topk_: et });
95
- class kt extends T {
96
- config;
96
+ class wt extends T {
97
97
  wte;
98
98
  // Token embeddings
99
99
  wpe;
@@ -107,19 +107,25 @@ class kt extends T {
107
107
  log = [];
108
108
  // Training log
109
109
  constructor(t = {}) {
110
- super(), this.config = { ...F, ...t }, this.wte = new P({
111
- vocabSize: this.config.vocabSize,
112
- embedDim: this.config.nEmbed,
110
+ super({ gpt: { ...x, ...t }, layerConfig: {} }), this.wte = new F({
111
+ vocabSize: this.config.gpt.vocabSize,
112
+ embedDim: this.config.gpt.nEmbed,
113
113
  name: "token_embedding"
114
- }), this.config.useRope === !1 ? this.wpe = j({
115
- inputDim: this.config.blockSize,
116
- outputDim: this.config.nEmbed,
114
+ }), this.config.gpt.useRope === !1 ? this.wpe = j({
115
+ inputDim: this.config.gpt.blockSize,
116
+ outputDim: this.config.gpt.nEmbed,
117
117
  name: "positional_embedding",
118
118
  embeddingsInitializer: R({ mean: 0, stddev: 0.02 })
119
- }) : this.ropeCache = new C(this.config), this.drop = U({ rate: this.config.dropout }), this.blocks = [];
120
- for (let e = 0; e < this.config.nLayer; e++)
121
- this.blocks.push(new L(e, this.config, this.ropeCache));
122
- this.lnF = new q([this.config.nEmbed], 1e-8, "final_rms_norm");
119
+ }) : (this.ropeCache = new P(this.config.gpt), this.config.layerConfig.ropeCache = this.ropeCache), this.drop = U({ rate: this.config.gpt.dropout }), this.blocks = [];
120
+ for (let e = 0; e < this.config.gpt.nLayer; e++)
121
+ this.blocks.push(new W(e, this.config));
122
+ this.lnF = new q(this.config, 1e-8, "final_rms_norm");
123
+ }
124
+ get checkpointing() {
125
+ return this.config.layerConfig.checkpointAttention === !0 || this.config.layerConfig.checkpointMLP === !0;
126
+ }
127
+ set checkpointing(t) {
128
+ this.config.layerConfig.checkpointAttention = t, this.config.layerConfig.checkpointMLP = t;
123
129
  }
124
130
  get variables() {
125
131
  return [
@@ -145,9 +151,9 @@ class kt extends T {
145
151
  inputPhase(t, e, o = !1) {
146
152
  return w(() => {
147
153
  const i = this.wte.embed(t);
148
- if (this.config.useRope === !1) {
149
- const [, s] = t.shape, r = this.config.blockSize, l = V(0, s, 1, "int32"), n = J(X(l, _(e, "int32")), _(r, "int32")), h = this.wpe.apply(n), a = i.add(h);
150
- return this.drop.apply(a, { training: o });
154
+ if (this.config.gpt.useRope === !1) {
155
+ const [, s] = t.shape, r = this.config.gpt.blockSize, a = V(0, s, 1, "int32"), n = Z(Q(a, C(e, "int32")), C(r, "int32")), h = this.wpe.apply(n), l = i.add(h);
156
+ return this.drop.apply(l, { training: o });
151
157
  } else
152
158
  return this.drop.apply(i, { training: o });
153
159
  });
@@ -169,17 +175,11 @@ class kt extends T {
169
175
  e.trainable = t;
170
176
  this.lnF.trainable = t;
171
177
  }
172
- setProfiler(t) {
173
- this._profiler = t;
174
- for (const e of this.blocks)
175
- e.setProfiler(t);
176
- this.lnF.setProfiler(t);
177
- }
178
178
  validateInput(t) {
179
179
  if (t.shape.length !== 2)
180
180
  throw new Error(`Invalid input shape: expected [batch_size, sequence_length], got ${t.shape}`);
181
- if (t.shape[1] > this.config.blockSize)
182
- throw new Error(`Input sequence length ${t.shape[1]} isn't block size ${this.config.blockSize}`);
181
+ if (t.shape[1] > this.config.gpt.blockSize)
182
+ throw new Error(`Input sequence length ${t.shape[1]} isn't block size ${this.config.gpt.blockSize}`);
183
183
  if (t.dtype !== "int32")
184
184
  throw new Error(`Input tensor must be of type int32, got ${t.dtype}`);
185
185
  }
@@ -198,17 +198,17 @@ class kt extends T {
198
198
  throw new Error("No attentions for rollout");
199
199
  const [e, o, i] = t[0].shape;
200
200
  for (const s of t) {
201
- const [r, l, n] = s.shape;
202
- if (r !== e || l !== o || n !== i)
201
+ const [r, a, n] = s.shape;
202
+ if (r !== e || a !== o || n !== i)
203
203
  throw new Error(
204
- `Inconsistent attention shapes in rollout: expected [${e},${o},${i}] got [${r},${l},${n}]`
204
+ `Inconsistent attention shapes in rollout: expected [${e},${o},${i}] got [${r},${a},${n}]`
205
205
  );
206
206
  }
207
207
  if (o === i) {
208
- const s = D(i, i).expandDims(0);
208
+ const s = X(i, i).expandDims(0);
209
209
  let r = s.tile([e, 1, 1]);
210
- for (const l of t) {
211
- const n = l.add(s);
210
+ for (const a of t) {
211
+ const n = a.add(s);
212
212
  r = n.div(n.sum(-1, !0)).matMul(r);
213
213
  }
214
214
  return r;
@@ -220,52 +220,52 @@ class kt extends T {
220
220
  return this.validateInput(t), w(() => {
221
221
  this.startMemory();
222
222
  const r = s?.[0]?.length ?? 0;
223
- let l = this.inputPhase(t, r, o);
223
+ let a = this.inputPhase(t, r, o);
224
224
  const n = [];
225
225
  if (s && s.length !== this.blocks.length)
226
226
  throw console.error("Cache", s), new Error(`Cache length ${s.length} does not match number of blocks ${this.blocks.length}`);
227
- for (let c = 0; c < this.blocks.length; c++) {
228
- const u = l, d = this.blocks[c], {
227
+ for (let p = 0; p < this.blocks.length; p++) {
228
+ const m = a, u = this.blocks[p], {
229
229
  output: b,
230
230
  attention: k,
231
- cache: f
232
- } = d.call(l, o, i, s ? s[c] : void 0);
233
- l = b, u.dispose(), i && k && n.push(k), s && f ? (s[c]?.k.dispose(), s[c]?.v.dispose(), s[c] = f) : f && (f.k.dispose(), f.v.dispose());
231
+ cache: g
232
+ } = u.call(a, o, i, s ? s[p] : void 0);
233
+ a = b, m.dispose(), i && k && n.push(k), s && g ? (s[p]?.k.dispose(), s[p]?.v.dispose(), s[p] = g) : g && (g.k.dispose(), g.v.dispose());
234
234
  }
235
235
  let h;
236
- i && n.length > 0 && (h = this.computeAttentionRollout(n)), l = this.lnF.apply(l);
237
- const a = this.wte.project(l);
238
- let p;
239
- return e && (p = this.calculateLoss(a, e)), this.endMemory("Forward"), { logits: a, loss: p, attention: i ? h : void 0 };
236
+ i && n.length > 0 && (h = this.computeAttentionRollout(n)), a = this.lnF.apply(a);
237
+ const l = this.wte.project(a);
238
+ let c;
239
+ return e && (c = this.calculateLoss(l, e)), this.endMemory("Forward"), { logits: l, loss: c, attention: i ? h : void 0 };
240
240
  });
241
241
  }
242
242
  generate(t, e, o) {
243
- const i = o?.temperature ?? 1, s = o?.topK, r = o?.usePadding ?? !1, l = o?.includeAttention ?? !1;
243
+ const i = o?.temperature ?? 1, s = o?.topK, r = o?.usePadding ?? !1, a = o?.includeAttention ?? !1;
244
244
  return w(() => {
245
- const n = t, h = n.shape[1], a = h <= this.config.blockSize ? n : n.slice(
246
- [0, h - this.config.blockSize],
247
- [n.shape[0], this.config.blockSize]
248
- ), p = r ? this.config.blockSize - a.shape[1] : 0, c = p > 0 ? A(a, [
245
+ const n = t, h = n.shape[1], l = h <= this.config.gpt.blockSize ? n : n.slice(
246
+ [0, h - this.config.gpt.blockSize],
247
+ [n.shape[0], this.config.gpt.blockSize]
248
+ ), c = r ? this.config.gpt.blockSize - l.shape[1] : 0, p = c > 0 ? D(l, [
249
249
  [0, 0],
250
- [0, p]
251
- ]) : a, { logits: u, attention: d } = this.forward(c, void 0, !1, l, e), b = u.shape[1] - 1 - p, k = u.slice([0, b, 0], [u.shape[0], 1, u.shape[2]]), f = d ? d.slice([0, b, 0], [d.shape[0], 1, d.shape[2]]) : void 0, $ = k.div(i);
252
- let g;
250
+ [0, c]
251
+ ]) : l, { logits: m, attention: u } = this.forward(p, void 0, !1, a, e), b = m.shape[1] - 1 - c, k = m.slice([0, b, 0], [m.shape[0], 1, m.shape[2]]), g = u ? u.slice([0, b, 0], [u.shape[0], 1, u.shape[2]]) : void 0, E = k.div(i);
252
+ let d;
253
253
  if (s) {
254
- const { values: M, indices: x } = ot($, s), W = I(M.squeeze([1]), 1);
255
- g = Y(x.squeeze([1]), W, 1);
254
+ const { values: S, indices: I } = ot(E, s), L = M(S.squeeze([1]), 1);
255
+ d = H(I.squeeze([1]), L, 1);
256
256
  } else
257
- g = I($.squeeze([1]), 1);
258
- let v;
259
- return o?.includeProbabilities && (v = Z($.squeeze([1]))), g = g.reshape([1, 1]), { output: g, attention: f?.squeeze([1]), probabilities: v };
257
+ d = M(E.squeeze([1]), 1);
258
+ let z;
259
+ return o?.includeProbabilities && (z = J(E.squeeze([1]))), d = d.reshape([1, 1]), { output: d, attention: g?.squeeze([1]), probabilities: z };
260
260
  });
261
261
  }
262
262
  getNumParams() {
263
- return K(this.config);
263
+ return K(this.config.gpt);
264
264
  }
265
265
  dispose() {
266
266
  this.wte.dispose(), this.wpe && this.wpe.dispose(), this.drop.dispose(), this.blocks.forEach((t) => t.dispose()), this.lnF.dispose();
267
267
  }
268
268
  }
269
269
  export {
270
- kt as default
270
+ wt as default
271
271
  };
@@ -1,17 +1,17 @@
1
- import { defaultConfig as d } from "./config.js";
2
- import h from "./NanoGPTModel.js";
3
- import { saveModel as l } from "./utilities/save.js";
1
+ import { defaultConfig as h } from "./config.js";
2
+ import l from "./NanoGPTModel.js";
3
+ import { saveModel as d } from "./utilities/save.js";
4
4
  import { loadModel as f } from "./utilities/load.js";
5
5
  import u from "./Generator.js";
6
6
  import _ from "./Trainer.js";
7
- import { E as c } from "./index-Dwqa6Zy2.js";
7
+ import { E as p } from "./index-Dwqa6Zy2.js";
8
8
  import { dummyPassAsync as m } from "./utilities/dummy.js";
9
- import p from "./tokeniser/CharTokeniser.js";
9
+ import c from "./tokeniser/CharTokeniser.js";
10
10
  import g from "./tokeniser/bpe.js";
11
11
  import "./papaparse.min-C8l2Kvo1.js";
12
12
  import "./index-Tf7vU29b.js";
13
13
  import "./jszip.min-CjP2V1VV.js";
14
- import "./index-pWA4_lUh.js";
14
+ import "./index-XjBAhiFO.js";
15
15
  import "./ops/cpu/scatterSub.js";
16
16
  import "./ops/webgl/scatterSub.js";
17
17
  import "./ops/cpu/gatherSub.js";
@@ -28,9 +28,12 @@ import "./ops/webgl/rope.js";
28
28
  import "./ops/grads/rope.js";
29
29
  import "./ops/cpu/appendCache.js";
30
30
  import "./ops/webgl/appendCache.js";
31
+ import "./ops/cpu/fusedSoftmax.js";
32
+ import "./ops/webgl/fusedSoftmax.js";
33
+ import "./ops/grads/fusedSoftmax.js";
31
34
  import w from "./utilities/profile.js";
32
35
  class a {
33
- ee = new c();
36
+ ee = new p();
34
37
  _config;
35
38
  _model;
36
39
  _tokeniser;
@@ -47,7 +50,7 @@ class a {
47
50
  get config() {
48
51
  if (!this._config)
49
52
  throw new Error("Model configuration is not initialized.");
50
- return this._config;
53
+ return this._config.gpt;
51
54
  }
52
55
  get model() {
53
56
  if (!this._model)
@@ -71,7 +74,7 @@ class a {
71
74
  saveModel(t) {
72
75
  if (!this._model || !this._tokeniser)
73
76
  throw new Error("Model or tokeniser is not initialized.");
74
- return l(this._model, this._tokeniser, t);
77
+ return d(this._model, this._tokeniser, t);
75
78
  }
76
79
  static loadModel(t) {
77
80
  const e = new a();
@@ -86,7 +89,7 @@ class a {
86
89
  }), e;
87
90
  }
88
91
  static create(t, e = {}) {
89
- const i = { ...d, ...e }, o = t === "char" ? new p(i.vocabSize) : new g(i.vocabSize), s = new h(i), r = new a(o, s);
92
+ const i = { ...h, ...e }, o = t === "char" ? new c(i.vocabSize) : new g(i.vocabSize), s = new l(i), r = new a(o, s);
90
93
  return r.setStatus("warmup"), m(s).then(() => {
91
94
  r.tokeniser.trained ? (r.setStatus("ready"), r.ee.emit("loaded")) : (r.setStatus("awaitingTokens"), r.ee.emit("loaded"), r.tokeniser.once("trainStatus", (n) => {
92
95
  n === "trained" && r.setStatus("ready");
@@ -103,11 +106,11 @@ class a {
103
106
  }
104
107
  set enableProfiler(t) {
105
108
  if (t) {
106
- if (!this._model)
109
+ if (!this._config)
107
110
  throw new Error("Model is not initialized.");
108
- this._model.getProfiler() || this._model.setProfiler(new w());
111
+ this._config.layerConfig.profiler || (this._config.layerConfig.profiler = new w());
109
112
  } else
110
- this._model && this._model.setProfiler(void 0);
113
+ this._config?.layerConfig.profiler && (this._config.layerConfig.profiler = void 0);
111
114
  }
112
115
  getNumParams() {
113
116
  if (!this._model)
@@ -0,0 +1,69 @@
1
+ import { j as c } from "./index-XjBAhiFO.js";
2
+ /**
3
+ * @license
4
+ * Copyright 2017 Google LLC. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 (the "License");
6
+ * you may not use this file except in compliance with the License.
7
+ * You may obtain a copy of the License at
8
+ *
9
+ * http://www.apache.org/licenses/LICENSE-2.0
10
+ *
11
+ * Unless required by applicable law or agreed to in writing, software
12
+ * distributed under the License is distributed on an "AS IS" BASIS,
13
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ * See the License for the specific language governing permissions and
15
+ * limitations under the License.
16
+ * =============================================================================
17
+ */
18
+ function i(e, n) {
19
+ for (let t = 0; t < e.length; ++t)
20
+ if (e[e.length - t - 1] !== n - 1 - t)
21
+ return !1;
22
+ return !0;
23
+ }
24
+ function p(e, n, t) {
25
+ const r = e.length + n.length, s = [];
26
+ let o = 0, f = 0;
27
+ for (let u = 0; u < r; u++)
28
+ t.indexOf(u) === -1 ? s.push(e[o++]) : s.push(n[f++]);
29
+ return s;
30
+ }
31
+ function a(e, n) {
32
+ const t = [], r = e.length;
33
+ for (let o = 0; o < r; o++)
34
+ n.indexOf(o) === -1 && t.push(e[o]);
35
+ const s = n.map((o) => e[o]);
36
+ return [t, s];
37
+ }
38
+ function m(e, n) {
39
+ const t = n.map((r) => 1);
40
+ return p(e, t, n);
41
+ }
42
+ function d(e, n, t) {
43
+ c(i(n, t), () => `${e} supports only inner-most axes for now. Got axes ${n} and rank-${t} input.`);
44
+ }
45
+ function h(e, n) {
46
+ if (i(e, n))
47
+ return null;
48
+ const t = [];
49
+ for (let r = 0; r < n; ++r)
50
+ e.indexOf(r) === -1 && t.push(r);
51
+ return e.forEach((r) => t.push(r)), t;
52
+ }
53
+ function g(e) {
54
+ return e.map((n, t) => [t, n]).sort((n, t) => n[1] - t[1]).map((n) => n[0]);
55
+ }
56
+ function x(e, n) {
57
+ const t = [];
58
+ for (let r = n - e; r < n; ++r)
59
+ t.push(r);
60
+ return t;
61
+ }
62
+ export {
63
+ x as a,
64
+ d as b,
65
+ a as c,
66
+ g as d,
67
+ m as e,
68
+ h as g
69
+ };
@@ -1,4 +1,4 @@
1
- import { o as s, j as a, i, u as p, E as l, C as u } from "./index-pWA4_lUh.js";
1
+ import { o as s, j as a, i, w as p, E as l, C as f } from "./index-XjBAhiFO.js";
2
2
  /**
3
3
  * @license
4
4
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -15,7 +15,7 @@ import { o as s, j as a, i, u as p, E as l, C as u } from "./index-pWA4_lUh.js";
15
15
  * limitations under the License.
16
16
  * =============================================================================
17
17
  */
18
- function f(o, e = 0) {
18
+ function h(o, e = 0) {
19
19
  a(o.length >= 1, () => "Pass at least one tensor to concat");
20
20
  const t = i(o, "tensors", "concat", "string_or_numeric");
21
21
  if (t[0].dtype === "complex64" && t.forEach((n) => {
@@ -25,9 +25,9 @@ function f(o, e = 0) {
25
25
  }), t.length === 1)
26
26
  return p(t[0]);
27
27
  const r = t, c = { axis: e };
28
- return l.runKernel(u, r, c);
28
+ return l.runKernel(f, r, c);
29
29
  }
30
- const m = /* @__PURE__ */ s({ concat_: f });
30
+ const u = /* @__PURE__ */ s({ concat_: h });
31
31
  export {
32
- m as c
32
+ u as c
33
33
  };