@genai-fi/nanogpt 0.2.11 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115)
  1. package/dist/Generator.js +30 -25
  2. package/dist/NanoGPTModel.d.ts +13 -14
  3. package/dist/NanoGPTModel.js +167 -85
  4. package/dist/TeachableLLM.d.ts +3 -5
  5. package/dist/TeachableLLM.js +47 -35
  6. package/dist/Trainer.js +8 -8
  7. package/dist/concat-BIZS_td9.js +33 -0
  8. package/dist/data/parquet.js +1 -1
  9. package/dist/exports_layers-7idKoYqh.js +25 -0
  10. package/dist/{sum-D7fu15XL.js → gather-BPGW8RsB.js} +6 -8
  11. package/dist/index-C4L8Cm77.js +349 -0
  12. package/dist/{index-YPKosni4.js → index-pWA4_lUh.js} +1020 -782
  13. package/dist/layers/CausalSelfAttention.d.ts +11 -11
  14. package/dist/layers/CausalSelfAttention.js +71 -63
  15. package/dist/layers/MLP.d.ts +6 -7
  16. package/dist/layers/MLP.js +18 -16
  17. package/dist/layers/RMSNorm.d.ts +6 -7
  18. package/dist/layers/RMSNorm.js +15 -13
  19. package/dist/layers/RoPECache.d.ts +4 -6
  20. package/dist/layers/RoPECache.js +36 -23
  21. package/dist/layers/TiedEmbedding.d.ts +7 -8
  22. package/dist/layers/TiedEmbedding.js +16 -418
  23. package/dist/layers/TransformerBlock.d.ts +8 -9
  24. package/dist/layers/TransformerBlock.js +12 -12
  25. package/dist/main.d.ts +1 -0
  26. package/dist/main.js +35 -21
  27. package/dist/{mat_mul-Bu7bhLms.js → mat_mul-D7_a4KJn.js} +5 -5
  28. package/dist/moments-DfcpfwKi.js +132 -0
  29. package/dist/ones-Cog-G2ag.js +29 -0
  30. package/dist/ops/appendCache.d.ts +2 -0
  31. package/dist/ops/appendCache.js +9 -0
  32. package/dist/ops/attentionMask.d.ts +1 -1
  33. package/dist/ops/attentionMask.js +7 -85
  34. package/dist/ops/cpu/appendCache.d.ts +2 -0
  35. package/dist/ops/cpu/appendCache.js +28 -0
  36. package/dist/ops/cpu/attentionMask.js +18 -0
  37. package/dist/ops/cpu/gatherSub.d.ts +1 -0
  38. package/dist/ops/cpu/gatherSub.js +34 -0
  39. package/dist/ops/cpu/qkv.d.ts +5 -0
  40. package/dist/ops/cpu/qkv.js +38 -0
  41. package/dist/ops/cpu/rope.d.ts +6 -0
  42. package/dist/ops/cpu/rope.js +38 -0
  43. package/dist/ops/cpu/scatterSub.d.ts +1 -0
  44. package/dist/ops/cpu/scatterSub.js +70 -0
  45. package/dist/ops/gatherSub.d.ts +1 -1
  46. package/dist/ops/gatherSub.js +6 -63
  47. package/dist/ops/grads/attentionMask.d.ts +1 -0
  48. package/dist/ops/grads/attentionMask.js +21 -0
  49. package/dist/ops/grads/qkv.d.ts +1 -0
  50. package/dist/ops/grads/qkv.js +20 -0
  51. package/dist/ops/grads/rope.d.ts +1 -0
  52. package/dist/ops/grads/rope.js +14 -0
  53. package/dist/ops/node/sparseCrossEntropy.js +1 -1
  54. package/dist/ops/qkv.d.ts +1 -6
  55. package/dist/ops/qkv.js +7 -124
  56. package/dist/ops/rope.d.ts +0 -5
  57. package/dist/ops/rope.js +7 -150
  58. package/dist/ops/scatterSub.d.ts +1 -1
  59. package/dist/ops/scatterSub.js +6 -147
  60. package/dist/ops/webgl/appendCache.d.ts +1 -0
  61. package/dist/ops/webgl/appendCache.js +43 -0
  62. package/dist/ops/webgl/attentionMask.d.ts +1 -0
  63. package/dist/ops/webgl/attentionMask.js +43 -0
  64. package/dist/ops/webgl/gatherSub.d.ts +1 -0
  65. package/dist/ops/webgl/gatherSub.js +27 -0
  66. package/dist/ops/webgl/qkv.d.ts +1 -0
  67. package/dist/ops/webgl/qkv.js +46 -0
  68. package/dist/ops/webgl/rope.d.ts +1 -0
  69. package/dist/ops/webgl/rope.js +56 -0
  70. package/dist/ops/webgl/scatterSub.d.ts +1 -0
  71. package/dist/ops/webgl/scatterSub.js +27 -0
  72. package/dist/{parquet-BRl5lE_I.js → parquet-C0Tlmv9c.js} +3045 -3048
  73. package/dist/random_width-PbCt7RXv.js +15489 -0
  74. package/dist/range-CcDl05lo.js +26 -0
  75. package/dist/{reshape-DmnmKT6r.js → reshape-C8CR_Bad.js} +3 -3
  76. package/dist/sin-BJIrfnj7.js +47 -0
  77. package/dist/softmax-Be_lsqUc.js +105 -0
  78. package/dist/{complex-CJ-qCcLB.js → split-DZbvruEP.js} +6 -8
  79. package/dist/stack-BMm-efee.js +27 -0
  80. package/dist/sum-C7Mgy9Bw.js +104 -0
  81. package/dist/tensor-DJVbYhh1.js +24 -0
  82. package/dist/tensor2d-ZuQSh2D-.js +30 -0
  83. package/dist/tokeniser/bpe.d.ts +17 -6
  84. package/dist/tokeniser/bpe.js +88 -60
  85. package/dist/training/AdamExt.js +1 -1
  86. package/dist/training/DatasetBuilder.d.ts +6 -6
  87. package/dist/training/DatasetBuilder.js +1262 -17
  88. package/dist/training/Evaluator.d.ts +3 -2
  89. package/dist/training/FullTrainer.d.ts +9 -8
  90. package/dist/training/FullTrainer.js +26 -25
  91. package/dist/training/LayerTrainer.d.ts +9 -8
  92. package/dist/training/LayerTrainer.js +34 -33
  93. package/dist/training/Trainer.d.ts +22 -21
  94. package/dist/training/Trainer.js +21 -18
  95. package/dist/training/sparseCrossEntropy.js +22 -166
  96. package/dist/utilities/dummy.js +10 -8
  97. package/dist/utilities/generate.js +14 -11
  98. package/dist/utilities/load.d.ts +1 -2
  99. package/dist/utilities/load.js +37 -35
  100. package/dist/utilities/profile.js +1 -1
  101. package/dist/utilities/save.js +14 -9
  102. package/dist/utilities/tokenParse.d.ts +1 -1
  103. package/dist/utilities/tokenParse.js +7 -61
  104. package/dist/utilities/weights.d.ts +3 -3
  105. package/dist/utilities/weights.js +21 -19
  106. package/dist/variable-Dl_ub3pk.js +23 -0
  107. package/dist/{stack-BtKpB0Ry.js → zeros-CCy9C3uU.js} +18 -16
  108. package/package.json +2 -1
  109. package/dist/assets/worker-BYeSPNkq.js +0 -1
  110. package/dist/tokeniser/NodeTokeniser.d.ts +0 -20
  111. package/dist/tokeniser/NodeTokeniser.js +0 -46
  112. package/dist/tokeniser/WebTokeniser.d.ts +0 -18
  113. package/dist/tokeniser/WebTokeniser.js +0 -96
  114. package/dist/tokeniser/worker.js +0 -53
  115. package/dist/{tokeniser/worker.d.ts → ops/cpu/attentionMask.d.ts} +0 -0
package/dist/Generator.js CHANGED
@@ -1,49 +1,52 @@
  import { E as u } from "./index-Dwqa6Zy2.js";
- class f extends u {
+ import "./index-pWA4_lUh.js";
+ import { t as d } from "./tensor2d-ZuQSh2D-.js";
+ import { c as k } from "./concat-BIZS_td9.js";
+ class w extends u {
  constructor(s, e) {
  super(), this.model = s, this.tokeniser = e;
  }
  active = !1;
  async tokenisePrompt(s) {
  const e = s ? await this.tokeniser.tokenise([s], !0) : [[this.tokeniser.eosToken]];
- return this.model.tf.tensor2d(e, [1, e[0].length], "int32");
+ return d(e, [1, e[0].length], "int32");
  }
  async generateNoCache(s, e) {
- let t = await this.tokenisePrompt(s), i = s || "";
- const o = e?.maxLength ?? 1e3;
- for (let a = 0; a < o && this.active; a++) {
+ let t = await this.tokenisePrompt(s), o = s || "";
+ const n = e?.maxLength ?? 1e3;
+ for (let a = 0; a < n && this.active; a++) {
  const {
- output: n,
+ output: i,
  attention: c,
  probabilities: l
  } = this.model.generate(t, void 0, e), h = t;
- t = this.model.tf.concat([t, n], 1), h.dispose();
- const r = await this.processResponse(n, c, l);
- if (n.dispose(), r === null)
+ t = k([t, i], 1), h.dispose();
+ const r = await this.processResponse(i, c, l);
+ if (i.dispose(), r === null)
  break;
- i += r;
+ o += r;
  }
- return t.dispose(), i;
+ return t.dispose(), o;
  }
  async processResponse(s, e, t) {
- const i = (await s.array())[0][0];
- if (i === this.tokeniser.eosToken)
+ const o = (await s.array())[0][0];
+ if (o === this.tokeniser.eosToken)
  return null;
- const o = await this.tokeniser.decode([i]);
+ const n = await this.tokeniser.decode([o]);
  let a;
  e && (a = await e.array(), e.dispose());
- let n;
- return t && (n = await t.array(), t.dispose()), this.emit("tokens", [i], o, a, n), o;
+ let i;
+ return t && (i = await t.array(), t.dispose()), this.emit("tokens", [o], n, a, i), n;
  }
  async generateCache(s, e) {
- let t = await this.tokenisePrompt(s), i = s || "";
- const o = new Array(this.model.config.nLayer).fill(void 0), a = e?.maxLength ?? 1e3;
- for (let n = 0; n < a && this.active; n++) {
+ let t = await this.tokenisePrompt(s), o = s || "";
+ const n = new Array(this.model.config.nLayer).fill(void 0), a = e?.maxLength ?? 1e3;
+ for (let i = 0; i < a && this.active; i++) {
  const {
  output: c,
  attention: l,
  probabilities: h
- } = this.model.generate(t, o, {
+ } = this.model.generate(t, n, {
  ...e,
  usePadding: !1
  });
@@ -51,20 +54,22 @@ class f extends u {
  const r = await this.processResponse(c, l, h);
  if (r === null)
  break;
- i += r;
+ o += r;
  }
- return t.dispose(), i;
+ return n.forEach((i) => {
+ i && (i.k.dispose(), i.v.dispose());
+ }), t.dispose(), o;
  }
  async generate(s, e) {
  const t = s && s.length > this.model.config.blockSize ? s.slice(-this.model.config.blockSize) : s;
  this.active = !0, this.emit("start");
- const o = await (this.model.config.useRope && !e?.noCache ? this.generateCache(t, e) : this.generateNoCache(t, e));
- return this.active = !1, this.emit("stop"), o;
+ const n = await (this.model.config.useRope && !e?.noCache ? this.generateCache(t, e) : this.generateNoCache(t, e));
+ return this.active = !1, this.emit("stop"), n;
  }
  stop() {
  this.active = !1;
  }
  }
  export {
- f as default
+ w as default
  };
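
For context, a rough usage sketch of the streaming generator changed above. The "tokens" event payload and the generate() options are taken from this diff; the root import path, the default export, and the listener signature are assumptions, not documented API.

```ts
// Hypothetical sketch, assuming TeachableLLM is the package's default export and
// that Generator's "tokens" event carries (tokenIds, decodedText, attention?, probabilities?),
// as suggested by processResponse() in the diff above.
import TeachableLLM from '@genai-fi/nanogpt';

async function streamDemo() {
  const llm = TeachableLLM.create('char');
  const generator = llm.generator();

  // The diff shows emit("tokens", [tokenId], text, attention, probabilities) per generated step.
  generator.on('tokens', (_ids: number[], text: string) => {
    process.stdout.write(text);
  });

  // maxLength and noCache appear as generate() options in the diff; the KV-cache path
  // (generateCache) is used when config.useRope is true and noCache is not set.
  const completed = await generator.generate('Once upon a time', { maxLength: 100 });
  console.log(`\nGenerated ${completed.length} characters`);
}

streamDemo().catch(console.error);
```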
package/dist/NanoGPTModel.d.ts CHANGED
@@ -1,8 +1,8 @@
- import { default as TF } from '@tensorflow/tfjs';
  import { GPTConfig } from './config';
  import { KVCache } from './layers/CausalSelfAttention';
  import { default as MemoryProfiler } from './utilities/profile';
  import { default as BaseLayer } from './layers/BaseLayer';
+ import { Tensor, Variable } from '@tensorflow/tfjs-core';
  export interface TrainingLogEntry {
  loss: number;
  valLoss?: number;
@@ -26,12 +26,11 @@ export default class NanoGPT extends BaseLayer {
  private blocks;
  private lnF;
  private ropeCache?;
- readonly tf: typeof TF;
  log: TrainingLogEntry[];
- constructor(tf: typeof TF, config?: Partial<GPTConfig>);
- get variables(): TF.Variable[];
- saveWeights(): Map<string, TF.Tensor[]>;
- loadWeights(weights: Map<string, TF.Tensor[]>): void;
+ constructor(config?: Partial<GPTConfig>);
+ get variables(): Variable[];
+ saveWeights(): Map<string, Tensor[]>;
+ loadWeights(weights: Map<string, Tensor[]>): void;
  private inputPhase;
  setSkipMask(mask: boolean[]): void;
  setTrainableMask(mask: boolean[]): void;
@@ -40,15 +39,15 @@ export default class NanoGPT extends BaseLayer {
  private validateInput;
  private calculateLoss;
  private computeAttentionRollout;
- forward(idx: TF.Tensor, targets?: TF.Tensor, training?: boolean, includeAttention?: boolean, cache?: (KVCache | undefined)[]): {
- logits: TF.Tensor;
- loss?: TF.Tensor;
- attention?: TF.Tensor;
+ forward(idx: Tensor, targets?: Tensor, training?: boolean, includeAttention?: boolean, cache?: (KVCache | undefined)[]): {
+ logits: Tensor;
+ loss?: Tensor;
+ attention?: Tensor;
  };
- generate(idx: TF.Tensor, cache?: (KVCache | undefined)[], options?: GenerateOptions): {
- output: TF.Tensor;
- attention?: TF.Tensor;
- probabilities?: TF.Tensor;
+ generate(idx: Tensor, cache?: (KVCache | undefined)[], options?: GenerateOptions): {
+ output: Tensor;
+ attention?: Tensor;
+ probabilities?: Tensor;
  };
  getNumParams(): number;
  dispose(): void;
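
The typings above show that NanoGPT no longer takes a TensorFlow.js handle and now types tensors via @tensorflow/tfjs-core. A minimal migration sketch; the deep import path is an assumption, since only the dist layout is visible in this diff.

```ts
// Sketch based on the .d.ts diff above; the import path is assumed, not documented.
import type { Tensor } from '@tensorflow/tfjs-core';
import NanoGPT from '@genai-fi/nanogpt/dist/NanoGPTModel';

// 0.2.x: const model = new NanoGPT(tf, { nLayer: 4 });
// 0.3.0: no tf argument; tensor types come from @tensorflow/tfjs-core.
const model = new NanoGPT({ nLayer: 4 });

declare const idx: Tensor; // int32 token ids, shape [batch, seq]
const { logits, loss } = model.forward(idx); // loss is undefined when no targets are passed
```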
package/dist/NanoGPTModel.js CHANGED
@@ -1,12 +1,98 @@
- import { defaultConfig as v } from "./config.js";
- import z from "./layers/TransformerBlock.js";
- import S from "./layers/TiedEmbedding.js";
- import _ from "./layers/RoPECache.js";
- import I from "./layers/RMSNorm.js";
- import { estimateParameterCount as F } from "./utilities/parameters.js";
- import { createSoftmaxCrossEntropyWithGrad as L } from "./training/sparseCrossEntropy.js";
- import P from "./layers/BaseLayer.js";
- class A extends P {
+ import { defaultConfig as F } from "./config.js";
+ import L from "./layers/TransformerBlock.js";
+ import P from "./layers/TiedEmbedding.js";
+ import C from "./layers/RoPECache.js";
+ import q from "./layers/RMSNorm.js";
+ import { estimateParameterCount as K } from "./utilities/parameters.js";
+ import { createSoftmaxCrossEntropyWithGrad as N } from "./training/sparseCrossEntropy.js";
+ import R from "./layers/BaseLayer.js";
+ import { r as T, e as D, t as A, o as B, p as G } from "./random_width-PbCt7RXv.js";
+ import { o as v, h as E, p as O, E as y, W as H, X as Q, Y as X, t as w, Z as j, f as _ } from "./index-pWA4_lUh.js";
+ import { e as U, a as V } from "./exports_layers-7idKoYqh.js";
+ import { r as I } from "./reshape-C8CR_Bad.js";
+ import { r as Y } from "./range-CcDl05lo.js";
+ import { g as Z } from "./gather-BPGW8RsB.js";
+ import { s as J } from "./softmax-Be_lsqUc.js";
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ function tt(m, t) {
+ let e = E(m, "a", "mod"), s = E(t, "b", "mod");
+ [e, s] = O(e, s);
+ const i = { a: e, b: s };
+ return y.runKernel(H, i);
+ }
+ const et = /* @__PURE__ */ v({ mod_: tt });
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ function ot(m, t, e, s = !1) {
+ const i = E(m, "logits", "multinomial"), o = i.size, n = i.rank;
+ if (o < 2)
+ throw new Error(`Error in multinomial: you need at least 2 outcomes, but got ${o}.`);
+ if (n > 2)
+ throw new Error(`Rank of probabilities must be 1 or 2, but is ${n}`);
+ e = e || Math.random();
+ const r = { logits: n === 1 ? I(i, [1, -1]) : i }, a = { numSamples: t, seed: e, normalized: s }, h = y.runKernel(Q, r, a);
+ return n === 1 ? I(h, [h.size]) : h;
+ }
+ const S = /* @__PURE__ */ v({ multinomial_: ot });
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ function st(m, t = 1, e = !0) {
+ const s = E(m, "x", "topk");
+ if (s.rank === 0)
+ throw new Error("topk() expects the input to be of rank 1 or higher");
+ const i = s.shape[s.shape.length - 1];
+ if (t < 0)
+ throw new Error(`'k' passed to topk() must be >= 0 but got ${t}`);
+ if (t > i)
+ throw new Error(`'k' passed to topk() must be <= the last dimension (${i}) but got ${t}`);
+ const o = { x: s }, n = { k: t, sorted: e }, [l, r] = y.runKernel(X, o, n);
+ return { values: l, indices: r };
+ }
+ const it = /* @__PURE__ */ v({ topk_: st });
+ class Et extends R {
  config;
  wte;
  // Token embeddings
@@ -18,23 +104,22 @@ class A extends P {
  lnF;
  // Final layer norm
  ropeCache;
- tf;
  log = [];
  // Training log
- constructor(t, e = {}) {
- super(), this.tf = t, this.config = { ...v, ...e }, this.wte = new S(t, {
+ constructor(t = {}) {
+ super(), this.config = { ...F, ...t }, this.wte = new P({
  vocabSize: this.config.vocabSize,
  embedDim: this.config.nEmbed,
  name: "token_embedding"
- }), this.config.useRope === !1 ? this.wpe = this.tf.layers.embedding({
+ }), this.config.useRope === !1 ? this.wpe = U({
  inputDim: this.config.blockSize,
  outputDim: this.config.nEmbed,
  name: "positional_embedding",
- embeddingsInitializer: this.tf.initializers.randomNormal({ mean: 0, stddev: 0.02 })
- }) : this.ropeCache = new _(t, this.config), this.drop = this.tf.layers.dropout({ rate: this.config.dropout }), this.blocks = [];
- for (let o = 0; o < this.config.nLayer; o++)
- this.blocks.push(new z(this.tf, o, this.config, this.ropeCache));
- this.lnF = new I(t, [this.config.nEmbed], 1e-8, "final_rms_norm");
+ embeddingsInitializer: T({ mean: 0, stddev: 0.02 })
+ }) : this.ropeCache = new C(this.config), this.drop = V({ rate: this.config.dropout }), this.blocks = [];
+ for (let e = 0; e < this.config.nLayer; e++)
+ this.blocks.push(new L(e, this.config, this.ropeCache));
+ this.lnF = new q([this.config.nEmbed], 1e-8, "final_rms_norm");
  }
  get variables() {
  return [
@@ -57,17 +142,14 @@ class A extends P {
  this.blocks[e].loadWeights(t);
  this.lnF.setWeights(t.get("final_rms_norm") || []);
  }
- inputPhase(t, e, o = !1) {
- return this.tf.tidy(() => {
+ inputPhase(t, e, s = !1) {
+ return w(() => {
  const i = this.wte.embed(t);
  if (this.config.useRope === !1) {
- const [, s] = t.shape, l = this.config.blockSize, r = this.tf.range(0, s, 1, "int32"), n = this.tf.mod(
- this.tf.add(r, this.tf.scalar(e, "int32")),
- this.tf.scalar(l, "int32")
- ), h = this.wpe.apply(n), c = i.add(h);
- return this.drop.apply(c, { training: o });
+ const [, o] = t.shape, n = this.config.blockSize, l = Y(0, o, 1, "int32"), r = et(j(l, _(e, "int32")), _(n, "int32")), a = this.wpe.apply(r), h = i.add(a);
+ return this.drop.apply(h, { training: s });
  } else
- return this.drop.apply(i, { training: o });
+ return this.drop.apply(i, { training: s });
  });
  }
  setSkipMask(t) {
@@ -103,97 +185,97 @@ class A extends P {
  }
  calculateLoss(t, e) {
  try {
- return L()(t, e).mean();
- } catch (o) {
- throw console.error("Error computing loss:", o), new Error(`Loss computation failed: ${o}`);
+ return N()(t, e).mean();
+ } catch (s) {
+ throw console.error("Error computing loss:", s), new Error(`Loss computation failed: ${s}`);
  }
  }
  // Attention rollout per Abnar & Zuidema (2020)
  // Expects list of (B, T, T) attention matrices already averaged over heads.
  computeAttentionRollout(t) {
- return this.tf.tidy(() => {
+ return w(() => {
  if (t.length === 0)
  throw new Error("No attentions for rollout");
- const [e, o, i] = t[0].shape;
- for (const s of t) {
- const [l, r, n] = s.shape;
- if (l !== e || r !== o || n !== i)
+ const [e, s, i] = t[0].shape;
+ for (const o of t) {
+ const [n, l, r] = o.shape;
+ if (n !== e || l !== s || r !== i)
  throw new Error(
- `Inconsistent attention shapes in rollout: expected [${e},${o},${i}] got [${l},${r},${n}]`
+ `Inconsistent attention shapes in rollout: expected [${e},${s},${i}] got [${n},${l},${r}]`
  );
  }
- if (o === i) {
- const s = this.tf.eye(i, i).expandDims(0);
- let l = s.tile([e, 1, 1]);
- for (const r of t) {
- const n = r.add(s);
- l = n.div(n.sum(-1, !0)).matMul(l);
+ if (s === i) {
+ const o = D(i, i).expandDims(0);
+ let n = o.tile([e, 1, 1]);
+ for (const l of t) {
+ const r = l.add(o);
+ n = r.div(r.sum(-1, !0)).matMul(n);
  }
- return l;
+ return n;
  }
- if (o === 1) {
- let s = null;
- const l = this.tf.tensor1d([i - 1], "int32"), r = this.tf.oneHot(l, i).reshape([1, 1, i]).tile([e, 1, 1]);
- l.dispose();
- for (const n of t) {
- let h = n.add(r);
- h = h.div(h.sum(-1, !0)), s == null ? s = h : (s = s.mul(h), s = s.div(s.sum(-1, !0)));
+ if (s === 1) {
+ let o = null;
+ const n = A([i - 1], "int32"), l = B(n, i).reshape([1, 1, i]).tile([e, 1, 1]);
+ n.dispose();
+ for (const r of t) {
+ let a = r.add(l);
+ a = a.div(a.sum(-1, !0)), o == null ? o = a : (o = o.mul(a), o = o.div(o.sum(-1, !0)));
  }
- return s;
+ return o;
  }
- throw new Error(`Unsupported attention shapes for rollout: [B=${e}, Q=${o}, K=${i}]`);
+ throw new Error(`Unsupported attention shapes for rollout: [B=${e}, Q=${s}, K=${i}]`);
  });
  }
- forward(t, e, o = !1, i = !1, s) {
- return this.validateInput(t), this.tf.tidy(() => {
+ forward(t, e, s = !1, i = !1, o) {
+ return this.validateInput(t), w(() => {
  this.startMemory();
- const l = s?.[0]?.length ?? 0;
- let r = this.inputPhase(t, l, o);
- const n = [];
- if (s && s.length !== this.blocks.length)
- throw console.error("Cache", s), new Error(`Cache length ${s.length} does not match number of blocks ${this.blocks.length}`);
- for (let a = 0; a < this.blocks.length; a++) {
- const d = r, g = this.blocks[a], {
- output: m,
- attention: b,
+ const n = o?.[0]?.length ?? 0;
+ let l = this.inputPhase(t, n, s);
+ const r = [];
+ if (o && o.length !== this.blocks.length)
+ throw console.error("Cache", o), new Error(`Cache length ${o.length} does not match number of blocks ${this.blocks.length}`);
+ for (let c = 0; c < this.blocks.length; c++) {
+ const u = l, d = this.blocks[c], {
+ output: b,
+ attention: k,
  cache: f
- } = g.call(r, o, i, s ? s[a] : void 0);
- r = m, d.dispose(), i && b && n.push(b), s && f ? (s[a]?.k.dispose(), s[a]?.v.dispose(), s[a] = f) : f && (f.k.dispose(), f.v.dispose());
+ } = d.call(l, s, i, o ? o[c] : void 0);
+ l = b, u.dispose(), i && k && r.push(k), o && f ? (o[c]?.k.dispose(), o[c]?.v.dispose(), o[c] = f) : f && (f.k.dispose(), f.v.dispose());
  }
- let h;
- i && n.length > 0 && (h = this.computeAttentionRollout(n)), r = this.lnF.apply(r);
- const c = this.wte.project(r);
+ let a;
+ i && r.length > 0 && (a = this.computeAttentionRollout(r)), l = this.lnF.apply(l);
+ const h = this.wte.project(l);
  let p;
- return e && (p = this.calculateLoss(c, e)), this.endMemory("Forward"), { logits: c, loss: p, attention: i ? h : void 0 };
+ return e && (p = this.calculateLoss(h, e)), this.endMemory("Forward"), { logits: h, loss: p, attention: i ? a : void 0 };
  });
  }
- generate(t, e, o) {
- const i = o?.temperature ?? 1, s = o?.topK, l = o?.usePadding ?? !1, r = o?.includeAttention ?? !1;
- return this.tf.tidy(() => {
- const n = t, h = n.shape[1], c = h <= this.config.blockSize ? n : n.slice(
- [0, h - this.config.blockSize],
- [n.shape[0], this.config.blockSize]
- ), p = l ? this.config.blockSize - c.shape[1] : 0, a = p > 0 ? this.tf.pad(c, [
+ generate(t, e, s) {
+ const i = s?.temperature ?? 1, o = s?.topK, n = s?.usePadding ?? !1, l = s?.includeAttention ?? !1;
+ return w(() => {
+ const r = t, a = r.shape[1], h = a <= this.config.blockSize ? r : r.slice(
+ [0, a - this.config.blockSize],
+ [r.shape[0], this.config.blockSize]
+ ), p = n ? this.config.blockSize - h.shape[1] : 0, c = p > 0 ? G(h, [
  [0, 0],
  [0, p]
- ]) : c, { logits: d, attention: g } = this.forward(a, void 0, !1, r, e), m = d.shape[1] - 1 - p, b = d.slice([0, m, 0], [d.shape[0], 1, d.shape[2]]), f = g ? g.slice([0, m, 0], [g.shape[0], 1, g.shape[2]]) : void 0, k = b.div(i);
- let u;
- if (s) {
- const { values: y, indices: E } = this.tf.topk(k, s), $ = this.tf.multinomial(y.squeeze([1]), 1);
- u = this.tf.gather(E.squeeze([1]), $, 1);
+ ]) : h, { logits: u, attention: d } = this.forward(c, void 0, !1, l, e), b = u.shape[1] - 1 - p, k = u.slice([0, b, 0], [u.shape[0], 1, u.shape[2]]), f = d ? d.slice([0, b, 0], [d.shape[0], 1, d.shape[2]]) : void 0, $ = k.div(i);
+ let g;
+ if (o) {
+ const { values: x, indices: M } = it($, o), W = S(x.squeeze([1]), 1);
+ g = Z(M.squeeze([1]), W, 1);
  } else
- u = this.tf.multinomial(k.squeeze([1]), 1);
- let w;
- return o?.includeProbabilities && (w = this.tf.softmax(k.squeeze([1]))), u = u.reshape([1, 1]), { output: u, attention: f?.squeeze([1]), probabilities: w };
+ g = S($.squeeze([1]), 1);
+ let z;
+ return s?.includeProbabilities && (z = J($.squeeze([1]))), g = g.reshape([1, 1]), { output: g, attention: f?.squeeze([1]), probabilities: z };
  });
  }
  getNumParams() {
- return F(this.config);
+ return K(this.config);
  }
  dispose() {
  this.wte.dispose(), this.wpe && this.wpe.dispose(), this.drop.dispose(), this.blocks.forEach((t) => t.dispose()), this.lnF.dispose();
  }
  }
  export {
- A as default
+ Et as default
  };
package/dist/TeachableLLM.d.ts CHANGED
@@ -1,4 +1,3 @@
- import { default as TF } from '@tensorflow/tfjs';
  import { GPTConfig } from './config';
  import { ITokeniser } from './tokeniser/type';
  import { default as NanoGPT } from './NanoGPTModel';
@@ -11,10 +10,9 @@ type TeachableLLMStatus = 'warmup' | 'awaitingTokens' | 'ready' | 'training' | '
  export default class TeachableLLM extends EE<'status' | 'error' | 'trainStep'> {
  private _config?;
  private _model?;
- readonly tf: typeof TF;
  private _tokeniser?;
  private _status;
- constructor(tf: typeof TF, tokeniser?: ITokeniser, model?: NanoGPT);
+ constructor(tokeniser?: ITokeniser, model?: NanoGPT);
  get config(): GPTConfig;
  get model(): NanoGPT;
  get tokeniser(): ITokeniser;
@@ -22,8 +20,8 @@ export default class TeachableLLM extends EE<'status' | 'error' | 'trainStep'> {
  get ready(): boolean;
  private setStatus;
  saveModel(options?: SaveOptions): Promise<Blob>;
- static loadModel(tf: typeof TF, data: Blob | Buffer | string): TeachableLLM;
- static create(tf: typeof TF, config?: Partial<GPTConfig>): TeachableLLM;
+ static loadModel(data: Blob | Buffer | string): TeachableLLM;
+ static create(tokeniserType: 'char' | 'bpe', config?: Partial<GPTConfig>): TeachableLLM;
  getProfiler(): MemoryProfiler | undefined;
  get enableProfiler(): boolean;
  set enableProfiler(value: boolean);
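
The TeachableLLM typings above drop the tf parameter from the static factories and switch create() to a tokeniser selector. A minimal migration sketch; the root import path and the vocabSize value are assumptions.

```ts
// Sketch based on the TeachableLLM.d.ts diff above; import path and config values assumed.
import TeachableLLM from '@genai-fi/nanogpt';

// 0.2.x: TeachableLLM.create(tf, config) and TeachableLLM.loadModel(tf, data)
// 0.3.0: no tf argument; create() now chooses the tokeniser ('char' | 'bpe').
const llm = TeachableLLM.create('bpe', { vocabSize: 4096 });

// The class extends EE<'status' | 'error' | 'trainStep'>, so status changes can be observed.
llm.on('status', (status: string) => console.log('model status:', status));

// loadModel still accepts a Blob, Buffer or string, just without tf:
// const restored = TeachableLLM.loadModel(savedBlob);
```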
package/dist/TeachableLLM.js CHANGED
@@ -1,29 +1,41 @@
- import { defaultConfig as h } from "./config.js";
- import m from "./NanoGPTModel.js";
+ import { defaultConfig as l } from "./config.js";
+ import h from "./NanoGPTModel.js";
  import { saveModel as d } from "./utilities/save.js";
  import { loadModel as f } from "./utilities/load.js";
  import u from "./Generator.js";
  import _ from "./Trainer.js";
- import { E as c } from "./index-Dwqa6Zy2.js";
- import { dummyPassAsync as l } from "./utilities/dummy.js";
- import g from "./tokeniser/CharTokeniser.js";
+ import { E as p } from "./index-Dwqa6Zy2.js";
+ import { dummyPassAsync as m } from "./utilities/dummy.js";
+ import c from "./tokeniser/CharTokeniser.js";
+ import g from "./tokeniser/bpe.js";
  import "./papaparse.min-C8l2Kvo1.js";
  import "./index-Tf7vU29b.js";
  import "./jszip.min-CjP2V1VV.js";
- import "./ops/scatterSub.js";
- import "./ops/gatherSub.js";
- import "./ops/attentionMask.js";
- import "./ops/qkv.js";
- import "./ops/rope.js";
- import p from "./utilities/profile.js";
- class a extends c {
+ import "./index-pWA4_lUh.js";
+ import "./ops/cpu/scatterSub.js";
+ import "./ops/webgl/scatterSub.js";
+ import "./ops/cpu/gatherSub.js";
+ import "./ops/webgl/gatherSub.js";
+ import "./ops/cpu/attentionMask.js";
+ import "./ops/webgl/attentionMask.js";
+ import "./ops/grads/attentionMask.js";
+ import "./ops/cpu/qkv.js";
+ import "./ops/webgl/qkv.js";
+ import "./ops/grads/qkv.js";
+ import "@tensorflow/tfjs";
+ import "./ops/cpu/rope.js";
+ import "./ops/webgl/rope.js";
+ import "./ops/grads/rope.js";
+ import "./ops/cpu/appendCache.js";
+ import "./ops/webgl/appendCache.js";
+ import w from "./utilities/profile.js";
+ class a extends p {
  _config;
  _model;
- tf;
  _tokeniser;
  _status = "loading";
- constructor(t, r, e) {
- super(), this.tf = t, this._config = e?.config, this._tokeniser = r, this._model = e;
+ constructor(t, e) {
+ super(), this._config = e?.config, this._tokeniser = t, this._model = e;
  }
  get config() {
  if (!this._config)
@@ -54,21 +66,21 @@ class a extends c {
  throw new Error("Model or tokeniser is not initialized.");
  return d(this._model, this._tokeniser, t);
  }
- static loadModel(t, r) {
- const e = new a(t);
- return f(t, r).then(({ model: o, tokeniser: s }) => {
- e._model = o, e._tokeniser = s, e._config = o.config, e.setStatus("warmup"), l(o).then(() => {
+ static loadModel(t) {
+ const e = new a();
+ return f(t).then(({ model: r, tokeniser: o }) => {
+ e._model = r, e._tokeniser = o, e._config = r.config, e.setStatus("warmup"), m(r).then(() => {
  e.setStatus("ready");
- }).catch((i) => {
- e.setStatus("error"), e.emit("error", i);
+ }).catch((s) => {
+ e.setStatus("error"), e.emit("error", s);
  });
- }).catch((o) => {
- e.setStatus("error"), e.emit("error", o);
+ }).catch((r) => {
+ e.setStatus("error"), e.emit("error", r);
  }), e;
  }
- static create(t, r = {}) {
- const e = { ...h, ...r }, o = new g(e.vocabSize), s = new m(t, e), i = new a(t, o, s);
- return i.setStatus("warmup"), l(s).then(() => {
+ static create(t, e = {}) {
+ const r = { ...l, ...e }, o = t === "char" ? new c(r.vocabSize) : new g(r.vocabSize), s = new h(r), i = new a(o, s);
+ return i.setStatus("warmup"), m(s).then(() => {
  i.tokeniser.trained ? i.setStatus("ready") : (i.setStatus("awaitingTokens"), i.tokeniser.once("trainStatus", (n) => {
  n === "trained" && i.setStatus("ready");
  }));
@@ -86,7 +98,7 @@ class a extends c {
  if (t) {
  if (!this._model)
  throw new Error("Model is not initialized.");
- this._model.getProfiler() || this._model.setProfiler(new p());
+ this._model.getProfiler() || this._model.setProfiler(new w());
  } else
  this._model && this._model.setProfiler(void 0);
  }
@@ -99,14 +111,14 @@ class a extends c {
  if (!this._model || !this._tokeniser)
  throw new Error("Model or tokeniser is not initialized.");
  const t = new _(this._model, this._tokeniser);
- return t.on("start", () => this.setStatus("training")), t.on("stop", () => this.setStatus("ready")), t.on("log", async (r) => {
- const e = this.listeners("trainStep");
- for (const o of e)
- await o(r);
+ return t.on("start", () => this.setStatus("training")), t.on("stop", () => this.setStatus("ready")), t.on("log", async (e) => {
+ const r = this.listeners("trainStep");
+ for (const o of r)
+ await o(e);
  }), t;
  }
- train(t, r) {
- return this.trainer().train(t, r);
+ train(t, e) {
+ return this.trainer().train(t, e);
  }
  generator() {
  if (!this._model || !this._tokeniser)
@@ -118,8 +130,8 @@ class a extends c {
  this.status === "busy" && this.setStatus("ready");
  }), t;
  }
- generateText(t, r) {
- return this.generator().generate(t, r);
+ generateText(t, e) {
+ return this.generator().generate(t, e);
  }
  dispose() {
  this._model?.dispose();