@genai-fi/nanogpt 0.7.3 → 0.8.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (197)
  1. package/dist/Generator.d.ts +25 -2
  2. package/dist/Generator.js +152 -49
  3. package/dist/{RealDiv-Dy0p8Bvo.js → RealDiv-D_q39E3A.js} +13 -13
  4. package/dist/{Reshape-DvudQDvJ.js → Reshape-41YpQqEo.js} +1 -1
  5. package/dist/{Reshape-DH5srBP0.js → Reshape-Bh_jzKzV.js} +5 -5
  6. package/dist/TeachableLLM.d.ts +6 -6
  7. package/dist/TeachableLLM.js +33 -31
  8. package/dist/Trainer.d.ts +13 -2
  9. package/dist/Trainer.js +21 -12
  10. package/dist/{axis_util-BzbKo31C.js → axis_util-Did9235A.js} +3 -3
  11. package/dist/backend.js +2 -2
  12. package/dist/{backend_util-TE7aTPhZ.js → backend_util-yC3YH1jo.js} +58 -58
  13. package/dist/{broadcast_to-CdbwV-Dj.js → broadcast_to-CUvOdOT5.js} +2 -2
  14. package/dist/checks/appendCache.d.ts +1 -0
  15. package/dist/checks/appendCache.js +22 -0
  16. package/dist/checks/attentionMask.d.ts +1 -0
  17. package/dist/checks/attentionMask.js +37 -0
  18. package/dist/checks/check.d.ts +9 -0
  19. package/dist/checks/check.js +20 -0
  20. package/dist/checks/gelu.d.ts +1 -0
  21. package/dist/checks/gelu.js +18 -0
  22. package/dist/checks/index.d.ts +19 -0
  23. package/dist/checks/index.js +21 -0
  24. package/dist/checks/normRMS.d.ts +1 -0
  25. package/dist/checks/normRMS.js +16 -0
  26. package/dist/checks/normRMSGrad.d.ts +1 -0
  27. package/dist/checks/normRMSGrad.js +12 -0
  28. package/dist/checks/qkv.d.ts +1 -0
  29. package/dist/checks/qkv.js +25 -0
  30. package/dist/checks/rope.d.ts +1 -0
  31. package/dist/checks/rope.js +21 -0
  32. package/dist/{concat-CsxrgovM.js → concat-pHiVqR3L.js} +1 -1
  33. package/dist/{dataset-CtdBYwjo.js → dataset-DPPl-iLT.js} +9 -9
  34. package/dist/{dropout-DYs5QFGQ.js → dropout-CcKSfOYE.js} +18 -18
  35. package/dist/exports_initializers-DKk7-bsx.js +16 -0
  36. package/dist/{gather-CMMy2KEG.js → gather-CPg6ZlQA.js} +1 -1
  37. package/dist/{gelu-C-dPj6Ku.js → gelu-BkcmEEyD.js} +1 -1
  38. package/dist/{gpgpu_math-DGNLNL4I.js → gpgpu_math-D_ODOLix.js} +26 -26
  39. package/dist/{index-BoWRt-10.js → index-DdmHGZjq.js} +659 -650
  40. package/dist/{index-CLthM0TO.js → index-evZ57wr4.js} +185 -185
  41. package/dist/{kernel_funcs_utils-BYKWV8Aa.js → kernel_funcs_utils-CDfFpUab.js} +21 -21
  42. package/dist/layers/BaseLayer.d.ts +8 -13
  43. package/dist/layers/BaseLayer.js +25 -13
  44. package/dist/layers/CausalSelfAttention.d.ts +3 -2
  45. package/dist/layers/CausalSelfAttention.js +28 -28
  46. package/dist/layers/MLP.d.ts +3 -2
  47. package/dist/layers/MLP.js +16 -20
  48. package/dist/layers/PositionEmbedding.d.ts +9 -0
  49. package/dist/layers/PositionEmbedding.js +45 -0
  50. package/dist/layers/RMSNorm.d.ts +3 -2
  51. package/dist/layers/RMSNorm.js +6 -6
  52. package/dist/layers/RoPECache.d.ts +1 -1
  53. package/dist/layers/RoPECache.js +4 -4
  54. package/dist/layers/TiedEmbedding.d.ts +3 -2
  55. package/dist/layers/TiedEmbedding.js +29 -7
  56. package/dist/layers/TransformerBlock.d.ts +3 -2
  57. package/dist/layers/TransformerBlock.js +1 -1
  58. package/dist/loader/load.d.ts +2 -2
  59. package/dist/loader/loadHF.d.ts +2 -2
  60. package/dist/loader/loadTransformers.d.ts +4 -2
  61. package/dist/loader/loadTransformers.js +10 -9
  62. package/dist/loader/newZipLoad.d.ts +2 -2
  63. package/dist/loader/oldZipLoad.d.ts +2 -2
  64. package/dist/loader/oldZipLoad.js +44 -51
  65. package/dist/loader/save.d.ts +8 -0
  66. package/dist/loader/save.js +62 -0
  67. package/dist/{log_sum_exp-DbjkV734.js → log_sum_exp-C8yFJfZz.js} +45 -24
  68. package/dist/main.d.ts +6 -4
  69. package/dist/main.js +24 -18
  70. package/dist/{mat_mul-8m8pfdcx.js → mat_mul-Dpy2mMRu.js} +1 -1
  71. package/dist/mod-CbibJi3D.js +27 -0
  72. package/dist/models/NanoGPTV1.d.ts +15 -0
  73. package/dist/models/NanoGPTV1.js +71 -0
  74. package/dist/{config.d.ts → models/config.d.ts} +1 -0
  75. package/dist/{config.js → models/config.js} +1 -0
  76. package/dist/models/factory.d.ts +3 -0
  77. package/dist/models/factory.js +14 -0
  78. package/dist/models/model.d.ts +26 -0
  79. package/dist/models/model.js +70 -0
  80. package/dist/{mulmat_packed_gpu-VSekgsNv.js → mulmat_packed_gpu-q_Gmwyld.js} +1 -1
  81. package/dist/{ones-Dj0SDhHf.js → ones-BAqVh-eA.js} +2 -2
  82. package/dist/ops/adamAdjust.js +1 -1
  83. package/dist/ops/adamMoments.js +1 -1
  84. package/dist/ops/appendCache.js +3 -3
  85. package/dist/ops/attentionMask.js +1 -1
  86. package/dist/ops/cpu/adamAdjust.js +9 -9
  87. package/dist/ops/cpu/adamMoments.js +2 -2
  88. package/dist/ops/cpu/appendCache.js +2 -2
  89. package/dist/ops/cpu/attentionMask.js +5 -5
  90. package/dist/ops/cpu/fusedSoftmax.js +2 -2
  91. package/dist/ops/cpu/gatherSub.js +5 -5
  92. package/dist/ops/cpu/gelu.js +1 -1
  93. package/dist/ops/cpu/matMulGelu.js +2 -2
  94. package/dist/ops/cpu/matMulMul.js +1 -1
  95. package/dist/ops/cpu/mulDropout.js +1 -1
  96. package/dist/ops/cpu/normRMS.js +1 -1
  97. package/dist/ops/cpu/qkv.js +3 -3
  98. package/dist/ops/cpu/rope.js +5 -5
  99. package/dist/ops/cpu/scatterSub.js +7 -7
  100. package/dist/ops/fusedSoftmax.js +1 -1
  101. package/dist/ops/gatherSub.js +1 -1
  102. package/dist/ops/gelu.js +2 -2
  103. package/dist/ops/grads/attentionMask.js +1 -1
  104. package/dist/ops/grads/fusedSoftmax.js +2 -2
  105. package/dist/ops/grads/gelu.js +2 -2
  106. package/dist/ops/grads/matMulGelu.js +1 -1
  107. package/dist/ops/grads/normRMS.js +1 -1
  108. package/dist/ops/grads/qkv.js +1 -1
  109. package/dist/ops/grads/rope.js +1 -1
  110. package/dist/ops/matMulGelu.js +1 -1
  111. package/dist/ops/matMulMul.js +1 -1
  112. package/dist/ops/mulDrop.js +1 -1
  113. package/dist/ops/normRMS.js +1 -1
  114. package/dist/ops/qkv.js +1 -1
  115. package/dist/ops/rope.js +4 -4
  116. package/dist/ops/scatterSub.js +1 -1
  117. package/dist/ops/webgl/adamAdjust.js +2 -2
  118. package/dist/ops/webgl/adamMoments.js +1 -1
  119. package/dist/ops/webgl/appendCache.js +1 -1
  120. package/dist/ops/webgl/attentionMask.js +1 -1
  121. package/dist/ops/webgl/fusedSoftmax.js +4 -4
  122. package/dist/ops/webgl/gatherSub.js +1 -1
  123. package/dist/ops/webgl/gelu.js +2 -2
  124. package/dist/ops/webgl/log.js +3 -3
  125. package/dist/ops/webgl/matMulGelu.js +10 -10
  126. package/dist/ops/webgl/matMulMul.js +1 -1
  127. package/dist/ops/webgl/mulDropout.js +1 -1
  128. package/dist/ops/webgl/normRMS.js +2 -2
  129. package/dist/ops/webgl/qkv.js +1 -1
  130. package/dist/ops/webgl/rope.js +1 -1
  131. package/dist/ops/webgl/scatterSub.js +1 -1
  132. package/dist/ops/webgpu/adamAdjust.js +3 -3
  133. package/dist/ops/webgpu/adamMoments.js +3 -3
  134. package/dist/ops/webgpu/appendCache.js +3 -3
  135. package/dist/ops/webgpu/attentionMask.js +3 -3
  136. package/dist/ops/webgpu/gatherSub.js +3 -3
  137. package/dist/ops/webgpu/gelu.js +3 -3
  138. package/dist/ops/webgpu/normRMS.js +2 -2
  139. package/dist/ops/webgpu/normRMSGrad.js +5 -5
  140. package/dist/ops/webgpu/qkv.js +3 -3
  141. package/dist/ops/webgpu/rope.js +3 -3
  142. package/dist/ops/webgpu/scatterSub.js +3 -3
  143. package/dist/ops/webgpu/utils/reductions.js +4 -4
  144. package/dist/ops-542ai2vG.js +1525 -0
  145. package/dist/{random_width-sZORGo5k.js → random_width-DKGeiFuR.js} +1471 -1538
  146. package/dist/{range-CRuAh-gd.js → range-BcUvLuf5.js} +1 -1
  147. package/dist/{reciprocal-BvGAyKyu.js → reciprocal-DhDWSKiD.js} +1 -1
  148. package/dist/{register_all_kernels-BwDSRN-f.js → register_all_kernels-Do9VvZmo.js} +2488 -2534
  149. package/dist/{max-Ddnnb5xe.js → relu-B1AXs7p5.js} +6 -6
  150. package/dist/{reshape-CdBq1WJ6.js → reshape-WeJkT3ja.js} +1 -1
  151. package/dist/{scatter_nd_util-DUstGbU1.js → scatter_nd_util-B7yDhiQr.js} +1 -1
  152. package/dist/{selu_util-BJEXVvjX.js → selu_util-BgUO9gHY.js} +125 -146
  153. package/dist/{shared-wS99K7_n.js → shared-CZiWmQCI.js} +1 -1
  154. package/dist/{shared-B8ztnyEk.js → shared-V6D_md-c.js} +72 -72
  155. package/dist/{sin-BeA3tsEd.js → sin-CPxad7Am.js} +1 -1
  156. package/dist/{slice-BiOsknYS.js → slice-B7jXtPnp.js} +1 -1
  157. package/dist/{softmax-Bv_6lyMX.js → softmax-BfsyI4As.js} +1 -1
  158. package/dist/{split-B-dikLRw.js → split-BPxr8_8m.js} +1 -1
  159. package/dist/{stack-B17UN2nn.js → stack-BNwLzE43.js} +1 -1
  160. package/dist/{sum-66ew2byf.js → sum-ByFINZgi.js} +3 -3
  161. package/dist/{tensor-JwS7ZYY6.js → tensor-DbqgIV9B.js} +1 -1
  162. package/dist/tensor1d-CtJq5BOv.js +27 -0
  163. package/dist/{tensor2d-wxPAnDQy.js → tensor2d-CObBWBkW.js} +1 -1
  164. package/dist/tensor3d-BOukqWwr.js +30 -0
  165. package/dist/tensor4d-DLtk7Nxh.js +30 -0
  166. package/dist/training/Adam.js +2 -2
  167. package/dist/training/AdamExt.js +1 -1
  168. package/dist/training/DatasetBuilder.js +2 -2
  169. package/dist/training/Evaluator.d.ts +2 -2
  170. package/dist/training/FullTrainer.d.ts +3 -3
  171. package/dist/training/FullTrainer.js +61 -69
  172. package/dist/training/Trainer.d.ts +15 -3
  173. package/dist/training/Trainer.js +39 -47
  174. package/dist/training/sparseCrossEntropy.js +12 -13
  175. package/dist/utilities/arrayClose.d.ts +1 -1
  176. package/dist/utilities/arrayClose.js +16 -7
  177. package/dist/utilities/dummy.d.ts +4 -4
  178. package/dist/utilities/dummy.js +13 -13
  179. package/dist/utilities/multinomialCPU.js +2 -2
  180. package/dist/utilities/parameters.d.ts +1 -1
  181. package/dist/utilities/performance.js +1 -1
  182. package/dist/utilities/profile.js +1 -1
  183. package/dist/utilities/safetensors.js +2 -2
  184. package/dist/utilities/weights.js +2 -2
  185. package/dist/{variable-BuddVFLa.js → variable-DPFOJyRG.js} +1 -1
  186. package/dist/{webgpu_program-PFzf1hAQ.js → webgpu_program-Dhk9R5aG.js} +1 -1
  187. package/dist/{webgpu_util-D____QpY.js → webgpu_util-BqGnZg8t.js} +27 -27
  188. package/dist/{zeros--BdLQ3oG.js → zeros-Dnwix0p4.js} +1 -1
  189. package/package.json +2 -3
  190. package/dist/NanoGPTModel.d.ts +0 -52
  191. package/dist/NanoGPTModel.js +0 -203
  192. package/dist/TiedEmbedding-BxOerUmB.js +0 -43
  193. package/dist/ops-BFGCx8Ri.js +0 -1202
  194. package/dist/utilities/generate.d.ts +0 -3
  195. package/dist/utilities/generate.js +0 -22
  196. package/dist/utilities/save.d.ts +0 -9
  197. package/dist/utilities/save.js +0 -61
package/dist/Generator.d.ts CHANGED
@@ -1,10 +1,23 @@
- import { default as NanoGPT, GenerateOptions } from './NanoGPTModel';
  import { ITokeniser } from './tokeniser/type';
  import { default as EE } from 'eventemitter3';
+ import { default as Model, ModelForwardAttributes } from './models/model';
+ export interface GenerateOptions {
+ temperature?: number;
+ topK?: number;
+ topP?: number;
+ usePadding?: boolean;
+ attentionScores?: boolean;
+ includeProbabilities?: boolean;
+ embeddings?: boolean;
+ }
  export interface IGenerateOptions extends GenerateOptions {
  maxLength?: number;
  noCache?: boolean;
  }
+ /**
+ * Text generator using a NanoGPT model and a tokeniser.
+ * This uses the forward method of the model to generate text token by token, including options for temperature, top-k, and top-p sampling.
+ */
  export default class Generator extends EE<'start' | 'stop' | 'tokens'> {
  private readonly model;
  private readonly tokeniser;
@@ -14,9 +27,16 @@ export default class Generator extends EE<'start' | 'stop' | 'tokens'> {
  private outputText;
  private actualTokeniser;
  private lastToken;
- constructor(model: NanoGPT, tokeniser: ITokeniser);
+ private attentionData;
+ private probabilitiesData;
+ private embeddingsData;
+ private tokens;
+ constructor(model: Model<ModelForwardAttributes>, tokeniser: ITokeniser);
  private tokenisePrompt;
  private processResponse;
+ /** Generate logits and select a token. */
+ private _generateToken;
+ /** Generate multiple tokens in a loop and produce text */
  private _generate;
  reset(): void;
  dispose(): void;
@@ -25,4 +45,7 @@ export default class Generator extends EE<'start' | 'stop' | 'tokens'> {
  generate(prompt?: string, options?: IGenerateOptions): Promise<string>;
  stop(): void;
  getText(): string;
+ getAttentionData(): number[][][][];
+ getProbabilitiesData(): number[][][];
+ getTokens(): number[];
  }
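The Generator.d.ts changes above define the new 0.8.x surface: GenerateOptions now lives here instead of NanoGPTModel, the constructor accepts any Model<ModelForwardAttributes>, and attention maps, per-step probabilities and sampled token ids are retrievable after generation. A minimal usage sketch based only on these declarations; the entry import path and the re-export of Generator and TeachableLLM are assumptions not confirmed by this diff:

    // Sketch only: exercising the 0.8.x Generator API (import path assumed).
    import TeachableLLM, { Generator } from '@genai-fi/nanogpt';

    async function demo(llm: TeachableLLM): Promise<void> {
      // constructor(model: Model<ModelForwardAttributes>, tokeniser: ITokeniser)
      const generator = new Generator(llm.model, llm.tokeniser);
      generator.on('tokens', (ids, text) => console.log(ids, text));

      const output = await generator.generate('Once upon a time', {
        maxLength: 64,
        temperature: 0.8,
        topK: 40,                  // or topP: 0.9 for nucleus sampling
        attentionScores: true,     // collect per-step attention maps
        includeProbabilities: true,
      });

      console.log(output);
      console.log(generator.getTokens());            // number[]
      console.log(generator.getAttentionData());     // number[][][][]
      console.log(generator.getProbabilitiesData()); // number[][][]
      generator.dispose();
    }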
package/dist/Generator.js CHANGED
@@ -1,15 +1,15 @@
- import { E as l } from "./index-Dwqa6Zy2.js";
- import "./index-BoWRt-10.js";
+ import { E as z } from "./index-Dwqa6Zy2.js";
+ import { C as A, D as L, E as C, a6 as I, t as O, k as R } from "./index-DdmHGZjq.js";
  import "./ops/cpu/attentionMask.js";
  import "./ops/webgl/attentionMask.js";
  import "./ops/grads/attentionMask.js";
  import "./ops/cpu/qkv.js";
  import "./ops/webgl/qkv.js";
  import "./ops/grads/qkv.js";
- import "./random_width-sZORGo5k.js";
- import "./register_all_kernels-BwDSRN-f.js";
+ import { p as _ } from "./random_width-DKGeiFuR.js";
+ import { t as K } from "./register_all_kernels-Do9VvZmo.js";
  import "./index-Tf7vU29b.js";
- import "./dataset-CtdBYwjo.js";
+ import "./dataset-DPPl-iLT.js";
  import "./ops/cpu/rope.js";
  import "./ops/webgl/rope.js";
  import "./ops/grads/rope.js";
@@ -29,7 +29,7 @@ import "./ops/webgl/gatherSub.js";
  import "./ops/cpu/scatterSub.js";
  import "./ops/webgl/scatterSub.js";
  import "./jszip.min-CjP2V1VV.js";
- import u from "./tokeniser/CharTokeniser.js";
+ import M from "./tokeniser/CharTokeniser.js";
  import "./ops/cpu/adamAdjust.js";
  import "./ops/webgl/adamAdjust.js";
  import "./ops/cpu/adamMoments.js";
@@ -37,12 +37,44 @@ import "./ops/webgl/adamMoments.js";
  import "./papaparse.min-C8l2Kvo1.js";
  import "./ops/cpu/gelu.js";
  import "./ops/webgl/gelu.js";
- import "./gelu-C-dPj6Ku.js";
+ import "./gelu-BkcmEEyD.js";
  import "./ops/webgl/log.js";
- import { t as p } from "./tensor2d-wxPAnDQy.js";
- import { c as f } from "./concat-CsxrgovM.js";
- const k = [
- ...Array.from({ length: 95 }, (r, t) => String.fromCharCode(t + 32)),
+ import "./checks/normRMS.js";
+ import "./checks/normRMSGrad.js";
+ import $ from "./utilities/multinomialCPU.js";
+ import { r as x } from "./reshape-WeJkT3ja.js";
+ import { t as P } from "./tensor2d-CObBWBkW.js";
+ import { s as v } from "./softmax-BfsyI4As.js";
+ import { g as q } from "./gather-CPg6ZlQA.js";
+ import { c as G } from "./concat-pHiVqR3L.js";
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ function N(m, t, e, i = !1) {
+ const o = L(m, "logits", "multinomial"), s = o.size, n = o.rank;
+ if (s < 2)
+ throw new Error(`Error in multinomial: you need at least 2 outcomes, but got ${s}.`);
+ if (n > 2)
+ throw new Error(`Rank of probabilities must be 1 or 2, but is ${n}`);
+ e = e || Math.random();
+ const a = { logits: n === 1 ? x(o, [1, -1]) : o }, p = { numSamples: t, seed: e, normalized: i }, l = C.runKernel(I, a, p);
+ return n === 1 ? x(l, [l.size]) : l;
+ }
+ const S = /* @__PURE__ */ A({ multinomial_: N }), H = [
+ ...Array.from({ length: 95 }, (m, t) => String.fromCharCode(t + 32)),
  // ASCII
  // Spanish accented letters and punctuation
  ..."áéíóúüñ¿¡",
@@ -53,12 +85,12 @@ const k = [
  // Cyrillic letters
  ..."абвгдеёжзийклмнопрстуфхцчшщъыьэюяАБВГДЕЁЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЫЬЭЮЯ"
  ];
- function d(r, t) {
- return r.length === t ? r : r.length > t ? r.slice(0, t) : r.concat(Array(t - r.length).fill(""));
+ function U(m, t) {
+ return m.length === t ? m : m.length > t ? m.slice(0, t) : m.concat(Array(t - m.length).fill(""));
  }
- class nt extends l {
- constructor(t, i) {
- super(), this.model = t, this.tokeniser = i, this.actualTokeniser = i;
+ class qt extends z {
+ constructor(t, e) {
+ super(), this.model = t, this.tokeniser = e, this.actualTokeniser = e;
  }
  active = !1;
  cache = null;
@@ -66,71 +98,133 @@ class nt extends l {
  outputText = "";
  actualTokeniser;
  lastToken = -1;
- async tokenisePrompt(t, i) {
- const e = i ? await t.tokenise([i], !0) : [[t.eosToken]];
- return p(e, [1, e[0].length], "int32");
+ attentionData = [];
+ probabilitiesData = [];
+ embeddingsData = [];
+ tokens = [];
+ async tokenisePrompt(t, e) {
+ const i = e ? await t.tokenise([e], !0) : [[t.eosToken]];
+ return P(i, [1, i[0].length], "int32");
  }
- async processResponse(t, i, e, o) {
- const s = (await i.array())[0][0];
+ async processResponse(t, e, i, o) {
+ const s = (await e.array())[0][0];
  if (this.lastToken = s, s === this.tokeniser.eosToken)
  return null;
  const n = await t.decode([s]);
- let c;
- e && (c = await Promise.all(e.map((h) => h.array().then((m) => m))), e.forEach((h) => h.dispose()));
- let a;
- return o && (a = await o.array(), o.dispose()), this.emit("tokens", [s], n, c, a), n;
+ if (i) {
+ const d = await Promise.all(i.map((a) => a.array().then((p) => p)));
+ i.forEach((a) => a.dispose()), this.attentionData.push(d);
+ }
+ if (o) {
+ const d = await o.array();
+ o.dispose(), this.probabilitiesData.push(d);
+ }
+ return this.tokens.push(s), this.emit("tokens", [s], n), n;
+ }
+ /** Generate logits and select a token. */
+ async _generateToken(t, e, i) {
+ const o = i?.temperature ?? 1, s = i?.topK, n = i?.topP, d = i?.usePadding ?? !1, a = {
+ training: !1,
+ attentionScores: i?.attentionScores ? {
+ attentionOut: []
+ } : void 0,
+ cache: e,
+ outputEmbeddings: i?.embeddings ?? !1
+ }, p = O(() => {
+ const r = t, h = r.shape[1], u = h <= this.model.config.blockSize ? r : r.slice(
+ [0, h - this.model.config.blockSize],
+ [r.shape[0], this.model.config.blockSize]
+ ), k = d ? this.model.config.blockSize - u.shape[1] : 0, b = k > 0 ? _(u, [
+ [0, 0],
+ [0, k]
+ ]) : u, [f] = this.model.forward(a, b), y = f.shape[1] - 1 - k, c = f.slice([0, y, 0], [f.shape[0], 1, f.shape[2]]);
+ return a.attentionScores?.attentionOut && a.attentionScores.attentionOut.forEach((T, E) => {
+ T.shape[1] !== 1 && (a.attentionScores.attentionOut[E] = R(
+ T.slice([0, y, 0], [T.shape[0], 1, T.shape[2]])
+ ), T.dispose());
+ }), f.dispose(), c.div(o).squeeze([1]);
+ });
+ let l;
+ if (n) {
+ const r = v(p), h = await r.array();
+ r.dispose();
+ const u = h[0].map((c, g) => ({ prob: c, index: g })).sort((c, g) => g.prob - c.prob);
+ let k = 0;
+ const b = new Array(u.length).fill(0);
+ for (const c of u)
+ if (k += c.prob, b[c.index] = c.prob, k >= n)
+ break;
+ const f = b.reduce((c, g) => c + g, 0), y = b.map((c) => c / f);
+ l = $(y);
+ } else if (s) {
+ const { values: r, indices: h } = K(p, s), u = S(r, 1);
+ l = q(h, u, 1), r.dispose(), h.dispose(), u.dispose();
+ } else
+ l = S(p, 1);
+ let w;
+ i?.includeProbabilities && (w = v(p)), a.embeddings && this.embeddingsData.push(
+ await Promise.all(
+ a.embeddings.map(async (r) => {
+ const h = await r.array();
+ return r.dispose(), h;
+ })
+ )
+ );
+ const D = l.reshape([1, 1]);
+ return l.dispose(), l = D, p.dispose(), { output: l, probabilities: w, attention: a.attentionScores?.attentionOut };
  }
+ /** Generate multiple tokens in a loop and produce text */
  async _generate(t) {
- let i = this.lastToken >= 0 && this.cache ? p([this.lastToken], [1, 1], "int32") : await this.tokenisePrompt(this.actualTokeniser, this.outputText);
- const e = t?.maxLength ?? 1e3;
- for (let o = 0; o < e && this.active; o++) {
+ let e = this.lastToken >= 0 && this.cache ? P([this.lastToken], [1, 1], "int32") : await this.tokenisePrompt(this.actualTokeniser, this.outputText);
+ const i = t?.maxLength ?? 1e3;
+ for (let o = 0; o < i && this.active; o++) {
  const {
  output: s,
  probabilities: n,
- attention: c
- } = await this.model.generate(i, this.cache ? this.cache : void 0, {
+ attention: d
+ } = await this._generateToken(e, this.cache ? this.cache : void 0, {
  ...t,
  usePadding: !this.cache
  });
  if (this.cache)
- i.dispose(), i = s;
+ e.dispose(), e = s;
  else {
- const h = i;
- i = f([i, s], 1), h.dispose();
+ const p = e;
+ e = G([e, s], 1), p.dispose();
  }
- const a = await this.processResponse(this.actualTokeniser, s, c, n);
+ const a = await this.processResponse(this.actualTokeniser, s, d, n);
  if (this.cache || s.dispose(), a === null)
  break;
  this.outputText += a;
  }
- return i.dispose(), this.outputText;
+ return e.dispose(), this.outputText;
  }
  reset() {
  this.cache && (this.cache.forEach((t) => {
  t && (t.k && t.k.dispose(), t.v && t.v.dispose());
- }), this.cache = null), this.outputText = "", this.initialPrompt = null, this.lastToken = -1;
+ }), this.cache = null), this.outputText = "", this.initialPrompt = null, this.lastToken = -1, this.attentionData = [], this.probabilitiesData = [], this.tokens = [];
  }
  dispose() {
  this.reset();
  }
- initialise(t, i) {
- const e = t && t.length > this.model.config.gpt.blockSize ? t.slice(-this.model.config.gpt.blockSize) : t ?? null;
- if (this.cache && i?.noCache && this.reset(), this.initialPrompt = e || null, this.lastToken === -1 && (this.outputText = this.initialPrompt || ""), !this.cache && !i?.noCache && this.model.config.gpt.useRope) {
- const s = new Array(this.model.config.gpt.nLayer);
- for (let n = 0; n < this.model.config.gpt.nLayer; n++)
+ initialise(t, e) {
+ const i = t && t.length > this.model.config.blockSize ? t.slice(-this.model.config.blockSize) : t ?? null;
+ if (this.cache && e?.noCache && this.reset(), this.initialPrompt = i || null, this.lastToken === -1 && (this.outputText = this.initialPrompt || ""), !this.cache && !e?.noCache && this.model.config.useRope) {
+ const s = new Array(this.model.config.nLayer);
+ for (let n = 0; n < this.model.config.nLayer; n++)
  s[n] = { k: void 0, v: void 0, length: 0, cumulativeLength: 0 };
  this.cache = s, this.lastToken = -1;
  }
- const o = this.tokeniser.trained ? this.tokeniser : new u(d(k, this.tokeniser.vocabSize));
+ const o = this.tokeniser.trained ? this.tokeniser : new M(U(H, this.tokeniser.vocabSize));
  this.actualTokeniser = o;
  }
- async step(t, i) {
- const e = { ...i, maxLength: 1 };
- return this.generate(t, e);
+ async step(t, e) {
+ const i = { ...e, maxLength: 1 };
+ return this.generate(t, i);
  }
- async generate(t, i) {
- this.initialise(t, i), this.active = !0, this.emit("start");
- const o = await this._generate(i);
+ async generate(t, e) {
+ this.initialise(t, e), this.active = !0, this.emit("start");
+ const o = await this._generate(e);
  return this.active = !1, this.emit("stop"), o;
  }
  stop() {
@@ -139,7 +233,16 @@ class nt extends l {
  getText() {
  return this.outputText;
  }
+ getAttentionData() {
+ return this.attentionData;
+ }
+ getProbabilitiesData() {
+ return this.probabilitiesData;
+ }
+ getTokens() {
+ return this.tokens;
+ }
  }
  export {
- nt as default
+ qt as default
  };
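The new _generateToken above folds temperature scaling, top-k and top-p (nucleus) sampling into a single step. Reading the minified topP branch: softmax probabilities are sorted in descending order, accumulated until the cumulative mass reaches topP, renormalised over the kept tokens, and then sampled on the CPU. A de-minified sketch of that selection step follows; variable names are mine, and the final draw stands in for the library's multinomialCPU helper:

    // Nucleus (top-p) selection as reconstructed from the minified branch above.
    // probs: softmax output for one position; returns the sampled token index.
    function sampleTopP(probs: number[], topP: number): number {
      const ranked = probs
        .map((prob, index) => ({ prob, index }))
        .sort((a, b) => b.prob - a.prob);

      // Keep the smallest set of tokens whose cumulative probability reaches topP.
      const kept = new Array<number>(probs.length).fill(0);
      let cumulative = 0;
      for (const { prob, index } of ranked) {
        cumulative += prob;
        kept[index] = prob;
        if (cumulative >= topP) break;
      }

      // Renormalise the kept probabilities and draw one sample (inverse CDF).
      const total = kept.reduce((sum, p) => sum + p, 0);
      const normalised = kept.map((p) => p / total);
      let r = Math.random();
      for (let i = 0; i < normalised.length; i++) {
        r -= normalised[i];
        if (r <= 0) return i;
      }
      return normalised.length - 1;
    }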
package/dist/{RealDiv-Dy0p8Bvo.js → RealDiv-D_q39E3A.js}
@@ -1,10 +1,10 @@
- import { aq as T, ac as E, p as O, j as V, ay as B, Y as F, U, az as j } from "./index-BoWRt-10.js";
- import { r as $ } from "./Reshape-DH5srBP0.js";
- import { g as A, a as k, b as C, c as N, e as R } from "./axis_util-BzbKo31C.js";
- import { t as K, m as W } from "./shared-wS99K7_n.js";
- import { c as _ } from "./backend_util-TE7aTPhZ.js";
- import { f as y } from "./gpgpu_math-DGNLNL4I.js";
- import { g as G, b as L } from "./kernel_funcs_utils-BYKWV8Aa.js";
+ import { aq as T, ag as E, p as O, j as V, aB as B, a1 as F, ah as j, aC as K } from "./index-DdmHGZjq.js";
+ import { r as $ } from "./Reshape-Bh_jzKzV.js";
+ import { g as A, a as C, b as k, c as N, e as R } from "./axis_util-Did9235A.js";
+ import { t as U, m as W } from "./shared-CZiWmQCI.js";
+ import { c as _ } from "./backend_util-yC3YH1jo.js";
+ import { f as y } from "./gpgpu_math-D_ODOLix.js";
+ import { g as G, b as L } from "./kernel_funcs_utils-CDfFpUab.js";
  /**
  * @license
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -381,7 +381,7 @@ function Q(a, s, e, t) {
  let i = r;
  const c = A(i, l), o = c != null;
  let u = a;
- o && (u = D(a, c, t), i = k(i.length, l)), C("sum", i, l);
+ o && (u = D(a, c, t), i = C(i.length, l)), k("sum", i, l);
  const [p, h] = N(u.shape, i);
  let d = p;
  e && (d = R(p, r));
@@ -459,15 +459,15 @@ function te(a) {
  const I = e.texData.get(d.dataId).values, m = new Array(i);
  for (let v = 0; v < m.length; v++)
  m[v] = n.shape[u[v]];
- const z = K(I, n.shape, n.dtype, u, m);
+ const z = U(I, n.shape, n.dtype, u, m);
  d = e.makeTensorInfo(m, n.dtype);
  const M = e.texData.get(d.dataId);
  M.values = z;
  } else
  d = D(n, u, e);
- o = k(o.length, i);
+ o = C(o.length, i);
  }
- C("max", o, i);
+ k("max", o, i);
  const [f, S] = N(d.shape, o);
  let g = f;
  r && (g = R(f, c));
@@ -482,7 +482,7 @@ function te(a) {
  return p && e.disposeIntermediateTensorInfo(d), x;
  }
  const he = {
- kernelName: U,
+ kernelName: j,
  backendName: "webgl",
  kernelFunc: te
  };
@@ -525,7 +525,7 @@ return a / b;`, se = `
  
  return result;
  `, ne = L({ opSnippet: ae, packedOpSnippet: se, checkOutOfBounds: !0 }), fe = {
- kernelName: j,
+ kernelName: K,
  backendName: "webgl",
  kernelFunc: ne
  };
package/dist/{Reshape-DvudQDvJ.js → Reshape-41YpQqEo.js}
@@ -1,4 +1,4 @@
- import { j as h, a3 as d, l as c, K as m } from "./index-BoWRt-10.js";
+ import { j as h, a5 as d, n as c, V as m } from "./index-DdmHGZjq.js";
  /**
  * @license
  * Copyright 2021 Google LLC. All Rights Reserved.
package/dist/{Reshape-DH5srBP0.js → Reshape-Bh_jzKzV.js}
@@ -1,5 +1,5 @@
- import { j as c, a3 as C, l as f, K as R } from "./index-BoWRt-10.js";
- import { u as g, g as I, a as x, b as F, c as $, d as u, e as l, i as m } from "./gpgpu_math-DGNLNL4I.js";
+ import { j as c, a5 as C, n as f, V as R } from "./index-DdmHGZjq.js";
+ import { u as g, g as I, a as x, b as F, c as $, d as u, e as m, i as l } from "./gpgpu_math-D_ODOLix.js";
  /**
  * @license
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -82,14 +82,14 @@ function v(s, t) {
  function b(s, t, i) {
  const a = [
  u(s.shape),
- ...l(s.shape)
+ ...m(s.shape)
  ], e = {
  dtype: s.dtype,
  shape: a,
  dataId: s.dataId
  }, o = [
  u(t),
- ...l(t)
+ ...m(t)
  ], r = new S(o, a), p = !0, n = [a], h = i.runWebGLProgram(r, [e], s.dtype, n, p);
  return { dataId: h.dataId, shape: t, dtype: h.dtype };
  }
@@ -113,7 +113,7 @@ function y(s) {
  const { inputs: t, backend: i, attrs: a } = s, { x: e } = t, { shape: o } = a, r = i, p = c(e.shape), n = C(o, p), h = c(n);
  f(p === h, () => `The new shape (${n}) has ${h} elements and the old shape (${e.shape}) has ${p} elements. The new shape and old shape must have the same number of elements.`);
  const d = r.texData.get(e.dataId);
- return d.isPacked && !m(e.shape, n) && !(d.texture !== null && m(d.shape, n)) ? b(e, n, r) : (r.incRef(e.dataId), { dataId: e.dataId, shape: n, dtype: e.dtype });
+ return d.isPacked && !l(e.shape, n) && !(d.texture !== null && l(d.shape, n)) ? b(e, n, r) : (r.incRef(e.dataId), { dataId: e.dataId, shape: n, dtype: e.dtype });
  }
  const U = {
  kernelName: R,
package/dist/TeachableLLM.d.ts CHANGED
@@ -1,11 +1,11 @@
- import { GPTConfig } from './config';
+ import { GPTConfig } from './models/config';
  import { ITokeniser } from './tokeniser/type';
- import { default as NanoGPT, TrainingLogEntry } from './NanoGPTModel';
- import { SaveOptions } from './utilities/save';
+ import { SaveOptions } from './loader/save';
  import { default as Generator, IGenerateOptions } from './Generator';
  import { default as Trainer, ITrainerOptions } from './Trainer';
  import { default as MemoryProfiler } from './utilities/profile';
- import { TrainingProgress } from './training/Trainer';
+ import { TrainingLogEntry, TrainingProgress } from './training/Trainer';
+ import { default as Model, ModelForwardAttributes } from './models/model';
  type TeachableLLMStatus = 'warmup' | 'awaitingTokens' | 'ready' | 'training' | 'loading' | 'busy' | 'error';
  interface TeachableLLMMeta {
  name?: string;
@@ -20,12 +20,12 @@ export default class TeachableLLM {
  private _status;
  private _memoryRequirements?;
  meta: TeachableLLMMeta;
- constructor(tokeniser?: ITokeniser, model?: NanoGPT);
+ constructor(tokeniser?: ITokeniser, model?: Model<ModelForwardAttributes>);
  get vocab(): string[];
  /** Model is fully loaded */
  get loaded(): boolean;
  get config(): GPTConfig;
- get model(): NanoGPT;
+ get model(): Model<ModelForwardAttributes>;
  get tokeniser(): ITokeniser;
  get status(): TeachableLLMStatus;
  /** Model is both ready and not busy */
package/dist/TeachableLLM.js CHANGED
@@ -1,30 +1,21 @@
- import { defaultConfig as _ } from "./config.js";
- import f from "./NanoGPTModel.js";
- import { saveModel as d } from "./utilities/save.js";
- import { loadModel as l } from "./loader/load.js";
+ import { defaultConfig as d } from "./models/config.js";
+ import { saveModel as l } from "./loader/save.js";
+ import { loadModel as _ } from "./loader/load.js";
  import u from "./Generator.js";
- import p from "./Trainer.js";
- import { E as c } from "./index-Dwqa6Zy2.js";
+ import f from "./Trainer.js";
+ import { E as p } from "./index-Dwqa6Zy2.js";
  import { dummyPassTrainAsync as m } from "./utilities/dummy.js";
- import g from "./tokeniser/CharTokeniser.js";
- import k from "./tokeniser/bpe.js";
- import "./papaparse.min-C8l2Kvo1.js";
- import "./index-Tf7vU29b.js";
- import "./jszip.min-CjP2V1VV.js";
- import "./index-BoWRt-10.js";
- import "./ops/cpu/scatterSub.js";
- import "./ops/webgl/scatterSub.js";
- import "./ops/cpu/gatherSub.js";
- import "./ops/webgl/gatherSub.js";
+ import "./index-DdmHGZjq.js";
  import "./ops/cpu/attentionMask.js";
  import "./ops/webgl/attentionMask.js";
  import "./ops/grads/attentionMask.js";
  import "./ops/cpu/qkv.js";
  import "./ops/webgl/qkv.js";
  import "./ops/grads/qkv.js";
- import "./random_width-sZORGo5k.js";
- import "./register_all_kernels-BwDSRN-f.js";
- import "./dataset-CtdBYwjo.js";
+ import "./random_width-DKGeiFuR.js";
+ import "./register_all_kernels-Do9VvZmo.js";
+ import "./index-Tf7vU29b.js";
+ import "./dataset-DPPl-iLT.js";
  import "./ops/cpu/rope.js";
  import "./ops/webgl/rope.js";
  import "./ops/grads/rope.js";
@@ -36,20 +27,31 @@ import "./ops/grads/fusedSoftmax.js";
  import "./ops/cpu/matMulGelu.js";
  import "./ops/webgl/matMulGelu.js";
  import "./ops/grads/matMulGelu.js";
- import "./ops/cpu/gelu.js";
- import "./ops/webgl/gelu.js";
- import "./gelu-C-dPj6Ku.js";
  import "./ops/cpu/normRMS.js";
  import "./ops/webgl/normRMS.js";
  import "./ops/grads/normRMS.js";
+ import "./ops/cpu/gatherSub.js";
+ import "./ops/webgl/gatherSub.js";
+ import "./ops/cpu/scatterSub.js";
+ import "./ops/webgl/scatterSub.js";
+ import c from "./tokeniser/CharTokeniser.js";
+ import g from "./tokeniser/bpe.js";
+ import "./papaparse.min-C8l2Kvo1.js";
+ import "./jszip.min-CjP2V1VV.js";
+ import "./ops/cpu/gelu.js";
+ import "./ops/webgl/gelu.js";
+ import "./gelu-BkcmEEyD.js";
  import "./ops/webgl/log.js";
  import "./ops/cpu/adamMoments.js";
  import "./ops/webgl/adamMoments.js";
  import "./ops/cpu/adamAdjust.js";
  import "./ops/webgl/adamAdjust.js";
- import w from "./utilities/profile.js";
+ import "./checks/normRMS.js";
+ import "./checks/normRMSGrad.js";
+ import k from "./utilities/profile.js";
+ import w from "./models/factory.js";
  class a {
- ee = new c();
+ ee = new p();
  _config;
  _model;
  _tokeniser;
@@ -69,7 +71,7 @@ class a {
  get config() {
  if (!this._config)
  throw new Error("configuration_not_initialized.");
- return this._config.gpt;
+ return this._config;
  }
  get model() {
  if (!this._model)
@@ -101,14 +103,14 @@ class a {
  saveModel(t) {
  if (!this._model || !this._tokeniser)
  throw new Error("model_or_tokeniser_not_initialized.");
- return d(this._model, this._tokeniser, {
+ return l(this._model, this._tokeniser, {
  ...t,
  name: t?.name || this.meta.name
  });
  }
  static loadModel(t) {
  const e = new a();
- return l(t).then(({ model: r, tokeniser: o, name: s }) => {
+ return _(t).then(({ model: r, tokeniser: o, name: s }) => {
  e._model = r, e._tokeniser = o, e._config = r.config, s && (e.meta.name = s), e.setStatus("warmup"), m(r).then((i) => {
  e._memoryRequirements = i, e.setStatus("ready"), e.ee.emit("loaded");
  }).catch((i) => {
@@ -119,7 +121,7 @@ class a {
  }), e;
  }
  static create(t, e = {}) {
- const r = { ..._, ...e }, o = t === "char" ? new g(r.vocabSize) : new k(r.vocabSize), s = new f(r), i = new a(o, s);
+ const r = { ...d, ...e }, o = t === "char" ? new c(r.vocabSize) : new g(r.vocabSize), s = w(r), i = new a(o, s);
  return i.setStatus("warmup"), m(s).then((n) => {
  i._memoryRequirements = n, i.tokeniser.trained ? (i.setStatus("ready"), i.ee.emit("loaded")) : (i.setStatus("awaitingTokens"), i.ee.emit("loaded"), i.tokeniser.once("trainStatus", (h) => {
  h === "trained" && i.setStatus("ready");
@@ -138,9 +140,9 @@ class a {
  if (t) {
  if (!this._config)
  return;
- this._config.layerConfig.profiler || (this._config.layerConfig.profiler = new w());
+ this.model.getProfiler() || this.model.setProfiler(new k());
  } else
- this._config?.layerConfig.profiler && (this._config.layerConfig.profiler = void 0);
+ this.model.getProfiler() && this.model.setProfiler(null);
  }
  getNumParams() {
  return this._model ? this._model.getNumParams() : 0;
@@ -148,7 +150,7 @@ class a {
  trainer() {
  if (!this._model || !this._tokeniser)
  throw new Error("model_or_tokeniser_not_initialized.");
- const t = new p(this._model, this._tokeniser);
+ const t = new f(this._model, this._tokeniser);
  return t.on("start", () => this.setStatus("training")), t.on("stop", () => this.setStatus("ready")), t.on("log", async (e, r) => {
  const o = this.ee.listeners("trainStep");
  for (const s of o)
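TeachableLLM.js now builds its model through models/factory instead of instantiating NanoGPTModel directly, and routes saving through loader/save. A sketch of the flow based on the calls visible above; the entry import path is an assumption, and option keys other than vocabSize are illustrative only:

    // Sketch of the 0.8.x TeachableLLM lifecycle (import path assumed).
    import TeachableLLM from '@genai-fi/nanogpt';

    async function buildAndTrain(corpus: string[]) {
      // 'char' selects the character tokeniser; any other value selects BPE.
      // The model instance now comes from models/factory, not NanoGPTModel.
      const llm = TeachableLLM.create('char', { vocabSize: 256 });
      // (In a real app you would wait for the status to reach 'ready' first.)

      const trainer = llm.trainer();   // wraps training/Trainer
      await trainer.prepare(corpus);
      await trainer.train();

      // Saving is routed through loader/save.saveModel in this release;
      // the return shape is not shown in this diff.
      return llm.saveModel({ name: 'my-model' });
    }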
package/dist/Trainer.d.ts CHANGED
@@ -1,6 +1,7 @@
- import { default as NanoGPT } from './NanoGPTModel';
  import { ITokeniser } from './tokeniser/type';
  import { default as EE } from 'eventemitter3';
+ import { TrainingLogEntry, TrainingProgress } from './training/Trainer';
+ import { default as Model, ModelForwardAttributes } from './models/model';
  export interface ITrainerOptions {
  batchSize?: number;
  learningRate?: number;
@@ -10,6 +11,11 @@ export interface ITrainerOptions {
  prompt?: string;
  validationSplit?: number;
  advancedMetrics?: boolean;
+ gradientCheckpointing?: boolean;
+ }
+ interface ExtendedTrainingProgress extends TrainingProgress {
+ progress: number;
+ remaining: number;
  }
  export default class Trainer extends EE<'start' | 'stop' | 'log'> {
  private trainer;
@@ -17,10 +23,15 @@ export default class Trainer extends EE<'start' | 'stop' | 'log'> {
  private trainDataset?;
  private validationDataset?;
  private totalSamples;
- constructor(model: NanoGPT, tokeniser: ITokeniser);
+ private log;
+ private progress;
+ constructor(model: Model<ModelForwardAttributes>, tokeniser: ITokeniser);
  stop(): void;
  reset(): void;
  prepare(text: string[], options?: ITrainerOptions): Promise<void>;
  train(options?: ITrainerOptions): Promise<void>;
  step(options?: ITrainerOptions): Promise<void>;
+ getLog(): TrainingLogEntry[];
+ getProgress(): ExtendedTrainingProgress | null;
  }
+ export {};
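The public Trainer now keeps its own log and progress state. A small sketch of polling the new accessors while training runs, assuming a Trainer instance obtained from TeachableLLM.trainer(); the type import path is an assumption:

    // Sketch: using the new Trainer getters introduced in 0.8.x.
    import type Trainer from '@genai-fi/nanogpt/dist/Trainer'; // assumed path

    async function trainWithProgress(trainer: Trainer, text: string[]) {
      await trainer.prepare(text, { batchSize: 8, gradientCheckpointing: true });

      trainer.on('log', () => {
        const p = trainer.getProgress();   // ExtendedTrainingProgress | null
        if (p) console.log(`progress ${p.progress}, remaining ${p.remaining}`);
      });

      await trainer.train();
      console.log(trainer.getLog());       // TrainingLogEntry[]
    }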