@genai-fi/nanogpt 0.7.3 → 0.8.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (173)
  1. package/dist/Generator.d.ts +25 -2
  2. package/dist/Generator.js +150 -49
  3. package/dist/{RealDiv-Dy0p8Bvo.js → RealDiv-N8TpOMYv.js} +14 -14
  4. package/dist/{Reshape-DvudQDvJ.js → Reshape-B-lWQRnF.js} +1 -1
  5. package/dist/{Reshape-DH5srBP0.js → Reshape-Bo8HzP8V.js} +5 -5
  6. package/dist/TeachableLLM.d.ts +6 -6
  7. package/dist/TeachableLLM.js +31 -31
  8. package/dist/Trainer.d.ts +13 -2
  9. package/dist/Trainer.js +21 -12
  10. package/dist/{axis_util-BzbKo31C.js → axis_util-DubwyOhW.js} +3 -3
  11. package/dist/backend.js +2 -2
  12. package/dist/{backend_util-TE7aTPhZ.js → backend_util-BJ-_jSeK.js} +46 -46
  13. package/dist/{broadcast_to-CdbwV-Dj.js → broadcast_to-BYfCp5iL.js} +2 -2
  14. package/dist/{concat-CsxrgovM.js → concat-BmDqqFsa.js} +1 -1
  15. package/dist/{dataset-CtdBYwjo.js → dataset-CJmEGu6D.js} +5 -5
  16. package/dist/{dropout-DYs5QFGQ.js → dropout-sx0sjVAT.js} +8 -8
  17. package/dist/exports_initializers-DAKM8UO9.js +16 -0
  18. package/dist/{gather-CMMy2KEG.js → gather-C1siEkdp.js} +1 -1
  19. package/dist/{gelu-C-dPj6Ku.js → gelu-Bd3UBBxg.js} +1 -1
  20. package/dist/{gpgpu_math-DGNLNL4I.js → gpgpu_math-TFLxaLkw.js} +26 -26
  21. package/dist/{index-CLthM0TO.js → index-BaPo_0H8.js} +185 -185
  22. package/dist/{index-BoWRt-10.js → index-CUQrfsw_.js} +266 -265
  23. package/dist/{kernel_funcs_utils-BYKWV8Aa.js → kernel_funcs_utils-P9aFa232.js} +9 -9
  24. package/dist/layers/BaseLayer.d.ts +8 -13
  25. package/dist/layers/BaseLayer.js +25 -13
  26. package/dist/layers/CausalSelfAttention.d.ts +3 -2
  27. package/dist/layers/CausalSelfAttention.js +28 -28
  28. package/dist/layers/MLP.d.ts +3 -2
  29. package/dist/layers/MLP.js +16 -20
  30. package/dist/layers/PositionEmbedding.d.ts +9 -0
  31. package/dist/layers/PositionEmbedding.js +45 -0
  32. package/dist/layers/RMSNorm.d.ts +3 -2
  33. package/dist/layers/RMSNorm.js +6 -6
  34. package/dist/layers/RoPECache.d.ts +1 -1
  35. package/dist/layers/RoPECache.js +4 -4
  36. package/dist/layers/TiedEmbedding.d.ts +3 -2
  37. package/dist/layers/TiedEmbedding.js +29 -7
  38. package/dist/layers/TransformerBlock.d.ts +3 -2
  39. package/dist/layers/TransformerBlock.js +1 -1
  40. package/dist/loader/load.d.ts +2 -2
  41. package/dist/loader/loadHF.d.ts +2 -2
  42. package/dist/loader/loadTransformers.d.ts +4 -2
  43. package/dist/loader/loadTransformers.js +10 -9
  44. package/dist/loader/newZipLoad.d.ts +2 -2
  45. package/dist/loader/oldZipLoad.d.ts +2 -2
  46. package/dist/loader/oldZipLoad.js +42 -51
  47. package/dist/loader/save.d.ts +8 -0
  48. package/dist/loader/save.js +62 -0
  49. package/dist/{log_sum_exp-DbjkV734.js → log_sum_exp-C142qZqY.js} +14 -14
  50. package/dist/main.d.ts +5 -4
  51. package/dist/main.js +22 -18
  52. package/dist/{mat_mul-8m8pfdcx.js → mat_mul-DMkduNJu.js} +1 -1
  53. package/dist/{max-Ddnnb5xe.js → max-B3JOcNGb.js} +1 -1
  54. package/dist/mod-uUuj4gSb.js +27 -0
  55. package/dist/models/NanoGPTV1.d.ts +15 -0
  56. package/dist/models/NanoGPTV1.js +71 -0
  57. package/dist/{config.d.ts → models/config.d.ts} +1 -0
  58. package/dist/{config.js → models/config.js} +1 -0
  59. package/dist/models/factory.d.ts +3 -0
  60. package/dist/models/factory.js +14 -0
  61. package/dist/models/model.d.ts +26 -0
  62. package/dist/models/model.js +68 -0
  63. package/dist/{mulmat_packed_gpu-VSekgsNv.js → mulmat_packed_gpu-Cm2gw-c8.js} +1 -1
  64. package/dist/{ones-Dj0SDhHf.js → ones-ZdgQGBCP.js} +2 -2
  65. package/dist/ops/adamAdjust.js +1 -1
  66. package/dist/ops/adamMoments.js +1 -1
  67. package/dist/ops/appendCache.js +3 -3
  68. package/dist/ops/attentionMask.js +1 -1
  69. package/dist/ops/cpu/adamAdjust.js +9 -9
  70. package/dist/ops/cpu/adamMoments.js +2 -2
  71. package/dist/ops/cpu/appendCache.js +2 -2
  72. package/dist/ops/cpu/attentionMask.js +5 -5
  73. package/dist/ops/cpu/fusedSoftmax.js +2 -2
  74. package/dist/ops/cpu/gatherSub.js +3 -3
  75. package/dist/ops/cpu/gelu.js +1 -1
  76. package/dist/ops/cpu/matMulGelu.js +2 -2
  77. package/dist/ops/cpu/matMulMul.js +1 -1
  78. package/dist/ops/cpu/mulDropout.js +1 -1
  79. package/dist/ops/cpu/normRMS.js +1 -1
  80. package/dist/ops/cpu/qkv.js +3 -3
  81. package/dist/ops/cpu/rope.js +5 -5
  82. package/dist/ops/cpu/scatterSub.js +11 -11
  83. package/dist/ops/fusedSoftmax.js +1 -1
  84. package/dist/ops/gatherSub.js +1 -1
  85. package/dist/ops/gelu.js +2 -2
  86. package/dist/ops/grads/attentionMask.js +1 -1
  87. package/dist/ops/grads/fusedSoftmax.js +2 -2
  88. package/dist/ops/grads/gelu.js +2 -2
  89. package/dist/ops/grads/matMulGelu.js +1 -1
  90. package/dist/ops/grads/normRMS.js +1 -1
  91. package/dist/ops/grads/qkv.js +1 -1
  92. package/dist/ops/grads/rope.js +1 -1
  93. package/dist/ops/matMulGelu.js +1 -1
  94. package/dist/ops/matMulMul.js +1 -1
  95. package/dist/ops/mulDrop.js +1 -1
  96. package/dist/ops/normRMS.js +1 -1
  97. package/dist/ops/qkv.js +1 -1
  98. package/dist/ops/rope.js +4 -4
  99. package/dist/ops/scatterSub.js +1 -1
  100. package/dist/ops/webgl/adamAdjust.js +2 -2
  101. package/dist/ops/webgl/adamMoments.js +1 -1
  102. package/dist/ops/webgl/appendCache.js +1 -1
  103. package/dist/ops/webgl/attentionMask.js +1 -1
  104. package/dist/ops/webgl/fusedSoftmax.js +4 -4
  105. package/dist/ops/webgl/gatherSub.js +1 -1
  106. package/dist/ops/webgl/gelu.js +2 -2
  107. package/dist/ops/webgl/log.js +3 -3
  108. package/dist/ops/webgl/matMulGelu.js +10 -10
  109. package/dist/ops/webgl/matMulMul.js +1 -1
  110. package/dist/ops/webgl/mulDropout.js +1 -1
  111. package/dist/ops/webgl/normRMS.js +2 -2
  112. package/dist/ops/webgl/qkv.js +1 -1
  113. package/dist/ops/webgl/rope.js +1 -1
  114. package/dist/ops/webgl/scatterSub.js +1 -1
  115. package/dist/ops/webgpu/adamAdjust.js +3 -3
  116. package/dist/ops/webgpu/adamMoments.js +3 -3
  117. package/dist/ops/webgpu/appendCache.js +3 -3
  118. package/dist/ops/webgpu/attentionMask.js +3 -3
  119. package/dist/ops/webgpu/gatherSub.js +3 -3
  120. package/dist/ops/webgpu/gelu.js +3 -3
  121. package/dist/ops/webgpu/normRMS.js +2 -2
  122. package/dist/ops/webgpu/normRMSGrad.js +5 -5
  123. package/dist/ops/webgpu/qkv.js +3 -3
  124. package/dist/ops/webgpu/rope.js +3 -3
  125. package/dist/ops/webgpu/scatterSub.js +3 -3
  126. package/dist/ops/webgpu/utils/reductions.js +4 -4
  127. package/dist/{ops-BFGCx8Ri.js → ops-C_1K_-35.js} +103 -103
  128. package/dist/{random_width-sZORGo5k.js → random_width-D8Pwy_na.js} +136 -136
  129. package/dist/{range-CRuAh-gd.js → range-LVHrSLdi.js} +1 -1
  130. package/dist/{reciprocal-BvGAyKyu.js → reciprocal-CaR9e67G.js} +1 -1
  131. package/dist/{register_all_kernels-BwDSRN-f.js → register_all_kernels-DUshvVWP.js} +2026 -2049
  132. package/dist/{reshape-CdBq1WJ6.js → reshape-DEfQGSin.js} +1 -1
  133. package/dist/{scatter_nd_util-DUstGbU1.js → scatter_nd_util-CUPPNLaA.js} +1 -1
  134. package/dist/{selu_util-BJEXVvjX.js → selu_util-8vv5JxQV.js} +3 -3
  135. package/dist/{shared-B8ztnyEk.js → shared-CkNorDcU.js} +83 -83
  136. package/dist/{shared-wS99K7_n.js → shared-D1elLckx.js} +1 -1
  137. package/dist/{sin-BeA3tsEd.js → sin-D2CKKmyR.js} +1 -1
  138. package/dist/{slice-BiOsknYS.js → slice-BnyE-M_7.js} +1 -1
  139. package/dist/{softmax-Bv_6lyMX.js → softmax-DLoZWYBx.js} +1 -1
  140. package/dist/{split-B-dikLRw.js → split-By_n4TKP.js} +1 -1
  141. package/dist/{stack-B17UN2nn.js → stack-DkdFLq37.js} +1 -1
  142. package/dist/{sum-66ew2byf.js → sum-l_0SqM4h.js} +3 -3
  143. package/dist/{tensor-JwS7ZYY6.js → tensor-BAQdLqoU.js} +1 -1
  144. package/dist/{tensor2d-wxPAnDQy.js → tensor2d-BHy261cI.js} +1 -1
  145. package/dist/training/Adam.js +2 -2
  146. package/dist/training/AdamExt.js +1 -1
  147. package/dist/training/DatasetBuilder.js +2 -2
  148. package/dist/training/Evaluator.d.ts +2 -2
  149. package/dist/training/FullTrainer.d.ts +3 -3
  150. package/dist/training/FullTrainer.js +61 -69
  151. package/dist/training/Trainer.d.ts +15 -3
  152. package/dist/training/Trainer.js +39 -47
  153. package/dist/training/sparseCrossEntropy.js +9 -9
  154. package/dist/utilities/dummy.d.ts +4 -4
  155. package/dist/utilities/dummy.js +13 -13
  156. package/dist/utilities/multinomialCPU.js +2 -2
  157. package/dist/utilities/parameters.d.ts +1 -1
  158. package/dist/utilities/performance.js +1 -1
  159. package/dist/utilities/profile.js +1 -1
  160. package/dist/utilities/safetensors.js +2 -2
  161. package/dist/utilities/weights.js +2 -2
  162. package/dist/{variable-BuddVFLa.js → variable-C9hihzDB.js} +1 -1
  163. package/dist/{webgpu_program-PFzf1hAQ.js → webgpu_program-dFEVbDPL.js} +1 -1
  164. package/dist/{webgpu_util-D____QpY.js → webgpu_util-DLImlSc6.js} +27 -27
  165. package/dist/{zeros--BdLQ3oG.js → zeros-VZ72lWXM.js} +1 -1
  166. package/package.json +2 -3
  167. package/dist/NanoGPTModel.d.ts +0 -52
  168. package/dist/NanoGPTModel.js +0 -203
  169. package/dist/TiedEmbedding-BxOerUmB.js +0 -43
  170. package/dist/utilities/generate.d.ts +0 -3
  171. package/dist/utilities/generate.js +0 -22
  172. package/dist/utilities/save.d.ts +0 -9
  173. package/dist/utilities/save.js +0 -61
package/dist/Generator.d.ts CHANGED
@@ -1,10 +1,23 @@
- import { default as NanoGPT, GenerateOptions } from './NanoGPTModel';
  import { ITokeniser } from './tokeniser/type';
  import { default as EE } from 'eventemitter3';
+ import { default as Model, ModelForwardAttributes } from './models/model';
+ export interface GenerateOptions {
+ temperature?: number;
+ topK?: number;
+ topP?: number;
+ usePadding?: boolean;
+ attentionScores?: boolean;
+ includeProbabilities?: boolean;
+ embeddings?: boolean;
+ }
  export interface IGenerateOptions extends GenerateOptions {
  maxLength?: number;
  noCache?: boolean;
  }
+ /**
+ * Text generator using a NanoGPT model and a tokeniser.
+ * This uses the forward method of the model to generate text token by token, including options for temperature, top-k, and top-p sampling.
+ */
  export default class Generator extends EE<'start' | 'stop' | 'tokens'> {
  private readonly model;
  private readonly tokeniser;
@@ -14,9 +27,16 @@ export default class Generator extends EE<'start' | 'stop' | 'tokens'> {
  private outputText;
  private actualTokeniser;
  private lastToken;
- constructor(model: NanoGPT, tokeniser: ITokeniser);
+ private attentionData;
+ private probabilitiesData;
+ private embeddingsData;
+ private tokens;
+ constructor(model: Model<ModelForwardAttributes>, tokeniser: ITokeniser);
  private tokenisePrompt;
  private processResponse;
+ /** Generate logits and select a token. */
+ private _generateToken;
+ /** Generate multiple tokens in a loop and produce text */
  private _generate;
  reset(): void;
  dispose(): void;
@@ -25,4 +45,7 @@ export default class Generator extends EE<'start' | 'stop' | 'tokens'> {
  generate(prompt?: string, options?: IGenerateOptions): Promise<string>;
  stop(): void;
  getText(): string;
+ getAttentionData(): number[][][][];
+ getProbabilitiesData(): number[][][];
+ getTokens(): number[];
  }
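The Generator.d.ts changes above move GenerateOptions into this module, switch the constructor to the new Model<ModelForwardAttributes> abstraction, and expose per-token diagnostics through getters instead of extra 'tokens' event arguments. A hedged usage sketch in TypeScript follows; only the method names, option fields, and the (ids, text) event payload come from this diff, while the type-only import path and the way a Generator instance is obtained are assumptions.

    // Hedged sketch of the 0.8.0 Generator surface declared above. Only the method
    // names, option fields, and the 'tokens' event payload (ids, text) are taken
    // from this diff; the deep import path is an assumption.
    import type { default as Generator, IGenerateOptions } from '@genai-fi/nanogpt/dist/Generator';

    async function sampleWithDiagnostics(generator: Generator, prompt: string): Promise<string> {
      generator.on('tokens', (ids: number[], text: string) => console.log(ids, text));

      const options: IGenerateOptions = {
        maxLength: 200,
        temperature: 0.8,
        topP: 0.9,                  // nucleus sampling, new GenerateOptions field
        attentionScores: true,      // accumulate attention maps per generated token
        includeProbabilities: true, // accumulate the sampled distribution per token
      };
      const text = await generator.generate(prompt, options);

      // 0.8.0 exposes the accumulated diagnostics through getters rather than
      // through extra 'tokens' event arguments.
      console.log('token ids:', generator.getTokens());
      console.log('attention steps:', generator.getAttentionData().length);
      console.log('probability steps:', generator.getProbabilitiesData().length);
      return text;
    }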
package/dist/Generator.js CHANGED
@@ -1,15 +1,15 @@
- import { E as l } from "./index-Dwqa6Zy2.js";
- import "./index-BoWRt-10.js";
+ import { E as z } from "./index-Dwqa6Zy2.js";
+ import { B as A, C as L, E as C, a5 as I, t as O, k as R } from "./index-CUQrfsw_.js";
  import "./ops/cpu/attentionMask.js";
  import "./ops/webgl/attentionMask.js";
  import "./ops/grads/attentionMask.js";
  import "./ops/cpu/qkv.js";
  import "./ops/webgl/qkv.js";
  import "./ops/grads/qkv.js";
- import "./random_width-sZORGo5k.js";
- import "./register_all_kernels-BwDSRN-f.js";
+ import { p as _ } from "./random_width-D8Pwy_na.js";
+ import { t as K } from "./register_all_kernels-DUshvVWP.js";
  import "./index-Tf7vU29b.js";
- import "./dataset-CtdBYwjo.js";
+ import "./dataset-CJmEGu6D.js";
  import "./ops/cpu/rope.js";
  import "./ops/webgl/rope.js";
  import "./ops/grads/rope.js";
@@ -29,7 +29,7 @@ import "./ops/webgl/gatherSub.js";
  import "./ops/cpu/scatterSub.js";
  import "./ops/webgl/scatterSub.js";
  import "./jszip.min-CjP2V1VV.js";
- import u from "./tokeniser/CharTokeniser.js";
+ import M from "./tokeniser/CharTokeniser.js";
  import "./ops/cpu/adamAdjust.js";
  import "./ops/webgl/adamAdjust.js";
  import "./ops/cpu/adamMoments.js";
@@ -37,12 +37,42 @@ import "./ops/webgl/adamMoments.js";
  import "./papaparse.min-C8l2Kvo1.js";
  import "./ops/cpu/gelu.js";
  import "./ops/webgl/gelu.js";
- import "./gelu-C-dPj6Ku.js";
+ import "./gelu-Bd3UBBxg.js";
  import "./ops/webgl/log.js";
- import { t as p } from "./tensor2d-wxPAnDQy.js";
- import { c as f } from "./concat-CsxrgovM.js";
- const k = [
- ...Array.from({ length: 95 }, (r, t) => String.fromCharCode(t + 32)),
+ import $ from "./utilities/multinomialCPU.js";
+ import { r as x } from "./reshape-DEfQGSin.js";
+ import { t as P } from "./tensor2d-BHy261cI.js";
+ import { s as v } from "./softmax-DLoZWYBx.js";
+ import { g as q } from "./gather-C1siEkdp.js";
+ import { c as G } from "./concat-BmDqqFsa.js";
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ function N(h, t, e, i = !1) {
+ const o = L(h, "logits", "multinomial"), s = o.size, n = o.rank;
+ if (s < 2)
+ throw new Error(`Error in multinomial: you need at least 2 outcomes, but got ${s}.`);
+ if (n > 2)
+ throw new Error(`Rank of probabilities must be 1 or 2, but is ${n}`);
+ e = e || Math.random();
+ const a = { logits: n === 1 ? x(o, [1, -1]) : o }, p = { numSamples: t, seed: e, normalized: i }, l = C.runKernel(I, a, p);
+ return n === 1 ? x(l, [l.size]) : l;
+ }
+ const S = /* @__PURE__ */ A({ multinomial_: N }), B = [
+ ...Array.from({ length: 95 }, (h, t) => String.fromCharCode(t + 32)),
  // ASCII
  // Spanish accented letters and punctuation
  ..."áéíóúüñ¿¡",
@@ -53,12 +83,12 @@ const k = [
  // Cyrillic letters
  ..."абвгдеёжзийклмнопрстуфхцчшщъыьэюяАБВГДЕЁЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЫЬЭЮЯ"
  ];
- function d(r, t) {
- return r.length === t ? r : r.length > t ? r.slice(0, t) : r.concat(Array(t - r.length).fill(""));
+ function H(h, t) {
+ return h.length === t ? h : h.length > t ? h.slice(0, t) : h.concat(Array(t - h.length).fill(""));
  }
- class nt extends l {
- constructor(t, i) {
- super(), this.model = t, this.tokeniser = i, this.actualTokeniser = i;
+ class Mt extends z {
+ constructor(t, e) {
+ super(), this.model = t, this.tokeniser = e, this.actualTokeniser = e;
  }
  active = !1;
  cache = null;
@@ -66,71 +96,133 @@ class nt extends l {
  outputText = "";
  actualTokeniser;
  lastToken = -1;
- async tokenisePrompt(t, i) {
- const e = i ? await t.tokenise([i], !0) : [[t.eosToken]];
- return p(e, [1, e[0].length], "int32");
+ attentionData = [];
+ probabilitiesData = [];
+ embeddingsData = [];
+ tokens = [];
+ async tokenisePrompt(t, e) {
+ const i = e ? await t.tokenise([e], !0) : [[t.eosToken]];
+ return P(i, [1, i[0].length], "int32");
  }
- async processResponse(t, i, e, o) {
- const s = (await i.array())[0][0];
+ async processResponse(t, e, i, o) {
+ const s = (await e.array())[0][0];
  if (this.lastToken = s, s === this.tokeniser.eosToken)
  return null;
  const n = await t.decode([s]);
- let c;
- e && (c = await Promise.all(e.map((h) => h.array().then((m) => m))), e.forEach((h) => h.dispose()));
- let a;
- return o && (a = await o.array(), o.dispose()), this.emit("tokens", [s], n, c, a), n;
+ if (i) {
+ const d = await Promise.all(i.map((a) => a.array().then((p) => p)));
+ i.forEach((a) => a.dispose()), this.attentionData.push(d);
+ }
+ if (o) {
+ const d = await o.array();
+ o.dispose(), this.probabilitiesData.push(d);
+ }
+ return this.tokens.push(s), this.emit("tokens", [s], n), n;
+ }
+ /** Generate logits and select a token. */
+ async _generateToken(t, e, i) {
+ const o = i?.temperature ?? 1, s = i?.topK, n = i?.topP, d = i?.usePadding ?? !1, a = {
+ training: !1,
+ attentionScores: i?.attentionScores ? {
+ attentionOut: []
+ } : void 0,
+ cache: e,
+ outputEmbeddings: i?.embeddings ?? !1
+ }, p = O(() => {
+ const r = t, m = r.shape[1], u = m <= this.model.config.blockSize ? r : r.slice(
+ [0, m - this.model.config.blockSize],
+ [r.shape[0], this.model.config.blockSize]
+ ), k = d ? this.model.config.blockSize - u.shape[1] : 0, b = k > 0 ? _(u, [
+ [0, 0],
+ [0, k]
+ ]) : u, [f] = this.model.forward(a, b), y = f.shape[1] - 1 - k, c = f.slice([0, y, 0], [f.shape[0], 1, f.shape[2]]);
+ return a.attentionScores?.attentionOut && a.attentionScores.attentionOut.forEach((T, E) => {
+ T.shape[1] !== 1 && (a.attentionScores.attentionOut[E] = R(
+ T.slice([0, y, 0], [T.shape[0], 1, T.shape[2]])
+ ), T.dispose());
+ }), f.dispose(), c.div(o).squeeze([1]);
+ });
+ let l;
+ if (n) {
+ const r = v(p), m = await r.array();
+ r.dispose();
+ const u = m[0].map((c, g) => ({ prob: c, index: g })).sort((c, g) => g.prob - c.prob);
+ let k = 0;
+ const b = new Array(u.length).fill(0);
+ for (const c of u)
+ if (k += c.prob, b[c.index] = c.prob, k >= n)
+ break;
+ const f = b.reduce((c, g) => c + g, 0), y = b.map((c) => c / f);
+ l = $(y);
+ } else if (s) {
+ const { values: r, indices: m } = K(p, s), u = S(r, 1);
+ l = q(m, u, 1), r.dispose(), m.dispose(), u.dispose();
+ } else
+ l = S(p, 1);
+ let w;
+ i?.includeProbabilities && (w = v(p)), a.embeddings && this.embeddingsData.push(
+ await Promise.all(
+ a.embeddings.map(async (r) => {
+ const m = await r.array();
+ return r.dispose(), m;
+ })
+ )
+ );
+ const D = l.reshape([1, 1]);
+ return l.dispose(), l = D, p.dispose(), { output: l, probabilities: w, attention: a.attentionScores?.attentionOut };
  }
+ /** Generate multiple tokens in a loop and produce text */
  async _generate(t) {
- let i = this.lastToken >= 0 && this.cache ? p([this.lastToken], [1, 1], "int32") : await this.tokenisePrompt(this.actualTokeniser, this.outputText);
- const e = t?.maxLength ?? 1e3;
- for (let o = 0; o < e && this.active; o++) {
+ let e = this.lastToken >= 0 && this.cache ? P([this.lastToken], [1, 1], "int32") : await this.tokenisePrompt(this.actualTokeniser, this.outputText);
+ const i = t?.maxLength ?? 1e3;
+ for (let o = 0; o < i && this.active; o++) {
  const {
  output: s,
  probabilities: n,
- attention: c
- } = await this.model.generate(i, this.cache ? this.cache : void 0, {
+ attention: d
+ } = await this._generateToken(e, this.cache ? this.cache : void 0, {
  ...t,
  usePadding: !this.cache
  });
  if (this.cache)
- i.dispose(), i = s;
+ e.dispose(), e = s;
  else {
- const h = i;
- i = f([i, s], 1), h.dispose();
+ const p = e;
+ e = G([e, s], 1), p.dispose();
  }
- const a = await this.processResponse(this.actualTokeniser, s, c, n);
+ const a = await this.processResponse(this.actualTokeniser, s, d, n);
  if (this.cache || s.dispose(), a === null)
  break;
  this.outputText += a;
  }
- return i.dispose(), this.outputText;
+ return e.dispose(), this.outputText;
  }
  reset() {
  this.cache && (this.cache.forEach((t) => {
  t && (t.k && t.k.dispose(), t.v && t.v.dispose());
- }), this.cache = null), this.outputText = "", this.initialPrompt = null, this.lastToken = -1;
+ }), this.cache = null), this.outputText = "", this.initialPrompt = null, this.lastToken = -1, this.attentionData = [], this.probabilitiesData = [], this.tokens = [];
  }
  dispose() {
  this.reset();
  }
- initialise(t, i) {
- const e = t && t.length > this.model.config.gpt.blockSize ? t.slice(-this.model.config.gpt.blockSize) : t ?? null;
- if (this.cache && i?.noCache && this.reset(), this.initialPrompt = e || null, this.lastToken === -1 && (this.outputText = this.initialPrompt || ""), !this.cache && !i?.noCache && this.model.config.gpt.useRope) {
- const s = new Array(this.model.config.gpt.nLayer);
- for (let n = 0; n < this.model.config.gpt.nLayer; n++)
+ initialise(t, e) {
+ const i = t && t.length > this.model.config.blockSize ? t.slice(-this.model.config.blockSize) : t ?? null;
+ if (this.cache && e?.noCache && this.reset(), this.initialPrompt = i || null, this.lastToken === -1 && (this.outputText = this.initialPrompt || ""), !this.cache && !e?.noCache && this.model.config.useRope) {
+ const s = new Array(this.model.config.nLayer);
+ for (let n = 0; n < this.model.config.nLayer; n++)
  s[n] = { k: void 0, v: void 0, length: 0, cumulativeLength: 0 };
  this.cache = s, this.lastToken = -1;
  }
- const o = this.tokeniser.trained ? this.tokeniser : new u(d(k, this.tokeniser.vocabSize));
+ const o = this.tokeniser.trained ? this.tokeniser : new M(H(B, this.tokeniser.vocabSize));
  this.actualTokeniser = o;
  }
- async step(t, i) {
- const e = { ...i, maxLength: 1 };
- return this.generate(t, e);
+ async step(t, e) {
+ const i = { ...e, maxLength: 1 };
+ return this.generate(t, i);
  }
- async generate(t, i) {
- this.initialise(t, i), this.active = !0, this.emit("start");
- const o = await this._generate(i);
+ async generate(t, e) {
+ this.initialise(t, e), this.active = !0, this.emit("start");
+ const o = await this._generate(e);
  return this.active = !1, this.emit("stop"), o;
  }
  stop() {
@@ -139,7 +231,16 @@ class nt extends l {
  getText() {
  return this.outputText;
  }
+ getAttentionData() {
+ return this.attentionData;
+ }
+ getProbabilitiesData() {
+ return this.probabilitiesData;
+ }
+ getTokens() {
+ return this.tokens;
+ }
  }
  export {
- nt as default
+ Mt as default
  };
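The largest addition in Generator.js is the new _generateToken path, which now owns sampling instead of delegating to model.generate. When topP is set it performs nucleus sampling on the CPU: softmax the logits, sort tokens by descending probability, keep the smallest prefix whose cumulative mass reaches topP, renormalise, and draw from that truncated distribution. A readable TypeScript sketch of that branch follows; the sampleIndex callback stands in for the package's multinomialCPU helper, whose behaviour is assumed rather than shown here.

    // Readable sketch of the top-p branch in the minified _generateToken above.
    // `sampleIndex` stands in for utilities/multinomialCPU: draw one index from a
    // probability vector (assumed behaviour, not shown in this diff).
    function topPSample(
      probs: number[],
      topP: number,
      sampleIndex: (p: number[]) => number
    ): number {
      // Rank token indices by descending probability.
      const ranked = probs
        .map((prob, index) => ({ prob, index }))
        .sort((a, b) => b.prob - a.prob);

      // Keep the smallest prefix whose cumulative probability reaches topP.
      const kept = new Array<number>(probs.length).fill(0);
      let cumulative = 0;
      for (const { prob, index } of ranked) {
        cumulative += prob;
        kept[index] = prob;
        if (cumulative >= topP) break;
      }

      // Renormalise the surviving probabilities and sample a token id.
      const total = kept.reduce((sum, p) => sum + p, 0);
      return sampleIndex(kept.map((p) => p / total));
    }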
package/dist/{RealDiv-Dy0p8Bvo.js → RealDiv-N8TpOMYv.js} RENAMED
@@ -1,10 +1,10 @@
- import { aq as T, ac as E, p as O, j as V, ay as B, Y as F, U, az as j } from "./index-BoWRt-10.js";
- import { r as $ } from "./Reshape-DH5srBP0.js";
- import { g as A, a as k, b as C, c as N, e as R } from "./axis_util-BzbKo31C.js";
- import { t as K, m as W } from "./shared-wS99K7_n.js";
- import { c as _ } from "./backend_util-TE7aTPhZ.js";
- import { f as y } from "./gpgpu_math-DGNLNL4I.js";
- import { g as G, b as L } from "./kernel_funcs_utils-BYKWV8Aa.js";
+ import { as as T, af as E, p as O, j as V, aA as B, a0 as F, X as j, aB as K } from "./index-CUQrfsw_.js";
+ import { r as $ } from "./Reshape-Bo8HzP8V.js";
+ import { g as A, a as k, b as C, c as N, e as R } from "./axis_util-DubwyOhW.js";
+ import { t as U, m as W } from "./shared-D1elLckx.js";
+ import { c as _ } from "./backend_util-BJ-_jSeK.js";
+ import { f as y } from "./gpgpu_math-TFLxaLkw.js";
+ import { g as G, b as L } from "./kernel_funcs_utils-P9aFa232.js";
  /**
  * @license
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -105,7 +105,7 @@ class w {
  * limitations under the License.
  * =============================================================================
  */
- class q {
+ class X {
  constructor(s, e) {
  this.variableNames = ["x"];
  const { windowSize: t, batchSize: n, inSize: l, outSize: r } = s;
@@ -229,7 +229,7 @@
  * limitations under the License.
  * =============================================================================
  */
- function X(a) {
+ function q(a) {
  const s = [];
  for (; s.length === 0 || s[s.length - 1].outSize !== 1; ) {
  const e = s.length ? s[s.length - 1].outSize : a[1], t = _(e);
@@ -242,12 +242,12 @@ function X(a) {
  return s;
  }
  function P(a, s, e, t) {
- const n = X(a.shape);
+ const n = q(a.shape);
  let l = a;
  for (let r = 0; r < n.length; r++) {
  const { inSize: i, windowSize: c, outSize: o } = n[r];
  let u, p;
- e === "mean" ? u = r === 0 ? new w({ windowSize: c, inSize: i, batchSize: a.shape[0], outSize: o }, i) : new w({ windowSize: c, inSize: i, batchSize: a.shape[0], outSize: o }) : u = new q({ windowSize: c, inSize: i, batchSize: a.shape[0], outSize: o }, e), p = l, l = t.runWebGLProgram(u, [l], s), p.dataId !== a.dataId && t.disposeIntermediateTensorInfo(p);
+ e === "mean" ? u = r === 0 ? new w({ windowSize: c, inSize: i, batchSize: a.shape[0], outSize: o }, i) : new w({ windowSize: c, inSize: i, batchSize: a.shape[0], outSize: o }) : u = new X({ windowSize: c, inSize: i, batchSize: a.shape[0], outSize: o }, e), p = l, l = t.runWebGLProgram(u, [l], s), p.dataId !== a.dataId && t.disposeIntermediateTensorInfo(p);
  }
  return l;
  }
@@ -459,7 +459,7 @@ function te(a) {
  const I = e.texData.get(d.dataId).values, m = new Array(i);
  for (let v = 0; v < m.length; v++)
  m[v] = n.shape[u[v]];
- const z = K(I, n.shape, n.dtype, u, m);
+ const z = U(I, n.shape, n.dtype, u, m);
  d = e.makeTensorInfo(m, n.dtype);
  const M = e.texData.get(d.dataId);
  M.values = z;
@@ -482,7 +482,7 @@
  return p && e.disposeIntermediateTensorInfo(d), x;
  }
  const he = {
- kernelName: U,
+ kernelName: j,
  backendName: "webgl",
  kernelFunc: te
  };
@@ -525,7 +525,7 @@ return a / b;`, se = `
  
  return result;
  `, ne = L({ opSnippet: ae, packedOpSnippet: se, checkOutOfBounds: !0 }), fe = {
- kernelName: j,
+ kernelName: K,
  backendName: "webgl",
  kernelFunc: ne
  };
package/dist/{Reshape-DvudQDvJ.js → Reshape-B-lWQRnF.js} RENAMED
@@ -1,4 +1,4 @@
- import { j as h, a3 as d, l as c, K as m } from "./index-BoWRt-10.js";
+ import { j as h, a4 as d, n as c, U as m } from "./index-CUQrfsw_.js";
  /**
  * @license
  * Copyright 2021 Google LLC. All Rights Reserved.
package/dist/{Reshape-DH5srBP0.js → Reshape-Bo8HzP8V.js} RENAMED
@@ -1,5 +1,5 @@
- import { j as c, a3 as C, l as f, K as R } from "./index-BoWRt-10.js";
- import { u as g, g as I, a as x, b as F, c as $, d as u, e as l, i as m } from "./gpgpu_math-DGNLNL4I.js";
+ import { j as c, a4 as C, n as f, U as R } from "./index-CUQrfsw_.js";
+ import { u as g, g as I, a as x, b as F, c as $, d as u, e as m, i as l } from "./gpgpu_math-TFLxaLkw.js";
  /**
  * @license
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -82,14 +82,14 @@ function v(s, t) {
  function b(s, t, i) {
  const a = [
  u(s.shape),
- ...l(s.shape)
+ ...m(s.shape)
  ], e = {
  dtype: s.dtype,
  shape: a,
  dataId: s.dataId
  }, o = [
  u(t),
- ...l(t)
+ ...m(t)
  ], r = new S(o, a), p = !0, n = [a], h = i.runWebGLProgram(r, [e], s.dtype, n, p);
  return { dataId: h.dataId, shape: t, dtype: h.dtype };
  }
@@ -113,7 +113,7 @@ function y(s) {
  const { inputs: t, backend: i, attrs: a } = s, { x: e } = t, { shape: o } = a, r = i, p = c(e.shape), n = C(o, p), h = c(n);
  f(p === h, () => `The new shape (${n}) has ${h} elements and the old shape (${e.shape}) has ${p} elements. The new shape and old shape must have the same number of elements.`);
  const d = r.texData.get(e.dataId);
- return d.isPacked && !m(e.shape, n) && !(d.texture !== null && m(d.shape, n)) ? b(e, n, r) : (r.incRef(e.dataId), { dataId: e.dataId, shape: n, dtype: e.dtype });
+ return d.isPacked && !l(e.shape, n) && !(d.texture !== null && l(d.shape, n)) ? b(e, n, r) : (r.incRef(e.dataId), { dataId: e.dataId, shape: n, dtype: e.dtype });
  }
  const U = {
  kernelName: R,
package/dist/TeachableLLM.d.ts CHANGED
@@ -1,11 +1,11 @@
- import { GPTConfig } from './config';
+ import { GPTConfig } from './models/config';
  import { ITokeniser } from './tokeniser/type';
- import { default as NanoGPT, TrainingLogEntry } from './NanoGPTModel';
- import { SaveOptions } from './utilities/save';
+ import { SaveOptions } from './loader/save';
  import { default as Generator, IGenerateOptions } from './Generator';
  import { default as Trainer, ITrainerOptions } from './Trainer';
  import { default as MemoryProfiler } from './utilities/profile';
- import { TrainingProgress } from './training/Trainer';
+ import { TrainingLogEntry, TrainingProgress } from './training/Trainer';
+ import { default as Model, ModelForwardAttributes } from './models/model';
  type TeachableLLMStatus = 'warmup' | 'awaitingTokens' | 'ready' | 'training' | 'loading' | 'busy' | 'error';
  interface TeachableLLMMeta {
  name?: string;
@@ -20,12 +20,12 @@ export default class TeachableLLM {
  private _status;
  private _memoryRequirements?;
  meta: TeachableLLMMeta;
- constructor(tokeniser?: ITokeniser, model?: NanoGPT);
+ constructor(tokeniser?: ITokeniser, model?: Model<ModelForwardAttributes>);
  get vocab(): string[];
  /** Model is fully loaded */
  get loaded(): boolean;
  get config(): GPTConfig;
- get model(): NanoGPT;
+ get model(): Model<ModelForwardAttributes>;
  get tokeniser(): ITokeniser;
  get status(): TeachableLLMStatus;
  /** Model is both ready and not busy */
package/dist/TeachableLLM.js CHANGED
@@ -1,30 +1,21 @@
- import { defaultConfig as _ } from "./config.js";
- import f from "./NanoGPTModel.js";
- import { saveModel as d } from "./utilities/save.js";
- import { loadModel as l } from "./loader/load.js";
+ import { defaultConfig as d } from "./models/config.js";
+ import { saveModel as l } from "./loader/save.js";
+ import { loadModel as _ } from "./loader/load.js";
  import u from "./Generator.js";
- import p from "./Trainer.js";
- import { E as c } from "./index-Dwqa6Zy2.js";
+ import f from "./Trainer.js";
+ import { E as p } from "./index-Dwqa6Zy2.js";
  import { dummyPassTrainAsync as m } from "./utilities/dummy.js";
- import g from "./tokeniser/CharTokeniser.js";
- import k from "./tokeniser/bpe.js";
- import "./papaparse.min-C8l2Kvo1.js";
- import "./index-Tf7vU29b.js";
- import "./jszip.min-CjP2V1VV.js";
- import "./index-BoWRt-10.js";
- import "./ops/cpu/scatterSub.js";
- import "./ops/webgl/scatterSub.js";
- import "./ops/cpu/gatherSub.js";
- import "./ops/webgl/gatherSub.js";
+ import "./index-CUQrfsw_.js";
  import "./ops/cpu/attentionMask.js";
  import "./ops/webgl/attentionMask.js";
  import "./ops/grads/attentionMask.js";
  import "./ops/cpu/qkv.js";
  import "./ops/webgl/qkv.js";
  import "./ops/grads/qkv.js";
- import "./random_width-sZORGo5k.js";
- import "./register_all_kernels-BwDSRN-f.js";
- import "./dataset-CtdBYwjo.js";
+ import "./random_width-D8Pwy_na.js";
+ import "./register_all_kernels-DUshvVWP.js";
+ import "./index-Tf7vU29b.js";
+ import "./dataset-CJmEGu6D.js";
  import "./ops/cpu/rope.js";
  import "./ops/webgl/rope.js";
  import "./ops/grads/rope.js";
@@ -36,20 +27,29 @@ import "./ops/grads/fusedSoftmax.js";
  import "./ops/cpu/matMulGelu.js";
  import "./ops/webgl/matMulGelu.js";
  import "./ops/grads/matMulGelu.js";
- import "./ops/cpu/gelu.js";
- import "./ops/webgl/gelu.js";
- import "./gelu-C-dPj6Ku.js";
  import "./ops/cpu/normRMS.js";
  import "./ops/webgl/normRMS.js";
  import "./ops/grads/normRMS.js";
+ import "./ops/cpu/gatherSub.js";
+ import "./ops/webgl/gatherSub.js";
+ import "./ops/cpu/scatterSub.js";
+ import "./ops/webgl/scatterSub.js";
+ import c from "./tokeniser/CharTokeniser.js";
+ import g from "./tokeniser/bpe.js";
+ import "./papaparse.min-C8l2Kvo1.js";
+ import "./jszip.min-CjP2V1VV.js";
+ import "./ops/cpu/gelu.js";
+ import "./ops/webgl/gelu.js";
+ import "./gelu-Bd3UBBxg.js";
  import "./ops/webgl/log.js";
  import "./ops/cpu/adamMoments.js";
  import "./ops/webgl/adamMoments.js";
  import "./ops/cpu/adamAdjust.js";
  import "./ops/webgl/adamAdjust.js";
- import w from "./utilities/profile.js";
+ import k from "./utilities/profile.js";
+ import w from "./models/factory.js";
  class a {
- ee = new c();
+ ee = new p();
  _config;
  _model;
  _tokeniser;
@@ -69,7 +69,7 @@ class a {
  get config() {
  if (!this._config)
  throw new Error("configuration_not_initialized.");
- return this._config.gpt;
+ return this._config;
  }
  get model() {
  if (!this._model)
@@ -101,14 +101,14 @@ class a {
  saveModel(t) {
  if (!this._model || !this._tokeniser)
  throw new Error("model_or_tokeniser_not_initialized.");
- return d(this._model, this._tokeniser, {
+ return l(this._model, this._tokeniser, {
  ...t,
  name: t?.name || this.meta.name
  });
  }
  static loadModel(t) {
  const e = new a();
- return l(t).then(({ model: r, tokeniser: o, name: s }) => {
+ return _(t).then(({ model: r, tokeniser: o, name: s }) => {
  e._model = r, e._tokeniser = o, e._config = r.config, s && (e.meta.name = s), e.setStatus("warmup"), m(r).then((i) => {
  e._memoryRequirements = i, e.setStatus("ready"), e.ee.emit("loaded");
  }).catch((i) => {
@@ -119,7 +119,7 @@ class a {
  }), e;
  }
  static create(t, e = {}) {
- const r = { ..._, ...e }, o = t === "char" ? new g(r.vocabSize) : new k(r.vocabSize), s = new f(r), i = new a(o, s);
+ const r = { ...d, ...e }, o = t === "char" ? new c(r.vocabSize) : new g(r.vocabSize), s = w(r), i = new a(o, s);
  return i.setStatus("warmup"), m(s).then((n) => {
  i._memoryRequirements = n, i.tokeniser.trained ? (i.setStatus("ready"), i.ee.emit("loaded")) : (i.setStatus("awaitingTokens"), i.ee.emit("loaded"), i.tokeniser.once("trainStatus", (h) => {
  h === "trained" && i.setStatus("ready");
@@ -138,9 +138,9 @@ class a {
  if (t) {
  if (!this._config)
  return;
- this._config.layerConfig.profiler || (this._config.layerConfig.profiler = new w());
+ this.model.getProfiler() || this.model.setProfiler(new k());
  } else
- this._config?.layerConfig.profiler && (this._config.layerConfig.profiler = void 0);
+ this.model.getProfiler() && this.model.setProfiler(null);
  }
  getNumParams() {
  return this._model ? this._model.getNumParams() : 0;
@@ -148,7 +148,7 @@ class a {
  trainer() {
  if (!this._model || !this._tokeniser)
  throw new Error("model_or_tokeniser_not_initialized.");
- const t = new p(this._model, this._tokeniser);
+ const t = new f(this._model, this._tokeniser);
  return t.on("start", () => this.setStatus("training")), t.on("stop", () => this.setStatus("ready")), t.on("log", async (e, r) => {
  const o = this.ee.listeners("trainStep");
  for (const s of o)
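Two behavioural changes stand out in TeachableLLM.js: models are now built through the default export of models/factory.js instead of constructing NanoGPTModel directly, and the memory profiler is attached to the model via getProfiler/setProfiler rather than being stored on config.layerConfig. A minimal sketch of the new profiler flow follows, using structural types inferred from this diff rather than the package's actual exports.

    // Structural types inferred from the diff above; not the package's real exports.
    interface Profiler {}            // stands in for utilities/profile's MemoryProfiler
    interface ProfilableModel {
      getProfiler(): Profiler | null;
      setProfiler(profiler: Profiler | null): void;
    }

    function setProfilingEnabled(
      model: ProfilableModel,
      enabled: boolean,
      makeProfiler: () => Profiler
    ): void {
      if (enabled) {
        // 0.7.x equivalent: config.layerConfig.profiler = new MemoryProfiler()
        if (!model.getProfiler()) model.setProfiler(makeProfiler());
      } else if (model.getProfiler()) {
        // 0.7.x equivalent: config.layerConfig.profiler = undefined
        model.setProfiler(null);
      }
    }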
package/dist/Trainer.d.ts CHANGED
@@ -1,6 +1,7 @@
- import { default as NanoGPT } from './NanoGPTModel';
  import { ITokeniser } from './tokeniser/type';
  import { default as EE } from 'eventemitter3';
+ import { TrainingLogEntry, TrainingProgress } from './training/Trainer';
+ import { default as Model, ModelForwardAttributes } from './models/model';
  export interface ITrainerOptions {
  batchSize?: number;
  learningRate?: number;
@@ -10,6 +11,11 @@ export interface ITrainerOptions {
  prompt?: string;
  validationSplit?: number;
  advancedMetrics?: boolean;
+ gradientCheckpointing?: boolean;
+ }
+ interface ExtendedTrainingProgress extends TrainingProgress {
+ progress: number;
+ remaining: number;
  }
  export default class Trainer extends EE<'start' | 'stop' | 'log'> {
  private trainer;
@@ -17,10 +23,15 @@ export default class Trainer extends EE<'start' | 'stop' | 'log'> {
  private trainDataset?;
  private validationDataset?;
  private totalSamples;
- constructor(model: NanoGPT, tokeniser: ITokeniser);
+ private log;
+ private progress;
+ constructor(model: Model<ModelForwardAttributes>, tokeniser: ITokeniser);
  stop(): void;
  reset(): void;
  prepare(text: string[], options?: ITrainerOptions): Promise<void>;
  train(options?: ITrainerOptions): Promise<void>;
  step(options?: ITrainerOptions): Promise<void>;
+ getLog(): TrainingLogEntry[];
+ getProgress(): ExtendedTrainingProgress | null;
  }
+ export {};
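The Trainer.d.ts changes add a gradientCheckpointing option, keep a training log and a progress snapshot on the instance, and type the constructor against the new Model abstraction. A hedged end-to-end sketch in TypeScript follows; TeachableLLM.create and trainer() come from the TeachableLLM.js diff, while the entry-point import, option values, and the units of progress/remaining are assumptions.

    // Hedged usage sketch of the 0.8.0 training surface. Option names and the
    // getLog/getProgress signatures come from the .d.ts above; everything else
    // (entry-point import, option values, progress units) is assumed.
    import TeachableLLM from '@genai-fi/nanogpt';

    async function trainExample(corpus: string[]): Promise<void> {
      const llm = TeachableLLM.create('char');
      const trainer = llm.trainer();

      await trainer.prepare(corpus, { validationSplit: 0.1 });

      trainer.on('log', () => {
        const progress = trainer.getProgress();
        // `progress` and `remaining` are numbers; their units are not documented in this diff.
        if (progress) console.log('progress:', progress.progress, 'remaining:', progress.remaining);
      });

      await trainer.train({
        batchSize: 16,
        learningRate: 1e-3,
        gradientCheckpointing: true, // new option in 0.8.0
      });

      console.log('log entries:', trainer.getLog().length);
    }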