@genai-fi/nanogpt 0.20.0 → 0.20.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (433) hide show
  1. package/dist/BaseTokeniser-DSg9zcYq.js +221 -0
  2. package/dist/DatasetBuilder-DgURD85T.js +712 -0
  3. package/dist/Generator.d.ts +82 -0
  4. package/dist/Generator.js +2 -0
  5. package/dist/RealDiv-DBu0FQqT.js +362 -0
  6. package/dist/Reshape-CABOPB9d.js +94 -0
  7. package/dist/Reshape-DqO3r8BC.js +17 -0
  8. package/dist/TeachableLLM.d.ts +70 -0
  9. package/dist/TeachableLLM.js +2 -0
  10. package/dist/Trainer.d.ts +43 -0
  11. package/dist/Trainer.js +2 -0
  12. package/dist/backend.d.ts +2 -0
  13. package/dist/backend.js +13 -0
  14. package/dist/backend_util-Cg-roD1p.js +399 -0
  15. package/dist/binary_op_util-CrYk9LXL.js +103 -0
  16. package/dist/checks/appendCache.d.ts +1 -0
  17. package/dist/checks/appendCache.js +55 -0
  18. package/dist/checks/attentionMask.d.ts +1 -0
  19. package/dist/checks/attentionMask.js +56 -0
  20. package/dist/checks/check.d.ts +9 -0
  21. package/dist/checks/check.js +32 -0
  22. package/dist/checks/gelu.d.ts +1 -0
  23. package/dist/checks/gelu.js +46 -0
  24. package/dist/checks/index.d.ts +26 -0
  25. package/dist/checks/index.js +28 -0
  26. package/dist/checks/matMulGelu.d.ts +1 -0
  27. package/dist/checks/matMulGelu.js +84 -0
  28. package/dist/checks/normRMS.d.ts +1 -0
  29. package/dist/checks/normRMS.js +28 -0
  30. package/dist/checks/normRMSGrad.d.ts +1 -0
  31. package/dist/checks/normRMSGrad.js +22 -0
  32. package/dist/checks/packUnpack.d.ts +1 -0
  33. package/dist/checks/packUnpack.js +46 -0
  34. package/dist/checks/qkv.d.ts +1 -0
  35. package/dist/checks/qkv.js +34 -0
  36. package/dist/checks/rope.d.ts +1 -0
  37. package/dist/checks/rope.js +30 -0
  38. package/dist/checks/weights.d.ts +14 -0
  39. package/dist/checks/weights.js +27 -0
  40. package/dist/chunk-BPntVaq0.js +23 -0
  41. package/dist/complex_util-CkazZsaH.js +60 -0
  42. package/dist/concat_util-CWDZCBlA.js +19 -0
  43. package/dist/data/docx.d.ts +2 -0
  44. package/dist/data/docx.js +3046 -0
  45. package/dist/data/pdf.d.ts +2 -0
  46. package/dist/data/pdf.js +17 -0
  47. package/dist/data/textLoader.d.ts +7 -0
  48. package/dist/data/textLoader.js +613 -0
  49. package/dist/dist-BewPQWjc.js +7572 -0
  50. package/dist/dist-DVmq73nz.js +8775 -0
  51. package/dist/dist-DXwIvKxl.js +896 -0
  52. package/dist/dist-VEU5mfO0.js +7545 -0
  53. package/dist/gelu-Bf1HW1RY.js +27 -0
  54. package/dist/gpgpu_math-DvLcCH6u.js +1612 -0
  55. package/dist/inference/types.d.ts +16 -0
  56. package/dist/inference/types.js +0 -0
  57. package/dist/kernel_funcs_utils-HiXOOx3f.js +229 -0
  58. package/dist/layers/BaseLayer.d.ts +44 -0
  59. package/dist/layers/BaseLayer.js +76 -0
  60. package/dist/layers/CausalSelfAttention.d.ts +39 -0
  61. package/dist/layers/CausalSelfAttention.js +99 -0
  62. package/dist/layers/LoRA.d.ts +14 -0
  63. package/dist/layers/LoRA.js +48 -0
  64. package/dist/layers/MLP.d.ts +17 -0
  65. package/dist/layers/MLP.js +34 -0
  66. package/dist/layers/PositionEmbedding.d.ts +8 -0
  67. package/dist/layers/PositionEmbedding.js +27 -0
  68. package/dist/layers/RMSNorm.d.ts +12 -0
  69. package/dist/layers/RMSNorm.js +20 -0
  70. package/dist/layers/RoPECache.d.ts +18 -0
  71. package/dist/layers/RoPECache.js +337 -0
  72. package/dist/layers/TiedEmbedding.d.ts +13 -0
  73. package/dist/layers/TiedEmbedding.js +32 -0
  74. package/dist/layers/TransformerBlock.d.ts +27 -0
  75. package/dist/layers/TransformerBlock.js +51 -0
  76. package/dist/layers/WeightStore.d.ts +20 -0
  77. package/dist/layers/WeightStore.js +69 -0
  78. package/dist/loader/load.d.ts +6 -0
  79. package/dist/loader/load.js +2 -0
  80. package/dist/loader/loadHF.d.ts +8 -0
  81. package/dist/loader/loadHF.js +2 -0
  82. package/dist/loader/loadTransformers.d.ts +4 -0
  83. package/dist/loader/loadTransformers.js +2 -0
  84. package/dist/loader/loadZipMeta.d.ts +3 -0
  85. package/dist/loader/loadZipMeta.js +16 -0
  86. package/dist/loader/newZipLoad.d.ts +3 -0
  87. package/dist/loader/newZipLoad.js +2 -0
  88. package/dist/loader/oldZipLoad.d.ts +9 -0
  89. package/dist/loader/oldZipLoad.js +2 -0
  90. package/dist/loader/save.d.ts +16 -0
  91. package/dist/loader/save.js +2 -0
  92. package/dist/loader/types.d.ts +68 -0
  93. package/dist/loader/types.js +0 -0
  94. package/dist/main-D5CbfCiV.js +13500 -0
  95. package/dist/main.d.ts +50 -0
  96. package/dist/main.js +16 -0
  97. package/dist/matMul16-BNfZSnNM.js +81 -0
  98. package/dist/matMulGelu-CPTntosE.js +162 -0
  99. package/dist/models/NanoGPTV1.d.ts +16 -0
  100. package/dist/models/NanoGPTV1.js +2 -0
  101. package/dist/models/NanoGPTV2.d.ts +16 -0
  102. package/dist/models/NanoGPTV2.js +2 -0
  103. package/dist/models/config.d.ts +27 -0
  104. package/dist/models/config.js +37 -0
  105. package/dist/models/factory.d.ts +3 -0
  106. package/dist/models/factory.js +2 -0
  107. package/dist/models/model.d.ts +44 -0
  108. package/dist/models/model.js +2 -0
  109. package/dist/ops/adamAdjust.d.ts +2 -0
  110. package/dist/ops/adamAdjust.js +18 -0
  111. package/dist/ops/adamMoments.d.ts +2 -0
  112. package/dist/ops/adamMoments.js +16 -0
  113. package/dist/ops/add16.d.ts +2 -0
  114. package/dist/ops/add16.js +12 -0
  115. package/dist/ops/appendCache.d.ts +2 -0
  116. package/dist/ops/appendCache.js +25 -0
  117. package/dist/ops/attentionMask.d.ts +2 -0
  118. package/dist/ops/attentionMask.js +16 -0
  119. package/dist/ops/concat16.d.ts +2 -0
  120. package/dist/ops/concat16.js +8 -0
  121. package/dist/ops/cpu/adamAdjust.d.ts +1 -0
  122. package/dist/ops/cpu/adamAdjust.js +16 -0
  123. package/dist/ops/cpu/adamMoments.d.ts +1 -0
  124. package/dist/ops/cpu/adamMoments.js +16 -0
  125. package/dist/ops/cpu/appendCache.d.ts +1 -0
  126. package/dist/ops/cpu/appendCache.js +65 -0
  127. package/dist/ops/cpu/attentionMask.d.ts +1 -0
  128. package/dist/ops/cpu/attentionMask.js +16 -0
  129. package/dist/ops/cpu/fusedSoftmax.d.ts +9 -0
  130. package/dist/ops/cpu/fusedSoftmax.js +22 -0
  131. package/dist/ops/cpu/gatherSub.d.ts +1 -0
  132. package/dist/ops/cpu/gatherSub.js +12 -0
  133. package/dist/ops/cpu/gelu.d.ts +1 -0
  134. package/dist/ops/cpu/gelu.js +36 -0
  135. package/dist/ops/cpu/matMul16.d.ts +1 -0
  136. package/dist/ops/cpu/matMul16.js +14 -0
  137. package/dist/ops/cpu/matMulGelu.d.ts +1 -0
  138. package/dist/ops/cpu/matMulGelu.js +41 -0
  139. package/dist/ops/cpu/matMulMul.d.ts +1 -0
  140. package/dist/ops/cpu/matMulMul.js +20 -0
  141. package/dist/ops/cpu/mulDropout.d.ts +1 -0
  142. package/dist/ops/cpu/mulDropout.js +20 -0
  143. package/dist/ops/cpu/normRMS.d.ts +1 -0
  144. package/dist/ops/cpu/normRMS.js +35 -0
  145. package/dist/ops/cpu/qkv.d.ts +5 -0
  146. package/dist/ops/cpu/qkv.js +73 -0
  147. package/dist/ops/cpu/rope.d.ts +6 -0
  148. package/dist/ops/cpu/rope.js +81 -0
  149. package/dist/ops/cpu/scatterSub.d.ts +1 -0
  150. package/dist/ops/cpu/scatterSub.js +12 -0
  151. package/dist/ops/dot16.d.ts +2 -0
  152. package/dist/ops/dot16.js +29 -0
  153. package/dist/ops/dropout.d.ts +2 -0
  154. package/dist/ops/dropout.js +11 -0
  155. package/dist/ops/dropout16.d.ts +2 -0
  156. package/dist/ops/dropout16.js +22 -0
  157. package/dist/ops/gatherSub.d.ts +2 -0
  158. package/dist/ops/gatherSub.js +13 -0
  159. package/dist/ops/gelu.d.ts +3 -0
  160. package/dist/ops/gelu.js +2 -0
  161. package/dist/ops/globalNorm.d.ts +2 -0
  162. package/dist/ops/globalNorm.js +19 -0
  163. package/dist/ops/grads/add16.d.ts +1 -0
  164. package/dist/ops/grads/add16.js +27 -0
  165. package/dist/ops/grads/attentionMask.d.ts +1 -0
  166. package/dist/ops/grads/attentionMask.js +26 -0
  167. package/dist/ops/grads/dropout16.d.ts +1 -0
  168. package/dist/ops/grads/dropout16.js +1 -0
  169. package/dist/ops/grads/gelu.d.ts +2 -0
  170. package/dist/ops/grads/gelu.js +2 -0
  171. package/dist/ops/grads/matMul16.d.ts +2 -0
  172. package/dist/ops/grads/matMul16.js +2 -0
  173. package/dist/ops/grads/matMulGelu.d.ts +1 -0
  174. package/dist/ops/grads/matMulGelu.js +22 -0
  175. package/dist/ops/grads/mul16.d.ts +1 -0
  176. package/dist/ops/grads/mul16.js +1 -0
  177. package/dist/ops/grads/normRMS.d.ts +3 -0
  178. package/dist/ops/grads/normRMS.js +37 -0
  179. package/dist/ops/grads/pack16.d.ts +2 -0
  180. package/dist/ops/grads/pack16.js +2 -0
  181. package/dist/ops/grads/qkv.d.ts +3 -0
  182. package/dist/ops/grads/qkv.js +46 -0
  183. package/dist/ops/grads/rope.d.ts +2 -0
  184. package/dist/ops/grads/rope.js +2 -0
  185. package/dist/ops/grads/softmax16.d.ts +2 -0
  186. package/dist/ops/grads/softmax16.js +23 -0
  187. package/dist/ops/grads/unpack16.d.ts +2 -0
  188. package/dist/ops/grads/unpack16.js +2 -0
  189. package/dist/ops/grads/utils.d.ts +4 -0
  190. package/dist/ops/grads/utils.js +12 -0
  191. package/dist/ops/log.d.ts +0 -0
  192. package/dist/ops/log.js +1 -0
  193. package/dist/ops/matMul16.d.ts +15 -0
  194. package/dist/ops/matMul16.js +2 -0
  195. package/dist/ops/matMulGelu.d.ts +3 -0
  196. package/dist/ops/matMulGelu.js +20 -0
  197. package/dist/ops/matMulMul.d.ts +2 -0
  198. package/dist/ops/matMulMul.js +16 -0
  199. package/dist/ops/mul16.d.ts +2 -0
  200. package/dist/ops/mul16.js +43 -0
  201. package/dist/ops/mulDrop.d.ts +2 -0
  202. package/dist/ops/mulDrop.js +15 -0
  203. package/dist/ops/normRMS.d.ts +2 -0
  204. package/dist/ops/normRMS.js +22 -0
  205. package/dist/ops/pack16.d.ts +2 -0
  206. package/dist/ops/pack16.js +2 -0
  207. package/dist/ops/qkv.d.ts +2 -0
  208. package/dist/ops/qkv.js +16 -0
  209. package/dist/ops/reshape16.d.ts +2 -0
  210. package/dist/ops/reshape16.js +33 -0
  211. package/dist/ops/rope.d.ts +3 -0
  212. package/dist/ops/rope.js +2 -0
  213. package/dist/ops/scatterSub.d.ts +2 -0
  214. package/dist/ops/scatterSub.js +13 -0
  215. package/dist/ops/slice16.d.ts +2 -0
  216. package/dist/ops/slice16.js +11 -0
  217. package/dist/ops/softmax16.d.ts +2 -0
  218. package/dist/ops/softmax16.js +9 -0
  219. package/dist/ops/sub16.d.ts +2 -0
  220. package/dist/ops/sub16.js +11 -0
  221. package/dist/ops/sum16.d.ts +2 -0
  222. package/dist/ops/sum16.js +13 -0
  223. package/dist/ops/transpose16.d.ts +3 -0
  224. package/dist/ops/transpose16.js +32 -0
  225. package/dist/ops/unpack16.d.ts +2 -0
  226. package/dist/ops/unpack16.js +2 -0
  227. package/dist/ops/webgl/adamAdjust.d.ts +1 -0
  228. package/dist/ops/webgl/adamAdjust.js +82 -0
  229. package/dist/ops/webgl/adamMoments.d.ts +1 -0
  230. package/dist/ops/webgl/adamMoments.js +44 -0
  231. package/dist/ops/webgl/appendCache.d.ts +1 -0
  232. package/dist/ops/webgl/appendCache.js +53 -0
  233. package/dist/ops/webgl/attentionMask.d.ts +1 -0
  234. package/dist/ops/webgl/attentionMask.js +64 -0
  235. package/dist/ops/webgl/dropout16.d.ts +1 -0
  236. package/dist/ops/webgl/dropout16.js +12 -0
  237. package/dist/ops/webgl/fusedSoftmax.d.ts +11 -0
  238. package/dist/ops/webgl/fusedSoftmax.js +70 -0
  239. package/dist/ops/webgl/gatherSub.d.ts +1 -0
  240. package/dist/ops/webgl/gatherSub.js +28 -0
  241. package/dist/ops/webgl/gelu.d.ts +2 -0
  242. package/dist/ops/webgl/gelu.js +48 -0
  243. package/dist/ops/webgl/log.d.ts +17 -0
  244. package/dist/ops/webgl/log.js +14 -0
  245. package/dist/ops/webgl/matMul16.d.ts +1 -0
  246. package/dist/ops/webgl/matMul16.js +37 -0
  247. package/dist/ops/webgl/matMulGelu.d.ts +21 -0
  248. package/dist/ops/webgl/matMulGelu.js +2 -0
  249. package/dist/ops/webgl/matMulMul.d.ts +14 -0
  250. package/dist/ops/webgl/matMulMul.js +24 -0
  251. package/dist/ops/webgl/mulDropout.d.ts +1 -0
  252. package/dist/ops/webgl/mulDropout.js +32 -0
  253. package/dist/ops/webgl/normRMS.d.ts +1 -0
  254. package/dist/ops/webgl/normRMS.js +114 -0
  255. package/dist/ops/webgl/qkv.d.ts +1 -0
  256. package/dist/ops/webgl/qkv.js +54 -0
  257. package/dist/ops/webgl/rope.d.ts +1 -0
  258. package/dist/ops/webgl/rope.js +72 -0
  259. package/dist/ops/webgl/scatterSub.d.ts +1 -0
  260. package/dist/ops/webgl/scatterSub.js +28 -0
  261. package/dist/ops/webgpu/adamAdjust.d.ts +1 -0
  262. package/dist/ops/webgpu/adamAdjust.js +77 -0
  263. package/dist/ops/webgpu/adamMoments.d.ts +1 -0
  264. package/dist/ops/webgpu/adamMoments.js +76 -0
  265. package/dist/ops/webgpu/add16.d.ts +1 -0
  266. package/dist/ops/webgpu/add16.js +14 -0
  267. package/dist/ops/webgpu/appendCache.d.ts +1 -0
  268. package/dist/ops/webgpu/appendCache.js +130 -0
  269. package/dist/ops/webgpu/attentionMask.d.ts +1 -0
  270. package/dist/ops/webgpu/attentionMask.js +42 -0
  271. package/dist/ops/webgpu/attentionMask32_program.d.ts +19 -0
  272. package/dist/ops/webgpu/attentionMask32_program.js +62 -0
  273. package/dist/ops/webgpu/clipScale.d.ts +1 -0
  274. package/dist/ops/webgpu/clipScale.js +45 -0
  275. package/dist/ops/webgpu/concat16.d.ts +19 -0
  276. package/dist/ops/webgpu/concat16.js +111 -0
  277. package/dist/ops/webgpu/dropout16.d.ts +1 -0
  278. package/dist/ops/webgpu/dropout16.js +59 -0
  279. package/dist/ops/webgpu/gatherSub.d.ts +1 -0
  280. package/dist/ops/webgpu/gatherSub.js +52 -0
  281. package/dist/ops/webgpu/gelu.d.ts +14 -0
  282. package/dist/ops/webgpu/gelu.js +147 -0
  283. package/dist/ops/webgpu/index.d.ts +0 -0
  284. package/dist/ops/webgpu/index.js +26 -0
  285. package/dist/ops/webgpu/matMul16.d.ts +1 -0
  286. package/dist/ops/webgpu/matMul16.js +70 -0
  287. package/dist/ops/webgpu/matMul16_program.d.ts +42 -0
  288. package/dist/ops/webgpu/matMul16_program.js +303 -0
  289. package/dist/ops/webgpu/mul16.d.ts +1 -0
  290. package/dist/ops/webgpu/mul16.js +14 -0
  291. package/dist/ops/webgpu/norm2.d.ts +1 -0
  292. package/dist/ops/webgpu/norm2.js +46 -0
  293. package/dist/ops/webgpu/normRMS.d.ts +1 -0
  294. package/dist/ops/webgpu/normRMS.js +26 -0
  295. package/dist/ops/webgpu/normRMS16_program.d.ts +10 -0
  296. package/dist/ops/webgpu/normRMS16_program.js +28 -0
  297. package/dist/ops/webgpu/normRMS32_program.d.ts +10 -0
  298. package/dist/ops/webgpu/normRMS32_program.js +28 -0
  299. package/dist/ops/webgpu/normRMSGrad.d.ts +1 -0
  300. package/dist/ops/webgpu/normRMSGrad.js +225 -0
  301. package/dist/ops/webgpu/pack16.d.ts +1 -0
  302. package/dist/ops/webgpu/pack16.js +21 -0
  303. package/dist/ops/webgpu/pack16_program.d.ts +19 -0
  304. package/dist/ops/webgpu/pack16_program.js +93 -0
  305. package/dist/ops/webgpu/qkv.d.ts +1 -0
  306. package/dist/ops/webgpu/qkv.js +64 -0
  307. package/dist/ops/webgpu/rope.d.ts +1 -0
  308. package/dist/ops/webgpu/rope.js +163 -0
  309. package/dist/ops/webgpu/scatterSub.d.ts +1 -0
  310. package/dist/ops/webgpu/scatterSub.js +53 -0
  311. package/dist/ops/webgpu/slice16.d.ts +7 -0
  312. package/dist/ops/webgpu/slice16.js +74 -0
  313. package/dist/ops/webgpu/softmax16.d.ts +17 -0
  314. package/dist/ops/webgpu/softmax16.js +18 -0
  315. package/dist/ops/webgpu/softmax16_program.d.ts +13 -0
  316. package/dist/ops/webgpu/softmax16_program.js +89 -0
  317. package/dist/ops/webgpu/softmax16_subgroup_program.d.ts +17 -0
  318. package/dist/ops/webgpu/softmax16_subgroup_program.js +70 -0
  319. package/dist/ops/webgpu/softmax16grad.d.ts +1 -0
  320. package/dist/ops/webgpu/softmax16grad.js +31 -0
  321. package/dist/ops/webgpu/sub16.d.ts +1 -0
  322. package/dist/ops/webgpu/sub16.js +14 -0
  323. package/dist/ops/webgpu/sum16.d.ts +1 -0
  324. package/dist/ops/webgpu/sum16.js +29 -0
  325. package/dist/ops/webgpu/transpose16.d.ts +1 -0
  326. package/dist/ops/webgpu/transpose16.js +37 -0
  327. package/dist/ops/webgpu/transpose16_program.d.ts +16 -0
  328. package/dist/ops/webgpu/transpose16_program.js +51 -0
  329. package/dist/ops/webgpu/transpose16_shared_program.d.ts +15 -0
  330. package/dist/ops/webgpu/transpose16_shared_program.js +79 -0
  331. package/dist/ops/webgpu/unpack16.d.ts +1 -0
  332. package/dist/ops/webgpu/unpack16.js +60 -0
  333. package/dist/ops/webgpu/utils/binary_op.d.ts +35 -0
  334. package/dist/ops/webgpu/utils/binary_op.js +141 -0
  335. package/dist/ops/webgpu/utils/deviceInfo.d.ts +7 -0
  336. package/dist/ops/webgpu/utils/deviceInfo.js +11 -0
  337. package/dist/ops/webgpu/utils/reductions.d.ts +43 -0
  338. package/dist/ops/webgpu/utils/reductions.js +263 -0
  339. package/dist/pack16-Ck-spx_F.js +39 -0
  340. package/dist/patches/webgpu_backend.d.ts +18 -0
  341. package/dist/patches/webgpu_backend.js +43 -0
  342. package/dist/patches/webgpu_base.d.ts +21 -0
  343. package/dist/patches/webgpu_base.js +22 -0
  344. package/dist/patches/webgpu_program.d.ts +36 -0
  345. package/dist/patches/webgpu_program.js +293 -0
  346. package/dist/pdf-UoDqCYzz.js +16726 -0
  347. package/dist/picomatch-3tUnMMbd.js +1063 -0
  348. package/dist/rope-CbeGlsV8.js +25 -0
  349. package/dist/selu_util-zkAx5doH.js +24 -0
  350. package/dist/shared-D1coEFea.js +1314 -0
  351. package/dist/shared-DOgWaqvL.js +5 -0
  352. package/dist/slice_util-Dgb3ANWI.js +208 -0
  353. package/dist/tfjs_backend-BjuQ5FqB.js +614 -0
  354. package/dist/tokeniser/BaseTokeniser.d.ts +33 -0
  355. package/dist/tokeniser/BaseTokeniser.js +2 -0
  356. package/dist/tokeniser/CharTokeniser.d.ts +24 -0
  357. package/dist/tokeniser/CharTokeniser.js +92 -0
  358. package/dist/tokeniser/bpe.d.ts +28 -0
  359. package/dist/tokeniser/bpe.js +170 -0
  360. package/dist/tokeniser/messages.d.ts +61 -0
  361. package/dist/tokeniser/messages.js +0 -0
  362. package/dist/tokeniser/type.d.ts +34 -0
  363. package/dist/tokeniser/type.js +0 -0
  364. package/dist/training/AdamW.d.ts +36 -0
  365. package/dist/training/AdamW.js +128 -0
  366. package/dist/training/BasicTrainer.d.ts +63 -0
  367. package/dist/training/BasicTrainer.js +265 -0
  368. package/dist/training/DatasetBuilder.d.ts +26 -0
  369. package/dist/training/DatasetBuilder.js +2 -0
  370. package/dist/training/Evaluator.d.ts +19 -0
  371. package/dist/training/Evaluator.js +48 -0
  372. package/dist/training/LRScheduler.d.ts +12 -0
  373. package/dist/training/LRScheduler.js +38 -0
  374. package/dist/training/PreTrainer.d.ts +11 -0
  375. package/dist/training/PreTrainer.js +22 -0
  376. package/dist/training/SFTTrainer.d.ts +12 -0
  377. package/dist/training/SFTTrainer.js +24 -0
  378. package/dist/training/loss.d.ts +3 -0
  379. package/dist/training/loss.js +19 -0
  380. package/dist/training/orthoGrad.d.ts +2 -0
  381. package/dist/training/orthoGrad.js +10 -0
  382. package/dist/training/sparseCrossEntropy.d.ts +7 -0
  383. package/dist/training/sparseCrossEntropy.js +47 -0
  384. package/dist/training/tasks/ConversationTask.d.ts +18 -0
  385. package/dist/training/tasks/ConversationTask.js +38 -0
  386. package/dist/training/tasks/PretrainingTask.d.ts +17 -0
  387. package/dist/training/tasks/PretrainingTask.js +42 -0
  388. package/dist/training/tasks/StartSentenceTask.d.ts +18 -0
  389. package/dist/training/tasks/StartSentenceTask.js +45 -0
  390. package/dist/training/tasks/Task.d.ts +22 -0
  391. package/dist/training/tasks/Task.js +55 -0
  392. package/dist/training/tasks/splitter.d.ts +5 -0
  393. package/dist/training/tasks/splitter.js +18 -0
  394. package/dist/training/types.d.ts +78 -0
  395. package/dist/training/types.js +0 -0
  396. package/dist/training/validation.d.ts +17 -0
  397. package/dist/training/validation.js +2 -0
  398. package/dist/utilities/arrayClose.d.ts +1 -0
  399. package/dist/utilities/arrayClose.js +16 -0
  400. package/dist/utilities/datasetID.d.ts +2 -0
  401. package/dist/utilities/datasetID.js +18 -0
  402. package/dist/utilities/dummy.d.ts +9 -0
  403. package/dist/utilities/dummy.js +36 -0
  404. package/dist/utilities/multinomialCPU.d.ts +2 -0
  405. package/dist/utilities/multinomialCPU.js +9 -0
  406. package/dist/utilities/naming.d.ts +4 -0
  407. package/dist/utilities/naming.js +0 -0
  408. package/dist/utilities/packed.d.ts +4 -0
  409. package/dist/utilities/packed.js +13 -0
  410. package/dist/utilities/parameters.d.ts +11 -0
  411. package/dist/utilities/parameters.js +38 -0
  412. package/dist/utilities/performance.d.ts +2 -0
  413. package/dist/utilities/performance.js +16 -0
  414. package/dist/utilities/profile.d.ts +17 -0
  415. package/dist/utilities/profile.js +33 -0
  416. package/dist/utilities/safetensors.d.ts +3 -0
  417. package/dist/utilities/safetensors.js +53 -0
  418. package/dist/utilities/sentences.d.ts +5 -0
  419. package/dist/utilities/sentences.js +32 -0
  420. package/dist/utilities/tokenParse.d.ts +1 -0
  421. package/dist/utilities/tokenParse.js +17 -0
  422. package/dist/utilities/topP.d.ts +1 -0
  423. package/dist/utilities/topP.js +12 -0
  424. package/dist/utilities/waitForModel.d.ts +2 -0
  425. package/dist/utilities/waitForModel.js +12 -0
  426. package/dist/utilities/weights.d.ts +12 -0
  427. package/dist/utilities/weights.js +40 -0
  428. package/dist/utilities/yielder.d.ts +1 -0
  429. package/dist/utilities/yielder.js +7 -0
  430. package/dist/webgpu-Dt7BMzWz.js +525 -0
  431. package/dist/webgpu_program-WOyIVMlZ.js +392 -0
  432. package/dist/webgpu_util-B_F3SShA.js +106 -0
  433. package/package.json +1 -1
@@ -0,0 +1,82 @@
1
+ import { Conversation, ITokeniser } from './tokeniser/type';
2
+ import { default as EE } from 'eventemitter3';
3
+ import { default as Model, ModelForwardAttributes } from './models/model';
4
+ import { GenerateOptions } from './inference/types';
5
+ export declare function isConversation(data: unknown): data is Conversation[];
6
+ export interface IGenerateOptions extends GenerateOptions {
7
+ maxLength?: number;
8
+ noCache?: boolean;
9
+ allowSpecial?: boolean;
10
+ nonConversational?: boolean;
11
+ continuation?: boolean;
12
+ }
13
+ export interface IGenerator extends EE<'start' | 'stop' | 'tokens' | 'reset'> {
14
+ generate(prompt: Conversation[], options?: IGenerateOptions): Promise<Conversation[]>;
15
+ generate(options?: IGenerateOptions): Promise<Conversation[]>;
16
+ step(prompt: Conversation[], options?: IGenerateOptions): Promise<Conversation[]>;
17
+ step(options?: IGenerateOptions): Promise<Conversation[]>;
18
+ stop(): void;
19
+ getConversation(): Conversation[];
20
+ getAttentionData(): number[][][][][];
21
+ getProbabilitiesData(): number[][][];
22
+ getEmbeddingsData(): {
23
+ name: string;
24
+ tensor: number[][];
25
+ }[][];
26
+ getTokens(): number[];
27
+ getLastLoss(): number | null;
28
+ getLastMultinomialRand(): number | null;
29
+ dispose(): void;
30
+ reset(): void;
31
+ }
32
+ /**
33
+ * Text generator using a NanoGPT model and a tokeniser.
34
+ * This uses the forward method of the model to generate text token by token, including options for temperature, top-k, and top-p sampling.
35
+ */
36
+ export default class Generator extends EE<'start' | 'stop' | 'tokens' | 'reset'> implements IGenerator {
37
+ private readonly model;
38
+ private readonly tokeniser;
39
+ private active;
40
+ private cache;
41
+ private initialPrompt;
42
+ private outputConversation;
43
+ private actualTokeniser;
44
+ private lastToken;
45
+ private attentionData;
46
+ private probabilitiesData;
47
+ private embeddingsData;
48
+ private tokens;
49
+ private lastLoss;
50
+ private lastMultinomialRand;
51
+ private jobQueue;
52
+ private processingJob;
53
+ private startTime;
54
+ constructor(model: Model<ModelForwardAttributes>, tokeniser: ITokeniser);
55
+ private tokenisePrompt;
56
+ private processResponse;
57
+ /** Generate logits and select a token. */
58
+ private _generateToken;
59
+ /** Generate multiple tokens in a loop and produce text */
60
+ private _generate;
61
+ private resetCache;
62
+ reset(): void;
63
+ dispose(): void;
64
+ private initialise;
65
+ step(prompt: Conversation[], options?: IGenerateOptions): Promise<Conversation[]>;
66
+ step(options?: IGenerateOptions): Promise<Conversation[]>;
67
+ generate(prompt: Conversation[], options?: IGenerateOptions): Promise<Conversation[]>;
68
+ generate(options?: IGenerateOptions): Promise<Conversation[]>;
69
+ private startJob;
70
+ getQueueLength(): number;
71
+ stop(): void;
72
+ getConversation(): Conversation[];
73
+ getAttentionData(): number[][][][][];
74
+ getProbabilitiesData(): number[][][];
75
+ getEmbeddingsData(): {
76
+ name: string;
77
+ tensor: number[][];
78
+ }[][];
79
+ getTokens(): number[];
80
+ getLastLoss(): number | null;
81
+ getLastMultinomialRand(): number | null;
82
+ }
@@ -0,0 +1,2 @@
1
+ import { c as e, s as t } from "./main-D5CbfCiV.js";
2
+ export { t as default, e as isConversation };
@@ -0,0 +1,362 @@
1
+ import { Dn as e, En as t, Io as n, Ks as r, Ms as i, Si as a, Tn as o, nc as s, oc as c, wn as l, xn as u } from "./dist-BewPQWjc.js";
2
+ import { L as d } from "./backend_util-Cg-roD1p.js";
3
+ import { o as f } from "./gpgpu_math-DvLcCH6u.js";
4
+ import { J as p, b as m } from "./shared-DOgWaqvL.js";
5
+ import { S as h, n as g } from "./kernel_funcs_utils-HiXOOx3f.js";
6
+ import { t as _ } from "./Reshape-CABOPB9d.js";
7
+ //#region node_modules/@tensorflow/tfjs-backend-webgl/dist/mean_gpu.js
8
+ var v = class {
9
+ constructor(e, t) {
10
+ this.variableNames = ["x"];
11
+ let { windowSize: n, batchSize: i, inSize: a, outSize: o } = e;
12
+ this.outputShape = [i, o];
13
+ let s = Math.floor(n / 4) * 4, c = n % 4, l = "sumValue += dot(values, ones);";
14
+ if (t != null) {
15
+ let e = 1 / t;
16
+ l = `sumValue += dot(values * ${r(e) ? e.toPrecision(2) : e}, ones);`;
17
+ }
18
+ let u = "";
19
+ a % n > 0 && (u = `
20
+ if (inIdx < 0 || inIdx >= ${a}) {
21
+ return 0.0;
22
+ }
23
+ `), this.userCode = `
24
+ const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);
25
+
26
+ float getValue(int batch, int inIdx) {
27
+ ${u}
28
+ return getX(batch, inIdx);
29
+ }
30
+
31
+ void main() {
32
+ ivec2 coords = getOutputCoords();
33
+ int batch = coords[0];
34
+ int outIdx = coords[1];
35
+ int inOffset = outIdx * ${n};
36
+
37
+ float sumValue = 0.0;
38
+
39
+ for (int i = 0; i < ${s}; i += 4) {
40
+ int inIdx = inOffset + i;
41
+ vec4 values = vec4(
42
+ getValue(batch, inIdx),
43
+ getValue(batch, inIdx + 1),
44
+ getValue(batch, inIdx + 2),
45
+ getValue(batch, inIdx + 3)
46
+ );
47
+
48
+ ${l}
49
+ }
50
+
51
+ int inIdx = inOffset + ${s};
52
+ if (${c === 1}) {
53
+ vec4 values = vec4(getValue(batch, inIdx), 0.0, 0.0, 0.0);
54
+
55
+ ${l}
56
+ } else if (${c === 2}) {
57
+ vec4 values = vec4(
58
+ getValue(batch, inIdx),
59
+ getValue(batch, inIdx + 1), 0.0, 0.0);
60
+
61
+ ${l}
62
+ } else if (${c === 3}) {
63
+ vec4 values = vec4(
64
+ getValue(batch, inIdx),
65
+ getValue(batch, inIdx + 1),
66
+ getValue(batch, inIdx + 2), 0.0);
67
+
68
+ ${l}
69
+ }
70
+ setOutput(sumValue);
71
+ }
72
+ `;
73
+ }
74
+ }, y = class {
75
+ constructor(e, t) {
76
+ this.variableNames = ["x"];
77
+ let { windowSize: n, batchSize: r, inSize: i, outSize: a } = e;
78
+ this.outputShape = [r, a];
79
+ let o = "0.0", s = "";
80
+ t === "prod" ? o = "1.0" : t === "min" ? (o = "1.0 / 1e-20", s = "min") : t === "max" && (o = "-1.0 / 1e-20", s = "max");
81
+ let c = `${t}(${t}(${t}(minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])`;
82
+ t === "sum" ? c = "sumValue" : t === "prod" ? c = "prodValue" : t === "all" ? c = "allValue" : t === "any" && (c = "anyValue");
83
+ let l = Math.floor(n / 4) * 4, u = n % 4, d = `
84
+ if (${t === "sum"}) {
85
+ sumValue += dot(values, ones);
86
+ } else if (${t === "prod"}) {
87
+ vec2 tmp = vec2(values[0], values[1]) * vec2(values[2], values[3]);
88
+ prodValue *= tmp[0] * tmp[1];
89
+ } else {
90
+ minMaxValue = ${s}(values, minMaxValue);
91
+ if (${t === "min"} || ${t === "max"}) {
92
+ minMaxValue = ${s}(values, minMaxValue);
93
+ bvec4 isNaN = isnan(values);
94
+ if (isNaN.r || isNaN.g || isNaN.b || isNaN.a) {
95
+ minMaxValue = vec4(NAN);
96
+ }
97
+ }
98
+ }
99
+ `, f = "vec4";
100
+ t === "all" ? (o = "1.0", d = "\n bool reducedAllValue = all(values);\n float floatedReducedAllValue = float(reducedAllValue);\n allValue = float(allValue >= 1.0 && floatedReducedAllValue >= 1.0);\n ", f = "bvec4") : t === "any" && (o = "0.0", d = "\n bool reducedAnyValue = any(values);\n float floatedReducedAnyValue = float(reducedAnyValue);\n anyValue = float(anyValue >= 1.0 || floatedReducedAnyValue >= 1.0);\n ", f = "bvec4");
101
+ let p = "";
102
+ i % n > 0 && (p = `
103
+ if (inIdx < 0 || inIdx >= ${i}) {
104
+ return initializationValue;
105
+ }
106
+ `), this.userCode = `
107
+ const float initializationValue = ${o};
108
+ const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);
109
+
110
+ float getValue(int batch, int inIdx) {
111
+ ${p}
112
+ return getX(batch, inIdx);
113
+ }
114
+
115
+ void main() {
116
+ ivec2 coords = getOutputCoords();
117
+ int batch = coords[0];
118
+ int outIdx = coords[1];
119
+ int inOffset = outIdx * ${n};
120
+
121
+ vec4 minMaxValue = vec4(${o});
122
+ float prodValue = 1.0;
123
+ float sumValue = 0.0;
124
+ float allValue = 1.0;
125
+ float anyValue = 0.0;
126
+
127
+ for (int i = 0; i < ${l}; i += 4) {
128
+ int inIdx = inOffset + i;
129
+ ${f} values = ${f}(
130
+ getValue(batch, inIdx),
131
+ getValue(batch, inIdx + 1),
132
+ getValue(batch, inIdx + 2),
133
+ getValue(batch, inIdx + 3)
134
+ );
135
+
136
+ ${d}
137
+ }
138
+
139
+ int inIdx = inOffset + ${l};
140
+ if (${u === 1}) {
141
+ ${f} values = ${f}(
142
+ getValue(batch, inIdx),
143
+ initializationValue,
144
+ initializationValue,
145
+ initializationValue
146
+ );
147
+
148
+ ${d}
149
+ } else if (${u === 2}) {
150
+ ${f} values = ${f}(
151
+ getValue(batch, inIdx),
152
+ getValue(batch, inIdx + 1),
153
+ initializationValue,
154
+ initializationValue
155
+ );
156
+
157
+ ${d}
158
+ } else if (${u === 3}) {
159
+ ${f} values = ${f}(
160
+ getValue(batch, inIdx),
161
+ getValue(batch, inIdx + 1),
162
+ getValue(batch, inIdx + 2),
163
+ initializationValue
164
+ );
165
+
166
+ ${d}
167
+ }
168
+ setOutput(${c});
169
+ }
170
+ `;
171
+ }
172
+ };
173
+ //#endregion
174
+ //#region node_modules/@tensorflow/tfjs-backend-webgl/dist/kernel_utils/reduce.js
175
+ function b(e) {
176
+ let t = [];
177
+ for (; t.length === 0 || t[t.length - 1].outSize !== 1;) {
178
+ let n = t.length ? t[t.length - 1].outSize : e[1], r = d(n);
179
+ t.push({
180
+ inSize: n,
181
+ windowSize: r,
182
+ outSize: Math.ceil(n / r)
183
+ });
184
+ }
185
+ return t;
186
+ }
187
+ function x(e, t, n, r) {
188
+ let i = b(e.shape), a = e;
189
+ for (let o = 0; o < i.length; o++) {
190
+ let { inSize: s, windowSize: c, outSize: l } = i[o], u, d;
191
+ u = n === "mean" ? o === 0 ? new v({
192
+ windowSize: c,
193
+ inSize: s,
194
+ batchSize: e.shape[0],
195
+ outSize: l
196
+ }, s) : new v({
197
+ windowSize: c,
198
+ inSize: s,
199
+ batchSize: e.shape[0],
200
+ outSize: l
201
+ }) : new y({
202
+ windowSize: c,
203
+ inSize: s,
204
+ batchSize: e.shape[0],
205
+ outSize: l
206
+ }, n), d = a, a = r.runWebGLProgram(u, [a], t), d.dataId !== e.dataId && r.disposeIntermediateTensorInfo(d);
207
+ }
208
+ return a;
209
+ }
210
+ //#endregion
211
+ //#region node_modules/@tensorflow/tfjs-backend-webgl/dist/transpose_gpu.js
212
+ var S = class {
213
+ constructor(e, t) {
214
+ this.variableNames = ["A"];
215
+ let n = Array(e.length);
216
+ for (let r = 0; r < n.length; r++) n[r] = e[t[r]];
217
+ this.outputShape = n, this.rank = n.length;
218
+ let r = f(this.rank), i = C(t);
219
+ this.userCode = `
220
+ void main() {
221
+ ${r} resRC = getOutputCoords();
222
+ setOutput(getA(${i}));
223
+ }
224
+ `;
225
+ }
226
+ };
227
+ function C(e) {
228
+ let t = e.length;
229
+ if (t > 6) throw Error(`Transpose for rank ${t} is not yet supported`);
230
+ let n = [
231
+ "resRC.x",
232
+ "resRC.y",
233
+ "resRC.z",
234
+ "resRC.w",
235
+ "resRC.u",
236
+ "resRC.v"
237
+ ], r = Array(t);
238
+ for (let t = 0; t < e.length; t++) r[e[t]] = n[t];
239
+ return r.join();
240
+ }
241
+ //#endregion
242
+ //#region node_modules/@tensorflow/tfjs-backend-webgl/dist/transpose_packed_gpu.js
243
+ var w = class {
244
+ constructor(e, t) {
245
+ this.variableNames = ["A"], this.packedInputs = !0, this.packedOutput = !0;
246
+ let n = Array(e.length);
247
+ for (let r = 0; r < n.length; r++) n[r] = e[t[r]];
248
+ if (this.outputShape = n, this.rank = n.length, this.rank > 6) throw Error(`Packed transpose for rank ${this.rank} is not yet supported.`);
249
+ let r = f(this.rank), i = h("rc", this.rank), a = Array(this.rank);
250
+ for (let e = 0; e < t.length; e++) a[t[e]] = i[e];
251
+ let o = `vec2(${a.slice(-2).join()})`, s = `++${i[this.rank - 1]} < ${n[this.rank - 1]}`, c = `getChannel(getA(${a.join()}), ${o})`;
252
+ this.userCode = `
253
+ void main() {
254
+ ${r} rc = getOutputCoords();
255
+ vec4 result = vec4(0.);
256
+ result[0] = ${c};
257
+ if(${s}) {
258
+ result[1] = ${c};
259
+ }
260
+ --${i[this.rank - 1]};
261
+ if(++${i[this.rank - 2]} < ${n[this.rank - 2]}) {
262
+ result[2] = ${c};
263
+ if(${s}) {
264
+ result[3] = ${c};
265
+ }
266
+ }
267
+ setOutput(result);
268
+ }
269
+ `;
270
+ }
271
+ };
272
+ //#endregion
273
+ //#region node_modules/@tensorflow/tfjs-backend-webgl/dist/kernels/Transpose_impl.js
274
+ function T(e, t, n) {
275
+ let r = i().getBool("WEBGL_PACK_ARRAY_OPERATIONS") ? new w(e.shape, t) : new S(e.shape, t);
276
+ return n.runWebGLProgram(r, [e], e.dtype);
277
+ }
278
+ //#endregion
279
+ //#region node_modules/@tensorflow/tfjs-backend-webgl/dist/kernels/Sum_impl.js
280
+ function E(n, r, i, d) {
281
+ let f = r, p = n.shape.length, m = s(f, n.shape), h = m, g = t(h, p), v = g != null, y = n;
282
+ v && (y = T(n, g, d), h = e(h.length, p)), u("sum", h, p);
283
+ let [b, S] = l(y.shape, h), C = b;
284
+ i && (C = o(b, m));
285
+ let w = c(S), E = c(n.shape) / w, D = _({
286
+ inputs: { x: y },
287
+ attrs: { shape: [E, w] },
288
+ backend: d
289
+ }), O = x(D, a(n.dtype), "sum", d), k = _({
290
+ inputs: { x: O },
291
+ attrs: { shape: C },
292
+ backend: d
293
+ });
294
+ return d.disposeIntermediateTensorInfo(D), d.disposeIntermediateTensorInfo(O), v && d.disposeIntermediateTensorInfo(y), k;
295
+ }
296
+ //#endregion
297
+ //#region node_modules/@tensorflow/tfjs-backend-webgl/dist/kernels/Sum.js
298
+ function D(e) {
299
+ let { inputs: t, backend: n, attrs: r } = e, { x: i } = t, { axis: a, keepDims: o } = r;
300
+ return E(i, a, o, n);
301
+ }
302
+ var O = {
303
+ kernelName: "Sum",
304
+ backendName: "webgl",
305
+ kernelFunc: D
306
+ };
307
+ //#endregion
308
+ //#region node_modules/@tensorflow/tfjs-backend-webgl/dist/kernels/Max_impl.js
309
+ function k(e, t, n, r) {
310
+ let i = c(t), a = c(e.shape) / i, o = _({
311
+ inputs: { x: e },
312
+ attrs: { shape: [a, i] },
313
+ backend: r
314
+ }), s = x(o, e.dtype, "max", r), l = _({
315
+ inputs: { x: s },
316
+ attrs: { shape: n },
317
+ backend: r
318
+ });
319
+ return r.disposeIntermediateTensorInfo(o), r.disposeIntermediateTensorInfo(s), l;
320
+ }
321
+ //#endregion
322
+ //#region node_modules/@tensorflow/tfjs-backend-webgl/dist/kernels/Max.js
323
+ function A(n) {
324
+ let { inputs: r, backend: i, attrs: a } = n, { x: d } = r, { reductionIndices: f, keepDims: h } = a, g = d.shape.length, _ = s(f, d.shape), v = _, y = t(v, g), b = y != null, x = i.shouldExecuteOnCPU([d]), S = d;
325
+ if (b) {
326
+ if (x) {
327
+ let e = i.texData.get(S.dataId).values, t = Array(g);
328
+ for (let e = 0; e < t.length; e++) t[e] = d.shape[y[e]];
329
+ let n = p(e, d.shape, d.dtype, y, t);
330
+ S = i.makeTensorInfo(t, d.dtype);
331
+ let r = i.texData.get(S.dataId);
332
+ r.values = n;
333
+ } else S = T(d, y, i);
334
+ v = e(v.length, g);
335
+ }
336
+ u("max", v, g);
337
+ let [C, w] = l(S.shape, v), E = C;
338
+ h && (E = o(C, _));
339
+ let D;
340
+ if (x) {
341
+ let e = i.texData.get(S.dataId).values, t = m(e, c(w), E, d.dtype);
342
+ D = i.makeTensorInfo(E, d.dtype);
343
+ let n = i.texData.get(D.dataId);
344
+ n.values = t;
345
+ } else D = k(S, w, E, i);
346
+ return b && i.disposeIntermediateTensorInfo(S), D;
347
+ }
348
+ var j = {
349
+ kernelName: "Max",
350
+ backendName: "webgl",
351
+ kernelFunc: A
352
+ }, M = g({
353
+ opSnippet: "\nif (a == b) {\n return 1.0;\n};\nreturn a / b;",
354
+ packedOpSnippet: "\n // vec4 one = vec4(equal(a, b));\n // return one + (vec4(1.0) - one) * a / b;\n vec4 result = a / b;\n if(a.x == b.x) {\n result.x = 1.;\n }\n if(a.y == b.y) {\n result.y = 1.;\n }\n if(a.z == b.z) {\n result.z = 1.;\n }\n if(a.w == b.w) {\n result.w = 1.;\n }\n\n return result;\n",
355
+ checkOutOfBounds: !0
356
+ }), N = {
357
+ kernelName: n,
358
+ backendName: "webgl",
359
+ kernelFunc: M
360
+ };
361
+ //#endregion
362
+ export { D as a, x as c, j as i, N as n, O as o, A as r, T as s, M as t };
@@ -0,0 +1,94 @@
1
+ import { Bo as e, Gs as t, Ps as n, oc as r } from "./dist-BewPQWjc.js";
2
+ import { E as i, a, c as o, d as s, j as c, l, u, z as d } from "./gpgpu_math-DvLcCH6u.js";
3
+ //#region node_modules/@tensorflow/tfjs-backend-webgl/dist/reshape_packed_gpu.js
4
+ var f = class {
5
+ constructor(e, t) {
6
+ this.variableNames = ["A"], this.packedInputs = !0, this.packedOutput = !0, this.customUniforms = [{
7
+ name: "inputShape",
8
+ type: "ivec3"
9
+ }], this.outputShape = e, this.enableShapeUniforms = a(this.outputShape.length);
10
+ let n = "";
11
+ for (let e = 0; e < 4; e++) {
12
+ let t = "thisRC = rc;";
13
+ e % 2 == 1 && (t += "thisRC.z += 1;"), e > 1 && (t += "thisRC.y += 1;"), n += `
14
+ ${t}
15
+ ${e > 0 ? "if(thisRC.y < rows && thisRC.z < cols){" : ""}
16
+ int flatIndex = getFlatIndex(thisRC);
17
+
18
+ ivec3 inputRC = inputCoordsFromReshapedOutCoords(flatIndex);
19
+ vec2 inputRCInnerDims = vec2(float(inputRC.y),float(inputRC.z));
20
+
21
+ result[${e}] =
22
+ getChannel(getA(inputRC.x, inputRC.y, inputRC.z), inputRCInnerDims);
23
+ ${e > 0 ? "}" : ""}
24
+ `;
25
+ }
26
+ this.userCode = `
27
+ ${p(t, this.enableShapeUniforms)}
28
+ ${this.enableShapeUniforms ? l() : o(e)}
29
+
30
+ void main() {
31
+ ivec3 rc = getOutputCoords();
32
+
33
+ vec4 result = vec4(0.);
34
+
35
+ ivec3 thisRC;
36
+ int rows = ${this.enableShapeUniforms ? "outShape[1]" : e[1]};
37
+ int cols = ${this.enableShapeUniforms ? "outShape[2]" : e[2]};
38
+
39
+ ${n}
40
+
41
+ setOutput(result);
42
+ }
43
+ `;
44
+ }
45
+ };
46
+ function p(e, t) {
47
+ return `
48
+ ivec3 inputCoordsFromReshapedOutCoords(int index) {
49
+ ${t ? s([
50
+ "r",
51
+ "c",
52
+ "d"
53
+ ], "inputShape") : u([
54
+ "r",
55
+ "c",
56
+ "d"
57
+ ], e)}
58
+ return ivec3(r, c, d);
59
+ }
60
+ `;
61
+ }
62
+ //#endregion
63
+ //#region node_modules/@tensorflow/tfjs-backend-webgl/dist/kernel_utils/reshape.js
64
+ function m(e, t, n) {
65
+ let r = [i(e.shape), ...c(e.shape)], a = {
66
+ dtype: e.dtype,
67
+ shape: r,
68
+ dataId: e.dataId
69
+ }, o = new f([i(t), ...c(t)], r), s = [r], l = n.runWebGLProgram(o, [a], e.dtype, s, !0);
70
+ return {
71
+ dataId: l.dataId,
72
+ shape: t,
73
+ dtype: l.dtype
74
+ };
75
+ }
76
+ //#endregion
77
+ //#region node_modules/@tensorflow/tfjs-backend-webgl/dist/kernels/Reshape.js
78
+ function h(e) {
79
+ let { inputs: i, backend: a, attrs: o } = e, { x: s } = i, { shape: c } = o, l = a, u = r(s.shape), f = t(c, u), p = r(f);
80
+ n(u === p, () => `The new shape (${f}) has ${p} elements and the old shape (${s.shape}) has ${u} elements. The new shape and old shape must have the same number of elements.`);
81
+ let h = l.texData.get(s.dataId);
82
+ return h.isPacked && !d(s.shape, f) && !(h.texture !== null && d(h.shape, f)) ? m(s, f, l) : (l.incRef(s.dataId), {
83
+ dataId: s.dataId,
84
+ shape: f,
85
+ dtype: s.dtype
86
+ });
87
+ }
88
+ var g = {
89
+ kernelName: e,
90
+ backendName: "webgl",
91
+ kernelFunc: h
92
+ };
93
+ //#endregion
94
+ export { g as n, f as r, h as t };
@@ -0,0 +1,17 @@
1
+ import { Bo as e, Gs as t, Ps as n, oc as r } from "./dist-BewPQWjc.js";
2
+ //#region node_modules/@tensorflow/tfjs-backend-webgpu/dist/kernels/Reshape.js
3
+ function i(e) {
4
+ let { inputs: i, attrs: a } = e, { x: o } = i, { shape: s } = a, c = r(o.shape), l = t(s, c), u = r(l);
5
+ return n(c === u, () => `The new shape (${l}) has ${u} elements and the old shape (${o.shape}) has ${c} elements. The new shape and old shape must have the same number of elements.`), e.backend.incRef(o.dataId), {
6
+ dataId: o.dataId,
7
+ shape: l,
8
+ dtype: o.dtype
9
+ };
10
+ }
11
+ var a = {
12
+ kernelName: e,
13
+ backendName: "webgpu",
14
+ kernelFunc: i
15
+ };
16
+ //#endregion
17
+ export { a as n, i as t };
@@ -0,0 +1,70 @@
1
+ import { GPTConfig, LoRAConfig } from './models/config';
2
+ import { Conversation, ITokeniser } from './tokeniser/type';
3
+ import { SaveOptions } from './loader/save';
4
+ import { LoadModelOptions } from './loader/load';
5
+ import { IGenerateOptions, IGenerator } from './Generator';
6
+ import { default as Trainer, TrainingType } from './Trainer';
7
+ import { default as MemoryProfiler } from './utilities/profile';
8
+ import { default as Model, ModelForwardAttributes } from './models/model';
9
+ import { Task } from './training/tasks/Task';
10
+ import { TrainingLogEntry, TrainingOptions } from './training/types';
11
+ import { ModelMode, TransformersMetadata } from './loader/types';
12
+ type TeachableLLMStatus = 'warmup' | 'awaitingTokens' | 'ready' | 'training' | 'loading' | 'busy' | 'error';
13
+ export default class TeachableLLM {
14
+ private ee;
15
+ private _config?;
16
+ private _model?;
17
+ private _tokeniser?;
18
+ private _status;
19
+ private _memoryRequirements?;
20
+ meta: TransformersMetadata;
21
+ private _trainer;
22
+ constructor(tokeniser?: ITokeniser, model?: Model<ModelForwardAttributes, GPTConfig>);
23
+ get currentTrainer(): Trainer | null;
24
+ get vocab(): string[];
25
+ get mode(): ModelMode;
26
+ set mode(mode: ModelMode);
27
+ /** Model is fully loaded */
28
+ get loaded(): boolean;
29
+ get config(): GPTConfig;
30
+ get model(): Model<ModelForwardAttributes, GPTConfig>;
31
+ get tokeniser(): ITokeniser;
32
+ get status(): TeachableLLMStatus;
33
+ /** Model is both ready and not busy */
34
+ get ready(): boolean;
35
+ get busy(): boolean;
36
+ createLoRA(name: string, loraConfig: LoRAConfig): void;
37
+ deleteLoRA(name: string): void;
38
+ renameLoRA(oldName: string, newName: string): void;
39
+ attachLoRA(name: string): void;
40
+ detachLoRA(): void;
41
+ hasLoRA(name?: string): boolean;
42
+ listLoRAs(): string[];
43
+ estimateTrainingMemoryUsage(batchSize: number): number;
44
+ private setStatus;
45
+ saveModel(options?: SaveOptions): Promise<Blob>;
46
+ static loadModel(data: Blob | Buffer | string, options?: LoadModelOptions): TeachableLLM;
47
+ static create(tokeniserType: 'char' | 'bpe', config: GPTConfig): TeachableLLM;
48
+ getProfiler(): MemoryProfiler | undefined;
49
+ get enableProfiler(): boolean;
50
+ set enableProfiler(value: boolean);
51
+ getNumParams(): number;
52
+ trainer(trainingType?: TrainingType, options?: TrainingOptions): Trainer;
53
+ train(text: Task[], options?: TrainingOptions, trainingType?: TrainingType): Promise<void>;
54
+ trainTokeniser(text: Conversation[][]): Promise<number>;
55
+ generator(): IGenerator;
56
+ generateText(prompt: Conversation[], options?: IGenerateOptions): Promise<Conversation[]>;
57
+ generateText(options?: IGenerateOptions): Promise<Conversation[]>;
58
+ dispose(): void;
59
+ on(event: 'status', listener: (status: TeachableLLMStatus) => void): void;
60
+ on(event: 'mode', listener: (mode: ModelMode) => void): void;
61
+ on(event: 'error', listener: (error: Error) => void): void;
62
+ on(event: 'trainStep', listener: (step: TrainingLogEntry) => void): void;
63
+ on(event: 'loaded' | 'changeLoRA', listener: () => void): void;
64
+ off(event: 'status', listener: (status: TeachableLLMStatus) => void): void;
65
+ off(event: 'mode', listener: (mode: ModelMode) => void): void;
66
+ off(event: 'error', listener: (error: Error) => void): void;
67
+ off(event: 'trainStep', listener: (step: TrainingLogEntry) => void): void;
68
+ off(event: 'loaded' | 'changeLoRA', listener: () => void): void;
69
+ }
70
+ export {};
@@ -0,0 +1,2 @@
1
+ import { i as e } from "./main-D5CbfCiV.js";
2
+ export { e as default };
@@ -0,0 +1,43 @@
1
+ import { ITokeniser } from './tokeniser/type';
2
+ import { default as EE } from 'eventemitter3';
3
+ import { default as Model, ModelForwardAttributes } from './models/model';
4
+ import { Task } from './training/tasks/Task';
5
+ import { TrainingOptions, TrainingLogEntry } from './training/types';
6
+ import { AdamWOptimizer } from './training/AdamW';
7
+ import { DatasetMetadata } from './loader/types';
8
+ interface TrainingProgress {
9
+ lastLog: TrainingLogEntry;
10
+ progress: number;
11
+ remaining: number;
12
+ }
13
+ export type TrainingType = 'pretraining' | 'sft';
14
+ export default class Trainer extends EE<'start' | 'stop' | 'log'> {
15
+ private trainer;
16
+ readonly trainingType: TrainingType;
17
+ private hasTrained;
18
+ private trainDataset?;
19
+ private validationDataset?;
20
+ private totalTokens;
21
+ private tokensProcessed;
22
+ log: TrainingLogEntry[];
23
+ private progress;
24
+ options: TrainingOptions;
25
+ protected tokenizer: ITokeniser;
26
+ constructor(model: Model<ModelForwardAttributes>, tokeniser: ITokeniser, trainingType?: TrainingType, options?: TrainingOptions, optimizer?: AdamWOptimizer);
27
+ constructor(trainer: Trainer, options?: TrainingOptions);
28
+ get model(): Model<ModelForwardAttributes>;
29
+ get optimizer(): AdamWOptimizer;
30
+ stop(): void;
31
+ reset(): void;
32
+ dispose(): void;
33
+ getTotalTokens(): number;
34
+ setOptions(options: TrainingOptions): void;
35
+ prepare(tasks?: Task[] | Uint16Array, datasets?: DatasetMetadata[]): Promise<void>;
36
+ private configureModel;
37
+ train(): Promise<void>;
38
+ step(options?: TrainingOptions): Promise<void>;
39
+ getLog(): TrainingLogEntry[];
40
+ getProgress(): TrainingProgress | null;
41
+ isPrepared(): boolean;
42
+ }
43
+ export {};
@@ -0,0 +1,2 @@
1
+ import { a as e } from "./main-D5CbfCiV.js";
2
+ export { e as default };
@@ -0,0 +1,2 @@
1
+ import { GPUOptions } from './patches/webgpu_base';
2
+ export declare function selectBackend(backendName: 'cpu' | 'webgl' | 'webgpu', options?: GPUOptions): Promise<void>;