@genai-fi/nanogpt 0.2.12 → 0.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115) hide show
  1. package/dist/Generator.js +30 -25
  2. package/dist/NanoGPTModel.d.ts +13 -14
  3. package/dist/NanoGPTModel.js +142 -70
  4. package/dist/TeachableLLM.d.ts +16 -7
  5. package/dist/TeachableLLM.js +81 -44
  6. package/dist/Trainer.js +8 -8
  7. package/dist/concat-BIZS_td9.js +33 -0
  8. package/dist/data/parquet.js +1 -1
  9. package/dist/exports_layers-tbTBcwMM.js +25 -0
  10. package/dist/{sum-D7fu15XL.js → gather-BPGW8RsB.js} +6 -8
  11. package/dist/index-C4L8Cm77.js +349 -0
  12. package/dist/{index-YPKosni4.js → index-pWA4_lUh.js} +1020 -782
  13. package/dist/layers/CausalSelfAttention.d.ts +11 -11
  14. package/dist/layers/CausalSelfAttention.js +71 -63
  15. package/dist/layers/MLP.d.ts +6 -7
  16. package/dist/layers/MLP.js +18 -16
  17. package/dist/layers/RMSNorm.d.ts +6 -7
  18. package/dist/layers/RMSNorm.js +15 -13
  19. package/dist/layers/RoPECache.d.ts +4 -5
  20. package/dist/layers/RoPECache.js +36 -12
  21. package/dist/layers/TiedEmbedding.d.ts +7 -8
  22. package/dist/layers/TiedEmbedding.js +16 -418
  23. package/dist/layers/TransformerBlock.d.ts +8 -9
  24. package/dist/layers/TransformerBlock.js +12 -12
  25. package/dist/main.d.ts +2 -0
  26. package/dist/main.js +35 -21
  27. package/dist/{mat_mul-Bu7bhLms.js → mat_mul-D7_a4KJn.js} +5 -5
  28. package/dist/moments-DfcpfwKi.js +132 -0
  29. package/dist/ones-Cog-G2ag.js +29 -0
  30. package/dist/ops/appendCache.d.ts +2 -0
  31. package/dist/ops/appendCache.js +9 -0
  32. package/dist/ops/attentionMask.d.ts +1 -1
  33. package/dist/ops/attentionMask.js +7 -85
  34. package/dist/ops/cpu/appendCache.d.ts +2 -0
  35. package/dist/ops/cpu/appendCache.js +28 -0
  36. package/dist/ops/cpu/attentionMask.js +18 -0
  37. package/dist/ops/cpu/gatherSub.d.ts +1 -0
  38. package/dist/ops/cpu/gatherSub.js +34 -0
  39. package/dist/ops/cpu/qkv.d.ts +5 -0
  40. package/dist/ops/cpu/qkv.js +38 -0
  41. package/dist/ops/cpu/rope.d.ts +6 -0
  42. package/dist/ops/cpu/rope.js +38 -0
  43. package/dist/ops/cpu/scatterSub.d.ts +1 -0
  44. package/dist/ops/cpu/scatterSub.js +70 -0
  45. package/dist/ops/gatherSub.d.ts +1 -1
  46. package/dist/ops/gatherSub.js +6 -63
  47. package/dist/ops/grads/attentionMask.d.ts +1 -0
  48. package/dist/ops/grads/attentionMask.js +21 -0
  49. package/dist/ops/grads/qkv.d.ts +1 -0
  50. package/dist/ops/grads/qkv.js +20 -0
  51. package/dist/ops/grads/rope.d.ts +1 -0
  52. package/dist/ops/grads/rope.js +14 -0
  53. package/dist/ops/node/sparseCrossEntropy.js +1 -1
  54. package/dist/ops/qkv.d.ts +1 -6
  55. package/dist/ops/qkv.js +7 -124
  56. package/dist/ops/rope.d.ts +0 -5
  57. package/dist/ops/rope.js +7 -151
  58. package/dist/ops/scatterSub.d.ts +1 -1
  59. package/dist/ops/scatterSub.js +6 -147
  60. package/dist/ops/webgl/appendCache.d.ts +1 -0
  61. package/dist/ops/webgl/appendCache.js +43 -0
  62. package/dist/ops/webgl/attentionMask.d.ts +1 -0
  63. package/dist/ops/webgl/attentionMask.js +43 -0
  64. package/dist/ops/webgl/gatherSub.d.ts +1 -0
  65. package/dist/ops/webgl/gatherSub.js +27 -0
  66. package/dist/ops/webgl/qkv.d.ts +1 -0
  67. package/dist/ops/webgl/qkv.js +46 -0
  68. package/dist/ops/webgl/rope.d.ts +1 -0
  69. package/dist/ops/webgl/rope.js +56 -0
  70. package/dist/ops/webgl/scatterSub.d.ts +1 -0
  71. package/dist/ops/webgl/scatterSub.js +27 -0
  72. package/dist/{parquet-BRl5lE_I.js → parquet-C0Tlmv9c.js} +3045 -3048
  73. package/dist/random_width-oeUIlUZj.js +15487 -0
  74. package/dist/range-CcDl05lo.js +26 -0
  75. package/dist/{reshape-DmnmKT6r.js → reshape-C8CR_Bad.js} +3 -3
  76. package/dist/sin-BJIrfnj7.js +47 -0
  77. package/dist/softmax-Be_lsqUc.js +105 -0
  78. package/dist/{complex-CJ-qCcLB.js → split-DZbvruEP.js} +6 -8
  79. package/dist/stack-BMm-efee.js +27 -0
  80. package/dist/sum-C7Mgy9Bw.js +104 -0
  81. package/dist/tensor-DJVbYhh1.js +24 -0
  82. package/dist/tensor2d-ZuQSh2D-.js +30 -0
  83. package/dist/tokeniser/bpe.d.ts +17 -6
  84. package/dist/tokeniser/bpe.js +89 -61
  85. package/dist/training/AdamExt.js +1 -1
  86. package/dist/training/DatasetBuilder.d.ts +6 -6
  87. package/dist/training/DatasetBuilder.js +1262 -17
  88. package/dist/training/Evaluator.d.ts +3 -2
  89. package/dist/training/FullTrainer.d.ts +9 -8
  90. package/dist/training/FullTrainer.js +26 -25
  91. package/dist/training/LayerTrainer.d.ts +9 -8
  92. package/dist/training/LayerTrainer.js +34 -33
  93. package/dist/training/Trainer.d.ts +22 -21
  94. package/dist/training/Trainer.js +21 -18
  95. package/dist/training/sparseCrossEntropy.js +22 -166
  96. package/dist/utilities/dummy.js +10 -8
  97. package/dist/utilities/generate.js +14 -11
  98. package/dist/utilities/load.d.ts +1 -2
  99. package/dist/utilities/load.js +37 -35
  100. package/dist/utilities/profile.js +1 -1
  101. package/dist/utilities/save.js +14 -9
  102. package/dist/utilities/tokenParse.d.ts +1 -1
  103. package/dist/utilities/tokenParse.js +7 -61
  104. package/dist/utilities/weights.d.ts +3 -3
  105. package/dist/utilities/weights.js +21 -19
  106. package/dist/variable-Dl_ub3pk.js +23 -0
  107. package/dist/{stack-BtKpB0Ry.js → zeros-CCy9C3uU.js} +18 -16
  108. package/package.json +2 -1
  109. package/dist/assets/worker-BYeSPNkq.js +0 -1
  110. package/dist/tokeniser/NodeTokeniser.d.ts +0 -20
  111. package/dist/tokeniser/NodeTokeniser.js +0 -46
  112. package/dist/tokeniser/WebTokeniser.d.ts +0 -18
  113. package/dist/tokeniser/WebTokeniser.js +0 -96
  114. package/dist/tokeniser/worker.js +0 -53
  115. /package/dist/{tokeniser/worker.d.ts → ops/cpu/attentionMask.d.ts} +0 -0
package/dist/Generator.js CHANGED
@@ -1,49 +1,52 @@
1
1
  import { E as u } from "./index-Dwqa6Zy2.js";
2
- class f extends u {
2
+ import "./index-pWA4_lUh.js";
3
+ import { t as d } from "./tensor2d-ZuQSh2D-.js";
4
+ import { c as k } from "./concat-BIZS_td9.js";
5
+ class w extends u {
3
6
  constructor(s, e) {
4
7
  super(), this.model = s, this.tokeniser = e;
5
8
  }
6
9
  active = !1;
7
10
  async tokenisePrompt(s) {
8
11
  const e = s ? await this.tokeniser.tokenise([s], !0) : [[this.tokeniser.eosToken]];
9
- return this.model.tf.tensor2d(e, [1, e[0].length], "int32");
12
+ return d(e, [1, e[0].length], "int32");
10
13
  }
11
14
  async generateNoCache(s, e) {
12
- let t = await this.tokenisePrompt(s), i = s || "";
13
- const o = e?.maxLength ?? 1e3;
14
- for (let a = 0; a < o && this.active; a++) {
15
+ let t = await this.tokenisePrompt(s), o = s || "";
16
+ const n = e?.maxLength ?? 1e3;
17
+ for (let a = 0; a < n && this.active; a++) {
15
18
  const {
16
- output: n,
19
+ output: i,
17
20
  attention: c,
18
21
  probabilities: l
19
22
  } = this.model.generate(t, void 0, e), h = t;
20
- t = this.model.tf.concat([t, n], 1), h.dispose();
21
- const r = await this.processResponse(n, c, l);
22
- if (n.dispose(), r === null)
23
+ t = k([t, i], 1), h.dispose();
24
+ const r = await this.processResponse(i, c, l);
25
+ if (i.dispose(), r === null)
23
26
  break;
24
- i += r;
27
+ o += r;
25
28
  }
26
- return t.dispose(), i;
29
+ return t.dispose(), o;
27
30
  }
28
31
  async processResponse(s, e, t) {
29
- const i = (await s.array())[0][0];
30
- if (i === this.tokeniser.eosToken)
32
+ const o = (await s.array())[0][0];
33
+ if (o === this.tokeniser.eosToken)
31
34
  return null;
32
- const o = await this.tokeniser.decode([i]);
35
+ const n = await this.tokeniser.decode([o]);
33
36
  let a;
34
37
  e && (a = await e.array(), e.dispose());
35
- let n;
36
- return t && (n = await t.array(), t.dispose()), this.emit("tokens", [i], o, a, n), o;
38
+ let i;
39
+ return t && (i = await t.array(), t.dispose()), this.emit("tokens", [o], n, a, i), n;
37
40
  }
38
41
  async generateCache(s, e) {
39
- let t = await this.tokenisePrompt(s), i = s || "";
40
- const o = new Array(this.model.config.nLayer).fill(void 0), a = e?.maxLength ?? 1e3;
41
- for (let n = 0; n < a && this.active; n++) {
42
+ let t = await this.tokenisePrompt(s), o = s || "";
43
+ const n = new Array(this.model.config.nLayer).fill(void 0), a = e?.maxLength ?? 1e3;
44
+ for (let i = 0; i < a && this.active; i++) {
42
45
  const {
43
46
  output: c,
44
47
  attention: l,
45
48
  probabilities: h
46
- } = this.model.generate(t, o, {
49
+ } = this.model.generate(t, n, {
47
50
  ...e,
48
51
  usePadding: !1
49
52
  });
@@ -51,20 +54,22 @@ class f extends u {
51
54
  const r = await this.processResponse(c, l, h);
52
55
  if (r === null)
53
56
  break;
54
- i += r;
57
+ o += r;
55
58
  }
56
- return t.dispose(), i;
59
+ return n.forEach((i) => {
60
+ i && (i.k.dispose(), i.v.dispose());
61
+ }), t.dispose(), o;
57
62
  }
58
63
  async generate(s, e) {
59
64
  const t = s && s.length > this.model.config.blockSize ? s.slice(-this.model.config.blockSize) : s;
60
65
  this.active = !0, this.emit("start");
61
- const o = await (this.model.config.useRope && !e?.noCache ? this.generateCache(t, e) : this.generateNoCache(t, e));
62
- return this.active = !1, this.emit("stop"), o;
66
+ const n = await (this.model.config.useRope && !e?.noCache ? this.generateCache(t, e) : this.generateNoCache(t, e));
67
+ return this.active = !1, this.emit("stop"), n;
63
68
  }
64
69
  stop() {
65
70
  this.active = !1;
66
71
  }
67
72
  }
68
73
  export {
69
- f as default
74
+ w as default
70
75
  };
@@ -1,8 +1,8 @@
1
- import { default as TF } from '@tensorflow/tfjs';
2
1
  import { GPTConfig } from './config';
3
2
  import { KVCache } from './layers/CausalSelfAttention';
4
3
  import { default as MemoryProfiler } from './utilities/profile';
5
4
  import { default as BaseLayer } from './layers/BaseLayer';
5
+ import { Tensor, Variable } from '@tensorflow/tfjs-core';
6
6
  export interface TrainingLogEntry {
7
7
  loss: number;
8
8
  valLoss?: number;
@@ -26,12 +26,11 @@ export default class NanoGPT extends BaseLayer {
26
26
  private blocks;
27
27
  private lnF;
28
28
  private ropeCache?;
29
- readonly tf: typeof TF;
30
29
  log: TrainingLogEntry[];
31
- constructor(tf: typeof TF, config?: Partial<GPTConfig>);
32
- get variables(): TF.Variable[];
33
- saveWeights(): Map<string, TF.Tensor[]>;
34
- loadWeights(weights: Map<string, TF.Tensor[]>): void;
30
+ constructor(config?: Partial<GPTConfig>);
31
+ get variables(): Variable[];
32
+ saveWeights(): Map<string, Tensor[]>;
33
+ loadWeights(weights: Map<string, Tensor[]>): void;
35
34
  private inputPhase;
36
35
  setSkipMask(mask: boolean[]): void;
37
36
  setTrainableMask(mask: boolean[]): void;
@@ -40,15 +39,15 @@ export default class NanoGPT extends BaseLayer {
40
39
  private validateInput;
41
40
  private calculateLoss;
42
41
  private computeAttentionRollout;
43
- forward(idx: TF.Tensor, targets?: TF.Tensor, training?: boolean, includeAttention?: boolean, cache?: (KVCache | undefined)[]): {
44
- logits: TF.Tensor;
45
- loss?: TF.Tensor;
46
- attention?: TF.Tensor;
42
+ forward(idx: Tensor, targets?: Tensor, training?: boolean, includeAttention?: boolean, cache?: (KVCache | undefined)[]): {
43
+ logits: Tensor;
44
+ loss?: Tensor;
45
+ attention?: Tensor;
47
46
  };
48
- generate(idx: TF.Tensor, cache?: (KVCache | undefined)[], options?: GenerateOptions): {
49
- output: TF.Tensor;
50
- attention?: TF.Tensor;
51
- probabilities?: TF.Tensor;
47
+ generate(idx: Tensor, cache?: (KVCache | undefined)[], options?: GenerateOptions): {
48
+ output: Tensor;
49
+ attention?: Tensor;
50
+ probabilities?: Tensor;
52
51
  };
53
52
  getNumParams(): number;
54
53
  dispose(): void;
@@ -1,12 +1,98 @@
1
- import { defaultConfig as v } from "./config.js";
2
- import z from "./layers/TransformerBlock.js";
3
- import S from "./layers/TiedEmbedding.js";
4
- import _ from "./layers/RoPECache.js";
5
- import I from "./layers/RMSNorm.js";
6
- import { estimateParameterCount as F } from "./utilities/parameters.js";
7
- import { createSoftmaxCrossEntropyWithGrad as L } from "./training/sparseCrossEntropy.js";
8
- import P from "./layers/BaseLayer.js";
9
- class A extends P {
1
+ import { defaultConfig as F } from "./config.js";
2
+ import L from "./layers/TransformerBlock.js";
3
+ import P from "./layers/TiedEmbedding.js";
4
+ import C from "./layers/RoPECache.js";
5
+ import q from "./layers/RMSNorm.js";
6
+ import { estimateParameterCount as K } from "./utilities/parameters.js";
7
+ import { createSoftmaxCrossEntropyWithGrad as N } from "./training/sparseCrossEntropy.js";
8
+ import T from "./layers/BaseLayer.js";
9
+ import { r as R, e as D, p as A } from "./random_width-oeUIlUZj.js";
10
+ import { o as y, h as E, p as B, E as z, W as G, X as O, Y as Q, t as w, Z as X, f as _ } from "./index-pWA4_lUh.js";
11
+ import { e as j, a as U } from "./exports_layers-tbTBcwMM.js";
12
+ import { r as S } from "./reshape-C8CR_Bad.js";
13
+ import { r as V } from "./range-CcDl05lo.js";
14
+ import { g as Y } from "./gather-BPGW8RsB.js";
15
+ import { s as Z } from "./softmax-Be_lsqUc.js";
16
+ /**
17
+ * @license
18
+ * Copyright 2020 Google LLC. All Rights Reserved.
19
+ * Licensed under the Apache License, Version 2.0 (the "License");
20
+ * you may not use this file except in compliance with the License.
21
+ * You may obtain a copy of the License at
22
+ *
23
+ * http://www.apache.org/licenses/LICENSE-2.0
24
+ *
25
+ * Unless required by applicable law or agreed to in writing, software
26
+ * distributed under the License is distributed on an "AS IS" BASIS,
27
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28
+ * See the License for the specific language governing permissions and
29
+ * limitations under the License.
30
+ * =============================================================================
31
+ */
32
+ function H(m, t) {
33
+ let e = E(m, "a", "mod"), o = E(t, "b", "mod");
34
+ [e, o] = B(e, o);
35
+ const i = { a: e, b: o };
36
+ return z.runKernel(G, i);
37
+ }
38
+ const J = /* @__PURE__ */ y({ mod_: H });
39
+ /**
40
+ * @license
41
+ * Copyright 2020 Google LLC. All Rights Reserved.
42
+ * Licensed under the Apache License, Version 2.0 (the "License");
43
+ * you may not use this file except in compliance with the License.
44
+ * You may obtain a copy of the License at
45
+ *
46
+ * http://www.apache.org/licenses/LICENSE-2.0
47
+ *
48
+ * Unless required by applicable law or agreed to in writing, software
49
+ * distributed under the License is distributed on an "AS IS" BASIS,
50
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
51
+ * See the License for the specific language governing permissions and
52
+ * limitations under the License.
53
+ * =============================================================================
54
+ */
55
+ function tt(m, t, e, o = !1) {
56
+ const i = E(m, "logits", "multinomial"), s = i.size, r = i.rank;
57
+ if (s < 2)
58
+ throw new Error(`Error in multinomial: you need at least 2 outcomes, but got ${s}.`);
59
+ if (r > 2)
60
+ throw new Error(`Rank of probabilities must be 1 or 2, but is ${r}`);
61
+ e = e || Math.random();
62
+ const n = { logits: r === 1 ? S(i, [1, -1]) : i }, h = { numSamples: t, seed: e, normalized: o }, a = z.runKernel(O, n, h);
63
+ return r === 1 ? S(a, [a.size]) : a;
64
+ }
65
+ const I = /* @__PURE__ */ y({ multinomial_: tt });
66
+ /**
67
+ * @license
68
+ * Copyright 2018 Google LLC. All Rights Reserved.
69
+ * Licensed under the Apache License, Version 2.0 (the "License");
70
+ * you may not use this file except in compliance with the License.
71
+ * You may obtain a copy of the License at
72
+ *
73
+ * http://www.apache.org/licenses/LICENSE-2.0
74
+ *
75
+ * Unless required by applicable law or agreed to in writing, software
76
+ * distributed under the License is distributed on an "AS IS" BASIS,
77
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
78
+ * See the License for the specific language governing permissions and
79
+ * limitations under the License.
80
+ * =============================================================================
81
+ */
82
+ function et(m, t = 1, e = !0) {
83
+ const o = E(m, "x", "topk");
84
+ if (o.rank === 0)
85
+ throw new Error("topk() expects the input to be of rank 1 or higher");
86
+ const i = o.shape[o.shape.length - 1];
87
+ if (t < 0)
88
+ throw new Error(`'k' passed to topk() must be >= 0 but got ${t}`);
89
+ if (t > i)
90
+ throw new Error(`'k' passed to topk() must be <= the last dimension (${i}) but got ${t}`);
91
+ const s = { x: o }, r = { k: t, sorted: e }, [l, n] = z.runKernel(Q, s, r);
92
+ return { values: l, indices: n };
93
+ }
94
+ const ot = /* @__PURE__ */ y({ topk_: et });
95
+ class kt extends T {
10
96
  config;
11
97
  wte;
12
98
  // Token embeddings
@@ -18,23 +104,22 @@ class A extends P {
18
104
  lnF;
19
105
  // Final layer norm
20
106
  ropeCache;
21
- tf;
22
107
  log = [];
23
108
  // Training log
24
- constructor(t, e = {}) {
25
- super(), this.tf = t, this.config = { ...v, ...e }, this.wte = new S(t, {
109
+ constructor(t = {}) {
110
+ super(), this.config = { ...F, ...t }, this.wte = new P({
26
111
  vocabSize: this.config.vocabSize,
27
112
  embedDim: this.config.nEmbed,
28
113
  name: "token_embedding"
29
- }), this.config.useRope === !1 ? this.wpe = this.tf.layers.embedding({
114
+ }), this.config.useRope === !1 ? this.wpe = j({
30
115
  inputDim: this.config.blockSize,
31
116
  outputDim: this.config.nEmbed,
32
117
  name: "positional_embedding",
33
- embeddingsInitializer: this.tf.initializers.randomNormal({ mean: 0, stddev: 0.02 })
34
- }) : this.ropeCache = new _(t, this.config), this.drop = this.tf.layers.dropout({ rate: this.config.dropout }), this.blocks = [];
35
- for (let o = 0; o < this.config.nLayer; o++)
36
- this.blocks.push(new z(this.tf, o, this.config, this.ropeCache));
37
- this.lnF = new I(t, [this.config.nEmbed], 1e-8, "final_rms_norm");
118
+ embeddingsInitializer: R({ mean: 0, stddev: 0.02 })
119
+ }) : this.ropeCache = new C(this.config), this.drop = U({ rate: this.config.dropout }), this.blocks = [];
120
+ for (let e = 0; e < this.config.nLayer; e++)
121
+ this.blocks.push(new L(e, this.config, this.ropeCache));
122
+ this.lnF = new q([this.config.nEmbed], 1e-8, "final_rms_norm");
38
123
  }
39
124
  get variables() {
40
125
  return [
@@ -58,14 +143,11 @@ class A extends P {
58
143
  this.lnF.setWeights(t.get("final_rms_norm") || []);
59
144
  }
60
145
  inputPhase(t, e, o = !1) {
61
- return this.tf.tidy(() => {
146
+ return w(() => {
62
147
  const i = this.wte.embed(t);
63
148
  if (this.config.useRope === !1) {
64
- const [, s] = t.shape, l = this.config.blockSize, r = this.tf.range(0, s, 1, "int32"), n = this.tf.mod(
65
- this.tf.add(r, this.tf.scalar(e, "int32")),
66
- this.tf.scalar(l, "int32")
67
- ), h = this.wpe.apply(n), c = i.add(h);
68
- return this.drop.apply(c, { training: o });
149
+ const [, s] = t.shape, r = this.config.blockSize, l = V(0, s, 1, "int32"), n = J(X(l, _(e, "int32")), _(r, "int32")), h = this.wpe.apply(n), a = i.add(h);
150
+ return this.drop.apply(a, { training: o });
69
151
  } else
70
152
  return this.drop.apply(i, { training: o });
71
153
  });
@@ -103,7 +185,7 @@ class A extends P {
103
185
  }
104
186
  calculateLoss(t, e) {
105
187
  try {
106
- return L()(t, e).mean();
188
+ return N()(t, e).mean();
107
189
  } catch (o) {
108
190
  throw console.error("Error computing loss:", o), new Error(`Loss computation failed: ${o}`);
109
191
  }
@@ -111,89 +193,79 @@ class A extends P {
111
193
  // Attention rollout per Abnar & Zuidema (2020)
112
194
  // Expects list of (B, T, T) attention matrices already averaged over heads.
113
195
  computeAttentionRollout(t) {
114
- return this.tf.tidy(() => {
196
+ return w(() => {
115
197
  if (t.length === 0)
116
198
  throw new Error("No attentions for rollout");
117
199
  const [e, o, i] = t[0].shape;
118
200
  for (const s of t) {
119
- const [l, r, n] = s.shape;
120
- if (l !== e || r !== o || n !== i)
201
+ const [r, l, n] = s.shape;
202
+ if (r !== e || l !== o || n !== i)
121
203
  throw new Error(
122
- `Inconsistent attention shapes in rollout: expected [${e},${o},${i}] got [${l},${r},${n}]`
204
+ `Inconsistent attention shapes in rollout: expected [${e},${o},${i}] got [${r},${l},${n}]`
123
205
  );
124
206
  }
125
207
  if (o === i) {
126
- const s = this.tf.eye(i, i).expandDims(0);
127
- let l = s.tile([e, 1, 1]);
128
- for (const r of t) {
129
- const n = r.add(s);
130
- l = n.div(n.sum(-1, !0)).matMul(l);
131
- }
132
- return l;
133
- }
134
- if (o === 1) {
135
- let s = null;
136
- const l = this.tf.tensor1d([i - 1], "int32"), r = this.tf.oneHot(l, i).reshape([1, 1, i]).tile([e, 1, 1]);
137
- l.dispose();
138
- for (const n of t) {
139
- let h = n.add(r);
140
- h = h.div(h.sum(-1, !0)), s == null ? s = h : (s = s.mul(h), s = s.div(s.sum(-1, !0)));
208
+ const s = D(i, i).expandDims(0);
209
+ let r = s.tile([e, 1, 1]);
210
+ for (const l of t) {
211
+ const n = l.add(s);
212
+ r = n.div(n.sum(-1, !0)).matMul(r);
141
213
  }
142
- return s;
214
+ return r;
143
215
  }
144
216
  throw new Error(`Unsupported attention shapes for rollout: [B=${e}, Q=${o}, K=${i}]`);
145
217
  });
146
218
  }
147
219
  forward(t, e, o = !1, i = !1, s) {
148
- return this.validateInput(t), this.tf.tidy(() => {
220
+ return this.validateInput(t), w(() => {
149
221
  this.startMemory();
150
- const l = s?.[0]?.length ?? 0;
151
- let r = this.inputPhase(t, l, o);
222
+ const r = s?.[0]?.length ?? 0;
223
+ let l = this.inputPhase(t, r, o);
152
224
  const n = [];
153
225
  if (s && s.length !== this.blocks.length)
154
226
  throw console.error("Cache", s), new Error(`Cache length ${s.length} does not match number of blocks ${this.blocks.length}`);
155
- for (let a = 0; a < this.blocks.length; a++) {
156
- const d = r, g = this.blocks[a], {
157
- output: m,
158
- attention: b,
227
+ for (let c = 0; c < this.blocks.length; c++) {
228
+ const u = l, d = this.blocks[c], {
229
+ output: b,
230
+ attention: k,
159
231
  cache: f
160
- } = g.call(r, o, i, s ? s[a] : void 0);
161
- r = m, d.dispose(), i && b && n.push(b), s && f ? (s[a]?.k.dispose(), s[a]?.v.dispose(), s[a] = f) : f && (f.k.dispose(), f.v.dispose());
232
+ } = d.call(l, o, i, s ? s[c] : void 0);
233
+ l = b, u.dispose(), i && k && n.push(k), s && f ? (s[c]?.k.dispose(), s[c]?.v.dispose(), s[c] = f) : f && (f.k.dispose(), f.v.dispose());
162
234
  }
163
235
  let h;
164
- i && n.length > 0 && (h = this.computeAttentionRollout(n)), r = this.lnF.apply(r);
165
- const c = this.wte.project(r);
236
+ i && n.length > 0 && (h = this.computeAttentionRollout(n)), l = this.lnF.apply(l);
237
+ const a = this.wte.project(l);
166
238
  let p;
167
- return e && (p = this.calculateLoss(c, e)), this.endMemory("Forward"), { logits: c, loss: p, attention: i ? h : void 0 };
239
+ return e && (p = this.calculateLoss(a, e)), this.endMemory("Forward"), { logits: a, loss: p, attention: i ? h : void 0 };
168
240
  });
169
241
  }
170
242
  generate(t, e, o) {
171
- const i = o?.temperature ?? 1, s = o?.topK, l = o?.usePadding ?? !1, r = o?.includeAttention ?? !1;
172
- return this.tf.tidy(() => {
173
- const n = t, h = n.shape[1], c = h <= this.config.blockSize ? n : n.slice(
243
+ const i = o?.temperature ?? 1, s = o?.topK, r = o?.usePadding ?? !1, l = o?.includeAttention ?? !1;
244
+ return w(() => {
245
+ const n = t, h = n.shape[1], a = h <= this.config.blockSize ? n : n.slice(
174
246
  [0, h - this.config.blockSize],
175
247
  [n.shape[0], this.config.blockSize]
176
- ), p = l ? this.config.blockSize - c.shape[1] : 0, a = p > 0 ? this.tf.pad(c, [
248
+ ), p = r ? this.config.blockSize - a.shape[1] : 0, c = p > 0 ? A(a, [
177
249
  [0, 0],
178
250
  [0, p]
179
- ]) : c, { logits: d, attention: g } = this.forward(a, void 0, !1, r, e), m = d.shape[1] - 1 - p, b = d.slice([0, m, 0], [d.shape[0], 1, d.shape[2]]), f = g ? g.slice([0, m, 0], [g.shape[0], 1, g.shape[2]]) : void 0, k = b.div(i);
180
- let u;
251
+ ]) : a, { logits: u, attention: d } = this.forward(c, void 0, !1, l, e), b = u.shape[1] - 1 - p, k = u.slice([0, b, 0], [u.shape[0], 1, u.shape[2]]), f = d ? d.slice([0, b, 0], [d.shape[0], 1, d.shape[2]]) : void 0, $ = k.div(i);
252
+ let g;
181
253
  if (s) {
182
- const { values: y, indices: E } = this.tf.topk(k, s), $ = this.tf.multinomial(y.squeeze([1]), 1);
183
- u = this.tf.gather(E.squeeze([1]), $, 1);
254
+ const { values: M, indices: x } = ot($, s), W = I(M.squeeze([1]), 1);
255
+ g = Y(x.squeeze([1]), W, 1);
184
256
  } else
185
- u = this.tf.multinomial(k.squeeze([1]), 1);
186
- let w;
187
- return o?.includeProbabilities && (w = this.tf.softmax(k.squeeze([1]))), u = u.reshape([1, 1]), { output: u, attention: f?.squeeze([1]), probabilities: w };
257
+ g = I($.squeeze([1]), 1);
258
+ let v;
259
+ return o?.includeProbabilities && (v = Z($.squeeze([1]))), g = g.reshape([1, 1]), { output: g, attention: f?.squeeze([1]), probabilities: v };
188
260
  });
189
261
  }
190
262
  getNumParams() {
191
- return F(this.config);
263
+ return K(this.config);
192
264
  }
193
265
  dispose() {
194
266
  this.wte.dispose(), this.wpe && this.wpe.dispose(), this.drop.dispose(), this.blocks.forEach((t) => t.dispose()), this.lnF.dispose();
195
267
  }
196
268
  }
197
269
  export {
198
- A as default
270
+ kt as default
199
271
  };
@@ -1,20 +1,20 @@
1
- import { default as TF } from '@tensorflow/tfjs';
2
1
  import { GPTConfig } from './config';
3
2
  import { ITokeniser } from './tokeniser/type';
4
3
  import { default as NanoGPT } from './NanoGPTModel';
5
4
  import { SaveOptions } from './utilities/save';
6
5
  import { default as Generator, IGenerateOptions } from './Generator';
7
6
  import { default as Trainer, ITrainerOptions } from './Trainer';
8
- import { default as EE } from 'eventemitter3';
9
7
  import { default as MemoryProfiler } from './utilities/profile';
10
8
  type TeachableLLMStatus = 'warmup' | 'awaitingTokens' | 'ready' | 'training' | 'loading' | 'busy' | 'error';
11
- export default class TeachableLLM extends EE<'status' | 'error' | 'trainStep'> {
9
+ export default class TeachableLLM {
10
+ private ee;
12
11
  private _config?;
13
12
  private _model?;
14
- readonly tf: typeof TF;
15
13
  private _tokeniser?;
16
14
  private _status;
17
- constructor(tf: typeof TF, tokeniser?: ITokeniser, model?: NanoGPT);
15
+ constructor(tokeniser?: ITokeniser, model?: NanoGPT);
16
+ get vocab(): string[];
17
+ get loaded(): boolean;
18
18
  get config(): GPTConfig;
19
19
  get model(): NanoGPT;
20
20
  get tokeniser(): ITokeniser;
@@ -22,16 +22,25 @@ export default class TeachableLLM extends EE<'status' | 'error' | 'trainStep'> {
22
22
  get ready(): boolean;
23
23
  private setStatus;
24
24
  saveModel(options?: SaveOptions): Promise<Blob>;
25
- static loadModel(tf: typeof TF, data: Blob | Buffer | string): TeachableLLM;
26
- static create(tf: typeof TF, config?: Partial<GPTConfig>): TeachableLLM;
25
+ static loadModel(data: Blob | Buffer | string): TeachableLLM;
26
+ static create(tokeniserType: 'char' | 'bpe', config?: Partial<GPTConfig>): TeachableLLM;
27
27
  getProfiler(): MemoryProfiler | undefined;
28
28
  get enableProfiler(): boolean;
29
29
  set enableProfiler(value: boolean);
30
30
  getNumParams(): number;
31
31
  trainer(): Trainer;
32
32
  train(text: string[], options?: ITrainerOptions): Promise<void>;
33
+ trainTokeniser(text: string[]): Promise<number>;
33
34
  generator(): Generator;
34
35
  generateText(prompt?: string, options?: IGenerateOptions): Promise<string>;
35
36
  dispose(): void;
37
+ on(event: 'status', listener: (status: TeachableLLMStatus) => void): void;
38
+ on(event: 'error', listener: (error: Error) => void): void;
39
+ on(event: 'trainStep', listener: (step: number) => void): void;
40
+ on(event: 'loaded', listener: () => void): void;
41
+ off(event: 'status', listener: (status: TeachableLLMStatus) => void): void;
42
+ off(event: 'error', listener: (error: Error) => void): void;
43
+ off(event: 'trainStep', listener: (step: number) => void): void;
44
+ off(event: 'loaded', listener: () => void): void;
36
45
  }
37
46
  export {};