@genai-fi/nanogpt 0.7.2 → 0.8.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (173) hide show
  1. package/dist/Generator.d.ts +36 -4
  2. package/dist/Generator.js +183 -69
  3. package/dist/{RealDiv-Dy0p8Bvo.js → RealDiv-N8TpOMYv.js} +14 -14
  4. package/dist/{Reshape-DvudQDvJ.js → Reshape-B-lWQRnF.js} +1 -1
  5. package/dist/{Reshape-DH5srBP0.js → Reshape-Bo8HzP8V.js} +5 -5
  6. package/dist/TeachableLLM.d.ts +6 -6
  7. package/dist/TeachableLLM.js +51 -50
  8. package/dist/Trainer.d.ts +19 -3
  9. package/dist/Trainer.js +71 -28
  10. package/dist/{axis_util-BzbKo31C.js → axis_util-DubwyOhW.js} +3 -3
  11. package/dist/backend.js +2 -2
  12. package/dist/{backend_util-TE7aTPhZ.js → backend_util-BJ-_jSeK.js} +46 -46
  13. package/dist/{broadcast_to-CdbwV-Dj.js → broadcast_to-BYfCp5iL.js} +2 -2
  14. package/dist/{concat-CsxrgovM.js → concat-BmDqqFsa.js} +1 -1
  15. package/dist/{dataset-CtdBYwjo.js → dataset-CJmEGu6D.js} +5 -5
  16. package/dist/{dropout-DYs5QFGQ.js → dropout-sx0sjVAT.js} +8 -8
  17. package/dist/exports_initializers-DAKM8UO9.js +16 -0
  18. package/dist/{gather-CMMy2KEG.js → gather-C1siEkdp.js} +1 -1
  19. package/dist/{gelu-C-dPj6Ku.js → gelu-Bd3UBBxg.js} +1 -1
  20. package/dist/{gpgpu_math-DGNLNL4I.js → gpgpu_math-TFLxaLkw.js} +26 -26
  21. package/dist/{index-CLthM0TO.js → index-BaPo_0H8.js} +185 -185
  22. package/dist/{index-BoWRt-10.js → index-CUQrfsw_.js} +266 -265
  23. package/dist/{kernel_funcs_utils-BYKWV8Aa.js → kernel_funcs_utils-P9aFa232.js} +9 -9
  24. package/dist/layers/BaseLayer.d.ts +8 -13
  25. package/dist/layers/BaseLayer.js +25 -13
  26. package/dist/layers/CausalSelfAttention.d.ts +3 -2
  27. package/dist/layers/CausalSelfAttention.js +28 -28
  28. package/dist/layers/MLP.d.ts +3 -2
  29. package/dist/layers/MLP.js +16 -20
  30. package/dist/layers/PositionEmbedding.d.ts +9 -0
  31. package/dist/layers/PositionEmbedding.js +45 -0
  32. package/dist/layers/RMSNorm.d.ts +3 -2
  33. package/dist/layers/RMSNorm.js +6 -6
  34. package/dist/layers/RoPECache.d.ts +1 -1
  35. package/dist/layers/RoPECache.js +4 -4
  36. package/dist/layers/TiedEmbedding.d.ts +3 -2
  37. package/dist/layers/TiedEmbedding.js +29 -7
  38. package/dist/layers/TransformerBlock.d.ts +3 -2
  39. package/dist/layers/TransformerBlock.js +1 -1
  40. package/dist/loader/load.d.ts +2 -2
  41. package/dist/loader/loadHF.d.ts +2 -2
  42. package/dist/loader/loadTransformers.d.ts +4 -2
  43. package/dist/loader/loadTransformers.js +10 -9
  44. package/dist/loader/newZipLoad.d.ts +2 -2
  45. package/dist/loader/oldZipLoad.d.ts +2 -2
  46. package/dist/loader/oldZipLoad.js +42 -51
  47. package/dist/loader/save.d.ts +8 -0
  48. package/dist/loader/save.js +62 -0
  49. package/dist/{log_sum_exp-DbjkV734.js → log_sum_exp-C142qZqY.js} +14 -14
  50. package/dist/main.d.ts +5 -4
  51. package/dist/main.js +22 -18
  52. package/dist/{mat_mul-8m8pfdcx.js → mat_mul-DMkduNJu.js} +1 -1
  53. package/dist/{max-Ddnnb5xe.js → max-B3JOcNGb.js} +1 -1
  54. package/dist/mod-uUuj4gSb.js +27 -0
  55. package/dist/models/NanoGPTV1.d.ts +15 -0
  56. package/dist/models/NanoGPTV1.js +71 -0
  57. package/dist/{config.d.ts → models/config.d.ts} +1 -0
  58. package/dist/{config.js → models/config.js} +1 -0
  59. package/dist/models/factory.d.ts +3 -0
  60. package/dist/models/factory.js +14 -0
  61. package/dist/models/model.d.ts +26 -0
  62. package/dist/models/model.js +68 -0
  63. package/dist/{mulmat_packed_gpu-VSekgsNv.js → mulmat_packed_gpu-Cm2gw-c8.js} +1 -1
  64. package/dist/{ones-Dj0SDhHf.js → ones-ZdgQGBCP.js} +2 -2
  65. package/dist/ops/adamAdjust.js +1 -1
  66. package/dist/ops/adamMoments.js +1 -1
  67. package/dist/ops/appendCache.js +3 -3
  68. package/dist/ops/attentionMask.js +1 -1
  69. package/dist/ops/cpu/adamAdjust.js +9 -9
  70. package/dist/ops/cpu/adamMoments.js +2 -2
  71. package/dist/ops/cpu/appendCache.js +2 -2
  72. package/dist/ops/cpu/attentionMask.js +5 -5
  73. package/dist/ops/cpu/fusedSoftmax.js +2 -2
  74. package/dist/ops/cpu/gatherSub.js +3 -3
  75. package/dist/ops/cpu/gelu.js +1 -1
  76. package/dist/ops/cpu/matMulGelu.js +2 -2
  77. package/dist/ops/cpu/matMulMul.js +1 -1
  78. package/dist/ops/cpu/mulDropout.js +1 -1
  79. package/dist/ops/cpu/normRMS.js +1 -1
  80. package/dist/ops/cpu/qkv.js +3 -3
  81. package/dist/ops/cpu/rope.js +5 -5
  82. package/dist/ops/cpu/scatterSub.js +11 -11
  83. package/dist/ops/fusedSoftmax.js +1 -1
  84. package/dist/ops/gatherSub.js +1 -1
  85. package/dist/ops/gelu.js +2 -2
  86. package/dist/ops/grads/attentionMask.js +1 -1
  87. package/dist/ops/grads/fusedSoftmax.js +2 -2
  88. package/dist/ops/grads/gelu.js +2 -2
  89. package/dist/ops/grads/matMulGelu.js +1 -1
  90. package/dist/ops/grads/normRMS.js +1 -1
  91. package/dist/ops/grads/qkv.js +1 -1
  92. package/dist/ops/grads/rope.js +1 -1
  93. package/dist/ops/matMulGelu.js +1 -1
  94. package/dist/ops/matMulMul.js +1 -1
  95. package/dist/ops/mulDrop.js +1 -1
  96. package/dist/ops/normRMS.js +1 -1
  97. package/dist/ops/qkv.js +1 -1
  98. package/dist/ops/rope.js +4 -4
  99. package/dist/ops/scatterSub.js +1 -1
  100. package/dist/ops/webgl/adamAdjust.js +2 -2
  101. package/dist/ops/webgl/adamMoments.js +1 -1
  102. package/dist/ops/webgl/appendCache.js +1 -1
  103. package/dist/ops/webgl/attentionMask.js +1 -1
  104. package/dist/ops/webgl/fusedSoftmax.js +4 -4
  105. package/dist/ops/webgl/gatherSub.js +1 -1
  106. package/dist/ops/webgl/gelu.js +2 -2
  107. package/dist/ops/webgl/log.js +3 -3
  108. package/dist/ops/webgl/matMulGelu.js +10 -10
  109. package/dist/ops/webgl/matMulMul.js +1 -1
  110. package/dist/ops/webgl/mulDropout.js +1 -1
  111. package/dist/ops/webgl/normRMS.js +2 -2
  112. package/dist/ops/webgl/qkv.js +1 -1
  113. package/dist/ops/webgl/rope.js +1 -1
  114. package/dist/ops/webgl/scatterSub.js +1 -1
  115. package/dist/ops/webgpu/adamAdjust.js +3 -3
  116. package/dist/ops/webgpu/adamMoments.js +3 -3
  117. package/dist/ops/webgpu/appendCache.js +3 -3
  118. package/dist/ops/webgpu/attentionMask.js +3 -3
  119. package/dist/ops/webgpu/gatherSub.js +3 -3
  120. package/dist/ops/webgpu/gelu.js +3 -3
  121. package/dist/ops/webgpu/normRMS.js +2 -2
  122. package/dist/ops/webgpu/normRMSGrad.js +5 -5
  123. package/dist/ops/webgpu/qkv.js +3 -3
  124. package/dist/ops/webgpu/rope.js +3 -3
  125. package/dist/ops/webgpu/scatterSub.js +3 -3
  126. package/dist/ops/webgpu/utils/reductions.js +4 -4
  127. package/dist/{ops-BFGCx8Ri.js → ops-C_1K_-35.js} +103 -103
  128. package/dist/{random_width-sZORGo5k.js → random_width-D8Pwy_na.js} +136 -136
  129. package/dist/{range-CRuAh-gd.js → range-LVHrSLdi.js} +1 -1
  130. package/dist/{reciprocal-BvGAyKyu.js → reciprocal-CaR9e67G.js} +1 -1
  131. package/dist/{register_all_kernels-BwDSRN-f.js → register_all_kernels-DUshvVWP.js} +2026 -2049
  132. package/dist/{reshape-CdBq1WJ6.js → reshape-DEfQGSin.js} +1 -1
  133. package/dist/{scatter_nd_util-DUstGbU1.js → scatter_nd_util-CUPPNLaA.js} +1 -1
  134. package/dist/{selu_util-BJEXVvjX.js → selu_util-8vv5JxQV.js} +3 -3
  135. package/dist/{shared-B8ztnyEk.js → shared-CkNorDcU.js} +83 -83
  136. package/dist/{shared-wS99K7_n.js → shared-D1elLckx.js} +1 -1
  137. package/dist/{sin-BeA3tsEd.js → sin-D2CKKmyR.js} +1 -1
  138. package/dist/{slice-BiOsknYS.js → slice-BnyE-M_7.js} +1 -1
  139. package/dist/{softmax-Bv_6lyMX.js → softmax-DLoZWYBx.js} +1 -1
  140. package/dist/{split-B-dikLRw.js → split-By_n4TKP.js} +1 -1
  141. package/dist/{stack-B17UN2nn.js → stack-DkdFLq37.js} +1 -1
  142. package/dist/{sum-66ew2byf.js → sum-l_0SqM4h.js} +3 -3
  143. package/dist/{tensor-JwS7ZYY6.js → tensor-BAQdLqoU.js} +1 -1
  144. package/dist/{tensor2d-wxPAnDQy.js → tensor2d-BHy261cI.js} +1 -1
  145. package/dist/training/Adam.js +2 -2
  146. package/dist/training/AdamExt.js +1 -1
  147. package/dist/training/DatasetBuilder.js +2 -2
  148. package/dist/training/Evaluator.d.ts +2 -2
  149. package/dist/training/FullTrainer.d.ts +16 -3
  150. package/dist/training/FullTrainer.js +91 -53
  151. package/dist/training/Trainer.d.ts +25 -3
  152. package/dist/training/Trainer.js +39 -47
  153. package/dist/training/sparseCrossEntropy.js +9 -9
  154. package/dist/utilities/dummy.d.ts +4 -4
  155. package/dist/utilities/dummy.js +13 -13
  156. package/dist/utilities/multinomialCPU.js +2 -2
  157. package/dist/utilities/parameters.d.ts +1 -1
  158. package/dist/utilities/performance.js +1 -1
  159. package/dist/utilities/profile.js +1 -1
  160. package/dist/utilities/safetensors.js +2 -2
  161. package/dist/utilities/weights.js +2 -2
  162. package/dist/{variable-BuddVFLa.js → variable-C9hihzDB.js} +1 -1
  163. package/dist/{webgpu_program-PFzf1hAQ.js → webgpu_program-dFEVbDPL.js} +1 -1
  164. package/dist/{webgpu_util-D____QpY.js → webgpu_util-DLImlSc6.js} +27 -27
  165. package/dist/{zeros--BdLQ3oG.js → zeros-VZ72lWXM.js} +1 -1
  166. package/package.json +2 -3
  167. package/dist/NanoGPTModel.d.ts +0 -52
  168. package/dist/NanoGPTModel.js +0 -203
  169. package/dist/TiedEmbedding-BxOerUmB.js +0 -43
  170. package/dist/utilities/generate.d.ts +0 -3
  171. package/dist/utilities/generate.js +0 -22
  172. package/dist/utilities/save.d.ts +0 -9
  173. package/dist/utilities/save.js +0 -61
@@ -1,19 +1,51 @@
1
- import { default as NanoGPT, GenerateOptions } from './NanoGPTModel';
2
1
  import { ITokeniser } from './tokeniser/type';
3
2
  import { default as EE } from 'eventemitter3';
3
+ import { default as Model, ModelForwardAttributes } from './models/model';
4
+ export interface GenerateOptions {
5
+ temperature?: number;
6
+ topK?: number;
7
+ topP?: number;
8
+ usePadding?: boolean;
9
+ attentionScores?: boolean;
10
+ includeProbabilities?: boolean;
11
+ embeddings?: boolean;
12
+ }
4
13
  export interface IGenerateOptions extends GenerateOptions {
5
14
  maxLength?: number;
6
15
  noCache?: boolean;
7
16
  }
17
+ /**
18
+ * Text generator using a NanoGPT model and a tokeniser.
19
+ * This uses the forward method of the model to generate text token by token, including options for temperature, top-k, and top-p sampling.
20
+ */
8
21
  export default class Generator extends EE<'start' | 'stop' | 'tokens'> {
9
22
  private readonly model;
10
23
  private readonly tokeniser;
11
24
  private active;
12
- constructor(model: NanoGPT, tokeniser: ITokeniser);
25
+ private cache;
26
+ private initialPrompt;
27
+ private outputText;
28
+ private actualTokeniser;
29
+ private lastToken;
30
+ private attentionData;
31
+ private probabilitiesData;
32
+ private embeddingsData;
33
+ private tokens;
34
+ constructor(model: Model<ModelForwardAttributes>, tokeniser: ITokeniser);
13
35
  private tokenisePrompt;
14
- private generateNoCache;
15
36
  private processResponse;
16
- private generateCache;
37
+ /** Generate logits and select a token. */
38
+ private _generateToken;
39
+ /** Generate multiple tokens in a loop and produce text */
40
+ private _generate;
41
+ reset(): void;
42
+ dispose(): void;
43
+ private initialise;
44
+ step(prompt?: string, options?: IGenerateOptions): Promise<string>;
17
45
  generate(prompt?: string, options?: IGenerateOptions): Promise<string>;
18
46
  stop(): void;
47
+ getText(): string;
48
+ getAttentionData(): number[][][][];
49
+ getProbabilitiesData(): number[][][];
50
+ getTokens(): number[];
19
51
  }
package/dist/Generator.js CHANGED
@@ -1,15 +1,15 @@
1
- import { E as u } from "./index-Dwqa6Zy2.js";
2
- import "./index-BoWRt-10.js";
1
+ import { E as z } from "./index-Dwqa6Zy2.js";
2
+ import { B as A, C as L, E as C, a5 as I, t as O, k as R } from "./index-CUQrfsw_.js";
3
3
  import "./ops/cpu/attentionMask.js";
4
4
  import "./ops/webgl/attentionMask.js";
5
5
  import "./ops/grads/attentionMask.js";
6
6
  import "./ops/cpu/qkv.js";
7
7
  import "./ops/webgl/qkv.js";
8
8
  import "./ops/grads/qkv.js";
9
- import "./random_width-sZORGo5k.js";
10
- import "./register_all_kernels-BwDSRN-f.js";
9
+ import { p as _ } from "./random_width-D8Pwy_na.js";
10
+ import { t as K } from "./register_all_kernels-DUshvVWP.js";
11
11
  import "./index-Tf7vU29b.js";
12
- import "./dataset-CtdBYwjo.js";
12
+ import "./dataset-CJmEGu6D.js";
13
13
  import "./ops/cpu/rope.js";
14
14
  import "./ops/webgl/rope.js";
15
15
  import "./ops/grads/rope.js";
@@ -29,7 +29,7 @@ import "./ops/webgl/gatherSub.js";
29
29
  import "./ops/cpu/scatterSub.js";
30
30
  import "./ops/webgl/scatterSub.js";
31
31
  import "./jszip.min-CjP2V1VV.js";
32
- import f from "./tokeniser/CharTokeniser.js";
32
+ import M from "./tokeniser/CharTokeniser.js";
33
33
  import "./ops/cpu/adamAdjust.js";
34
34
  import "./ops/webgl/adamAdjust.js";
35
35
  import "./ops/cpu/adamMoments.js";
@@ -37,12 +37,42 @@ import "./ops/webgl/adamMoments.js";
37
37
  import "./papaparse.min-C8l2Kvo1.js";
38
38
  import "./ops/cpu/gelu.js";
39
39
  import "./ops/webgl/gelu.js";
40
- import "./gelu-C-dPj6Ku.js";
40
+ import "./gelu-Bd3UBBxg.js";
41
41
  import "./ops/webgl/log.js";
42
- import { t as d } from "./tensor2d-wxPAnDQy.js";
43
- import { c as g } from "./concat-CsxrgovM.js";
44
- const k = [
45
- ...Array.from({ length: 95 }, (a, t) => String.fromCharCode(t + 32)),
42
+ import $ from "./utilities/multinomialCPU.js";
43
+ import { r as x } from "./reshape-DEfQGSin.js";
44
+ import { t as P } from "./tensor2d-BHy261cI.js";
45
+ import { s as v } from "./softmax-DLoZWYBx.js";
46
+ import { g as q } from "./gather-C1siEkdp.js";
47
+ import { c as G } from "./concat-BmDqqFsa.js";
48
+ /**
49
+ * @license
50
+ * Copyright 2020 Google LLC. All Rights Reserved.
51
+ * Licensed under the Apache License, Version 2.0 (the "License");
52
+ * you may not use this file except in compliance with the License.
53
+ * You may obtain a copy of the License at
54
+ *
55
+ * http://www.apache.org/licenses/LICENSE-2.0
56
+ *
57
+ * Unless required by applicable law or agreed to in writing, software
58
+ * distributed under the License is distributed on an "AS IS" BASIS,
59
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
60
+ * See the License for the specific language governing permissions and
61
+ * limitations under the License.
62
+ * =============================================================================
63
+ */
64
+ function N(h, t, e, i = !1) {
65
+ const o = L(h, "logits", "multinomial"), s = o.size, n = o.rank;
66
+ if (s < 2)
67
+ throw new Error(`Error in multinomial: you need at least 2 outcomes, but got ${s}.`);
68
+ if (n > 2)
69
+ throw new Error(`Rank of probabilities must be 1 or 2, but is ${n}`);
70
+ e = e || Math.random();
71
+ const a = { logits: n === 1 ? x(o, [1, -1]) : o }, p = { numSamples: t, seed: e, normalized: i }, l = C.runKernel(I, a, p);
72
+ return n === 1 ? x(l, [l.size]) : l;
73
+ }
74
+ const S = /* @__PURE__ */ A({ multinomial_: N }), B = [
75
+ ...Array.from({ length: 95 }, (h, t) => String.fromCharCode(t + 32)),
46
76
  // ASCII
47
77
  // Spanish accented letters and punctuation
48
78
  ..."áéíóúüñ¿¡",
@@ -53,80 +83,164 @@ const k = [
53
83
  // Cyrillic letters
54
84
  ..."абвгдеёжзийклмнопрстуфхцчшщъыьэюяАБВГДЕЁЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЫЬЭЮЯ"
55
85
  ];
56
- function w(a, t) {
57
- return a.length === t ? a : a.length > t ? a.slice(0, t) : a.concat(Array(t - a.length).fill(""));
86
+ function H(h, t) {
87
+ return h.length === t ? h : h.length > t ? h.slice(0, t) : h.concat(Array(t - h.length).fill(""));
58
88
  }
59
- class pt extends u {
60
- constructor(t, o) {
61
- super(), this.model = t, this.tokeniser = o;
89
+ class Mt extends z {
90
+ constructor(t, e) {
91
+ super(), this.model = t, this.tokeniser = e, this.actualTokeniser = e;
62
92
  }
63
93
  active = !1;
64
- async tokenisePrompt(t, o) {
65
- const r = o ? await t.tokenise([o], !0) : [[t.eosToken]];
66
- return d(r, [1, r[0].length], "int32");
67
- }
68
- async generateNoCache(t, o, r) {
69
- let i = await this.tokenisePrompt(t, o), s = o || "";
70
- const n = r?.maxLength ?? 1e3;
71
- for (let m = 0; m < n && this.active; m++) {
72
- const {
73
- output: e,
74
- attention: p,
75
- probabilities: c
76
- } = await this.model.generate(i, void 0, r), h = i;
77
- i = g([i, e], 1), h.dispose();
78
- const l = await this.processResponse(t, e, p, c);
79
- if (e.dispose(), l === null)
80
- break;
81
- s += l;
82
- }
83
- return i.dispose(), s;
94
+ cache = null;
95
+ initialPrompt = null;
96
+ outputText = "";
97
+ actualTokeniser;
98
+ lastToken = -1;
99
+ attentionData = [];
100
+ probabilitiesData = [];
101
+ embeddingsData = [];
102
+ tokens = [];
103
+ async tokenisePrompt(t, e) {
104
+ const i = e ? await t.tokenise([e], !0) : [[t.eosToken]];
105
+ return P(i, [1, i[0].length], "int32");
84
106
  }
85
- async processResponse(t, o, r, i) {
86
- const s = (await o.array())[0][0];
87
- if (s === this.tokeniser.eosToken)
107
+ async processResponse(t, e, i, o) {
108
+ const s = (await e.array())[0][0];
109
+ if (this.lastToken = s, s === this.tokeniser.eosToken)
88
110
  return null;
89
111
  const n = await t.decode([s]);
90
- let m;
91
- r && (m = await Promise.all(r.map((p) => p.array().then((c) => c))), r.forEach((p) => p.dispose()));
92
- let e;
93
- return i && (e = await i.array(), i.dispose()), this.emit("tokens", [s], n, m, e), n;
112
+ if (i) {
113
+ const d = await Promise.all(i.map((a) => a.array().then((p) => p)));
114
+ i.forEach((a) => a.dispose()), this.attentionData.push(d);
115
+ }
116
+ if (o) {
117
+ const d = await o.array();
118
+ o.dispose(), this.probabilitiesData.push(d);
119
+ }
120
+ return this.tokens.push(s), this.emit("tokens", [s], n), n;
94
121
  }
95
- async generateCache(t, o, r) {
96
- let i = await this.tokenisePrompt(t, o), s = o || "";
97
- const n = new Array(this.model.config.gpt.nLayer);
98
- for (let e = 0; e < this.model.config.gpt.nLayer; e++)
99
- n[e] = { k: void 0, v: void 0, length: 0, cumulativeLength: 0 };
100
- const m = r?.maxLength ?? 1e3;
101
- for (let e = 0; e < m && this.active; e++) {
122
+ /** Generate logits and select a token. */
123
+ async _generateToken(t, e, i) {
124
+ const o = i?.temperature ?? 1, s = i?.topK, n = i?.topP, d = i?.usePadding ?? !1, a = {
125
+ training: !1,
126
+ attentionScores: i?.attentionScores ? {
127
+ attentionOut: []
128
+ } : void 0,
129
+ cache: e,
130
+ outputEmbeddings: i?.embeddings ?? !1
131
+ }, p = O(() => {
132
+ const r = t, m = r.shape[1], u = m <= this.model.config.blockSize ? r : r.slice(
133
+ [0, m - this.model.config.blockSize],
134
+ [r.shape[0], this.model.config.blockSize]
135
+ ), k = d ? this.model.config.blockSize - u.shape[1] : 0, b = k > 0 ? _(u, [
136
+ [0, 0],
137
+ [0, k]
138
+ ]) : u, [f] = this.model.forward(a, b), y = f.shape[1] - 1 - k, c = f.slice([0, y, 0], [f.shape[0], 1, f.shape[2]]);
139
+ return a.attentionScores?.attentionOut && a.attentionScores.attentionOut.forEach((T, E) => {
140
+ T.shape[1] !== 1 && (a.attentionScores.attentionOut[E] = R(
141
+ T.slice([0, y, 0], [T.shape[0], 1, T.shape[2]])
142
+ ), T.dispose());
143
+ }), f.dispose(), c.div(o).squeeze([1]);
144
+ });
145
+ let l;
146
+ if (n) {
147
+ const r = v(p), m = await r.array();
148
+ r.dispose();
149
+ const u = m[0].map((c, g) => ({ prob: c, index: g })).sort((c, g) => g.prob - c.prob);
150
+ let k = 0;
151
+ const b = new Array(u.length).fill(0);
152
+ for (const c of u)
153
+ if (k += c.prob, b[c.index] = c.prob, k >= n)
154
+ break;
155
+ const f = b.reduce((c, g) => c + g, 0), y = b.map((c) => c / f);
156
+ l = $(y);
157
+ } else if (s) {
158
+ const { values: r, indices: m } = K(p, s), u = S(r, 1);
159
+ l = q(m, u, 1), r.dispose(), m.dispose(), u.dispose();
160
+ } else
161
+ l = S(p, 1);
162
+ let w;
163
+ i?.includeProbabilities && (w = v(p)), a.embeddings && this.embeddingsData.push(
164
+ await Promise.all(
165
+ a.embeddings.map(async (r) => {
166
+ const m = await r.array();
167
+ return r.dispose(), m;
168
+ })
169
+ )
170
+ );
171
+ const D = l.reshape([1, 1]);
172
+ return l.dispose(), l = D, p.dispose(), { output: l, probabilities: w, attention: a.attentionScores?.attentionOut };
173
+ }
174
+ /** Generate multiple tokens in a loop and produce text */
175
+ async _generate(t) {
176
+ let e = this.lastToken >= 0 && this.cache ? P([this.lastToken], [1, 1], "int32") : await this.tokenisePrompt(this.actualTokeniser, this.outputText);
177
+ const i = t?.maxLength ?? 1e3;
178
+ for (let o = 0; o < i && this.active; o++) {
102
179
  const {
103
- output: p,
104
- probabilities: c,
105
- attention: h
106
- } = await this.model.generate(i, n, {
107
- ...r,
108
- usePadding: !1
180
+ output: s,
181
+ probabilities: n,
182
+ attention: d
183
+ } = await this._generateToken(e, this.cache ? this.cache : void 0, {
184
+ ...t,
185
+ usePadding: !this.cache
109
186
  });
110
- i.dispose(), i = p;
111
- const l = await this.processResponse(t, p, h, c);
112
- if (l === null)
187
+ if (this.cache)
188
+ e.dispose(), e = s;
189
+ else {
190
+ const p = e;
191
+ e = G([e, s], 1), p.dispose();
192
+ }
193
+ const a = await this.processResponse(this.actualTokeniser, s, d, n);
194
+ if (this.cache || s.dispose(), a === null)
113
195
  break;
114
- s += l;
196
+ this.outputText += a;
115
197
  }
116
- return n.forEach((e) => {
117
- e && (e.k && e.k.dispose(), e.v && e.v.dispose());
118
- }), i.dispose(), s;
198
+ return e.dispose(), this.outputText;
199
+ }
200
+ reset() {
201
+ this.cache && (this.cache.forEach((t) => {
202
+ t && (t.k && t.k.dispose(), t.v && t.v.dispose());
203
+ }), this.cache = null), this.outputText = "", this.initialPrompt = null, this.lastToken = -1, this.attentionData = [], this.probabilitiesData = [], this.tokens = [];
119
204
  }
120
- async generate(t, o) {
121
- const r = t && t.length > this.model.config.gpt.blockSize ? t.slice(-this.model.config.gpt.blockSize) : t;
122
- this.active = !0, this.emit("start");
123
- const i = this.tokeniser.trained ? this.tokeniser : new f(w(k, this.tokeniser.vocabSize)), n = await (this.model.config.gpt.useRope && !o?.noCache ? this.generateCache(i, r, o) : this.generateNoCache(i, r, o));
124
- return this.active = !1, this.emit("stop"), n;
205
+ dispose() {
206
+ this.reset();
207
+ }
208
+ initialise(t, e) {
209
+ const i = t && t.length > this.model.config.blockSize ? t.slice(-this.model.config.blockSize) : t ?? null;
210
+ if (this.cache && e?.noCache && this.reset(), this.initialPrompt = i || null, this.lastToken === -1 && (this.outputText = this.initialPrompt || ""), !this.cache && !e?.noCache && this.model.config.useRope) {
211
+ const s = new Array(this.model.config.nLayer);
212
+ for (let n = 0; n < this.model.config.nLayer; n++)
213
+ s[n] = { k: void 0, v: void 0, length: 0, cumulativeLength: 0 };
214
+ this.cache = s, this.lastToken = -1;
215
+ }
216
+ const o = this.tokeniser.trained ? this.tokeniser : new M(H(B, this.tokeniser.vocabSize));
217
+ this.actualTokeniser = o;
218
+ }
219
+ async step(t, e) {
220
+ const i = { ...e, maxLength: 1 };
221
+ return this.generate(t, i);
222
+ }
223
+ async generate(t, e) {
224
+ this.initialise(t, e), this.active = !0, this.emit("start");
225
+ const o = await this._generate(e);
226
+ return this.active = !1, this.emit("stop"), o;
125
227
  }
126
228
  stop() {
127
229
  this.active = !1;
128
230
  }
231
+ getText() {
232
+ return this.outputText;
233
+ }
234
+ getAttentionData() {
235
+ return this.attentionData;
236
+ }
237
+ getProbabilitiesData() {
238
+ return this.probabilitiesData;
239
+ }
240
+ getTokens() {
241
+ return this.tokens;
242
+ }
129
243
  }
130
244
  export {
131
- pt as default
245
+ Mt as default
132
246
  };
@@ -1,10 +1,10 @@
1
- import { aq as T, ac as E, p as O, j as V, ay as B, Y as F, U, az as j } from "./index-BoWRt-10.js";
2
- import { r as $ } from "./Reshape-DH5srBP0.js";
3
- import { g as A, a as k, b as C, c as N, e as R } from "./axis_util-BzbKo31C.js";
4
- import { t as K, m as W } from "./shared-wS99K7_n.js";
5
- import { c as _ } from "./backend_util-TE7aTPhZ.js";
6
- import { f as y } from "./gpgpu_math-DGNLNL4I.js";
7
- import { g as G, b as L } from "./kernel_funcs_utils-BYKWV8Aa.js";
1
+ import { as as T, af as E, p as O, j as V, aA as B, a0 as F, X as j, aB as K } from "./index-CUQrfsw_.js";
2
+ import { r as $ } from "./Reshape-Bo8HzP8V.js";
3
+ import { g as A, a as k, b as C, c as N, e as R } from "./axis_util-DubwyOhW.js";
4
+ import { t as U, m as W } from "./shared-D1elLckx.js";
5
+ import { c as _ } from "./backend_util-BJ-_jSeK.js";
6
+ import { f as y } from "./gpgpu_math-TFLxaLkw.js";
7
+ import { g as G, b as L } from "./kernel_funcs_utils-P9aFa232.js";
8
8
  /**
9
9
  * @license
10
10
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -105,7 +105,7 @@ class w {
105
105
  * limitations under the License.
106
106
  * =============================================================================
107
107
  */
108
- class q {
108
+ class X {
109
109
  constructor(s, e) {
110
110
  this.variableNames = ["x"];
111
111
  const { windowSize: t, batchSize: n, inSize: l, outSize: r } = s;
@@ -229,7 +229,7 @@ class q {
229
229
  * limitations under the License.
230
230
  * =============================================================================
231
231
  */
232
- function X(a) {
232
+ function q(a) {
233
233
  const s = [];
234
234
  for (; s.length === 0 || s[s.length - 1].outSize !== 1; ) {
235
235
  const e = s.length ? s[s.length - 1].outSize : a[1], t = _(e);
@@ -242,12 +242,12 @@ function X(a) {
242
242
  return s;
243
243
  }
244
244
  function P(a, s, e, t) {
245
- const n = X(a.shape);
245
+ const n = q(a.shape);
246
246
  let l = a;
247
247
  for (let r = 0; r < n.length; r++) {
248
248
  const { inSize: i, windowSize: c, outSize: o } = n[r];
249
249
  let u, p;
250
- e === "mean" ? u = r === 0 ? new w({ windowSize: c, inSize: i, batchSize: a.shape[0], outSize: o }, i) : new w({ windowSize: c, inSize: i, batchSize: a.shape[0], outSize: o }) : u = new q({ windowSize: c, inSize: i, batchSize: a.shape[0], outSize: o }, e), p = l, l = t.runWebGLProgram(u, [l], s), p.dataId !== a.dataId && t.disposeIntermediateTensorInfo(p);
250
+ e === "mean" ? u = r === 0 ? new w({ windowSize: c, inSize: i, batchSize: a.shape[0], outSize: o }, i) : new w({ windowSize: c, inSize: i, batchSize: a.shape[0], outSize: o }) : u = new X({ windowSize: c, inSize: i, batchSize: a.shape[0], outSize: o }, e), p = l, l = t.runWebGLProgram(u, [l], s), p.dataId !== a.dataId && t.disposeIntermediateTensorInfo(p);
251
251
  }
252
252
  return l;
253
253
  }
@@ -459,7 +459,7 @@ function te(a) {
459
459
  const I = e.texData.get(d.dataId).values, m = new Array(i);
460
460
  for (let v = 0; v < m.length; v++)
461
461
  m[v] = n.shape[u[v]];
462
- const z = K(I, n.shape, n.dtype, u, m);
462
+ const z = U(I, n.shape, n.dtype, u, m);
463
463
  d = e.makeTensorInfo(m, n.dtype);
464
464
  const M = e.texData.get(d.dataId);
465
465
  M.values = z;
@@ -482,7 +482,7 @@ function te(a) {
482
482
  return p && e.disposeIntermediateTensorInfo(d), x;
483
483
  }
484
484
  const he = {
485
- kernelName: U,
485
+ kernelName: j,
486
486
  backendName: "webgl",
487
487
  kernelFunc: te
488
488
  };
@@ -525,7 +525,7 @@ return a / b;`, se = `
525
525
 
526
526
  return result;
527
527
  `, ne = L({ opSnippet: ae, packedOpSnippet: se, checkOutOfBounds: !0 }), fe = {
528
- kernelName: j,
528
+ kernelName: K,
529
529
  backendName: "webgl",
530
530
  kernelFunc: ne
531
531
  };
@@ -1,4 +1,4 @@
1
- import { j as h, a3 as d, l as c, K as m } from "./index-BoWRt-10.js";
1
+ import { j as h, a4 as d, n as c, U as m } from "./index-CUQrfsw_.js";
2
2
  /**
3
3
  * @license
4
4
  * Copyright 2021 Google LLC. All Rights Reserved.
@@ -1,5 +1,5 @@
1
- import { j as c, a3 as C, l as f, K as R } from "./index-BoWRt-10.js";
2
- import { u as g, g as I, a as x, b as F, c as $, d as u, e as l, i as m } from "./gpgpu_math-DGNLNL4I.js";
1
+ import { j as c, a4 as C, n as f, U as R } from "./index-CUQrfsw_.js";
2
+ import { u as g, g as I, a as x, b as F, c as $, d as u, e as m, i as l } from "./gpgpu_math-TFLxaLkw.js";
3
3
  /**
4
4
  * @license
5
5
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -82,14 +82,14 @@ function v(s, t) {
82
82
  function b(s, t, i) {
83
83
  const a = [
84
84
  u(s.shape),
85
- ...l(s.shape)
85
+ ...m(s.shape)
86
86
  ], e = {
87
87
  dtype: s.dtype,
88
88
  shape: a,
89
89
  dataId: s.dataId
90
90
  }, o = [
91
91
  u(t),
92
- ...l(t)
92
+ ...m(t)
93
93
  ], r = new S(o, a), p = !0, n = [a], h = i.runWebGLProgram(r, [e], s.dtype, n, p);
94
94
  return { dataId: h.dataId, shape: t, dtype: h.dtype };
95
95
  }
@@ -113,7 +113,7 @@ function y(s) {
113
113
  const { inputs: t, backend: i, attrs: a } = s, { x: e } = t, { shape: o } = a, r = i, p = c(e.shape), n = C(o, p), h = c(n);
114
114
  f(p === h, () => `The new shape (${n}) has ${h} elements and the old shape (${e.shape}) has ${p} elements. The new shape and old shape must have the same number of elements.`);
115
115
  const d = r.texData.get(e.dataId);
116
- return d.isPacked && !m(e.shape, n) && !(d.texture !== null && m(d.shape, n)) ? b(e, n, r) : (r.incRef(e.dataId), { dataId: e.dataId, shape: n, dtype: e.dtype });
116
+ return d.isPacked && !l(e.shape, n) && !(d.texture !== null && l(d.shape, n)) ? b(e, n, r) : (r.incRef(e.dataId), { dataId: e.dataId, shape: n, dtype: e.dtype });
117
117
  }
118
118
  const U = {
119
119
  kernelName: R,
@@ -1,11 +1,11 @@
1
- import { GPTConfig } from './config';
1
+ import { GPTConfig } from './models/config';
2
2
  import { ITokeniser } from './tokeniser/type';
3
- import { default as NanoGPT, TrainingLogEntry } from './NanoGPTModel';
4
- import { SaveOptions } from './utilities/save';
3
+ import { SaveOptions } from './loader/save';
5
4
  import { default as Generator, IGenerateOptions } from './Generator';
6
5
  import { default as Trainer, ITrainerOptions } from './Trainer';
7
6
  import { default as MemoryProfiler } from './utilities/profile';
8
- import { TrainingProgress } from './training/Trainer';
7
+ import { TrainingLogEntry, TrainingProgress } from './training/Trainer';
8
+ import { default as Model, ModelForwardAttributes } from './models/model';
9
9
  type TeachableLLMStatus = 'warmup' | 'awaitingTokens' | 'ready' | 'training' | 'loading' | 'busy' | 'error';
10
10
  interface TeachableLLMMeta {
11
11
  name?: string;
@@ -20,12 +20,12 @@ export default class TeachableLLM {
20
20
  private _status;
21
21
  private _memoryRequirements?;
22
22
  meta: TeachableLLMMeta;
23
- constructor(tokeniser?: ITokeniser, model?: NanoGPT);
23
+ constructor(tokeniser?: ITokeniser, model?: Model<ModelForwardAttributes>);
24
24
  get vocab(): string[];
25
25
  /** Model is fully loaded */
26
26
  get loaded(): boolean;
27
27
  get config(): GPTConfig;
28
- get model(): NanoGPT;
28
+ get model(): Model<ModelForwardAttributes>;
29
29
  get tokeniser(): ITokeniser;
30
30
  get status(): TeachableLLMStatus;
31
31
  /** Model is both ready and not busy */