@genai-fi/nanogpt 0.17.4 → 0.18.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258) hide show
  1. package/dist/Generator.d.ts +2 -15
  2. package/dist/Generator.js +45 -34
  3. package/dist/{RealDiv-CGwv0liw.js → RealDiv-ioj6Z-ox.js} +9 -9
  4. package/dist/{Reshape-BW__R4mZ.js → Reshape-BZC-ebeR.js} +7 -7
  5. package/dist/{Reshape-CPBkTIH2.js → Reshape-pwprEaej.js} +1 -1
  6. package/dist/TeachableLLM.d.ts +3 -8
  7. package/dist/TeachableLLM.js +61 -44
  8. package/dist/Trainer.d.ts +6 -4
  9. package/dist/Trainer.js +107 -92
  10. package/dist/{axis_util-GTVlo58H.js → axis_util-QWWgLjut.js} +1 -1
  11. package/dist/backend.js +2 -2
  12. package/dist/{backend_util-GaFarB78.js → backend_util-qwSFfxYx.js} +21 -21
  13. package/dist/{backend_webgpu-BqASlsbV.js → backend_webgpu-DI2wXEC2.js} +8 -8
  14. package/dist/{broadcast_to-eS93CCN_.js → broadcast_to-C_EJTVTZ.js} +2 -2
  15. package/dist/checks/appendCache.js +2 -2
  16. package/dist/checks/attentionMask.js +5 -5
  17. package/dist/checks/gelu.js +2 -2
  18. package/dist/checks/matMulGelu.js +2 -2
  19. package/dist/checks/normRMS.js +6 -6
  20. package/dist/checks/normRMSGrad.js +3 -3
  21. package/dist/checks/packUnpack.js +6 -6
  22. package/dist/checks/qkv.js +2 -2
  23. package/dist/checks/rope.js +2 -2
  24. package/dist/{clip_by_value-DDA7rrcT.js → clip_by_value-CLAD4h_I.js} +1 -1
  25. package/dist/complex-3DpPEG9B.js +11 -0
  26. package/dist/{concat-CAQpCret.js → concat-Dqk7Xk7h.js} +5 -5
  27. package/dist/{concat_util-D18dJ4fD.js → concat_util-C1Mxe27t.js} +1 -1
  28. package/dist/{dataset-CGGp1z9P.js → dataset-DlqAN81i.js} +3 -3
  29. package/dist/{dropout_util--NxWuYg2.js → dropout_util-N0z8Os-K.js} +1 -1
  30. package/dist/{expand_dims-Bkd1YD5x.js → expand_dims-D0rBtgT1.js} +4 -4
  31. package/dist/{exports_initializers-CYzKLjN7.js → exports_initializers-DIOZQt_L.js} +1 -1
  32. package/dist/{floor-BQtb-Azg.js → floor-CymuCmTO.js} +1 -1
  33. package/dist/{gather-qIqEqaGn.js → gather-DEyjXNb1.js} +1 -1
  34. package/dist/{gelu-B220X1Go.js → gelu-DpTCC3eB.js} +1 -1
  35. package/dist/{gpgpu_math-BwvV12df.js → gpgpu_math-3bCb5ooU.js} +25 -25
  36. package/dist/{index-CjOWnMXP.js → index-BQvB7LCC.js} +15 -15
  37. package/dist/{index-CUXkjxiT.js → index-DSGwv2Yx.js} +33 -33
  38. package/dist/inference/types.d.ts +16 -0
  39. package/dist/inference/types.js +1 -0
  40. package/dist/{kernel_funcs_utils-pq0CK9co.js → kernel_funcs_utils-DGqzNlHT.js} +6 -6
  41. package/dist/layers/BaseLayer.js +4 -4
  42. package/dist/layers/CausalSelfAttention.js +6 -6
  43. package/dist/layers/LoRA.js +4 -4
  44. package/dist/layers/MLP.js +4 -4
  45. package/dist/layers/PositionEmbedding.js +5 -5
  46. package/dist/layers/RMSNorm.js +3 -3
  47. package/dist/layers/RoPECache.js +4 -4
  48. package/dist/layers/TiedEmbedding.js +6 -6
  49. package/dist/layers/TransformerBlock.js +1 -1
  50. package/dist/layers/WeightStore.js +2 -2
  51. package/dist/loader/load.d.ts +2 -8
  52. package/dist/loader/loadTransformers.d.ts +2 -8
  53. package/dist/loader/loadTransformers.js +13 -11
  54. package/dist/loader/newZipLoad.d.ts +2 -8
  55. package/dist/loader/newZipLoad.js +25 -10
  56. package/dist/loader/oldZipLoad.js +13 -13
  57. package/dist/loader/save.d.ts +9 -2
  58. package/dist/loader/save.js +64 -55
  59. package/dist/loader/types.d.ts +29 -1
  60. package/dist/main.d.ts +2 -0
  61. package/dist/main.js +45 -43
  62. package/dist/{matMul16-BcVC_E62.js → matMul16-BIT70Vya.js} +3 -3
  63. package/dist/{matMulGelu-JNLZqKQp.js → matMulGelu-CsZnh18H.js} +18 -18
  64. package/dist/mat_mul-DP86qZtZ.js +11 -0
  65. package/dist/mod-BXjLYwvM.js +11 -0
  66. package/dist/models/NanoGPTV1.js +2 -2
  67. package/dist/models/NanoGPTV2.js +2 -2
  68. package/dist/models/model.d.ts +3 -2
  69. package/dist/models/model.js +13 -13
  70. package/dist/{not_equal-hurPF26l.js → not_equal-CkQKkKZy.js} +15 -15
  71. package/dist/{ones-BytntneX.js → ones-DbVB5N58.js} +3 -3
  72. package/dist/ops/adamAdjust.js +3 -3
  73. package/dist/ops/adamMoments.js +3 -3
  74. package/dist/ops/add16.js +1 -1
  75. package/dist/ops/appendCache.js +6 -6
  76. package/dist/ops/attentionMask.js +3 -3
  77. package/dist/ops/concat16.js +3 -3
  78. package/dist/ops/cpu/adamAdjust.js +9 -9
  79. package/dist/ops/cpu/adamMoments.js +5 -5
  80. package/dist/ops/cpu/appendCache.js +2 -2
  81. package/dist/ops/cpu/attentionMask.js +6 -6
  82. package/dist/ops/cpu/fusedSoftmax.js +4 -4
  83. package/dist/ops/cpu/gatherSub.js +5 -5
  84. package/dist/ops/cpu/gelu.js +4 -4
  85. package/dist/ops/cpu/matMul16.js +2 -2
  86. package/dist/ops/cpu/matMulGelu.js +7 -7
  87. package/dist/ops/cpu/matMulMul.js +2 -2
  88. package/dist/ops/cpu/mulDropout.js +5 -5
  89. package/dist/ops/cpu/normRMS.js +1 -1
  90. package/dist/ops/cpu/qkv.js +3 -3
  91. package/dist/ops/cpu/rope.js +5 -5
  92. package/dist/ops/cpu/scatterSub.js +5 -5
  93. package/dist/ops/dot16.js +2 -2
  94. package/dist/ops/dropout.js +6 -6
  95. package/dist/ops/dropout16.js +1 -1
  96. package/dist/ops/gatherSub.js +1 -1
  97. package/dist/ops/gelu.js +2 -2
  98. package/dist/ops/globalNorm.js +7 -7
  99. package/dist/ops/grads/add16.js +1 -1
  100. package/dist/ops/grads/attentionMask.js +2 -2
  101. package/dist/ops/grads/dropout16.js +1 -1
  102. package/dist/ops/grads/gelu.js +2 -2
  103. package/dist/ops/grads/matMul16.js +3 -3
  104. package/dist/ops/grads/matMulGelu.js +1 -1
  105. package/dist/ops/grads/mul16.js +1 -1
  106. package/dist/ops/grads/normRMS.js +7 -7
  107. package/dist/ops/grads/pack16.js +3 -3
  108. package/dist/ops/grads/qkv.js +11 -11
  109. package/dist/ops/grads/rope.js +2 -2
  110. package/dist/ops/grads/softmax16.js +1 -1
  111. package/dist/ops/grads/unpack16.js +2 -2
  112. package/dist/ops/matMul16.js +3 -3
  113. package/dist/ops/matMulGelu.js +6 -6
  114. package/dist/ops/matMulMul.js +3 -3
  115. package/dist/ops/mul16.js +1 -1
  116. package/dist/ops/mulDrop.js +3 -3
  117. package/dist/ops/normRMS.js +4 -4
  118. package/dist/ops/pack16.js +2 -2
  119. package/dist/ops/qkv.js +3 -3
  120. package/dist/ops/reshape16.js +6 -6
  121. package/dist/ops/rope.js +2 -2
  122. package/dist/ops/scatterSub.js +1 -1
  123. package/dist/ops/slice16.js +2 -2
  124. package/dist/ops/softmax16.js +1 -1
  125. package/dist/ops/sub16.js +1 -1
  126. package/dist/ops/sum16.js +6 -6
  127. package/dist/ops/transpose16.js +3 -3
  128. package/dist/ops/unpack16.js +2 -2
  129. package/dist/ops/webgl/adamAdjust.js +2 -2
  130. package/dist/ops/webgl/adamMoments.js +1 -1
  131. package/dist/ops/webgl/appendCache.js +1 -1
  132. package/dist/ops/webgl/attentionMask.js +1 -1
  133. package/dist/ops/webgl/dropout16.js +1 -1
  134. package/dist/ops/webgl/fusedSoftmax.js +7 -7
  135. package/dist/ops/webgl/gatherSub.js +3 -3
  136. package/dist/ops/webgl/gelu.js +2 -2
  137. package/dist/ops/webgl/log.js +3 -3
  138. package/dist/ops/webgl/matMul16.js +13 -13
  139. package/dist/ops/webgl/matMulGelu.js +4 -4
  140. package/dist/ops/webgl/matMulMul.js +2 -2
  141. package/dist/ops/webgl/mulDropout.js +1 -1
  142. package/dist/ops/webgl/normRMS.js +2 -2
  143. package/dist/ops/webgl/qkv.js +1 -1
  144. package/dist/ops/webgl/rope.js +1 -1
  145. package/dist/ops/webgl/scatterSub.js +2 -2
  146. package/dist/ops/webgpu/adamAdjust.js +3 -3
  147. package/dist/ops/webgpu/adamMoments.js +3 -3
  148. package/dist/ops/webgpu/add16.js +6 -6
  149. package/dist/ops/webgpu/appendCache.js +3 -3
  150. package/dist/ops/webgpu/attentionMask.js +2 -2
  151. package/dist/ops/webgpu/attentionMask32_program.js +2 -2
  152. package/dist/ops/webgpu/clipScale.js +7 -7
  153. package/dist/ops/webgpu/concat16.js +5 -5
  154. package/dist/ops/webgpu/dropout16.js +6 -6
  155. package/dist/ops/webgpu/gatherSub.js +3 -3
  156. package/dist/ops/webgpu/gelu.js +8 -8
  157. package/dist/ops/webgpu/matMul16.js +16 -16
  158. package/dist/ops/webgpu/matMul16_program.js +2 -2
  159. package/dist/ops/webgpu/mul16.js +5 -5
  160. package/dist/ops/webgpu/norm2.js +1 -1
  161. package/dist/ops/webgpu/normRMS.js +2 -2
  162. package/dist/ops/webgpu/normRMSGrad.js +4 -4
  163. package/dist/ops/webgpu/pack16.js +4 -4
  164. package/dist/ops/webgpu/pack16_program.js +2 -2
  165. package/dist/ops/webgpu/qkv.js +2 -2
  166. package/dist/ops/webgpu/rope.js +3 -3
  167. package/dist/ops/webgpu/scatterSub.js +3 -3
  168. package/dist/ops/webgpu/slice16.js +4 -4
  169. package/dist/ops/webgpu/softmax16.js +4 -4
  170. package/dist/ops/webgpu/softmax16_program.js +2 -2
  171. package/dist/ops/webgpu/softmax16_subgroup_program.js +2 -2
  172. package/dist/ops/webgpu/softmax16grad.js +4 -4
  173. package/dist/ops/webgpu/sub16.js +6 -6
  174. package/dist/ops/webgpu/sum16.js +3 -3
  175. package/dist/ops/webgpu/transpose16.js +8 -8
  176. package/dist/ops/webgpu/transpose16_program.js +2 -2
  177. package/dist/ops/webgpu/transpose16_shared_program.js +3 -3
  178. package/dist/ops/webgpu/unpack16.js +3 -3
  179. package/dist/ops/webgpu/utils/binary_op.js +3 -3
  180. package/dist/ops/webgpu/utils/reductions.js +5 -5
  181. package/dist/{ops-CsXeTq1P.js → ops-CURIZSVt.js} +100 -100
  182. package/dist/{pack16-bqltoUlR.js → pack16-WlOSOuZA.js} +2 -2
  183. package/dist/patches/webgpu_backend.js +6 -6
  184. package/dist/patches/webgpu_base.js +1 -1
  185. package/dist/patches/webgpu_program.js +2 -2
  186. package/dist/{random_normal-IBRrha8a.js → random_normal-CIm8lk2-.js} +1 -1
  187. package/dist/{random_width-DN5ZtQkM.js → random_width-B_fVXhGx.js} +131 -131
  188. package/dist/{range-C-CjF-LI.js → range-BDxO73mk.js} +1 -1
  189. package/dist/{readers-iz5u3HBo.js → readers-17HLdxVM.js} +2 -2
  190. package/dist/relu-DTvZKBsZ.js +9 -0
  191. package/dist/{reshape-BDOuCSNW.js → reshape-BIN71H3p.js} +1 -1
  192. package/dist/{resize_nearest_neighbor-BojqlfRe.js → resize_nearest_neighbor-C6_0dAnK.js} +41 -41
  193. package/dist/{rope-0j_f1TPm.js → rope-CC5RjmKU.js} +4 -4
  194. package/dist/{scatter_nd_util-ByNJaL6I.js → scatter_nd_util-C-x73Cj6.js} +1 -1
  195. package/dist/{segment_util-Dasb2Zaf.js → segment_util-4zuHV5IG.js} +2 -2
  196. package/dist/{selu_util-BLhIqRkw.js → selu_util-BXdhy_W6.js} +5 -5
  197. package/dist/{shared-CagdqkLh.js → shared-DRWDyk9w.js} +6 -6
  198. package/dist/{shared-3agzAqQ_.js → shared-zTaJ5siv.js} +1 -1
  199. package/dist/slice-BvItlgXu.js +12 -0
  200. package/dist/{slice_util-CC35pLmT.js → slice_util-DPY56GzQ.js} +5 -5
  201. package/dist/{softmax-D4q1LJN7.js → softmax-BLGJqdwx.js} +1 -1
  202. package/dist/split-BN9LkEgS.js +9 -0
  203. package/dist/{squeeze-ho4wLUek.js → squeeze-O_YWJpw_.js} +2 -2
  204. package/dist/{stack-DudVrtmG.js → stack-z6QE7kmP.js} +1 -1
  205. package/dist/{step-BTxPtq1r.js → step-DQY6_ABw.js} +4 -4
  206. package/dist/{sum-BpiwSWvg.js → sum-D39FeU5h.js} +3 -3
  207. package/dist/{tensor-BWFldCso.js → tensor-D8e0Gd7c.js} +1 -1
  208. package/dist/{tensor1d-LMGMIUlr.js → tensor1d-BMl0eZYV.js} +1 -1
  209. package/dist/{tensor2d-BnXMKScO.js → tensor2d-DTtQ1QcT.js} +1 -1
  210. package/dist/{tensor4d-C6UCG_u8.js → tensor4d-Dj4rDssL.js} +1 -1
  211. package/dist/{tfjs_backend-BGnG-ppu.js → tfjs_backend-Bk3PmK91.js} +65 -65
  212. package/dist/{tile-CFy-xTO6.js → tile-CsWlVKKz.js} +1 -1
  213. package/dist/tokeniser/BaseTokeniser.d.ts +4 -1
  214. package/dist/tokeniser/BaseTokeniser.js +21 -5
  215. package/dist/tokeniser/CharTokeniser.d.ts +1 -1
  216. package/dist/tokeniser/CharTokeniser.js +62 -50
  217. package/dist/tokeniser/bpe.d.ts +1 -1
  218. package/dist/tokeniser/bpe.js +41 -35
  219. package/dist/tokeniser/type.d.ts +3 -1
  220. package/dist/training/AdamW.d.ts +3 -0
  221. package/dist/training/AdamW.js +59 -30
  222. package/dist/training/BasicTrainer.d.ts +1 -0
  223. package/dist/training/BasicTrainer.js +112 -92
  224. package/dist/training/DatasetBuilder.js +3 -3
  225. package/dist/training/Evaluator.js +2 -2
  226. package/dist/training/LRScheduler.d.ts +1 -0
  227. package/dist/training/LRScheduler.js +18 -12
  228. package/dist/training/PreTrainer.js +3 -3
  229. package/dist/training/SFTDatasetBuilder.js +3 -3
  230. package/dist/training/SFTTrainer.js +1 -1
  231. package/dist/training/orthoGrad.js +1 -1
  232. package/dist/training/sparseCrossEntropy.js +30 -30
  233. package/dist/training/types.d.ts +5 -3
  234. package/dist/training/validation.js +13 -13
  235. package/dist/{transpose-9kRxIXWR.js → transpose-Qxz-4os3.js} +7 -7
  236. package/dist/{unsorted_segment_sum-DJvk5xnh.js → unsorted_segment_sum-BfFVV9Zm.js} +20 -20
  237. package/dist/utilities/datasetID.d.ts +2 -0
  238. package/dist/utilities/datasetID.js +21 -0
  239. package/dist/utilities/dummy.js +6 -6
  240. package/dist/utilities/multinomialCPU.js +2 -2
  241. package/dist/utilities/packed.js +1 -1
  242. package/dist/utilities/performance.js +1 -1
  243. package/dist/utilities/profile.js +1 -1
  244. package/dist/utilities/safetensors.js +2 -2
  245. package/dist/utilities/sentences.js +5 -5
  246. package/dist/utilities/weights.js +2 -2
  247. package/dist/{variable-Ck482e3n.js → variable-SSATClyt.js} +1 -1
  248. package/dist/{webgpu_program-B4HmApL1.js → webgpu_program-CbjdYLYk.js} +1 -1
  249. package/dist/{webgpu_util-DYlGSwOJ.js → webgpu_util-DuofJBMo.js} +7 -7
  250. package/dist/{zeros-DvZpK8s6.js → zeros-Bw0puq_w.js} +2 -2
  251. package/dist/{zeros_like-CWjDdwr-.js → zeros_like-rOHr54NY.js} +69 -69
  252. package/package.json +3 -3
  253. package/dist/complex-DI35Q-gW.js +0 -11
  254. package/dist/mat_mul-DhG0Newp.js +0 -11
  255. package/dist/mod-CSdCpRjf.js +0 -11
  256. package/dist/relu-J_X6MUzx.js +0 -9
  257. package/dist/slice-BzS11Qh0.js +0 -12
  258. package/dist/split-C2Sj255c.js +0 -9
@@ -1,26 +1,26 @@
1
- import { o as g, q as w, x as A, ag as Te, i as ke, j as M, m as Q, E as J, B as ae, U as ue, _ as le, a2 as fe, bb as he, aD as pe, bc as Ie, t as S, L as $e, a_ as Ee } from "./index-CUXkjxiT.js";
2
- import { t as be } from "./tensor1d-LMGMIUlr.js";
3
- import { r as Le } from "./random_normal-IBRrha8a.js";
4
- import { s as P } from "./slice-BzS11Qh0.js";
5
- import { r as c } from "./reshape-BDOuCSNW.js";
6
- import { g as Ne } from "./gather-qIqEqaGn.js";
7
- import { e as Pe } from "./step-BTxPtq1r.js";
8
- import { c as Ce } from "./clip_by_value-DDA7rrcT.js";
9
- import { t as Fe } from "./tile-CFy-xTO6.js";
10
- import { s as ve, b as Me, c as je, g as Ue } from "./selu_util-BLhIqRkw.js";
11
- import { m as $ } from "./mat_mul-DhG0Newp.js";
12
- import { t as Ve } from "./transpose-9kRxIXWR.js";
13
- import { c as j } from "./concat-CAQpCret.js";
14
- import { g as xe, r as Be } from "./dropout_util--NxWuYg2.js";
15
- import { f as Ge } from "./floor-BQtb-Azg.js";
16
- function qe(e) {
1
+ import { o as g, n as w, v as A, ag as Te, d as ke, h as M, m as Q, E as q, z as ae, N as ue, _ as le, a2 as fe, bb as he, aD as pe, bc as Ie, t as S, J as $e, a_ as Ee } from "./index-DSGwv2Yx.js";
2
+ import { t as be } from "./tensor1d-BMl0eZYV.js";
3
+ import { r as Le } from "./random_normal-CIm8lk2-.js";
4
+ import { s as P } from "./slice-BvItlgXu.js";
5
+ import { r as c } from "./reshape-BIN71H3p.js";
6
+ import { g as Ne } from "./gather-DEyjXNb1.js";
7
+ import { e as Pe } from "./step-DQY6_ABw.js";
8
+ import { c as ve } from "./clip_by_value-CLAD4h_I.js";
9
+ import { t as Ce } from "./tile-CsWlVKKz.js";
10
+ import { s as Fe, b as Me, c as je, g as Ve } from "./selu_util-BXdhy_W6.js";
11
+ import { m as $ } from "./mat_mul-DP86qZtZ.js";
12
+ import { t as Ue } from "./transpose-Qxz-4os3.js";
13
+ import { c as j } from "./concat-Dqk7Xk7h.js";
14
+ import { g as xe, r as Be } from "./dropout_util-N0z8Os-K.js";
15
+ import { f as Ge } from "./floor-CymuCmTO.js";
16
+ function Je(e) {
17
17
  return j(
18
18
  e,
19
19
  0
20
20
  /* axis */
21
21
  );
22
22
  }
23
- const Je = /* @__PURE__ */ g({ concat1d_: qe });
23
+ const qe = /* @__PURE__ */ g({ concat1d_: Je });
24
24
  function Ke(e, n) {
25
25
  return j(e, n);
26
26
  }
@@ -52,7 +52,7 @@ function en(e, n, t) {
52
52
  const s = w(e, "x", "slice4d");
53
53
  return A(s.rank === 4, () => `slice4d expects a rank-4 tensor, but got a rank-${s.rank} tensor`), P(s, n, t);
54
54
  }
55
- const V = /* @__PURE__ */ g({ slice4d_: en });
55
+ const U = /* @__PURE__ */ g({ slice4d_: en });
56
56
  function nn(e, n, t, s) {
57
57
  const r = w(e, "x", "dropout");
58
58
  if (A(r.dtype === "float32", () => `x has to be a floating point tensor since it's going to be scaled, but got a ${r.dtype} tensor instead.`), A(n >= 0 && n < 1, () => `rate must be a float in the range [0, 1), but got ${n}.`), n === 0)
@@ -62,7 +62,7 @@ function nn(e, n, t, s) {
62
62
  }
63
63
  const tn = /* @__PURE__ */ g({ dropout_: nn });
64
64
  function sn({ a: e, b: n, transposeA: t = !1, transposeB: s = !1, bias: r, activation: o = "linear", preluActivationWeights: i, leakyreluAlpha: f = 0.2 }) {
65
- if (ve(J.state.gradientDepth, o) === !1) {
65
+ if (Fe(q.state.gradientDepth, o) === !1) {
66
66
  let _ = $(e, n, t, s);
67
67
  return r != null && (_ = M(_, r)), Me(_, o, i, f);
68
68
  }
@@ -70,38 +70,38 @@ function sn({ a: e, b: n, transposeA: t = !1, transposeB: s = !1, bias: r, activ
70
70
  [a, u] = ae(a, u);
71
71
  const m = t ? a.shape[a.rank - 2] : a.shape[a.rank - 1], d = s ? u.shape[u.rank - 1] : u.shape[u.rank - 2], T = t ? a.shape[a.rank - 1] : a.shape[a.rank - 2], h = s ? u.shape[u.rank - 2] : u.shape[u.rank - 1], ne = a.shape.slice(0, -2), I = u.shape.slice(0, -2), te = ue(ne), se = ue(I);
72
72
  A(m === d, () => `Error in fused matMul: inner shapes (${m}) and (${d}) of Tensors with shapes ${a.shape} and ${u.shape} and transposeA=${t} and transposeB=${s} must match.`);
73
- const B = le(a.shape.slice(0, -2), u.shape.slice(0, -2)).concat([T, h]), G = t ? c(a, [te, m, T]) : c(a, [te, T, m]), q = s ? c(u, [se, h, d]) : c(u, [se, d, h]);
73
+ const B = le(a.shape.slice(0, -2), u.shape.slice(0, -2)).concat([T, h]), G = t ? c(a, [te, m, T]) : c(a, [te, T, m]), J = s ? c(u, [se, h, d]) : c(u, [se, d, h]);
74
74
  let E;
75
75
  r != null && (E = w(r, "bias", "fused matMul"), [E] = ae(E, a), le(B, E.shape));
76
76
  let re;
77
77
  i != null && (re = w(i, "prelu weights", "fused matMul"));
78
- const oe = (_, C) => {
79
- const [O, D, y, U] = C, k = je(c(_, y.shape), y, o);
78
+ const oe = (_, v) => {
79
+ const [O, D, y, V] = v, k = je(c(_, y.shape), y, o);
80
80
  let b, L;
81
81
  if (!t && !s ? (b = $(k, D, !1, !0), L = $(O, k, !0, !1)) : !t && s ? (b = $(k, D, !1, !1), L = $(k, O, !0, !1)) : t && !s ? (b = $(D, k, !1, !0), L = $(O, k, !1, !1)) : (b = $(D, k, !0, !0), L = $(k, O, !0, !0)), r != null) {
82
- const ye = Ue(U, k);
82
+ const ye = Ve(V, k);
83
83
  return [b, L, ye];
84
84
  } else
85
85
  return [b, L];
86
86
  }, ie = {
87
87
  a: G,
88
- b: q,
88
+ b: J,
89
89
  bias: E,
90
90
  preluActivationWeights: re
91
91
  }, ce = { transposeA: t, transposeB: s, activation: o, leakyreluAlpha: f };
92
- return r == null ? fe((C, O, D) => {
92
+ return r == null ? fe((v, O, D) => {
93
93
  const y = (
94
94
  // tslint:disable-next-line: no-unnecessary-type-assertion
95
- J.runKernel(he, ie, ce)
95
+ q.runKernel(he, ie, ce)
96
96
  );
97
- return D([C, O, y]), { value: c(y, B), gradFunc: oe };
98
- })(G, q) : fe((C, O, D, y) => {
99
- const U = (
97
+ return D([v, O, y]), { value: c(y, B), gradFunc: oe };
98
+ })(G, J) : fe((v, O, D, y) => {
99
+ const V = (
100
100
  // tslint:disable-next-line: no-unnecessary-type-assertion
101
- J.runKernel(he, ie, ce)
101
+ q.runKernel(he, ie, ce)
102
102
  );
103
- return y([C, O, U, D]), { value: c(U, B), gradFunc: oe };
104
- })(G, q, E);
103
+ return y([v, O, V, D]), { value: c(V, B), gradFunc: oe };
104
+ })(G, J, E);
105
105
  }
106
106
  const de = /* @__PURE__ */ g({ fusedMatMul_: sn });
107
107
  class Ae extends Error {
@@ -119,9 +119,9 @@ class l extends Error {
119
119
  super(n), Object.setPrototypeOf(this, l.prototype);
120
120
  }
121
121
  }
122
- class v extends Error {
122
+ class F extends Error {
123
123
  constructor(n) {
124
- super(n), Object.setPrototypeOf(this, v.prototype);
124
+ super(n), Object.setPrototypeOf(this, F.prototype);
125
125
  }
126
126
  }
127
127
  class ee extends Error {
@@ -144,16 +144,16 @@ function me(e, n) {
144
144
  if (!e)
145
145
  throw new ee(n);
146
146
  }
147
- function Cn(e, n) {
147
+ function vn(e, n) {
148
148
  let t = 0;
149
149
  for (const s of e)
150
150
  s === n && t++;
151
151
  return t;
152
152
  }
153
- function Fn(e) {
153
+ function Cn(e) {
154
154
  return e.length === 1 ? e[0] : e;
155
155
  }
156
- function vn(e) {
156
+ function Fn(e) {
157
157
  return Array.isArray(e) ? e : [e];
158
158
  }
159
159
  function Mn(e) {
@@ -164,7 +164,7 @@ function jn(e) {
164
164
  return e.length <= 1 || e.indexOf("_") === -1 ? e : e.replace(/[_]+(\w|$)/g, (n, t) => t.toUpperCase());
165
165
  }
166
166
  let p = {};
167
- function Un(e) {
167
+ function Vn(e) {
168
168
  if (e == null)
169
169
  return null;
170
170
  const n = {};
@@ -182,7 +182,7 @@ function W(e) {
182
182
  }
183
183
  }
184
184
  }
185
- function Vn(e, n = {}, t = {}, s = "object", r = !1) {
185
+ function Un(e, n = {}, t = {}, s = "object", r = !1) {
186
186
  if (typeof e == "string") {
187
187
  const o = e;
188
188
  let i;
@@ -255,7 +255,7 @@ function x(e, n, t) {
255
255
  if (t != null && e.indexOf(t) < 0)
256
256
  throw new l(`${t} is not a valid ${n}. Valid values are ${e} or null/undefined.`);
257
257
  }
258
- function qn(e, n, t = 0, s = 1 / 0) {
258
+ function Jn(e, n, t = 0, s = 1 / 0) {
259
259
  return me(t >= 0), me(s >= t), Array.isArray(e) && e.length >= t && e.length <= s && e.every((r) => typeof r === n);
260
260
  }
261
261
  function on(e, n) {
@@ -264,7 +264,7 @@ function on(e, n) {
264
264
  function Oe(e) {
265
265
  return e === null ? "null" : Array.isArray(e) ? "[" + e.map((n) => Oe(n)).join(",") + "]" : typeof e == "string" ? `"${e}"` : `${e}`;
266
266
  }
267
- function Jn(e, n, t) {
267
+ function qn(e, n, t) {
268
268
  let s = t != null ? t() : pe(), r;
269
269
  return (...i) => {
270
270
  const f = t != null ? t() : pe();
@@ -288,18 +288,18 @@ function Wn(e) {
288
288
  function Yn(e) {
289
289
  x(ln, "PoolMode", e);
290
290
  }
291
- const F = [], ge = "/";
291
+ const C = [], ge = "/";
292
292
  function Hn(e, n) {
293
- F.push(e);
293
+ C.push(e);
294
294
  try {
295
295
  const t = n();
296
- return F.pop(), t;
296
+ return C.pop(), t;
297
297
  } catch (t) {
298
- throw F.pop(), t;
298
+ throw C.pop(), t;
299
299
  }
300
300
  }
301
301
  function hn() {
302
- return F.length === 0 ? "" : F.join(ge) + ge;
302
+ return C.length === 0 ? "" : C.join(ge) + ge;
303
303
  }
304
304
  function Qn(e) {
305
305
  if (!De(e))
@@ -401,7 +401,7 @@ function R(e, n, t) {
401
401
  case 3:
402
402
  return z(e, [n, 0, 0], [t, e.shape[1], e.shape[2]]);
403
403
  case 4:
404
- return V(e, [n, 0, 0, 0], [t, e.shape[1], e.shape[2], e.shape[3]]);
404
+ return U(e, [n, 0, 0, 0], [t, e.shape[1], e.shape[2], e.shape[3]]);
405
405
  case 5:
406
406
  return P(e, [n, 0, 0, 0, 0], [
407
407
  t,
@@ -434,7 +434,7 @@ function Z(e, n, t) {
434
434
  case 3:
435
435
  return z(e, [0, 0, n], [e.shape[0], e.shape[1], t]);
436
436
  case 4:
437
- return V(e, [0, 0, 0, n], [e.shape[0], e.shape[1], e.shape[2], t]);
437
+ return U(e, [0, 0, 0, n], [e.shape[0], e.shape[1], e.shape[2], t]);
438
438
  default:
439
439
  throw new l(`sliceAlongLastAxis() received an unsupported tensor rank: ${e.rank}`);
440
440
  }
@@ -470,9 +470,9 @@ function at(e, n, t, s) {
470
470
  case 1:
471
471
  return R(e, n, t);
472
472
  case 2:
473
- return V(e, [0, n, 0, 0], [e.shape[0], t, e.shape[2], e.shape[3]]);
473
+ return U(e, [0, n, 0, 0], [e.shape[0], t, e.shape[2], e.shape[3]]);
474
474
  case 3:
475
- return V(e, [0, 0, n, 0], [e.shape[0], e.shape[1], t, e.shape[3]]);
475
+ return U(e, [0, 0, n, 0], [e.shape[0], e.shape[1], t, e.shape[3]]);
476
476
  case 4:
477
477
  return Z(e, n, t);
478
478
  default:
@@ -490,7 +490,7 @@ function ut(e, n = -1) {
490
490
  function lt(e, n) {
491
491
  switch (e.rank) {
492
492
  case 1:
493
- return Je([e, n]);
493
+ return qe([e, n]);
494
494
  case 2:
495
495
  return Re([e, n], 0);
496
496
  case 3:
@@ -504,18 +504,18 @@ function lt(e, n) {
504
504
  function mn(e, n) {
505
505
  if (Array.isArray(n) || (n = [n]), e.rank !== n.length)
506
506
  throw new l(`The length of input n (${n.length}) does not match the number of dimensions in input x (${e.rank})`);
507
- return Fe(e, n);
507
+ return Ce(e, n);
508
508
  }
509
509
  function ft(e, n = 0, t = 1, s, r) {
510
510
  return Le(e, n, t, s, r);
511
511
  }
512
512
  function ht(e, n, t, s) {
513
513
  if (e.rank < 2 || n.rank < 2)
514
- throw new v(`dot requires both inputs to be rank >= 2 but got x shape = ${e.shape} and y shape = ${n.shape}`);
514
+ throw new F(`dot requires both inputs to be rank >= 2 but got x shape = ${e.shape} and y shape = ${n.shape}`);
515
515
  if (n.rank >= 3) {
516
516
  const r = e.shape.slice(-1)[0], o = n.shape.slice(-2)[0];
517
517
  if (r !== o)
518
- throw new v(`If rank y >= 3, then the second last dim of y must equal the last dim of x but got x shape = ${e.shape} and y shape = ${n.shape}`);
518
+ throw new F(`If rank y >= 3, then the second last dim of y must equal the last dim of x but got x shape = ${e.shape} and y shape = ${n.shape}`);
519
519
  }
520
520
  if (e.rank === 2 && n.rank === 2)
521
521
  return de({
@@ -530,7 +530,7 @@ function ht(e, n, t, s) {
530
530
  const r = e.shape.slice(), o = r.pop();
531
531
  e = c(e, [-1, o]);
532
532
  const i = n.shape.slice(), f = i.pop(), a = i.pop(), u = [...i, f], m = Array.from({ length: n.rank }, (ne, I) => I === 0 ? n.rank - 2 : I <= n.rank - 2 ? I - 1 : I);
533
- n = c(Ve(n, m), [a, -1]);
533
+ n = c(Ue(n, m), [a, -1]);
534
534
  const d = [...r, ...u];
535
535
  return c(de({
536
536
  a: e,
@@ -576,7 +576,7 @@ function mt(e, n, t) {
576
576
  }
577
577
  function gt(e, n = 1) {
578
578
  if (n !== 1)
579
- throw new v(`Support for alpha values other than 1 (${n}) is not implemented yet.`);
579
+ throw new F(`Support for alpha values other than 1 (${n}) is not implemented yet.`);
580
580
  return Pe(e);
581
581
  }
582
582
  function kt(e) {
@@ -588,7 +588,7 @@ function $t(e, n, t, s) {
588
588
  function wt(e) {
589
589
  return S(() => {
590
590
  const n = M(0.5, Q(0.2, e));
591
- return Ce(n, 0, 1);
591
+ return ve(n, 0, 1);
592
592
  });
593
593
  }
594
594
  function At(e, n, t = !1) {
@@ -599,17 +599,17 @@ export {
599
599
  Ae as A,
600
600
  pt as B,
601
601
  tt as C,
602
- Cn as D,
602
+ vn as D,
603
603
  gt as E,
604
604
  wt as F,
605
605
  kt as G,
606
606
  nt as H,
607
607
  zn as I,
608
- qn as J,
608
+ Jn as J,
609
609
  mt as K,
610
610
  at as L,
611
611
  Zn as M,
612
- v as N,
612
+ F as N,
613
613
  on as O,
614
614
  Kn as P,
615
615
  Wn as Q,
@@ -629,13 +629,13 @@ export {
629
629
  _e as b,
630
630
  x as c,
631
631
  ht as d,
632
- Vn as e,
632
+ Un as e,
633
633
  Xn as f,
634
634
  Qn as g,
635
- Fn as h,
636
- vn as i,
635
+ Cn as h,
636
+ Fn as i,
637
637
  st as j,
638
- Jn as k,
638
+ qn as k,
639
639
  it as l,
640
640
  dt as m,
641
641
  Hn as n,
@@ -643,7 +643,7 @@ export {
643
643
  rt as p,
644
644
  jn as q,
645
645
  ft as r,
646
- Un as s,
646
+ Vn as s,
647
647
  Mn as t,
648
648
  Bn as u,
649
649
  xn as v,
@@ -1,4 +1,4 @@
1
- import { o as e, q as a, x as i, E as c, T as l } from "./index-CUXkjxiT.js";
1
+ import { o as e, n as a, v as i, E as c, T as l } from "./index-DSGwv2Yx.js";
2
2
  function u(r, t) {
3
3
  const n = a(r, "x", "tile", "string_or_numeric");
4
4
  i(n.rank === t.length, () => `Error in transpose: rank of input ${n.rank} must match length of reps ${t}.`);
@@ -2,6 +2,8 @@ import { Conversation, ITokeniser } from './type';
2
2
  import { default as EE } from 'eventemitter3';
3
3
  export declare const SPECIALS: string[];
4
4
  export default abstract class BaseTokeniser extends EE<'trainStatus'> implements ITokeniser {
5
+ id: string;
6
+ datasetID?: string;
5
7
  protected specialTokens: Map<string, number>;
6
8
  protected specialTokenSet: Set<number>;
7
9
  abstract vocabSize: number;
@@ -12,7 +14,8 @@ export default abstract class BaseTokeniser extends EE<'trainStatus'> implements
12
14
  isSpecialToken(index: number): boolean;
13
15
  protected addSpecialTokens(): void;
14
16
  protected addSpecialToken(token: string, index: number): void;
15
- abstract train(text: Conversation[][], cb?: (vocab: number) => void): Promise<number>;
17
+ protected generateID(): void;
18
+ abstract train(text: Conversation[][], cb?: (vocab: number) => void, datasetID?: string): Promise<number>;
16
19
  abstract getVocab(): string[];
17
20
  abstract getMerges(): [string, string][];
18
21
  abstract destroy(): void;
@@ -1,5 +1,5 @@
1
1
  import { E as r } from "../index-DvYrXKkX.js";
2
- const h = [
2
+ const l = [
3
3
  "<eos>",
4
4
  "<bos>",
5
5
  "",
@@ -11,20 +11,36 @@ const h = [
11
11
  "<|system_start|>",
12
12
  "<|system_end|>"
13
13
  ];
14
- class k extends r {
14
+ class T extends r {
15
+ id = "untrained";
16
+ datasetID;
15
17
  specialTokens = /* @__PURE__ */ new Map();
16
18
  specialTokenSet = /* @__PURE__ */ new Set();
17
19
  isSpecialToken(s) {
18
20
  return this.specialTokenSet.has(s);
19
21
  }
20
22
  addSpecialTokens() {
21
- h.forEach((s, t) => {
23
+ l.forEach((s, t) => {
22
24
  this.addToken(s, t), this.specialTokens.set(s, t), this.specialTokenSet.add(t);
23
25
  });
24
26
  }
25
27
  addSpecialToken(s, t) {
26
28
  this.specialTokens.set(s, t), this.specialTokenSet.add(t);
27
29
  }
30
+ generateID() {
31
+ const s = this.getVocab();
32
+ let t = 2166136261, e = 2654435769;
33
+ for (let a = 0; a < s.length; a++) {
34
+ const i = s[a];
35
+ t ^= i.length, t = Math.imul(t, 16777619), e ^= a, e = Math.imul(e, 2246822507);
36
+ for (let c = 0; c < i.length; c++) {
37
+ const h = i.charCodeAt(c);
38
+ t ^= h, t = Math.imul(t, 16777619), e ^= h, e = Math.imul(e, 3266489909);
39
+ }
40
+ }
41
+ const o = (t >>> 0).toString(36), n = (e >>> 0).toString(36);
42
+ this.id = "tokeniser_" + o + "_" + n;
43
+ }
28
44
  encodeSequence(s) {
29
45
  const t = this.encode(s);
30
46
  return [this.bosToken, ...t, this.eosToken];
@@ -94,6 +110,6 @@ class k extends r {
94
110
  }
95
111
  }
96
112
  export {
97
- h as SPECIALS,
98
- k as default
113
+ l as SPECIALS,
114
+ T as default
99
115
  };
@@ -12,7 +12,7 @@ export default class CharTokeniser extends BaseTokeniser {
12
12
  addToken(token: string, index?: number): number;
13
13
  get trained(): boolean;
14
14
  destroy(): void;
15
- train(text: Conversation[][]): Promise<number>;
15
+ train(text: Conversation[][], cb?: (vocab: number) => void, datasetID?: string): Promise<number>;
16
16
  tokenise(text: string[], numeric: true): number[][];
17
17
  tokenise(text: string[]): string[][];
18
18
  detokenise(tokens: (number[] | Uint16Array)[]): string[];
@@ -1,6 +1,7 @@
1
- import k, { SPECIALS as d } from "./BaseTokeniser.js";
2
- const u = ["<eos>", "<unk>"];
3
- class T extends k {
1
+ import { yieldIfNeeded as u } from "../utilities/yielder.js";
2
+ import b, { SPECIALS as T } from "./BaseTokeniser.js";
3
+ const l = ["<eos>", "<unk>"];
4
+ class x extends b {
4
5
  vocabSize = 0;
5
6
  eosToken = 0;
6
7
  bosToken = 0;
@@ -8,30 +9,30 @@ class T extends k {
8
9
  vocab = [];
9
10
  cache = /* @__PURE__ */ new Map();
10
11
  _trained = !1;
11
- constructor(i) {
12
- if (super(), Array.isArray(i)) {
13
- if (this.vocab = i, this.vocab.length > 0)
14
- this.vocabSize = this.vocab.length, d.forEach((t) => {
15
- const e = this.vocab.indexOf(t);
16
- e !== -1 && this.addSpecialToken(t, e);
17
- }), this.eosToken = this.getSpecialTokenIndex("<eos>"), this.bosToken = this.getSpecialTokenIndex("<bos>") ?? this.eosToken, this.unkToken = this.getSpecialTokenIndex("") ?? -1, this.unkToken === -1 && (this.unkToken = this.vocab.indexOf("<unk>")), this.unkToken === -1 && (this.unkToken = this.vocab.indexOf("<pad>")), this.unkToken === -1 && (this.unkToken = this.vocab.indexOf("_")), this.unkToken === -1 && (this.unkToken = this.vocab.indexOf(" ")), this.unkToken === -1 && (this.unkToken = this.eosToken), this.vocab = this.vocab.map((t) => t === "<pad>" ? "" : t), this.vocab.forEach((t, e) => {
18
- this.cache.set(t, e);
12
+ constructor(t) {
13
+ if (super(), Array.isArray(t)) {
14
+ if (this.vocab = t, this.vocab.length > 0)
15
+ this.vocabSize = this.vocab.length, T.forEach((i) => {
16
+ const e = this.vocab.indexOf(i);
17
+ e !== -1 && this.addSpecialToken(i, e);
18
+ }), this.eosToken = this.getSpecialTokenIndex("<eos>"), this.bosToken = this.getSpecialTokenIndex("<bos>") ?? this.eosToken, this.unkToken = this.getSpecialTokenIndex("") ?? -1, this.unkToken === -1 && (this.unkToken = this.vocab.indexOf("<unk>")), this.unkToken === -1 && (this.unkToken = this.vocab.indexOf("<pad>")), this.unkToken === -1 && (this.unkToken = this.vocab.indexOf("_")), this.unkToken === -1 && (this.unkToken = this.vocab.indexOf(" ")), this.unkToken === -1 && (this.unkToken = this.eosToken), this.vocab = this.vocab.map((i) => i === "<pad>" ? "" : i), this.vocab.forEach((i, e) => {
19
+ this.cache.set(i, e);
19
20
  });
20
21
  else
21
22
  throw new Error("Vocab cannot be empty");
22
23
  this._trained = !0;
23
24
  } else
24
- this.vocabSize = i, this.vocab = new Array(this.vocabSize).fill(""), this.addSpecialTokens(), this.eosToken = this.getSpecialTokenIndex("<eos>"), this.bosToken = this.getSpecialTokenIndex("<bos>") ?? this.eosToken, this.unkToken = this.getSpecialTokenIndex(""), this.vocab.forEach((t, e) => {
25
- this.cache.set(t, e);
25
+ this.vocabSize = t, this.vocab = new Array(this.vocabSize).fill(""), this.addSpecialTokens(), this.eosToken = this.getSpecialTokenIndex("<eos>"), this.bosToken = this.getSpecialTokenIndex("<bos>") ?? this.eosToken, this.unkToken = this.getSpecialTokenIndex(""), this.vocab.forEach((i, e) => {
26
+ this.cache.set(i, e);
26
27
  }), this.cache.set("", this.unkToken);
27
28
  }
28
- addToken(i, t) {
29
- if (this.cache.has(i))
30
- return this.cache.get(i);
29
+ addToken(t, i) {
30
+ if (this.cache.has(t))
31
+ return this.cache.get(t);
31
32
  let e;
32
- if (t !== void 0 ? e = t : (e = this.vocab.indexOf("", this.unkToken + 1), e === -1 && (e = this.vocabSize)), e >= this.vocabSize)
33
+ if (i !== void 0 ? e = i : (e = this.vocab.indexOf("", this.unkToken + 1), e === -1 && (e = this.vocabSize)), e >= this.vocabSize)
33
34
  throw new Error("Vocab size exceeded");
34
- return this.vocab[e] = i, this.cache.set(i, e), e;
35
+ return this.vocab[e] = t, this.cache.set(t, e), e;
35
36
  }
36
37
  get trained() {
37
38
  return this.vocab.length === this.vocabSize && this._trained;
@@ -39,43 +40,54 @@ class T extends k {
39
40
  destroy() {
40
41
  this.cache.clear(), this.vocab = [];
41
42
  }
42
- async train(i) {
43
- const t = i.map((o) => o.map((n) => n.content.split(""))).flat(2), e = new Set(t), s = Array.from(e), h = this.vocab.indexOf("", this.unkToken + 1), a = this.vocabSize - u.length;
44
- if (h === -1)
45
- return this.vocabSize;
46
- if (this._trained = !0, s.length > a) {
47
- const o = /* @__PURE__ */ new Map();
48
- t.forEach((n) => {
49
- o.set(n, (o.get(n) || 0) + 1);
50
- }), s.sort((n, r) => (o.get(n) || 0) - (o.get(r) || 0)), s.splice(0, s.length - a);
43
+ async train(t, i, e) {
44
+ this.datasetID = e;
45
+ const a = /* @__PURE__ */ new Set();
46
+ let h = performance.now();
47
+ for (const n of t)
48
+ n.forEach((o) => {
49
+ for (const r of o.content)
50
+ a.add(r);
51
+ }), h = await u(h, i, 0);
52
+ const s = Array.from(a), k = this.vocab.indexOf("", this.unkToken + 1), d = this.vocabSize - l.length;
53
+ if (k === -1)
54
+ return this.generateID(), this.vocabSize;
55
+ if (this._trained = !0, s.length > d) {
56
+ const n = /* @__PURE__ */ new Map();
57
+ t.forEach((o) => {
58
+ o.forEach((r) => {
59
+ for (const f of r.content)
60
+ n.set(f, (n.get(f) || 0) + 1);
61
+ });
62
+ }), s.sort((o, r) => (n.get(o) || 0) - (n.get(r) || 0)), s.splice(0, s.length - d);
51
63
  }
52
- let c = h;
64
+ let c = k;
53
65
  if (c !== -1) {
54
- const o = new Set(this.vocab);
55
- for (const n of s)
56
- if (!o.has(n) && (this.vocab[c] = n, o.add(n), c = this.vocab.indexOf("", c + 1), c === -1))
66
+ const n = new Set(this.vocab);
67
+ for (const o of s)
68
+ if (!n.has(o) && (this.vocab[c] = o, n.add(o), c = this.vocab.indexOf("", c + 1), c === -1))
57
69
  break;
58
70
  }
59
- return this.cache.clear(), this.vocab.forEach((o, n) => {
60
- this.cache.set(o, n);
61
- }), this.emit("trainStatus", "trained"), this.vocabSize;
71
+ return this.cache.clear(), this.vocab.forEach((n, o) => {
72
+ this.cache.set(n, o);
73
+ }), this.generateID(), this.emit("trainStatus", "trained"), this.vocabSize;
62
74
  }
63
- tokenise(i, t) {
75
+ tokenise(t, i) {
64
76
  if (!this.trained)
65
77
  throw new Error("Tokeniser not trained");
66
- return i.map((s) => t ? s.split("").map((h) => this.cache.get(h) ?? this.unkToken) : s.split("").map((h) => {
67
- const a = this.cache.get(h);
68
- return a !== void 0 ? this.vocab[a] : "";
78
+ return t.map((a) => i ? a.split("").map((h) => this.cache.get(h) ?? this.unkToken) : a.split("").map((h) => {
79
+ const s = this.cache.get(h);
80
+ return s !== void 0 ? this.vocab[s] : "";
69
81
  }));
70
82
  }
71
- detokenise(i) {
72
- return i.map((e) => Array.from(e).map((s) => this.vocab[s] || "").join(""));
83
+ detokenise(t) {
84
+ return t.map((e) => Array.from(e).map((a) => this.vocab[a] || "").join(""));
73
85
  }
74
- encode(i) {
75
- return this.tokenise([i], !0)[0];
86
+ encode(t) {
87
+ return this.tokenise([t], !0)[0];
76
88
  }
77
- decode(i) {
78
- return this.detokenise([i])[0];
89
+ decode(t) {
90
+ return this.detokenise([t])[0];
79
91
  }
80
92
  getVocab() {
81
93
  return this.vocab;
@@ -83,13 +95,13 @@ class T extends k {
83
95
  getMerges() {
84
96
  return [];
85
97
  }
86
- async createTrainingData(i, t = 5) {
87
- const e = await this.tokenise(i, !0), s = [], h = [];
88
- for (let a = 0; a < e.length - t; a++)
89
- s.push(...e[a].slice(0, t)), h.push(e[a + 1][0]);
90
- return [s, h];
98
+ async createTrainingData(t, i = 5) {
99
+ const e = await this.tokenise(t, !0), a = [], h = [];
100
+ for (let s = 0; s < e.length - i; s++)
101
+ a.push(...e[s].slice(0, i)), h.push(e[s + 1][0]);
102
+ return [a, h];
91
103
  }
92
104
  }
93
105
  export {
94
- T as default
106
+ x as default
95
107
  };
@@ -15,7 +15,7 @@ export default class BPETokeniser extends BaseTokeniser {
15
15
  get eosToken(): number;
16
16
  get bosToken(): number;
17
17
  get unkToken(): number;
18
- train(text?: Conversation[][], cb?: (vocab: number) => void): Promise<number>;
18
+ train(text?: Conversation[][], cb?: (vocab: number) => void, datasetID?: string): Promise<number>;
19
19
  getVocab(): string[];
20
20
  getMerges(): [string, string][];
21
21
  private tokeniseWord;