@genai-fi/nanogpt 0.10.2 → 0.11.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (262) hide show
  1. package/dist/Generator.d.ts +10 -5
  2. package/dist/Generator.js +11760 -146
  3. package/dist/{RealDiv-zz7FpkKX.js → RealDiv-Ds-jvL09.js} +28 -30
  4. package/dist/Reshape-Cd6e-Otn.js +14 -0
  5. package/dist/{Reshape-CHdUjC72.js → Reshape-Ct266DEk.js} +21 -23
  6. package/dist/TeachableLLM.d.ts +4 -3
  7. package/dist/TeachableLLM.js +15 -16
  8. package/dist/Trainer.d.ts +2 -2
  9. package/dist/Trainer.js +6 -6
  10. package/dist/{axis_util-BsIr9ZNu.js → axis_util-DofAuy0p.js} +1 -1
  11. package/dist/backend.js +2 -2
  12. package/dist/{backend_util-B1XRLuq9.js → backend_util-C7NWHpv7.js} +72 -73
  13. package/dist/{backend_webgpu-CqpfEImu.js → backend_webgpu-B0Vls736.js} +52 -54
  14. package/dist/broadcast_to-DDaNMbX7.js +28 -0
  15. package/dist/checks/appendCache.js +2 -2
  16. package/dist/checks/attentionMask.js +3 -3
  17. package/dist/checks/gelu.js +2 -2
  18. package/dist/checks/matMulGelu.js +7 -11
  19. package/dist/checks/normRMS.js +9 -9
  20. package/dist/checks/normRMSGrad.js +3 -3
  21. package/dist/checks/packUnpack.js +2 -2
  22. package/dist/checks/qkv.js +11 -12
  23. package/dist/checks/rope.js +2 -2
  24. package/dist/clip_by_value-Dn5tzexi.js +12 -0
  25. package/dist/complex-DClmWqJt.js +11 -0
  26. package/dist/concat-C6X3AAlQ.js +17 -0
  27. package/dist/{concat_util-iBYIyuQe.js → concat_util-CHsJFZJJ.js} +1 -1
  28. package/dist/{dataset-D2P7rHAw.js → dataset-DcjWqUVQ.js} +135 -137
  29. package/dist/dropout-OxuaJz6z.js +92 -0
  30. package/dist/expand_dims-BzfJK2uc.js +11 -0
  31. package/dist/{exports_initializers-CZSUJoVE.js → exports_initializers-eS9QJ6ut.js} +1 -1
  32. package/dist/floor-DIb-lN_u.js +9 -0
  33. package/dist/gather-BcO5UQNJ.js +9 -0
  34. package/dist/{gelu-Bmhopi0J.js → gelu-DqTbCx5x.js} +10 -11
  35. package/dist/{gpgpu_math-DsCcikas.js → gpgpu_math-CJcbnKPC.js} +841 -1015
  36. package/dist/index-D0RBWjq8.js +3520 -0
  37. package/dist/{index-DRyE072i.js → index-Dj5TkmPY.js} +330 -331
  38. package/dist/{kernel_funcs_utils-CWfOAPGO.js → kernel_funcs_utils-CSaumNDs.js} +132 -134
  39. package/dist/layers/BaseLayer.js +15 -16
  40. package/dist/layers/CausalSelfAttention.js +6 -6
  41. package/dist/layers/MLP.js +4 -4
  42. package/dist/layers/PositionEmbedding.js +7 -7
  43. package/dist/layers/RMSNorm.js +3 -3
  44. package/dist/layers/RoPECache.js +9 -9
  45. package/dist/layers/TiedEmbedding.js +6 -6
  46. package/dist/layers/TransformerBlock.js +1 -1
  47. package/dist/loader/loadTransformers.js +1 -1
  48. package/dist/loader/oldZipLoad.js +21 -22
  49. package/dist/log_sum_exp-VLZgbFAH.js +39 -0
  50. package/dist/main.d.ts +1 -1
  51. package/dist/main.js +49 -50
  52. package/dist/{matMul16-fEAJ4smh.js → matMul16-cDxwemKj.js} +14 -15
  53. package/dist/matMulGelu-B2s_80-H.js +163 -0
  54. package/dist/mat_mul-DxpNTCRz.js +11 -0
  55. package/dist/mod-PrOKlFxH.js +11 -0
  56. package/dist/models/NanoGPTV1.js +2 -2
  57. package/dist/models/model.js +13 -14
  58. package/dist/ones-BX_wEgzB.js +14 -0
  59. package/dist/ops/adamAdjust.js +1 -1
  60. package/dist/ops/adamMoments.js +1 -1
  61. package/dist/ops/add16.js +1 -1
  62. package/dist/ops/appendCache.js +3 -3
  63. package/dist/ops/attentionMask.js +1 -1
  64. package/dist/ops/concat16.js +2 -2
  65. package/dist/ops/cpu/adamAdjust.js +12 -13
  66. package/dist/ops/cpu/adamMoments.js +6 -7
  67. package/dist/ops/cpu/appendCache.js +7 -8
  68. package/dist/ops/cpu/attentionMask.js +11 -11
  69. package/dist/ops/cpu/fusedSoftmax.js +10 -11
  70. package/dist/ops/cpu/gatherSub.js +10 -11
  71. package/dist/ops/cpu/gelu.js +14 -15
  72. package/dist/ops/cpu/matMul16.js +6 -7
  73. package/dist/ops/cpu/matMulGelu.js +5 -6
  74. package/dist/ops/cpu/matMulMul.js +3 -4
  75. package/dist/ops/cpu/mulDropout.js +3 -4
  76. package/dist/ops/cpu/normRMS.js +11 -12
  77. package/dist/ops/cpu/qkv.js +8 -9
  78. package/dist/ops/cpu/rope.js +9 -10
  79. package/dist/ops/cpu/scatterSub.js +14 -16
  80. package/dist/ops/dot16.js +2 -2
  81. package/dist/ops/gatherSub.js +1 -1
  82. package/dist/ops/gelu.js +2 -2
  83. package/dist/ops/grads/add16.js +10 -11
  84. package/dist/ops/grads/attentionMask.js +5 -6
  85. package/dist/ops/grads/gelu.js +3 -4
  86. package/dist/ops/grads/matMul16.js +4 -5
  87. package/dist/ops/grads/matMulGelu.js +8 -9
  88. package/dist/ops/grads/normRMS.js +9 -10
  89. package/dist/ops/grads/pack16.js +4 -5
  90. package/dist/ops/grads/qkv.js +17 -19
  91. package/dist/ops/grads/rope.js +3 -5
  92. package/dist/ops/grads/softmax16.js +3 -4
  93. package/dist/ops/grads/unpack16.js +3 -4
  94. package/dist/ops/grads/utils.d.ts +1 -0
  95. package/dist/ops/grads/utils.js +8 -4
  96. package/dist/ops/matMul16.js +3 -3
  97. package/dist/ops/matMulGelu.js +2 -2
  98. package/dist/ops/matMulMul.js +1 -1
  99. package/dist/ops/mul16.js +1 -1
  100. package/dist/ops/mulDrop.js +1 -1
  101. package/dist/ops/normRMS.js +1 -1
  102. package/dist/ops/pack16.js +3 -4
  103. package/dist/ops/qkv.js +4 -8
  104. package/dist/ops/reshape16.js +16 -18
  105. package/dist/ops/rope.d.ts +1 -1
  106. package/dist/ops/rope.js +3 -8
  107. package/dist/ops/scatterSub.js +1 -1
  108. package/dist/ops/slice16.js +2 -2
  109. package/dist/ops/softmax16.js +5 -8
  110. package/dist/ops/sub16.js +1 -1
  111. package/dist/ops/sum16.js +2 -2
  112. package/dist/ops/transpose16.js +23 -24
  113. package/dist/ops/unpack16.js +2 -2
  114. package/dist/ops/webgl/adamAdjust.js +2 -3
  115. package/dist/ops/webgl/adamMoments.js +1 -2
  116. package/dist/ops/webgl/appendCache.js +1 -2
  117. package/dist/ops/webgl/attentionMask.js +5 -6
  118. package/dist/ops/webgl/fusedSoftmax.js +6 -8
  119. package/dist/ops/webgl/gatherSub.js +6 -7
  120. package/dist/ops/webgl/gelu.js +2 -3
  121. package/dist/ops/webgl/log.js +11 -12
  122. package/dist/ops/webgl/matMul16.js +15 -16
  123. package/dist/ops/webgl/matMulGelu.js +7 -111
  124. package/dist/ops/webgl/matMulMul.js +14 -15
  125. package/dist/ops/webgl/mulDropout.js +8 -9
  126. package/dist/ops/webgl/normRMS.js +7 -8
  127. package/dist/ops/webgl/qkv.js +5 -6
  128. package/dist/ops/webgl/rope.js +7 -8
  129. package/dist/ops/webgl/scatterSub.js +5 -6
  130. package/dist/ops/webgpu/adamAdjust.js +10 -12
  131. package/dist/ops/webgpu/adamMoments.js +8 -10
  132. package/dist/ops/webgpu/add16.js +8 -9
  133. package/dist/ops/webgpu/appendCache.js +23 -25
  134. package/dist/ops/webgpu/attentionMask.js +10 -12
  135. package/dist/ops/webgpu/attentionMask32_program.js +2 -2
  136. package/dist/ops/webgpu/concat16.js +12 -14
  137. package/dist/ops/webgpu/gatherSub.js +9 -11
  138. package/dist/ops/webgpu/gelu.js +28 -29
  139. package/dist/ops/webgpu/matMul16.js +26 -28
  140. package/dist/ops/webgpu/matMul16_program.js +4 -5
  141. package/dist/ops/webgpu/mul16.js +7 -8
  142. package/dist/ops/webgpu/normRMS.js +17 -19
  143. package/dist/ops/webgpu/normRMSGrad.js +21 -28
  144. package/dist/ops/webgpu/pack16.js +12 -13
  145. package/dist/ops/webgpu/pack16_program.js +2 -2
  146. package/dist/ops/webgpu/qkv.js +13 -15
  147. package/dist/ops/webgpu/rope.js +25 -27
  148. package/dist/ops/webgpu/scatterSub.js +7 -9
  149. package/dist/ops/webgpu/slice16.js +21 -23
  150. package/dist/ops/webgpu/softmax16.js +17 -19
  151. package/dist/ops/webgpu/softmax16_program.js +2 -2
  152. package/dist/ops/webgpu/softmax16_subgroup_program.js +2 -2
  153. package/dist/ops/webgpu/softmax16grad.js +7 -8
  154. package/dist/ops/webgpu/sub16.js +8 -9
  155. package/dist/ops/webgpu/sum16.js +19 -21
  156. package/dist/ops/webgpu/transpose16.js +19 -20
  157. package/dist/ops/webgpu/transpose16_program.js +2 -2
  158. package/dist/ops/webgpu/transpose16_shared_program.js +11 -12
  159. package/dist/ops/webgpu/unpack16.js +3 -4
  160. package/dist/ops/webgpu/utils/binary_op.js +7 -8
  161. package/dist/ops/webgpu/utils/reductions.js +14 -22
  162. package/dist/ops-FJapAPfm.js +476 -0
  163. package/dist/pack16-k4jq6aMX.js +39 -0
  164. package/dist/patches/webgpu_backend.js +19 -20
  165. package/dist/patches/webgpu_base.js +1 -1
  166. package/dist/patches/webgpu_program.js +15 -16
  167. package/dist/{random_width-BVV9HveY.js → random_width-UGQn4OWb.js} +2506 -2761
  168. package/dist/range-CuGvVN2c.js +10 -0
  169. package/dist/relu-Cf80uA2p.js +9 -0
  170. package/dist/reshape-CkjKPPqB.js +9 -0
  171. package/dist/resize_nearest_neighbor-DB8k9KN_.js +175 -0
  172. package/dist/rope-BmZmp9uP.js +24 -0
  173. package/dist/{scatter_nd_util-C7zXRT_h.js → scatter_nd_util-BY22Cc-C.js} +1 -1
  174. package/dist/selu_util-BuLbmbrl.js +44 -0
  175. package/dist/{shared-CHhxz-O5.js → shared-B7USJZgw.js} +1 -1
  176. package/dist/{shared-D2NP_CpY.js → shared-BQboIImQ.js} +379 -381
  177. package/dist/slice-Aqy7KbJh.js +12 -0
  178. package/dist/{slice_util-DyjSAD0u.js → slice_util-D8CQRenR.js} +7 -7
  179. package/dist/{softmax-C9JQEtnO.js → softmax-faLoUZVT.js} +4 -5
  180. package/dist/split-BNz5jcGc.js +9 -0
  181. package/dist/squeeze--YMgaAAf.js +10 -0
  182. package/dist/stack-WJK22CFn.js +11 -0
  183. package/dist/step-dXR33iOg.js +261 -0
  184. package/dist/sum-BdplSvq_.js +11 -0
  185. package/dist/{tensor-0r5yOo2R.js → tensor-BQqrDvpx.js} +1 -1
  186. package/dist/tensor1d-LxP9asMm.js +11 -0
  187. package/dist/{tensor2d-CSB4KOb0.js → tensor2d-BN1sSfQO.js} +6 -7
  188. package/dist/{tensor4d-D7bLqGqz.js → tensor4d-DVwr7pLF.js} +6 -7
  189. package/dist/{tfjs_backend-CNkSTL0c.js → tfjs_backend-Vi4JfLzT.js} +256 -265
  190. package/dist/tile-CvN_LyVr.js +11 -0
  191. package/dist/tokeniser/BaseTokeniser.d.ts +27 -0
  192. package/dist/tokeniser/BaseTokeniser.js +94 -0
  193. package/dist/tokeniser/CharTokeniser.d.ts +4 -3
  194. package/dist/tokeniser/CharTokeniser.js +46 -32
  195. package/dist/tokeniser/bpe.d.ts +4 -3
  196. package/dist/tokeniser/bpe.js +60 -45
  197. package/dist/tokeniser/type.d.ts +11 -0
  198. package/dist/training/Adam.js +2 -2
  199. package/dist/training/AdamExt.js +1 -1
  200. package/dist/training/DatasetBuilder.d.ts +2 -2
  201. package/dist/training/DatasetBuilder.js +32 -36
  202. package/dist/training/FullTrainer.js +1 -1
  203. package/dist/training/Trainer.d.ts +3 -3
  204. package/dist/training/Trainer.js +2 -2
  205. package/dist/training/sparseCrossEntropy.js +5 -5
  206. package/dist/transpose-JawVKyZy.js +36 -0
  207. package/dist/unsorted_segment_sum-LAbmE9G4.js +277 -0
  208. package/dist/utilities/dummy.js +3 -3
  209. package/dist/utilities/multinomialCPU.js +2 -2
  210. package/dist/utilities/packed.d.ts +1 -4
  211. package/dist/utilities/packed.js +10 -745
  212. package/dist/utilities/performance.js +1 -1
  213. package/dist/utilities/profile.js +1 -1
  214. package/dist/utilities/safetensors.js +2 -2
  215. package/dist/utilities/sentences.js +5 -5
  216. package/dist/utilities/weights.js +2 -2
  217. package/dist/{variable-DzfrwYuP.js → variable-DQ9yYgEU.js} +1 -1
  218. package/dist/{webgpu_program-DzaQiqel.js → webgpu_program-CAE4RICo.js} +177 -171
  219. package/dist/{webgpu_util-0_ubCEHJ.js → webgpu_util-BdovYhXr.js} +34 -35
  220. package/dist/zeros-DeiE2zTa.js +13 -0
  221. package/dist/zeros_like-BAz3iKru.js +721 -0
  222. package/package.json +4 -2
  223. package/dist/Reshape-CDVLyVfz.js +0 -16
  224. package/dist/broadcast_to-B0ChcDaz.js +0 -30
  225. package/dist/complex-BBiRlsVq.js +0 -13
  226. package/dist/concat-DmBLPVGC.js +0 -19
  227. package/dist/dropout-B1x1kYMa.js +0 -99
  228. package/dist/expand_dims-ouvfxQ1n.js +0 -13
  229. package/dist/gather-CH9sdacz.js +0 -10
  230. package/dist/index-D6Q1lPZO.js +0 -2157
  231. package/dist/log_sum_exp-D3ftBNY5.js +0 -41
  232. package/dist/mat_mul-C59XWcJd.js +0 -12
  233. package/dist/mod-DESSvHIU.js +0 -12
  234. package/dist/mulmat_packed_gpu-Coh6qbJk.js +0 -55
  235. package/dist/ones-jU9jlQvM.js +0 -15
  236. package/dist/ops-BFDtP6th.js +0 -645
  237. package/dist/pack16-CmVZs6af.js +0 -41
  238. package/dist/patches/PackedTensor.d.ts +0 -12
  239. package/dist/patches/PackedTensor.js +0 -11
  240. package/dist/patches/engine.d.ts +0 -261
  241. package/dist/patches/engine.js +0 -12
  242. package/dist/patches/tape.d.ts +0 -12
  243. package/dist/patches/tape.js +0 -5
  244. package/dist/range-ZZZD60Fx.js +0 -11
  245. package/dist/reciprocal-CrYlsAGD.js +0 -10
  246. package/dist/register_all_kernels-nvj2k7OC.js +0 -12307
  247. package/dist/relu-BYDneVPn.js +0 -10
  248. package/dist/reshape-CaPQzFvz.js +0 -10
  249. package/dist/rope-s4W2XO9B.js +0 -32
  250. package/dist/selu_util-BGPXmd4B.js +0 -303
  251. package/dist/sin-Djs4aQiu.js +0 -16
  252. package/dist/slice-DvovR5wq.js +0 -13
  253. package/dist/split-DBck65sX.js +0 -10
  254. package/dist/squeeze-C00Ipm_7.js +0 -11
  255. package/dist/stack-ChnHwRpX.js +0 -13
  256. package/dist/sum-ywRJj3Zr.js +0 -12
  257. package/dist/tensor-CzmOBsdf.js +0 -909
  258. package/dist/tensor1d-BlUT89BP.js +0 -12
  259. package/dist/tensor_util-DfwaWayG.js +0 -523
  260. package/dist/tile-CR074jmp.js +0 -13
  261. package/dist/transpose-DH4gmHvu.js +0 -38
  262. package/dist/zeros-DBFVbpv5.js +0 -14
@@ -1,12 +1,10 @@
1
- import "./index-D6Q1lPZO.js";
2
- import { r as $ } from "./Reshape-CHdUjC72.js";
3
- import { _ as T, g as E, y as B, $ as F } from "./tensor_util-DfwaWayG.js";
4
- import { G as _, e as G, p as O, s as V } from "./tensor-CzmOBsdf.js";
5
- import { a as A, b as k, d as C, c as N, e as R } from "./axis_util-BsIr9ZNu.js";
6
- import { t as K, m as U } from "./shared-CHhxz-O5.js";
7
- import { c as W } from "./backend_util-B1XRLuq9.js";
8
- import { f as y } from "./gpgpu_math-DsCcikas.js";
9
- import { g as j, b as L } from "./kernel_funcs_utils-CWfOAPGO.js";
1
+ import { aG as T, ab as E, af as O, V, aS as B, Q as F, am as G, aT as K } from "./index-D0RBWjq8.js";
2
+ import { r as $ } from "./Reshape-Ct266DEk.js";
3
+ import { a as A, b as k, d as C, c as N, e as R } from "./axis_util-DofAuy0p.js";
4
+ import { t as U, m as W } from "./shared-B7USJZgw.js";
5
+ import { c as _ } from "./backend_util-C7NWHpv7.js";
6
+ import { f as y } from "./gpgpu_math-CJcbnKPC.js";
7
+ import { g as j, b as L } from "./kernel_funcs_utils-CSaumNDs.js";
10
8
  class w {
11
9
  constructor(s, e) {
12
10
  this.variableNames = ["x"];
@@ -16,7 +14,7 @@ class w {
16
14
  let o = "sumValue += dot(values, ones);";
17
15
  if (e != null) {
18
16
  const p = 1 / e;
19
- o = `sumValue += dot(values * ${_(p) ? p.toPrecision(2) : p}, ones);`;
17
+ o = `sumValue += dot(values * ${T(p) ? p.toPrecision(2) : p}, ones);`;
20
18
  }
21
19
  let u = "";
22
20
  l % t > 0 && (u = `
@@ -186,7 +184,7 @@ class X {
186
184
  function q(a) {
187
185
  const s = [];
188
186
  for (; s.length === 0 || s[s.length - 1].outSize !== 1; ) {
189
- const e = s.length ? s[s.length - 1].outSize : a[1], t = W(e);
187
+ const e = s.length ? s[s.length - 1].outSize : a[1], t = _(e);
190
188
  s.push({
191
189
  inSize: e,
192
190
  windowSize: t,
@@ -205,14 +203,14 @@ function P(a, s, e, t) {
205
203
  }
206
204
  return l;
207
205
  }
208
- class Y {
206
+ class Q {
209
207
  constructor(s, e) {
210
208
  this.variableNames = ["A"];
211
209
  const t = new Array(s.length);
212
210
  for (let r = 0; r < t.length; r++)
213
211
  t[r] = s[e[r]];
214
212
  this.outputShape = t, this.rank = t.length;
215
- const n = y(this.rank), l = H(e);
213
+ const n = y(this.rank), l = Y(e);
216
214
  this.userCode = `
217
215
  void main() {
218
216
  ${n} resRC = getOutputCoords();
@@ -221,7 +219,7 @@ class Y {
221
219
  `;
222
220
  }
223
221
  }
224
- function H(a) {
222
+ function Y(a) {
225
223
  const s = a.length;
226
224
  if (s > 6)
227
225
  throw Error(`Transpose for rank ${s} is not yet supported`);
@@ -230,7 +228,7 @@ function H(a) {
230
228
  t[a[n]] = e[n];
231
229
  return t.join();
232
230
  }
233
- class J {
231
+ class H {
234
232
  constructor(s, e) {
235
233
  this.variableNames = ["A"], this.packedInputs = !0, this.packedOutput = !0;
236
234
  const t = new Array(s.length);
@@ -263,10 +261,10 @@ class J {
263
261
  }
264
262
  }
265
263
  function D(a, s, e) {
266
- const t = G().getBool("WEBGL_PACK_ARRAY_OPERATIONS") ? new J(a.shape, s) : new Y(a.shape, s);
264
+ const t = E().getBool("WEBGL_PACK_ARRAY_OPERATIONS") ? new H(a.shape, s) : new Q(a.shape, s);
267
265
  return e.runWebGLProgram(t, [a], a.dtype);
268
266
  }
269
- function Q(a, s, e, t) {
267
+ function J(a, s, e, t) {
270
268
  const n = s, l = a.shape.length, r = O(n, a.shape);
271
269
  let i = r;
272
270
  const c = A(i, l), o = c != null;
@@ -275,15 +273,15 @@ function Q(a, s, e, t) {
275
273
  const [p, h] = N(u.shape, i);
276
274
  let d = p;
277
275
  e && (d = R(p, r));
278
- const f = V(h), g = V(a.shape) / f, x = $({ inputs: { x: u }, attrs: { shape: [g, f] }, backend: t }), b = T(a.dtype), I = P(x, b, "sum", t), m = $({ inputs: { x: I }, attrs: { shape: d }, backend: t });
276
+ const f = V(h), g = V(a.shape) / f, x = $({ inputs: { x: u }, attrs: { shape: [g, f] }, backend: t }), b = B(a.dtype), I = P(x, b, "sum", t), m = $({ inputs: { x: I }, attrs: { shape: d }, backend: t });
279
277
  return t.disposeIntermediateTensorInfo(x), t.disposeIntermediateTensorInfo(I), o && t.disposeIntermediateTensorInfo(u), m;
280
278
  }
281
279
  function Z(a) {
282
280
  const { inputs: s, backend: e, attrs: t } = a, { x: n } = s, { axis: l, keepDims: r } = t;
283
- return Q(n, l, r, e);
281
+ return J(n, l, r, e);
284
282
  }
285
- const fe = {
286
- kernelName: E,
283
+ const pe = {
284
+ kernelName: F,
287
285
  backendName: "webgl",
288
286
  kernelFunc: Z
289
287
  };
@@ -301,7 +299,7 @@ function te(a) {
301
299
  const I = e.texData.get(d.dataId).values, m = new Array(i);
302
300
  for (let v = 0; v < m.length; v++)
303
301
  m[v] = n.shape[u[v]];
304
- const z = K(I, n.shape, n.dtype, u, m);
302
+ const z = U(I, n.shape, n.dtype, u, m);
305
303
  d = e.makeTensorInfo(m, n.dtype);
306
304
  const M = e.texData.get(d.dataId);
307
305
  M.values = z;
@@ -315,7 +313,7 @@ function te(a) {
315
313
  r && (g = R(f, c));
316
314
  let x;
317
315
  if (h) {
318
- const I = e.texData.get(d.dataId).values, m = U(I, V(S), g, n.dtype);
316
+ const I = e.texData.get(d.dataId).values, m = W(I, V(S), g, n.dtype);
319
317
  x = e.makeTensorInfo(g, n.dtype);
320
318
  const z = e.texData.get(x.dataId);
321
319
  z.values = m;
@@ -323,8 +321,8 @@ function te(a) {
323
321
  x = ee(d, S, g, e);
324
322
  return p && e.disposeIntermediateTensorInfo(d), x;
325
323
  }
326
- const me = {
327
- kernelName: B,
324
+ const he = {
325
+ kernelName: G,
328
326
  backendName: "webgl",
329
327
  kernelFunc: te
330
328
  };
@@ -350,16 +348,16 @@ return a / b;`, se = `
350
348
  }
351
349
 
352
350
  return result;
353
- `, ne = L({ opSnippet: ae, packedOpSnippet: se, checkOutOfBounds: !0 }), xe = {
354
- kernelName: F,
351
+ `, ne = L({ opSnippet: ae, packedOpSnippet: se, checkOutOfBounds: !0 }), fe = {
352
+ kernelName: K,
355
353
  backendName: "webgl",
356
354
  kernelFunc: ne
357
355
  };
358
356
  export {
359
357
  P as a,
360
- me as b,
361
- xe as c,
362
- fe as d,
358
+ he as b,
359
+ fe as c,
360
+ pe as d,
363
361
  te as m,
364
362
  ne as r,
365
363
  Z as s,
@@ -0,0 +1,14 @@
1
+ import { V as h, ah as d, y as c, R as m } from "./index-D0RBWjq8.js";
2
+ function i(n) {
3
+ const { inputs: p, attrs: o } = n, { x: e } = p, { shape: r } = o, a = h(e.shape), s = d(r, a), t = h(s);
4
+ return c(a === t, () => `The new shape (${s}) has ${t} elements and the old shape (${e.shape}) has ${a} elements. The new shape and old shape must have the same number of elements.`), n.backend.incRef(e.dataId), { dataId: e.dataId, shape: s, dtype: e.dtype };
5
+ }
6
+ const u = {
7
+ kernelName: m,
8
+ backendName: "webgpu",
9
+ kernelFunc: i
10
+ };
11
+ export {
12
+ u as a,
13
+ i as r
14
+ };
@@ -1,10 +1,8 @@
1
- import "./index-D6Q1lPZO.js";
2
- import { u as C, g as f, a as R, b as g, c as I, d as c, e as u, i as m } from "./gpgpu_math-DsCcikas.js";
3
- import { b as x } from "./tensor_util-DfwaWayG.js";
4
- import { s as l, n as F, a as $ } from "./tensor-CzmOBsdf.js";
1
+ import { R as C, V as c, ah as R, y as f } from "./index-D0RBWjq8.js";
2
+ import { u as g, g as I, a as x, b as F, c as $, d as u, e as m, i as l } from "./gpgpu_math-CJcbnKPC.js";
5
3
  class S {
6
4
  constructor(t, i) {
7
- this.variableNames = ["A"], this.packedInputs = !0, this.packedOutput = !0, this.customUniforms = [{ name: "inputShape", type: "ivec3" }], this.outputShape = t, this.enableShapeUniforms = C(this.outputShape.length);
5
+ this.variableNames = ["A"], this.packedInputs = !0, this.packedOutput = !0, this.customUniforms = [{ name: "inputShape", type: "ivec3" }], this.outputShape = t, this.enableShapeUniforms = g(this.outputShape.length);
8
6
  let a = "";
9
7
  for (let e = 0; e < 4; e++) {
10
8
  let o = "thisRC = rc;";
@@ -22,8 +20,8 @@ class S {
22
20
  `;
23
21
  }
24
22
  this.userCode = `
25
- ${b(i, this.enableShapeUniforms)}
26
- ${this.enableShapeUniforms ? f() : R(t)}
23
+ ${v(i, this.enableShapeUniforms)}
24
+ ${this.enableShapeUniforms ? I() : x(t)}
27
25
 
28
26
  void main() {
29
27
  ivec3 rc = getOutputCoords();
@@ -41,41 +39,41 @@ class S {
41
39
  `;
42
40
  }
43
41
  }
44
- function b(s, t) {
42
+ function v(s, t) {
45
43
  return `
46
44
  ivec3 inputCoordsFromReshapedOutCoords(int index) {
47
- ${t ? g(["r", "c", "d"], "inputShape") : I(["r", "c", "d"], s)}
45
+ ${t ? F(["r", "c", "d"], "inputShape") : $(["r", "c", "d"], s)}
48
46
  return ivec3(r, c, d);
49
47
  }
50
48
  `;
51
49
  }
52
- function v(s, t, i) {
50
+ function y(s, t, i) {
53
51
  const a = [
54
- c(s.shape),
55
- ...u(s.shape)
52
+ u(s.shape),
53
+ ...m(s.shape)
56
54
  ], e = {
57
55
  dtype: s.dtype,
58
56
  shape: a,
59
57
  dataId: s.dataId
60
58
  }, o = [
61
- c(t),
62
- ...u(t)
59
+ u(t),
60
+ ...m(t)
63
61
  ], r = new S(o, a), p = !0, n = [a], h = i.runWebGLProgram(r, [e], s.dtype, n, p);
64
62
  return { dataId: h.dataId, shape: t, dtype: h.dtype };
65
63
  }
66
- function y(s) {
67
- const { inputs: t, backend: i, attrs: a } = s, { x: e } = t, { shape: o } = a, r = i, p = l(e.shape), n = F(o, p), h = l(n);
68
- $(p === h, () => `The new shape (${n}) has ${h} elements and the old shape (${e.shape}) has ${p} elements. The new shape and old shape must have the same number of elements.`);
64
+ function b(s) {
65
+ const { inputs: t, backend: i, attrs: a } = s, { x: e } = t, { shape: o } = a, r = i, p = c(e.shape), n = R(o, p), h = c(n);
66
+ f(p === h, () => `The new shape (${n}) has ${h} elements and the old shape (${e.shape}) has ${p} elements. The new shape and old shape must have the same number of elements.`);
69
67
  const d = r.texData.get(e.dataId);
70
- return d.isPacked && !m(e.shape, n) && !(d.texture !== null && m(d.shape, n)) ? v(e, n, r) : (r.incRef(e.dataId), { dataId: e.dataId, shape: n, dtype: e.dtype });
68
+ return d.isPacked && !l(e.shape, n) && !(d.texture !== null && l(d.shape, n)) ? y(e, n, r) : (r.incRef(e.dataId), { dataId: e.dataId, shape: n, dtype: e.dtype });
71
69
  }
72
- const O = {
73
- kernelName: x,
70
+ const U = {
71
+ kernelName: C,
74
72
  backendName: "webgl",
75
- kernelFunc: y
73
+ kernelFunc: b
76
74
  };
77
75
  export {
78
76
  S as R,
79
- O as a,
80
- y as r
77
+ U as a,
78
+ b as r
81
79
  };
@@ -1,5 +1,5 @@
1
1
  import { GPTConfig } from './models/config';
2
- import { ITokeniser } from './tokeniser/type';
2
+ import { Conversation, ITokeniser } from './tokeniser/type';
3
3
  import { SaveOptions } from './loader/save';
4
4
  import { default as Generator, IGenerateOptions } from './Generator';
5
5
  import { default as Trainer, ITrainerOptions } from './Trainer';
@@ -41,10 +41,11 @@ export default class TeachableLLM {
41
41
  set enableProfiler(value: boolean);
42
42
  getNumParams(): number;
43
43
  trainer(): Trainer;
44
- train(text: string[], options?: ITrainerOptions): Promise<void>;
44
+ train(text: Conversation[][], options?: ITrainerOptions): Promise<void>;
45
45
  trainTokeniser(text: string[]): Promise<number>;
46
46
  generator(): Generator;
47
- generateText(prompt?: string, options?: IGenerateOptions): Promise<string>;
47
+ generateText(prompt: Conversation[], options?: IGenerateOptions): Promise<Conversation[]>;
48
+ generateText(options?: IGenerateOptions): Promise<Conversation[]>;
48
49
  dispose(): void;
49
50
  on(event: 'status', listener: (status: TeachableLLMStatus) => void): void;
50
51
  on(event: 'error', listener: (error: Error) => void): void;
@@ -2,28 +2,27 @@ import { defaultConfig as d } from "./models/config.js";
2
2
  import { saveModel as l } from "./loader/save.js";
3
3
  import { loadModel as _ } from "./loader/load.js";
4
4
  import u from "./Generator.js";
5
- import p from "./Trainer.js";
6
- import { E as f } from "./index-DvYrXKkX.js";
5
+ import f from "./Trainer.js";
6
+ import { E as p } from "./index-DvYrXKkX.js";
7
7
  import { dummyPassTrainAsync as m } from "./utilities/dummy.js";
8
- import "./utilities/packed.js";
9
- import "./index-D6Q1lPZO.js";
8
+ import "./index-D0RBWjq8.js";
9
+ import "./random_width-UGQn4OWb.js";
10
+ import "./zeros_like-BAz3iKru.js";
11
+ import "./index-Cp39cXWe.js";
12
+ import "./dataset-DcjWqUVQ.js";
10
13
  import "./ops/cpu/attentionMask.js";
11
14
  import "./ops/webgl/attentionMask.js";
12
15
  import "./ops/grads/attentionMask.js";
13
- import "./random_width-BVV9HveY.js";
14
- import "./register_all_kernels-nvj2k7OC.js";
15
- import "./index-Cp39cXWe.js";
16
- import "./dataset-D2P7rHAw.js";
17
16
  import "./ops/cpu/rope.js";
18
17
  import "./ops/webgl/rope.js";
19
- import "./rope-s4W2XO9B.js";
18
+ import "./rope-BmZmp9uP.js";
20
19
  import "./ops/cpu/appendCache.js";
21
20
  import "./ops/webgl/appendCache.js";
22
21
  import "./ops/grads/softmax16.js";
23
- import "./matMul16-fEAJ4smh.js";
22
+ import "./matMul16-cDxwemKj.js";
24
23
  import "./ops/webgl/matMul16.js";
25
24
  import "./ops/cpu/matMul16.js";
26
- import "./pack16-CmVZs6af.js";
25
+ import "./pack16-k4jq6aMX.js";
27
26
  import "./ops/transpose16.js";
28
27
  import "./ops/reshape16.js";
29
28
  import "./ops/cpu/qkv.js";
@@ -42,11 +41,11 @@ import g from "./tokeniser/bpe.js";
42
41
  import "./papaparse.min-C0cScC2i.js";
43
42
  import "./jszip.min-Bz5-11Bk.js";
44
43
  import "./ops/cpu/matMulGelu.js";
45
- import "./ops/webgl/matMulGelu.js";
44
+ import "./matMulGelu-B2s_80-H.js";
46
45
  import "./ops/grads/matMulGelu.js";
47
46
  import "./ops/cpu/gelu.js";
48
47
  import "./ops/webgl/gelu.js";
49
- import "./gelu-Bmhopi0J.js";
48
+ import "./gelu-DqTbCx5x.js";
50
49
  import "./ops/webgl/log.js";
51
50
  import "./ops/cpu/adamMoments.js";
52
51
  import "./ops/webgl/adamMoments.js";
@@ -57,7 +56,7 @@ import "./checks/normRMSGrad.js";
57
56
  import k from "./utilities/profile.js";
58
57
  import w from "./models/factory.js";
59
58
  class a {
60
- ee = new f();
59
+ ee = new p();
61
60
  _config;
62
61
  _model;
63
62
  _tokeniser;
@@ -156,7 +155,7 @@ class a {
156
155
  trainer() {
157
156
  if (!this._model || !this._tokeniser)
158
157
  throw new Error("model_or_tokeniser_not_initialized.");
159
- const t = new p(this._model, this._tokeniser);
158
+ const t = new f(this._model, this._tokeniser);
160
159
  return t.on("start", () => this.setStatus("training")), t.on("stop", () => this.setStatus("ready")), t.on("log", async (e, r) => {
161
160
  const o = this.ee.listeners("trainStep");
162
161
  for (const s of o)
@@ -184,7 +183,7 @@ class a {
184
183
  }), t;
185
184
  }
186
185
  generateText(t, e) {
187
- return this.generator().generate(t, e);
186
+ return Array.isArray(t) ? this.generator().generate(t, e) : this.generator().generate([], e);
188
187
  }
189
188
  dispose() {
190
189
  this._model?.dispose();
package/dist/Trainer.d.ts CHANGED
@@ -1,4 +1,4 @@
1
- import { ITokeniser } from './tokeniser/type';
1
+ import { Conversation, ITokeniser } from './tokeniser/type';
2
2
  import { default as EE } from 'eventemitter3';
3
3
  import { TrainingLogEntry, TrainingProgress } from './training/Trainer';
4
4
  import { default as Model, ModelForwardAttributes } from './models/model';
@@ -30,7 +30,7 @@ export default class Trainer extends EE<'start' | 'stop' | 'log'> {
30
30
  constructor(model: Model<ModelForwardAttributes>, tokeniser: ITokeniser);
31
31
  stop(): void;
32
32
  reset(): void;
33
- prepare(text: string[], options?: ITrainerOptions): Promise<void>;
33
+ prepare(text: Conversation[][], options?: ITrainerOptions): Promise<void>;
34
34
  train(options?: ITrainerOptions): Promise<void>;
35
35
  step(options?: ITrainerOptions): Promise<void>;
36
36
  getLog(): TrainingLogEntry[];
package/dist/Trainer.js CHANGED
@@ -1,6 +1,6 @@
1
- import { E as l } from "./index-DvYrXKkX.js";
2
- import h from "./training/FullTrainer.js";
3
- class m extends l {
1
+ import { E as o } from "./index-DvYrXKkX.js";
2
+ import d from "./training/FullTrainer.js";
3
+ class g extends o {
4
4
  trainer;
5
5
  hasTrained = !1;
6
6
  trainDataset;
@@ -9,7 +9,7 @@ class m extends l {
9
9
  log = [];
10
10
  progress = null;
11
11
  constructor(t, e) {
12
- super(), this.trainer = new h(t, e, 1e-3);
12
+ super(), this.trainer = new d(t, e, 1e-3);
13
13
  }
14
14
  stop() {
15
15
  this.trainer.stop();
@@ -22,7 +22,7 @@ class m extends l {
22
22
  t,
23
23
  e?.batchSize || 32,
24
24
  e?.validationSplit || 0.1
25
- ), i = t.reduce((r, n) => r + n.length, 0) * (1 - (e?.validationSplit || 0));
25
+ ), i = t.reduce((r, n) => r + n.reduce((l, h) => l + h.content.length, 0), 0) * (1 - (e?.validationSplit || 0));
26
26
  this.trainDataset = a, this.validationDataset = s, this.totalSamples = i;
27
27
  }
28
28
  async train(t) {
@@ -91,5 +91,5 @@ class m extends l {
91
91
  }
92
92
  }
93
93
  export {
94
- m as default
94
+ g as default
95
95
  };
@@ -1,4 +1,4 @@
1
- import { a as c } from "./tensor-CzmOBsdf.js";
1
+ import { y as c } from "./index-D0RBWjq8.js";
2
2
  function i(e, n) {
3
3
  for (let t = 0; t < e.length; ++t)
4
4
  if (e[e.length - t - 1] !== n - 1 - t)
package/dist/backend.js CHANGED
@@ -1,9 +1,9 @@
1
- import { g as o, s as e, r as s } from "./index-D6Q1lPZO.js";
1
+ import { g as o, s as e, r as s } from "./index-D0RBWjq8.js";
2
2
  async function c(t, a) {
3
3
  if (o() !== t) {
4
4
  if (t === "webgpu") {
5
5
  const { registerWebGPUBackend: i } = await import("./patches/webgpu_base.js");
6
- i(a), await import("./index-DRyE072i.js"), await import("./ops/webgpu/index.js");
6
+ i(a), await import("./index-Dj5TkmPY.js"), await import("./ops/webgpu/index.js");
7
7
  }
8
8
  await e(t), await s(), console.log(`Backend set to ${t}`);
9
9
  }