@genai-fi/nanogpt 0.6.2 → 0.7.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (142) hide show
  1. package/dist/Generator.js +11 -11
  2. package/dist/NanoGPTModel.d.ts +2 -2
  3. package/dist/NanoGPTModel.js +104 -136
  4. package/dist/{RealDiv-BYViZwhN.js → RealDiv-C4hOvYOZ.js} +26 -25
  5. package/dist/{Reshape-t7Kcikjk.js → Reshape-BLijOA8h.js} +5 -5
  6. package/dist/TeachableLLM.d.ts +3 -0
  7. package/dist/TeachableLLM.js +50 -47
  8. package/dist/{TiedEmbedding-9WeDwvjO.js → TiedEmbedding-BLltddza.js} +4 -4
  9. package/dist/{axis_util-Bu4h7XWV.js → axis_util-DaAl5MER.js} +3 -3
  10. package/dist/backend.d.ts +1 -0
  11. package/dist/backend.js +7 -0
  12. package/dist/backend_util-DWiwsi2N.js +749 -0
  13. package/dist/{broadcast_to-DARN-DBD.js → broadcast_to-C4v-j9yA.js} +2 -2
  14. package/dist/{concat-5aPGqw3Z.js → concat-CsHeR4zV.js} +8 -8
  15. package/dist/{dataset-pgqp-YfL.js → dataset-JDyjG3QR.js} +3 -3
  16. package/dist/{dropout-Bciw46HT.js → dropout-hpDwECTe.js} +7 -7
  17. package/dist/{gather-DjyCjmOD.js → gather-D0_gPiBz.js} +4 -4
  18. package/dist/gelu-uyHP1x1f.js +26 -0
  19. package/dist/gpgpu_math-DJm3ZTAf.js +2371 -0
  20. package/dist/index-BPPzKVdR.js +12099 -0
  21. package/dist/{index-BAzbokzv.js → index-C0dhsYom.js} +405 -389
  22. package/dist/{kernel_funcs_utils-CUxJCg0g.js → kernel_funcs_utils-CwRTFqrc.js} +31 -30
  23. package/dist/layers/BaseLayer.js +2 -2
  24. package/dist/layers/CausalSelfAttention.js +6 -6
  25. package/dist/layers/MLP.js +5 -5
  26. package/dist/layers/RMSNorm.js +3 -3
  27. package/dist/layers/RoPECache.js +4 -4
  28. package/dist/layers/TiedEmbedding.js +5 -5
  29. package/dist/layers/TransformerBlock.js +1 -1
  30. package/dist/loader/loadTransformers.js +1 -1
  31. package/dist/loader/oldZipLoad.js +5 -5
  32. package/dist/{log_sum_exp-YEo2h3gb.js → log_sum_exp-D086OgZJ.js} +15 -15
  33. package/dist/main.d.ts +2 -0
  34. package/dist/main.js +9 -5
  35. package/dist/{mat_mul-7121rsJk.js → mat_mul-1nwdPkQ_.js} +4 -4
  36. package/dist/{max-DtlIuVeW.js → max-BQc2Aj-I.js} +4 -4
  37. package/dist/{mulmat_packed_gpu-D4nKF7Je.js → mulmat_packed_gpu-Gzf3I9UV.js} +1 -1
  38. package/dist/non_max_suppression_impl-CsEgBuMA.js +134 -0
  39. package/dist/{ones-BBlSRqn1.js → ones-D63HpSF_.js} +2 -2
  40. package/dist/ops/appendCache.js +3 -3
  41. package/dist/ops/attentionMask.js +1 -1
  42. package/dist/ops/cpu/appendCache.js +8 -8
  43. package/dist/ops/cpu/attentionMask.js +9 -9
  44. package/dist/ops/cpu/fusedSoftmax.js +17 -11
  45. package/dist/ops/cpu/gatherSub.js +7 -7
  46. package/dist/ops/cpu/gelu.js +13 -13
  47. package/dist/ops/cpu/matMulGelu.js +36 -24
  48. package/dist/ops/cpu/matMulMul.js +14 -8
  49. package/dist/ops/cpu/mulDropout.js +9 -3
  50. package/dist/ops/cpu/normRMS.js +5 -5
  51. package/dist/ops/cpu/qkv.js +3 -3
  52. package/dist/ops/cpu/rope.js +5 -5
  53. package/dist/ops/cpu/scatterSub.js +11 -11
  54. package/dist/ops/fusedSoftmax.js +1 -1
  55. package/dist/ops/gatherSub.js +1 -1
  56. package/dist/ops/gelu.js +2 -2
  57. package/dist/ops/grads/attentionMask.js +1 -1
  58. package/dist/ops/grads/fusedSoftmax.js +2 -2
  59. package/dist/ops/grads/gelu.js +3 -24
  60. package/dist/ops/grads/matMulGelu.js +5 -5
  61. package/dist/ops/grads/normRMS.js +6 -6
  62. package/dist/ops/grads/qkv.js +1 -1
  63. package/dist/ops/grads/rope.js +3 -3
  64. package/dist/ops/matMulGelu.js +1 -1
  65. package/dist/ops/matMulMul.js +1 -1
  66. package/dist/ops/mulDrop.js +1 -1
  67. package/dist/ops/normRMS.js +1 -1
  68. package/dist/ops/qkv.js +1 -1
  69. package/dist/ops/rope.js +4 -4
  70. package/dist/ops/scatterSub.js +1 -1
  71. package/dist/ops/webgl/appendCache.js +1 -1
  72. package/dist/ops/webgl/attentionMask.js +1 -1
  73. package/dist/ops/webgl/fusedSoftmax.js +4 -4
  74. package/dist/ops/webgl/gatherSub.js +1 -1
  75. package/dist/ops/webgl/gelu.js +2 -2
  76. package/dist/ops/webgl/log.js +5 -5
  77. package/dist/ops/webgl/matMulGelu.js +17 -17
  78. package/dist/ops/webgl/matMulMul.js +1 -1
  79. package/dist/ops/webgl/mulDropout.js +4 -4
  80. package/dist/ops/webgl/normRMS.js +2 -2
  81. package/dist/ops/webgl/qkv.js +1 -1
  82. package/dist/ops/webgl/rope.js +1 -1
  83. package/dist/ops/webgl/scatterSub.js +1 -1
  84. package/dist/ops/webgpu/appendCache.js +56 -0
  85. package/dist/ops/webgpu/attentionMask.d.ts +1 -0
  86. package/dist/ops/webgpu/attentionMask.js +64 -0
  87. package/dist/ops/webgpu/gatherSub.d.ts +1 -0
  88. package/dist/ops/webgpu/gatherSub.js +37 -0
  89. package/dist/ops/webgpu/gelu.d.ts +14 -0
  90. package/dist/ops/webgpu/gelu.js +86 -0
  91. package/dist/ops/webgpu/index.d.ts +0 -0
  92. package/dist/ops/webgpu/index.js +8 -0
  93. package/dist/ops/webgpu/normRMS.d.ts +1 -0
  94. package/dist/ops/webgpu/normRMS.js +115 -0
  95. package/dist/ops/webgpu/qkv.d.ts +1 -0
  96. package/dist/ops/webgpu/qkv.js +56 -0
  97. package/dist/ops/webgpu/rope.d.ts +1 -0
  98. package/dist/ops/webgpu/rope.js +68 -0
  99. package/dist/ops/webgpu/scatterSub.d.ts +1 -0
  100. package/dist/ops/webgpu/scatterSub.js +37 -0
  101. package/dist/{ops-C0sQEcPw.js → ops-CIQLNshk.js} +452 -503
  102. package/dist/{random_width-DWzaOgrn.js → random_width-DkYP8W8N.js} +143 -144
  103. package/dist/{range-DYsrnfiy.js → range-CYzpQY53.js} +1 -1
  104. package/dist/{reciprocal-CJQeasVa.js → reciprocal-_A9yv27J.js} +1 -1
  105. package/dist/{register_all_kernels-BfFCQAqs.js → register_all_kernels-guvSxp7M.js} +202 -200
  106. package/dist/{reshape-krWGKraP.js → reshape-BMUzc1UY.js} +3 -3
  107. package/dist/{scatter_nd_util-93ln7Hut.js → scatter_nd_util-IRBqKz_b.js} +3 -3
  108. package/dist/{selu_util-sntGesxr.js → selu_util-Dt_iuXaq.js} +6 -6
  109. package/dist/shared-BNa2q6jD.js +69 -0
  110. package/dist/{shared-Ca6iDobD.js → shared-CDu9S76h.js} +541 -606
  111. package/dist/{sin-D_h-qCSx.js → sin-Cocju-BY.js} +6 -6
  112. package/dist/{softmax-fsdtf6JC.js → softmax-GPNK3o-U.js} +3 -3
  113. package/dist/{split-eiktj-6L.js → split-CHzJjxDv.js} +4 -4
  114. package/dist/{stack-dfEEz2OY.js → stack-Dpgg_1W1.js} +2 -2
  115. package/dist/{sum-BE_Irnim.js → sum-B8wEpKsg.js} +5 -5
  116. package/dist/{tensor-Xyi595sG.js → tensor-RvZVNmg0.js} +1 -1
  117. package/dist/{tensor2d-CPEkynbH.js → tensor2d-B_kyod7_.js} +1 -1
  118. package/dist/training/AdamExt.js +1 -1
  119. package/dist/training/DatasetBuilder.js +2 -2
  120. package/dist/training/Evaluator.js +1 -1
  121. package/dist/training/FullTrainer.js +20 -20
  122. package/dist/training/Trainer.d.ts +5 -6
  123. package/dist/training/Trainer.js +59 -60
  124. package/dist/training/sparseCrossEntropy.js +19 -26
  125. package/dist/utilities/dummy.js +19 -19
  126. package/dist/utilities/generate.js +15 -16
  127. package/dist/utilities/multinomialCPU.d.ts +2 -0
  128. package/dist/utilities/multinomialCPU.js +13 -0
  129. package/dist/utilities/performance.d.ts +2 -0
  130. package/dist/utilities/performance.js +16 -0
  131. package/dist/utilities/profile.d.ts +1 -0
  132. package/dist/utilities/profile.js +9 -6
  133. package/dist/utilities/safetensors.js +2 -2
  134. package/dist/utilities/weights.js +2 -2
  135. package/dist/{variable-wSS22xj5.js → variable-DXEUOwew.js} +1 -1
  136. package/dist/webgpu_util-g13LvDIv.js +625 -0
  137. package/dist/{zeros-YJDE7oRb.js → zeros-DCPCdFGq.js} +8 -8
  138. package/package.json +2 -1
  139. package/dist/gpgpu_math-CNslybmD.js +0 -3115
  140. package/dist/norm-CzltS9Fz.js +0 -86
  141. package/dist/ops/node/sparseCrossEntropy.js +0 -11
  142. /package/dist/ops/{node/sparseCrossEntropy.d.ts → webgpu/appendCache.d.ts} +0 -0
package/dist/Generator.js CHANGED
@@ -1,15 +1,15 @@
1
1
  import { E as u } from "./index-Dwqa6Zy2.js";
2
- import "./index-BAzbokzv.js";
2
+ import "./index-C0dhsYom.js";
3
3
  import "./ops/cpu/attentionMask.js";
4
4
  import "./ops/webgl/attentionMask.js";
5
5
  import "./ops/grads/attentionMask.js";
6
6
  import "./ops/cpu/qkv.js";
7
7
  import "./ops/webgl/qkv.js";
8
8
  import "./ops/grads/qkv.js";
9
- import "./random_width-DWzaOgrn.js";
10
- import "./register_all_kernels-BfFCQAqs.js";
9
+ import "./random_width-DkYP8W8N.js";
10
+ import "./register_all_kernels-guvSxp7M.js";
11
11
  import "./index-Tf7vU29b.js";
12
- import "./dataset-pgqp-YfL.js";
12
+ import "./dataset-JDyjG3QR.js";
13
13
  import "./ops/cpu/rope.js";
14
14
  import "./ops/webgl/rope.js";
15
15
  import "./ops/grads/rope.js";
@@ -33,10 +33,10 @@ import f from "./tokeniser/CharTokeniser.js";
33
33
  import "./papaparse.min-C8l2Kvo1.js";
34
34
  import "./ops/cpu/gelu.js";
35
35
  import "./ops/webgl/gelu.js";
36
- import "./ops/grads/gelu.js";
36
+ import "./gelu-uyHP1x1f.js";
37
37
  import "./ops/webgl/log.js";
38
- import { t as d } from "./tensor2d-CPEkynbH.js";
39
- import { c as g } from "./concat-5aPGqw3Z.js";
38
+ import { t as d } from "./tensor2d-B_kyod7_.js";
39
+ import { c as g } from "./concat-CsHeR4zV.js";
40
40
  const k = [
41
41
  ...Array.from({ length: 95 }, (a, t) => String.fromCharCode(t + 32)),
42
42
  // ASCII
@@ -49,7 +49,7 @@ const k = [
49
49
  // Cyrillic letters
50
50
  ..."абвгдеёжзийклмнопрстуфхцчшщъыьэюяАБВГДЕЁЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЫЬЭЮЯ"
51
51
  ];
52
- function T(a, t) {
52
+ function w(a, t) {
53
53
  return a.length === t ? a : a.length > t ? a.slice(0, t) : a.concat(Array(t - a.length).fill(""));
54
54
  }
55
55
  class rt extends u {
@@ -69,7 +69,7 @@ class rt extends u {
69
69
  output: e,
70
70
  attention: p,
71
71
  probabilities: c
72
- } = this.model.generate(i, void 0, r), h = i;
72
+ } = await this.model.generate(i, void 0, r), h = i;
73
73
  i = g([i, e], 1), h.dispose();
74
74
  const l = await this.processResponse(t, e, p, c);
75
75
  if (e.dispose(), l === null)
@@ -99,7 +99,7 @@ class rt extends u {
99
99
  output: p,
100
100
  probabilities: c,
101
101
  attention: h
102
- } = this.model.generate(i, n, {
102
+ } = await this.model.generate(i, n, {
103
103
  ...r,
104
104
  usePadding: !1
105
105
  });
@@ -116,7 +116,7 @@ class rt extends u {
116
116
  async generate(t, o) {
117
117
  const r = t && t.length > this.model.config.gpt.blockSize ? t.slice(-this.model.config.gpt.blockSize) : t;
118
118
  this.active = !0, this.emit("start");
119
- const i = this.tokeniser.trained ? this.tokeniser : new f(T(k, this.tokeniser.vocabSize)), n = await (this.model.config.gpt.useRope && !o?.noCache ? this.generateCache(i, r, o) : this.generateNoCache(i, r, o));
119
+ const i = this.tokeniser.trained ? this.tokeniser : new f(w(k, this.tokeniser.vocabSize)), n = await (this.model.config.gpt.useRope && !o?.noCache ? this.generateCache(i, r, o) : this.generateNoCache(i, r, o));
120
120
  return this.active = !1, this.emit("stop"), n;
121
121
  }
122
122
  stop() {
@@ -42,11 +42,11 @@ export default class NanoGPT extends BaseLayer<ModelForwardAttributes> {
42
42
  private validateInput;
43
43
  private calculateLoss;
44
44
  forward(attrs: ModelForwardAttributes, idx: Tensor, targets?: Tensor): Tensor[];
45
- generate(idx: Tensor, cache?: KVCache[], options?: GenerateOptions): {
45
+ generate(idx: Tensor, cache?: KVCache[], options?: GenerateOptions): Promise<{
46
46
  output: Tensor;
47
47
  probabilities?: Tensor;
48
48
  attention?: Tensor[];
49
- };
49
+ }>;
50
50
  getNumParams(): number;
51
51
  dispose(): void;
52
52
  }
@@ -1,19 +1,19 @@
1
- import { defaultConfig as F } from "./config.js";
2
- import O from "./layers/TransformerBlock.js";
3
- import { T as _, r as D } from "./TiedEmbedding-9WeDwvjO.js";
4
- import K from "./layers/RoPECache.js";
5
- import N from "./layers/RMSNorm.js";
6
- import { estimateParameterCount as R } from "./utilities/parameters.js";
7
- import { createSoftmaxCrossEntropyWithGrad as A } from "./training/sparseCrossEntropy.js";
8
- import G from "./layers/BaseLayer.js";
9
- import { E as B, D as V, p as j } from "./random_width-DWzaOgrn.js";
10
- import { o as W, q as H, E as J, a6 as Q, t as z, a7 as U, s as v, k as X } from "./index-BAzbokzv.js";
11
- import { m as Y, t as Z } from "./register_all_kernels-BfFCQAqs.js";
12
- import { r as L } from "./reshape-krWGKraP.js";
13
- import { r as tt } from "./range-DYsrnfiy.js";
14
- import { s as M } from "./softmax-fsdtf6JC.js";
15
- import { t as et } from "./ops-C0sQEcPw.js";
16
- import { g as ot } from "./gather-DjyCjmOD.js";
1
+ import { defaultConfig as M } from "./config.js";
2
+ import v from "./layers/TransformerBlock.js";
3
+ import { T as x, r as T } from "./TiedEmbedding-BLltddza.js";
4
+ import F from "./layers/RoPECache.js";
5
+ import O from "./layers/RMSNorm.js";
6
+ import { estimateParameterCount as _ } from "./utilities/parameters.js";
7
+ import { createSoftmaxCrossEntropyWithGrad as D } from "./training/sparseCrossEntropy.js";
8
+ import K from "./layers/BaseLayer.js";
9
+ import { E as N, D as R, p as q } from "./random_width-DkYP8W8N.js";
10
+ import { x as A, y as G, E as B, a5 as V, t as C, a6 as j, b as z, o as U } from "./index-C0dhsYom.js";
11
+ import W from "./utilities/multinomialCPU.js";
12
+ import { m as H, t as J } from "./register_all_kernels-guvSxp7M.js";
13
+ import { r as P } from "./reshape-BMUzc1UY.js";
14
+ import { r as Q } from "./range-CYzpQY53.js";
15
+ import { s as $ } from "./softmax-GPNK3o-U.js";
16
+ import { g as X } from "./gather-D0_gPiBz.js";
17
17
  /**
18
18
  * @license
19
19
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -30,17 +30,17 @@ import { g as ot } from "./gather-DjyCjmOD.js";
30
30
  * limitations under the License.
31
31
  * =============================================================================
32
32
  */
33
- function st(u, t, e, o = !1) {
34
- const r = H(u, "logits", "multinomial"), s = r.size, n = r.rank;
33
+ function Y(u, t, o, e = !1) {
34
+ const l = G(u, "logits", "multinomial"), s = l.size, r = l.rank;
35
35
  if (s < 2)
36
36
  throw new Error(`Error in multinomial: you need at least 2 outcomes, but got ${s}.`);
37
- if (n > 2)
38
- throw new Error(`Rank of probabilities must be 1 or 2, but is ${n}`);
39
- e = e || Math.random();
40
- const i = { logits: n === 1 ? L(r, [1, -1]) : r }, h = { numSamples: t, seed: e, normalized: o }, c = J.runKernel(Q, i, h);
41
- return n === 1 ? L(c, [c.size]) : c;
37
+ if (r > 2)
38
+ throw new Error(`Rank of probabilities must be 1 or 2, but is ${r}`);
39
+ o = o || Math.random();
40
+ const n = { logits: r === 1 ? P(l, [1, -1]) : l }, a = { numSamples: t, seed: o, normalized: e }, i = B.runKernel(V, n, a);
41
+ return r === 1 ? P(i, [i.size]) : i;
42
42
  }
43
- const C = /* @__PURE__ */ W({ multinomial_: st });
43
+ const I = /* @__PURE__ */ A({ multinomial_: Y });
44
44
  /**
45
45
  * @license
46
46
  * Copyright 2018 Google LLC
@@ -50,13 +50,13 @@ const C = /* @__PURE__ */ W({ multinomial_: st });
50
50
  * https://opensource.org/licenses/MIT.
51
51
  * =============================================================================
52
52
  */
53
- function nt(u) {
54
- return new V(u);
53
+ function Z(u) {
54
+ return new R(u);
55
55
  }
56
- function it(u) {
57
- return new B(u);
56
+ function tt(u) {
57
+ return new N(u);
58
58
  }
59
- class St extends G {
59
+ class bt extends K {
60
60
  wte;
61
61
  // Token embeddings
62
62
  wpe;
@@ -70,15 +70,15 @@ class St extends G {
70
70
  log = [];
71
71
  // Training log
72
72
  constructor(t = {}) {
73
- super({ gpt: { ...F, ...t }, layerConfig: {} }), this.wte = new _(this.config, "token_embedding", this), this.config.gpt.useRope === !1 ? this.wpe = it({
73
+ super({ gpt: { ...M, ...t }, layerConfig: {} }), this.wte = new x(this.config, "token_embedding", this), this.config.gpt.useRope === !1 ? this.wpe = tt({
74
74
  inputDim: this.config.gpt.blockSize,
75
75
  outputDim: this.config.gpt.nEmbed,
76
76
  name: "positional_embedding",
77
- embeddingsInitializer: D({ mean: 0, stddev: 0.02 })
78
- }) : (this.ropeCache = new K(this.config.gpt), this.config.layerConfig.ropeCache = this.ropeCache), this.drop = nt({ rate: this.config.gpt.dropout }), this.blocks = [];
79
- for (let e = 0; e < this.config.gpt.nLayer; e++)
80
- this.blocks.push(new O(e, this.config, this));
81
- this.lnF = new N(this.config, "final_rms_norm", this);
77
+ embeddingsInitializer: T({ mean: 0, stddev: 0.02 })
78
+ }) : (this.ropeCache = new F(this.config.gpt), this.config.layerConfig.ropeCache = this.ropeCache), this.drop = Z({ rate: this.config.gpt.dropout }), this.blocks = [];
79
+ for (let o = 0; o < this.config.gpt.nLayer; o++)
80
+ this.blocks.push(new v(o, this.config, this));
81
+ this.lnF = new O(this.config, "final_rms_norm", this);
82
82
  }
83
83
  get checkpointing() {
84
84
  return this.config.layerConfig.checkpointing === !0;
@@ -86,27 +86,27 @@ class St extends G {
86
86
  set checkpointing(t) {
87
87
  this.config.layerConfig.checkpointing = t;
88
88
  }
89
- inputPhase(t, e, o = !1) {
90
- return z(() => {
91
- const r = this.wte.embed(t);
89
+ inputPhase(t, o, e = !1) {
90
+ return C(() => {
91
+ const l = this.wte.embed(t);
92
92
  if (this.config.gpt.useRope === !1) {
93
- const [, s] = t.shape, n = this.config.gpt.blockSize, p = tt(0, s, 1, "int32"), i = Y(U(p, v(e, "int32")), v(n, "int32")), h = this.wpe.apply(i), c = r.add(h);
94
- return this.drop.apply(c, { training: o });
93
+ const [, s] = t.shape, r = this.config.gpt.blockSize, g = Q(0, s, 1, "int32"), n = H(j(g, z(o, "int32")), z(r, "int32")), a = this.wpe.apply(n), i = l.add(a);
94
+ return this.drop.apply(i, { training: e });
95
95
  } else
96
- return this.drop.apply(r, { training: o });
96
+ return this.drop.apply(l, { training: e });
97
97
  });
98
98
  }
99
99
  setSkipMask(t) {
100
100
  if (t.length !== this.blocks.length)
101
101
  throw new Error(`Mask length ${t.length} does not match number of blocks ${this.blocks.length}`);
102
- for (let e = 0; e < this.blocks.length; e++)
103
- this.blocks[e].skipped = t[e];
102
+ for (let o = 0; o < this.blocks.length; o++)
103
+ this.blocks[o].skipped = t[o];
104
104
  }
105
105
  setTrainableMask(t) {
106
106
  if (t.length !== this.blocks.length)
107
107
  throw new Error(`Mask length ${t.length} does not match number of blocks ${this.blocks.length}`);
108
- for (let e = 0; e < this.blocks.length; e++)
109
- this.blocks[e].trainable = t[e];
108
+ for (let o = 0; o < this.blocks.length; o++)
109
+ this.blocks[o].trainable = t[o];
110
110
  }
111
111
  validateInput(t) {
112
112
  if (t.shape.length !== 2)
@@ -116,120 +116,88 @@ class St extends G {
116
116
  if (t.dtype !== "int32")
117
117
  throw new Error(`Input tensor must be of type int32, got ${t.dtype}`);
118
118
  }
119
- calculateLoss(t, e) {
119
+ calculateLoss(t, o) {
120
120
  try {
121
- return A()(t, e).mean();
122
- } catch (o) {
123
- throw console.error("Error computing loss:", o), new Error(`Loss computation failed: ${o}`);
121
+ return D()(t, o).mean();
122
+ } catch (e) {
123
+ throw console.error("Error computing loss:", e), new Error(`Loss computation failed: ${e}`);
124
124
  }
125
125
  }
126
- // Attention rollout per Abnar & Zuidema (2020)
127
- // Expects list of (B, T, T) attention matrices already averaged over heads.
128
- /*private computeAttentionRollout(attentions: Tensor[]): Tensor {
129
- return tidy(() => {
130
- if (attentions.length === 0) {
131
- throw new Error('No attentions for rollout');
132
- }
133
- const [B, Q, K] = attentions[0].shape as number[];
134
-
135
- // Validate shapes are consistent
136
- for (const a of attentions) {
137
- const [b2, q2, k2] = a.shape as number[];
138
- if (b2 !== B || q2 !== Q || k2 !== K) {
139
- throw new Error(
140
- `Inconsistent attention shapes in rollout: expected [${B},${Q},${K}] got [${b2},${q2},${k2}]`
141
- );
142
- }
143
- }
144
-
145
- // Always slice to [B, Q, Q] for rollout
146
- const attentionsSliced = attentions.map((att) => att.slice([0, 0, 0], [B, Q, Q]));
147
-
148
- const ey = eye(Q, Q).expandDims(0); // (1,Q,Q)
149
- let rollout = ey.tile([B, 1, 1]); // (B,Q,Q)
150
- for (const att of attentionsSliced) {
151
- const a = att.add(ey);
152
- const aNorm = a.div(a.sum(-1, true)); // (B,Q,Q)
153
- rollout = aNorm.matMul(rollout); // (B,Q,Q)
154
- }
155
- return rollout;
156
- });
157
- }*/
158
- forward(t, e, o) {
159
- return this.validateInput(e), z(() => {
126
+ forward(t, o, e) {
127
+ return this.validateInput(o), C(() => {
160
128
  this.startMemory();
161
- const r = t.cache?.[0]?.length ?? 0;
162
- let s = this.inputPhase(e, r, t.training);
129
+ const l = t.cache?.[0]?.length ?? 0;
130
+ let s = this.inputPhase(o, l, t.training);
163
131
  if (t.cache && t.cache.length !== this.blocks.length)
164
132
  throw console.error("Cache", t.cache), new Error(
165
133
  `Cache length ${t.cache.length} does not match number of blocks ${this.blocks.length}`
166
134
  );
167
- for (let i = 0; i < this.blocks.length; i++) {
168
- const h = this.blocks[i], c = Math.random() * 1e9, g = {
135
+ for (let n = 0; n < this.blocks.length; n++) {
136
+ const a = this.blocks[n], i = Math.random() * 1e9, d = {
169
137
  training: t.training,
170
- seed: c,
138
+ seed: i,
171
139
  attentionScores: t.attentionScores,
172
- pastKV: t.cache ? t.cache[i] : void 0
173
- }, E = this.config.layerConfig.checkpointing && t.training ? h.callCheckpoint(g, s) : h.call(g, s);
174
- s.dispose(), s = E;
140
+ pastKV: t.cache ? t.cache[n] : void 0
141
+ }, S = this.config.layerConfig.checkpointing && t.training ? a.callCheckpoint(d, s) : a.call(d, s);
142
+ s.dispose(), s = S;
175
143
  }
176
144
  s = this.lnF.call(t, s);
177
- const n = this.wte.project(s);
145
+ const r = this.wte.project(s);
178
146
  s.dispose();
179
- let p;
180
- return o && (p = this.calculateLoss(n, o)), this.endMemory("Forward"), p ? [n, p] : [n];
147
+ let g;
148
+ return e && (g = this.calculateLoss(r, e)), this.endMemory("Forward"), g ? [r, g] : [r];
181
149
  });
182
150
  }
183
- generate(t, e, o) {
184
- const r = o?.temperature ?? 1, s = o?.topK, n = o?.topP, p = o?.usePadding ?? !1;
185
- return z(() => {
186
- const i = t, h = i.shape[1], c = h <= this.config.gpt.blockSize ? i : i.slice(
187
- [0, h - this.config.gpt.blockSize],
188
- [i.shape[0], this.config.gpt.blockSize]
189
- ), g = p ? this.config.gpt.blockSize - c.shape[1] : 0, E = g > 0 ? j(c, [
151
+ async generate(t, o, e) {
152
+ const l = e?.temperature ?? 1, s = e?.topK, r = e?.topP, g = e?.usePadding ?? !1, n = {
153
+ training: !1,
154
+ attentionScores: e?.attentionScores ? {
155
+ attentionOut: []
156
+ } : void 0,
157
+ cache: o
158
+ }, a = C(() => {
159
+ const p = t, m = p.shape[1], h = m <= this.config.gpt.blockSize ? p : p.slice(
160
+ [0, m - this.config.gpt.blockSize],
161
+ [p.shape[0], this.config.gpt.blockSize]
162
+ ), b = g ? this.config.gpt.blockSize - h.shape[1] : 0, w = b > 0 ? q(h, [
190
163
  [0, 0],
191
- [0, g]
192
- ]) : c, f = {
193
- training: !1,
194
- attentionScores: o?.attentionScores ? {
195
- attentionOut: []
196
- } : void 0,
197
- cache: e
198
- }, [d] = this.forward(f, E), $ = d.shape[1] - 1 - g, q = d.slice([0, $, 0], [d.shape[0], 1, d.shape[2]]);
199
- f.attentionScores?.attentionOut && f.attentionScores.attentionOut.forEach((l, b) => {
200
- l.shape[1] !== 1 && (f.attentionScores.attentionOut[b] = X(
201
- l.slice([0, $, 0], [l.shape[0], 1, l.shape[2]])
202
- ), l.dispose());
203
- }), d.dispose();
204
- const w = q.div(r);
205
- let m;
206
- if (n) {
207
- const l = M(w.squeeze([1])), b = l.arraySync()[0];
208
- l.dispose();
209
- const y = b.map((a, k) => ({ prob: a, index: k })).sort((a, k) => k.prob - a.prob);
210
- let P = 0;
211
- const S = new Array(y.length).fill(0);
212
- for (const a of y)
213
- if (P += a.prob, S[a.index] = a.prob, P >= n)
214
- break;
215
- const x = S.reduce((a, k) => a + k, 0), T = S.map((a) => a / x);
216
- m = C(et(T), 1, void 0, !0);
217
- } else if (s) {
218
- const { values: l, indices: b } = Z(w, s), y = C(l.squeeze([1]), 1);
219
- m = ot(b.squeeze([1]), y, 1);
220
- } else
221
- m = C(w.squeeze([1]), 1);
222
- let I;
223
- return o?.includeProbabilities && (I = M(w.squeeze([1]))), m = m.reshape([1, 1]), { output: m, probabilities: I, attention: f.attentionScores?.attentionOut };
164
+ [0, b]
165
+ ]) : h, [f] = this.forward(n, w), E = f.shape[1] - 1 - b, c = f.slice([0, E, 0], [f.shape[0], 1, f.shape[2]]);
166
+ return n.attentionScores?.attentionOut && n.attentionScores.attentionOut.forEach((y, L) => {
167
+ y.shape[1] !== 1 && (n.attentionScores.attentionOut[L] = U(
168
+ y.slice([0, E, 0], [y.shape[0], 1, y.shape[2]])
169
+ ), y.dispose());
170
+ }), f.dispose(), c.div(l).squeeze([1]);
224
171
  });
172
+ let i;
173
+ if (r) {
174
+ const p = $(a), m = await p.array();
175
+ p.dispose();
176
+ const h = m[0].map((c, k) => ({ prob: c, index: k })).sort((c, k) => k.prob - c.prob);
177
+ let b = 0;
178
+ const w = new Array(h.length).fill(0);
179
+ for (const c of h)
180
+ if (b += c.prob, w[c.index] = c.prob, b >= r)
181
+ break;
182
+ const f = w.reduce((c, k) => c + k, 0), E = w.map((c) => c / f);
183
+ i = W(E);
184
+ } else if (s) {
185
+ const { values: p, indices: m } = J(a, s), h = I(p, 1);
186
+ i = X(m, h, 1), p.dispose(), m.dispose(), h.dispose();
187
+ } else
188
+ i = I(a, 1);
189
+ let d;
190
+ e?.includeProbabilities && (d = $(a));
191
+ const S = i.reshape([1, 1]);
192
+ return i.dispose(), i = S, a.dispose(), { output: i, probabilities: d, attention: n.attentionScores?.attentionOut };
225
193
  }
226
194
  getNumParams() {
227
- return R(this.config.gpt);
195
+ return _(this.config.gpt);
228
196
  }
229
197
  dispose() {
230
198
  this.wte.dispose(), this.wpe && this.wpe.dispose(), this.drop.dispose(), this.blocks.forEach((t) => t.dispose()), this.lnF.dispose();
231
199
  }
232
200
  }
233
201
  export {
234
- St as default
202
+ bt as default
235
203
  };
@@ -1,9 +1,10 @@
1
- import { ao as T, ac as E, p as O, g as V, aw as B, N as F, M as j, ax as K } from "./index-BAzbokzv.js";
2
- import { r as $ } from "./Reshape-t7Kcikjk.js";
3
- import { g as A, a as k, b as C, c as N, e as R } from "./axis_util-Bu4h7XWV.js";
4
- import { t as U, m as W } from "./shared-Ca6iDobD.js";
5
- import { j as _, f as y } from "./gpgpu_math-CNslybmD.js";
6
- import { g as G, b as L } from "./kernel_funcs_utils-CUxJCg0g.js";
1
+ import { ao as T, ac as E, p as O, j as V, aw as B, U as F, N as U, ax as j } from "./index-C0dhsYom.js";
2
+ import { r as $ } from "./Reshape-BLijOA8h.js";
3
+ import { g as A, a as k, b as C, c as N, e as R } from "./axis_util-DaAl5MER.js";
4
+ import { t as K, m as W } from "./shared-BNa2q6jD.js";
5
+ import { c as _ } from "./backend_util-DWiwsi2N.js";
6
+ import { f as y } from "./gpgpu_math-DJm3ZTAf.js";
7
+ import { g as G, b as L } from "./kernel_funcs_utils-CwRTFqrc.js";
7
8
  /**
8
9
  * @license
9
10
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -240,7 +241,7 @@ function q(a) {
240
241
  }
241
242
  return s;
242
243
  }
243
- function M(a, s, e, t) {
244
+ function P(a, s, e, t) {
244
245
  const n = q(a.shape);
245
246
  let l = a;
246
247
  for (let r = 0; r < n.length; r++) {
@@ -355,7 +356,7 @@ class J {
355
356
  * limitations under the License.
356
357
  * =============================================================================
357
358
  */
358
- function P(a, s, e) {
359
+ function D(a, s, e) {
359
360
  const t = E().getBool("WEBGL_PACK_ARRAY_OPERATIONS") ? new J(a.shape, s) : new Y(a.shape, s);
360
361
  return e.runWebGLProgram(t, [a], a.dtype);
361
362
  }
@@ -380,11 +381,11 @@ function Q(a, s, e, t) {
380
381
  let i = r;
381
382
  const c = A(i, l), o = c != null;
382
383
  let u = a;
383
- o && (u = P(a, c, t), i = k(i.length, l)), C("sum", i, l);
384
+ o && (u = D(a, c, t), i = k(i.length, l)), C("sum", i, l);
384
385
  const [p, h] = N(u.shape, i);
385
386
  let d = p;
386
387
  e && (d = R(p, r));
387
- const f = V(h), g = V(a.shape) / f, x = $({ inputs: { x: u }, attrs: { shape: [g, f] }, backend: t }), b = B(a.dtype), I = M(x, b, "sum", t), m = $({ inputs: { x: I }, attrs: { shape: d }, backend: t });
388
+ const f = V(h), g = V(a.shape) / f, x = $({ inputs: { x: u }, attrs: { shape: [g, f] }, backend: t }), b = B(a.dtype), I = P(x, b, "sum", t), m = $({ inputs: { x: I }, attrs: { shape: d }, backend: t });
388
389
  return t.disposeIntermediateTensorInfo(x), t.disposeIntermediateTensorInfo(I), o && t.disposeIntermediateTensorInfo(u), m;
389
390
  }
390
391
  /**
@@ -407,7 +408,7 @@ function Z(a) {
407
408
  const { inputs: s, backend: e, attrs: t } = a, { x: n } = s, { axis: l, keepDims: r } = t;
408
409
  return Q(n, l, r, e);
409
410
  }
410
- const de = {
411
+ const pe = {
411
412
  kernelName: F,
412
413
  backendName: "webgl",
413
414
  kernelFunc: Z
@@ -429,7 +430,7 @@ const de = {
429
430
  * =============================================================================
430
431
  */
431
432
  function ee(a, s, e, t) {
432
- const n = V(s), r = V(a.shape) / n, i = $({ inputs: { x: a }, attrs: { shape: [r, n] }, backend: t }), c = M(i, a.dtype, "max", t), o = $({ inputs: { x: c }, attrs: { shape: e }, backend: t });
433
+ const n = V(s), r = V(a.shape) / n, i = $({ inputs: { x: a }, attrs: { shape: [r, n] }, backend: t }), c = P(i, a.dtype, "max", t), o = $({ inputs: { x: c }, attrs: { shape: e }, backend: t });
433
434
  return t.disposeIntermediateTensorInfo(i), t.disposeIntermediateTensorInfo(c), o;
434
435
  }
435
436
  /**
@@ -458,12 +459,12 @@ function te(a) {
458
459
  const I = e.texData.get(d.dataId).values, m = new Array(i);
459
460
  for (let v = 0; v < m.length; v++)
460
461
  m[v] = n.shape[u[v]];
461
- const z = U(I, n.shape, n.dtype, u, m);
462
+ const z = K(I, n.shape, n.dtype, u, m);
462
463
  d = e.makeTensorInfo(m, n.dtype);
463
- const D = e.texData.get(d.dataId);
464
- D.values = z;
464
+ const M = e.texData.get(d.dataId);
465
+ M.values = z;
465
466
  } else
466
- d = P(n, u, e);
467
+ d = D(n, u, e);
467
468
  o = k(o.length, i);
468
469
  }
469
470
  C("max", o, i);
@@ -480,8 +481,8 @@ function te(a) {
480
481
  x = ee(d, S, g, e);
481
482
  return p && e.disposeIntermediateTensorInfo(d), x;
482
483
  }
483
- const pe = {
484
- kernelName: j,
484
+ const he = {
485
+ kernelName: U,
485
486
  backendName: "webgl",
486
487
  kernelFunc: te
487
488
  };
@@ -523,18 +524,18 @@ return a / b;`, se = `
523
524
  }
524
525
 
525
526
  return result;
526
- `, ne = L({ opSnippet: ae, packedOpSnippet: se, checkOutOfBounds: !0 }), he = {
527
- kernelName: K,
527
+ `, ne = L({ opSnippet: ae, packedOpSnippet: se, checkOutOfBounds: !0 }), fe = {
528
+ kernelName: j,
528
529
  backendName: "webgl",
529
530
  kernelFunc: ne
530
531
  };
531
532
  export {
532
- M as a,
533
- pe as b,
534
- he as c,
535
- de as d,
533
+ P as a,
534
+ he as b,
535
+ fe as c,
536
+ pe as d,
536
537
  te as m,
537
538
  ne as r,
538
539
  Z as s,
539
- P as t
540
+ D as t
540
541
  };
@@ -1,5 +1,5 @@
1
- import { g as c, aa as C, i as f, D as R } from "./index-BAzbokzv.js";
2
- import { u as g, g as I, a as x, b as F, c as $, d as u, e as m, i as l } from "./gpgpu_math-CNslybmD.js";
1
+ import { j as c, a9 as C, l as f, I as R } from "./index-C0dhsYom.js";
2
+ import { u as g, g as I, a as x, b as F, c as $, d as u, e as l, i as m } from "./gpgpu_math-DJm3ZTAf.js";
3
3
  /**
4
4
  * @license
5
5
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -82,14 +82,14 @@ function v(s, t) {
82
82
  function b(s, t, i) {
83
83
  const a = [
84
84
  u(s.shape),
85
- ...m(s.shape)
85
+ ...l(s.shape)
86
86
  ], e = {
87
87
  dtype: s.dtype,
88
88
  shape: a,
89
89
  dataId: s.dataId
90
90
  }, o = [
91
91
  u(t),
92
- ...m(t)
92
+ ...l(t)
93
93
  ], r = new S(o, a), p = !0, n = [a], h = i.runWebGLProgram(r, [e], s.dtype, n, p);
94
94
  return { dataId: h.dataId, shape: t, dtype: h.dtype };
95
95
  }
@@ -113,7 +113,7 @@ function y(s) {
113
113
  const { inputs: t, backend: i, attrs: a } = s, { x: e } = t, { shape: o } = a, r = i, p = c(e.shape), n = C(o, p), h = c(n);
114
114
  f(p === h, () => `The new shape (${n}) has ${h} elements and the old shape (${e.shape}) has ${p} elements. The new shape and old shape must have the same number of elements.`);
115
115
  const d = r.texData.get(e.dataId);
116
- return d.isPacked && !l(e.shape, n) && !(d.texture !== null && l(d.shape, n)) ? b(e, n, r) : (r.incRef(e.dataId), { dataId: e.dataId, shape: n, dtype: e.dtype });
116
+ return d.isPacked && !m(e.shape, n) && !(d.texture !== null && m(d.shape, n)) ? b(e, n, r) : (r.incRef(e.dataId), { dataId: e.dataId, shape: n, dtype: e.dtype });
117
117
  }
118
118
  const U = {
119
119
  kernelName: R,
@@ -22,12 +22,15 @@ export default class TeachableLLM {
22
22
  meta: TeachableLLMMeta;
23
23
  constructor(tokeniser?: ITokeniser, model?: NanoGPT);
24
24
  get vocab(): string[];
25
+ /** Model is fully loaded */
25
26
  get loaded(): boolean;
26
27
  get config(): GPTConfig;
27
28
  get model(): NanoGPT;
28
29
  get tokeniser(): ITokeniser;
29
30
  get status(): TeachableLLMStatus;
31
+ /** Model is both ready and not busy */
30
32
  get ready(): boolean;
33
+ get busy(): boolean;
31
34
  estimateTrainingMemoryUsage(batchSize: number): number;
32
35
  private setStatus;
33
36
  saveModel(options?: SaveOptions): Promise<Blob>;