@genai-fi/nanogpt 0.20.0 → 0.20.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (433)
  1. package/dist/BaseTokeniser-DSg9zcYq.js +221 -0
  2. package/dist/DatasetBuilder-DgURD85T.js +712 -0
  3. package/dist/Generator.d.ts +82 -0
  4. package/dist/Generator.js +2 -0
  5. package/dist/RealDiv-DBu0FQqT.js +362 -0
  6. package/dist/Reshape-CABOPB9d.js +94 -0
  7. package/dist/Reshape-DqO3r8BC.js +17 -0
  8. package/dist/TeachableLLM.d.ts +70 -0
  9. package/dist/TeachableLLM.js +2 -0
  10. package/dist/Trainer.d.ts +43 -0
  11. package/dist/Trainer.js +2 -0
  12. package/dist/backend.d.ts +2 -0
  13. package/dist/backend.js +13 -0
  14. package/dist/backend_util-Cg-roD1p.js +399 -0
  15. package/dist/binary_op_util-CrYk9LXL.js +103 -0
  16. package/dist/checks/appendCache.d.ts +1 -0
  17. package/dist/checks/appendCache.js +55 -0
  18. package/dist/checks/attentionMask.d.ts +1 -0
  19. package/dist/checks/attentionMask.js +56 -0
  20. package/dist/checks/check.d.ts +9 -0
  21. package/dist/checks/check.js +32 -0
  22. package/dist/checks/gelu.d.ts +1 -0
  23. package/dist/checks/gelu.js +46 -0
  24. package/dist/checks/index.d.ts +26 -0
  25. package/dist/checks/index.js +28 -0
  26. package/dist/checks/matMulGelu.d.ts +1 -0
  27. package/dist/checks/matMulGelu.js +84 -0
  28. package/dist/checks/normRMS.d.ts +1 -0
  29. package/dist/checks/normRMS.js +28 -0
  30. package/dist/checks/normRMSGrad.d.ts +1 -0
  31. package/dist/checks/normRMSGrad.js +22 -0
  32. package/dist/checks/packUnpack.d.ts +1 -0
  33. package/dist/checks/packUnpack.js +46 -0
  34. package/dist/checks/qkv.d.ts +1 -0
  35. package/dist/checks/qkv.js +34 -0
  36. package/dist/checks/rope.d.ts +1 -0
  37. package/dist/checks/rope.js +30 -0
  38. package/dist/checks/weights.d.ts +14 -0
  39. package/dist/checks/weights.js +27 -0
  40. package/dist/chunk-BPntVaq0.js +23 -0
  41. package/dist/complex_util-CkazZsaH.js +60 -0
  42. package/dist/concat_util-CWDZCBlA.js +19 -0
  43. package/dist/data/docx.d.ts +2 -0
  44. package/dist/data/docx.js +3046 -0
  45. package/dist/data/pdf.d.ts +2 -0
  46. package/dist/data/pdf.js +17 -0
  47. package/dist/data/textLoader.d.ts +7 -0
  48. package/dist/data/textLoader.js +613 -0
  49. package/dist/dist-BewPQWjc.js +7572 -0
  50. package/dist/dist-DVmq73nz.js +8775 -0
  51. package/dist/dist-DXwIvKxl.js +896 -0
  52. package/dist/dist-VEU5mfO0.js +7545 -0
  53. package/dist/gelu-Bf1HW1RY.js +27 -0
  54. package/dist/gpgpu_math-DvLcCH6u.js +1612 -0
  55. package/dist/inference/types.d.ts +16 -0
  56. package/dist/inference/types.js +0 -0
  57. package/dist/kernel_funcs_utils-HiXOOx3f.js +229 -0
  58. package/dist/layers/BaseLayer.d.ts +44 -0
  59. package/dist/layers/BaseLayer.js +76 -0
  60. package/dist/layers/CausalSelfAttention.d.ts +39 -0
  61. package/dist/layers/CausalSelfAttention.js +99 -0
  62. package/dist/layers/LoRA.d.ts +14 -0
  63. package/dist/layers/LoRA.js +48 -0
  64. package/dist/layers/MLP.d.ts +17 -0
  65. package/dist/layers/MLP.js +34 -0
  66. package/dist/layers/PositionEmbedding.d.ts +8 -0
  67. package/dist/layers/PositionEmbedding.js +27 -0
  68. package/dist/layers/RMSNorm.d.ts +12 -0
  69. package/dist/layers/RMSNorm.js +20 -0
  70. package/dist/layers/RoPECache.d.ts +18 -0
  71. package/dist/layers/RoPECache.js +337 -0
  72. package/dist/layers/TiedEmbedding.d.ts +13 -0
  73. package/dist/layers/TiedEmbedding.js +32 -0
  74. package/dist/layers/TransformerBlock.d.ts +27 -0
  75. package/dist/layers/TransformerBlock.js +51 -0
  76. package/dist/layers/WeightStore.d.ts +20 -0
  77. package/dist/layers/WeightStore.js +69 -0
  78. package/dist/loader/load.d.ts +6 -0
  79. package/dist/loader/load.js +2 -0
  80. package/dist/loader/loadHF.d.ts +8 -0
  81. package/dist/loader/loadHF.js +2 -0
  82. package/dist/loader/loadTransformers.d.ts +4 -0
  83. package/dist/loader/loadTransformers.js +2 -0
  84. package/dist/loader/loadZipMeta.d.ts +3 -0
  85. package/dist/loader/loadZipMeta.js +16 -0
  86. package/dist/loader/newZipLoad.d.ts +3 -0
  87. package/dist/loader/newZipLoad.js +2 -0
  88. package/dist/loader/oldZipLoad.d.ts +9 -0
  89. package/dist/loader/oldZipLoad.js +2 -0
  90. package/dist/loader/save.d.ts +16 -0
  91. package/dist/loader/save.js +2 -0
  92. package/dist/loader/types.d.ts +68 -0
  93. package/dist/loader/types.js +0 -0
  94. package/dist/main-D5CbfCiV.js +13500 -0
  95. package/dist/main.d.ts +50 -0
  96. package/dist/main.js +16 -0
  97. package/dist/matMul16-BNfZSnNM.js +81 -0
  98. package/dist/matMulGelu-CPTntosE.js +162 -0
  99. package/dist/models/NanoGPTV1.d.ts +16 -0
  100. package/dist/models/NanoGPTV1.js +2 -0
  101. package/dist/models/NanoGPTV2.d.ts +16 -0
  102. package/dist/models/NanoGPTV2.js +2 -0
  103. package/dist/models/config.d.ts +27 -0
  104. package/dist/models/config.js +37 -0
  105. package/dist/models/factory.d.ts +3 -0
  106. package/dist/models/factory.js +2 -0
  107. package/dist/models/model.d.ts +44 -0
  108. package/dist/models/model.js +2 -0
  109. package/dist/ops/adamAdjust.d.ts +2 -0
  110. package/dist/ops/adamAdjust.js +18 -0
  111. package/dist/ops/adamMoments.d.ts +2 -0
  112. package/dist/ops/adamMoments.js +16 -0
  113. package/dist/ops/add16.d.ts +2 -0
  114. package/dist/ops/add16.js +12 -0
  115. package/dist/ops/appendCache.d.ts +2 -0
  116. package/dist/ops/appendCache.js +25 -0
  117. package/dist/ops/attentionMask.d.ts +2 -0
  118. package/dist/ops/attentionMask.js +16 -0
  119. package/dist/ops/concat16.d.ts +2 -0
  120. package/dist/ops/concat16.js +8 -0
  121. package/dist/ops/cpu/adamAdjust.d.ts +1 -0
  122. package/dist/ops/cpu/adamAdjust.js +16 -0
  123. package/dist/ops/cpu/adamMoments.d.ts +1 -0
  124. package/dist/ops/cpu/adamMoments.js +16 -0
  125. package/dist/ops/cpu/appendCache.d.ts +1 -0
  126. package/dist/ops/cpu/appendCache.js +65 -0
  127. package/dist/ops/cpu/attentionMask.d.ts +1 -0
  128. package/dist/ops/cpu/attentionMask.js +16 -0
  129. package/dist/ops/cpu/fusedSoftmax.d.ts +9 -0
  130. package/dist/ops/cpu/fusedSoftmax.js +22 -0
  131. package/dist/ops/cpu/gatherSub.d.ts +1 -0
  132. package/dist/ops/cpu/gatherSub.js +12 -0
  133. package/dist/ops/cpu/gelu.d.ts +1 -0
  134. package/dist/ops/cpu/gelu.js +36 -0
  135. package/dist/ops/cpu/matMul16.d.ts +1 -0
  136. package/dist/ops/cpu/matMul16.js +14 -0
  137. package/dist/ops/cpu/matMulGelu.d.ts +1 -0
  138. package/dist/ops/cpu/matMulGelu.js +41 -0
  139. package/dist/ops/cpu/matMulMul.d.ts +1 -0
  140. package/dist/ops/cpu/matMulMul.js +20 -0
  141. package/dist/ops/cpu/mulDropout.d.ts +1 -0
  142. package/dist/ops/cpu/mulDropout.js +20 -0
  143. package/dist/ops/cpu/normRMS.d.ts +1 -0
  144. package/dist/ops/cpu/normRMS.js +35 -0
  145. package/dist/ops/cpu/qkv.d.ts +5 -0
  146. package/dist/ops/cpu/qkv.js +73 -0
  147. package/dist/ops/cpu/rope.d.ts +6 -0
  148. package/dist/ops/cpu/rope.js +81 -0
  149. package/dist/ops/cpu/scatterSub.d.ts +1 -0
  150. package/dist/ops/cpu/scatterSub.js +12 -0
  151. package/dist/ops/dot16.d.ts +2 -0
  152. package/dist/ops/dot16.js +29 -0
  153. package/dist/ops/dropout.d.ts +2 -0
  154. package/dist/ops/dropout.js +11 -0
  155. package/dist/ops/dropout16.d.ts +2 -0
  156. package/dist/ops/dropout16.js +22 -0
  157. package/dist/ops/gatherSub.d.ts +2 -0
  158. package/dist/ops/gatherSub.js +13 -0
  159. package/dist/ops/gelu.d.ts +3 -0
  160. package/dist/ops/gelu.js +2 -0
  161. package/dist/ops/globalNorm.d.ts +2 -0
  162. package/dist/ops/globalNorm.js +19 -0
  163. package/dist/ops/grads/add16.d.ts +1 -0
  164. package/dist/ops/grads/add16.js +27 -0
  165. package/dist/ops/grads/attentionMask.d.ts +1 -0
  166. package/dist/ops/grads/attentionMask.js +26 -0
  167. package/dist/ops/grads/dropout16.d.ts +1 -0
  168. package/dist/ops/grads/dropout16.js +1 -0
  169. package/dist/ops/grads/gelu.d.ts +2 -0
  170. package/dist/ops/grads/gelu.js +2 -0
  171. package/dist/ops/grads/matMul16.d.ts +2 -0
  172. package/dist/ops/grads/matMul16.js +2 -0
  173. package/dist/ops/grads/matMulGelu.d.ts +1 -0
  174. package/dist/ops/grads/matMulGelu.js +22 -0
  175. package/dist/ops/grads/mul16.d.ts +1 -0
  176. package/dist/ops/grads/mul16.js +1 -0
  177. package/dist/ops/grads/normRMS.d.ts +3 -0
  178. package/dist/ops/grads/normRMS.js +37 -0
  179. package/dist/ops/grads/pack16.d.ts +2 -0
  180. package/dist/ops/grads/pack16.js +2 -0
  181. package/dist/ops/grads/qkv.d.ts +3 -0
  182. package/dist/ops/grads/qkv.js +46 -0
  183. package/dist/ops/grads/rope.d.ts +2 -0
  184. package/dist/ops/grads/rope.js +2 -0
  185. package/dist/ops/grads/softmax16.d.ts +2 -0
  186. package/dist/ops/grads/softmax16.js +23 -0
  187. package/dist/ops/grads/unpack16.d.ts +2 -0
  188. package/dist/ops/grads/unpack16.js +2 -0
  189. package/dist/ops/grads/utils.d.ts +4 -0
  190. package/dist/ops/grads/utils.js +12 -0
  191. package/dist/ops/log.d.ts +0 -0
  192. package/dist/ops/log.js +1 -0
  193. package/dist/ops/matMul16.d.ts +15 -0
  194. package/dist/ops/matMul16.js +2 -0
  195. package/dist/ops/matMulGelu.d.ts +3 -0
  196. package/dist/ops/matMulGelu.js +20 -0
  197. package/dist/ops/matMulMul.d.ts +2 -0
  198. package/dist/ops/matMulMul.js +16 -0
  199. package/dist/ops/mul16.d.ts +2 -0
  200. package/dist/ops/mul16.js +43 -0
  201. package/dist/ops/mulDrop.d.ts +2 -0
  202. package/dist/ops/mulDrop.js +15 -0
  203. package/dist/ops/normRMS.d.ts +2 -0
  204. package/dist/ops/normRMS.js +22 -0
  205. package/dist/ops/pack16.d.ts +2 -0
  206. package/dist/ops/pack16.js +2 -0
  207. package/dist/ops/qkv.d.ts +2 -0
  208. package/dist/ops/qkv.js +16 -0
  209. package/dist/ops/reshape16.d.ts +2 -0
  210. package/dist/ops/reshape16.js +33 -0
  211. package/dist/ops/rope.d.ts +3 -0
  212. package/dist/ops/rope.js +2 -0
  213. package/dist/ops/scatterSub.d.ts +2 -0
  214. package/dist/ops/scatterSub.js +13 -0
  215. package/dist/ops/slice16.d.ts +2 -0
  216. package/dist/ops/slice16.js +11 -0
  217. package/dist/ops/softmax16.d.ts +2 -0
  218. package/dist/ops/softmax16.js +9 -0
  219. package/dist/ops/sub16.d.ts +2 -0
  220. package/dist/ops/sub16.js +11 -0
  221. package/dist/ops/sum16.d.ts +2 -0
  222. package/dist/ops/sum16.js +13 -0
  223. package/dist/ops/transpose16.d.ts +3 -0
  224. package/dist/ops/transpose16.js +32 -0
  225. package/dist/ops/unpack16.d.ts +2 -0
  226. package/dist/ops/unpack16.js +2 -0
  227. package/dist/ops/webgl/adamAdjust.d.ts +1 -0
  228. package/dist/ops/webgl/adamAdjust.js +82 -0
  229. package/dist/ops/webgl/adamMoments.d.ts +1 -0
  230. package/dist/ops/webgl/adamMoments.js +44 -0
  231. package/dist/ops/webgl/appendCache.d.ts +1 -0
  232. package/dist/ops/webgl/appendCache.js +53 -0
  233. package/dist/ops/webgl/attentionMask.d.ts +1 -0
  234. package/dist/ops/webgl/attentionMask.js +64 -0
  235. package/dist/ops/webgl/dropout16.d.ts +1 -0
  236. package/dist/ops/webgl/dropout16.js +12 -0
  237. package/dist/ops/webgl/fusedSoftmax.d.ts +11 -0
  238. package/dist/ops/webgl/fusedSoftmax.js +70 -0
  239. package/dist/ops/webgl/gatherSub.d.ts +1 -0
  240. package/dist/ops/webgl/gatherSub.js +28 -0
  241. package/dist/ops/webgl/gelu.d.ts +2 -0
  242. package/dist/ops/webgl/gelu.js +48 -0
  243. package/dist/ops/webgl/log.d.ts +17 -0
  244. package/dist/ops/webgl/log.js +14 -0
  245. package/dist/ops/webgl/matMul16.d.ts +1 -0
  246. package/dist/ops/webgl/matMul16.js +37 -0
  247. package/dist/ops/webgl/matMulGelu.d.ts +21 -0
  248. package/dist/ops/webgl/matMulGelu.js +2 -0
  249. package/dist/ops/webgl/matMulMul.d.ts +14 -0
  250. package/dist/ops/webgl/matMulMul.js +24 -0
  251. package/dist/ops/webgl/mulDropout.d.ts +1 -0
  252. package/dist/ops/webgl/mulDropout.js +32 -0
  253. package/dist/ops/webgl/normRMS.d.ts +1 -0
  254. package/dist/ops/webgl/normRMS.js +114 -0
  255. package/dist/ops/webgl/qkv.d.ts +1 -0
  256. package/dist/ops/webgl/qkv.js +54 -0
  257. package/dist/ops/webgl/rope.d.ts +1 -0
  258. package/dist/ops/webgl/rope.js +72 -0
  259. package/dist/ops/webgl/scatterSub.d.ts +1 -0
  260. package/dist/ops/webgl/scatterSub.js +28 -0
  261. package/dist/ops/webgpu/adamAdjust.d.ts +1 -0
  262. package/dist/ops/webgpu/adamAdjust.js +77 -0
  263. package/dist/ops/webgpu/adamMoments.d.ts +1 -0
  264. package/dist/ops/webgpu/adamMoments.js +76 -0
  265. package/dist/ops/webgpu/add16.d.ts +1 -0
  266. package/dist/ops/webgpu/add16.js +14 -0
  267. package/dist/ops/webgpu/appendCache.d.ts +1 -0
  268. package/dist/ops/webgpu/appendCache.js +130 -0
  269. package/dist/ops/webgpu/attentionMask.d.ts +1 -0
  270. package/dist/ops/webgpu/attentionMask.js +42 -0
  271. package/dist/ops/webgpu/attentionMask32_program.d.ts +19 -0
  272. package/dist/ops/webgpu/attentionMask32_program.js +62 -0
  273. package/dist/ops/webgpu/clipScale.d.ts +1 -0
  274. package/dist/ops/webgpu/clipScale.js +45 -0
  275. package/dist/ops/webgpu/concat16.d.ts +19 -0
  276. package/dist/ops/webgpu/concat16.js +111 -0
  277. package/dist/ops/webgpu/dropout16.d.ts +1 -0
  278. package/dist/ops/webgpu/dropout16.js +59 -0
  279. package/dist/ops/webgpu/gatherSub.d.ts +1 -0
  280. package/dist/ops/webgpu/gatherSub.js +52 -0
  281. package/dist/ops/webgpu/gelu.d.ts +14 -0
  282. package/dist/ops/webgpu/gelu.js +147 -0
  283. package/dist/ops/webgpu/index.d.ts +0 -0
  284. package/dist/ops/webgpu/index.js +26 -0
  285. package/dist/ops/webgpu/matMul16.d.ts +1 -0
  286. package/dist/ops/webgpu/matMul16.js +70 -0
  287. package/dist/ops/webgpu/matMul16_program.d.ts +42 -0
  288. package/dist/ops/webgpu/matMul16_program.js +303 -0
  289. package/dist/ops/webgpu/mul16.d.ts +1 -0
  290. package/dist/ops/webgpu/mul16.js +14 -0
  291. package/dist/ops/webgpu/norm2.d.ts +1 -0
  292. package/dist/ops/webgpu/norm2.js +46 -0
  293. package/dist/ops/webgpu/normRMS.d.ts +1 -0
  294. package/dist/ops/webgpu/normRMS.js +26 -0
  295. package/dist/ops/webgpu/normRMS16_program.d.ts +10 -0
  296. package/dist/ops/webgpu/normRMS16_program.js +28 -0
  297. package/dist/ops/webgpu/normRMS32_program.d.ts +10 -0
  298. package/dist/ops/webgpu/normRMS32_program.js +28 -0
  299. package/dist/ops/webgpu/normRMSGrad.d.ts +1 -0
  300. package/dist/ops/webgpu/normRMSGrad.js +225 -0
  301. package/dist/ops/webgpu/pack16.d.ts +1 -0
  302. package/dist/ops/webgpu/pack16.js +21 -0
  303. package/dist/ops/webgpu/pack16_program.d.ts +19 -0
  304. package/dist/ops/webgpu/pack16_program.js +93 -0
  305. package/dist/ops/webgpu/qkv.d.ts +1 -0
  306. package/dist/ops/webgpu/qkv.js +64 -0
  307. package/dist/ops/webgpu/rope.d.ts +1 -0
  308. package/dist/ops/webgpu/rope.js +163 -0
  309. package/dist/ops/webgpu/scatterSub.d.ts +1 -0
  310. package/dist/ops/webgpu/scatterSub.js +53 -0
  311. package/dist/ops/webgpu/slice16.d.ts +7 -0
  312. package/dist/ops/webgpu/slice16.js +74 -0
  313. package/dist/ops/webgpu/softmax16.d.ts +17 -0
  314. package/dist/ops/webgpu/softmax16.js +18 -0
  315. package/dist/ops/webgpu/softmax16_program.d.ts +13 -0
  316. package/dist/ops/webgpu/softmax16_program.js +89 -0
  317. package/dist/ops/webgpu/softmax16_subgroup_program.d.ts +17 -0
  318. package/dist/ops/webgpu/softmax16_subgroup_program.js +70 -0
  319. package/dist/ops/webgpu/softmax16grad.d.ts +1 -0
  320. package/dist/ops/webgpu/softmax16grad.js +31 -0
  321. package/dist/ops/webgpu/sub16.d.ts +1 -0
  322. package/dist/ops/webgpu/sub16.js +14 -0
  323. package/dist/ops/webgpu/sum16.d.ts +1 -0
  324. package/dist/ops/webgpu/sum16.js +29 -0
  325. package/dist/ops/webgpu/transpose16.d.ts +1 -0
  326. package/dist/ops/webgpu/transpose16.js +37 -0
  327. package/dist/ops/webgpu/transpose16_program.d.ts +16 -0
  328. package/dist/ops/webgpu/transpose16_program.js +51 -0
  329. package/dist/ops/webgpu/transpose16_shared_program.d.ts +15 -0
  330. package/dist/ops/webgpu/transpose16_shared_program.js +79 -0
  331. package/dist/ops/webgpu/unpack16.d.ts +1 -0
  332. package/dist/ops/webgpu/unpack16.js +60 -0
  333. package/dist/ops/webgpu/utils/binary_op.d.ts +35 -0
  334. package/dist/ops/webgpu/utils/binary_op.js +141 -0
  335. package/dist/ops/webgpu/utils/deviceInfo.d.ts +7 -0
  336. package/dist/ops/webgpu/utils/deviceInfo.js +11 -0
  337. package/dist/ops/webgpu/utils/reductions.d.ts +43 -0
  338. package/dist/ops/webgpu/utils/reductions.js +263 -0
  339. package/dist/pack16-Ck-spx_F.js +39 -0
  340. package/dist/patches/webgpu_backend.d.ts +18 -0
  341. package/dist/patches/webgpu_backend.js +43 -0
  342. package/dist/patches/webgpu_base.d.ts +21 -0
  343. package/dist/patches/webgpu_base.js +22 -0
  344. package/dist/patches/webgpu_program.d.ts +36 -0
  345. package/dist/patches/webgpu_program.js +293 -0
  346. package/dist/pdf-UoDqCYzz.js +16726 -0
  347. package/dist/picomatch-3tUnMMbd.js +1063 -0
  348. package/dist/rope-CbeGlsV8.js +25 -0
  349. package/dist/selu_util-zkAx5doH.js +24 -0
  350. package/dist/shared-D1coEFea.js +1314 -0
  351. package/dist/shared-DOgWaqvL.js +5 -0
  352. package/dist/slice_util-Dgb3ANWI.js +208 -0
  353. package/dist/tfjs_backend-BjuQ5FqB.js +614 -0
  354. package/dist/tokeniser/BaseTokeniser.d.ts +33 -0
  355. package/dist/tokeniser/BaseTokeniser.js +2 -0
  356. package/dist/tokeniser/CharTokeniser.d.ts +24 -0
  357. package/dist/tokeniser/CharTokeniser.js +92 -0
  358. package/dist/tokeniser/bpe.d.ts +28 -0
  359. package/dist/tokeniser/bpe.js +170 -0
  360. package/dist/tokeniser/messages.d.ts +61 -0
  361. package/dist/tokeniser/messages.js +0 -0
  362. package/dist/tokeniser/type.d.ts +34 -0
  363. package/dist/tokeniser/type.js +0 -0
  364. package/dist/training/AdamW.d.ts +36 -0
  365. package/dist/training/AdamW.js +128 -0
  366. package/dist/training/BasicTrainer.d.ts +63 -0
  367. package/dist/training/BasicTrainer.js +265 -0
  368. package/dist/training/DatasetBuilder.d.ts +26 -0
  369. package/dist/training/DatasetBuilder.js +2 -0
  370. package/dist/training/Evaluator.d.ts +19 -0
  371. package/dist/training/Evaluator.js +48 -0
  372. package/dist/training/LRScheduler.d.ts +12 -0
  373. package/dist/training/LRScheduler.js +38 -0
  374. package/dist/training/PreTrainer.d.ts +11 -0
  375. package/dist/training/PreTrainer.js +22 -0
  376. package/dist/training/SFTTrainer.d.ts +12 -0
  377. package/dist/training/SFTTrainer.js +24 -0
  378. package/dist/training/loss.d.ts +3 -0
  379. package/dist/training/loss.js +19 -0
  380. package/dist/training/orthoGrad.d.ts +2 -0
  381. package/dist/training/orthoGrad.js +10 -0
  382. package/dist/training/sparseCrossEntropy.d.ts +7 -0
  383. package/dist/training/sparseCrossEntropy.js +47 -0
  384. package/dist/training/tasks/ConversationTask.d.ts +18 -0
  385. package/dist/training/tasks/ConversationTask.js +38 -0
  386. package/dist/training/tasks/PretrainingTask.d.ts +17 -0
  387. package/dist/training/tasks/PretrainingTask.js +42 -0
  388. package/dist/training/tasks/StartSentenceTask.d.ts +18 -0
  389. package/dist/training/tasks/StartSentenceTask.js +45 -0
  390. package/dist/training/tasks/Task.d.ts +22 -0
  391. package/dist/training/tasks/Task.js +55 -0
  392. package/dist/training/tasks/splitter.d.ts +5 -0
  393. package/dist/training/tasks/splitter.js +18 -0
  394. package/dist/training/types.d.ts +78 -0
  395. package/dist/training/types.js +0 -0
  396. package/dist/training/validation.d.ts +17 -0
  397. package/dist/training/validation.js +2 -0
  398. package/dist/utilities/arrayClose.d.ts +1 -0
  399. package/dist/utilities/arrayClose.js +16 -0
  400. package/dist/utilities/datasetID.d.ts +2 -0
  401. package/dist/utilities/datasetID.js +18 -0
  402. package/dist/utilities/dummy.d.ts +9 -0
  403. package/dist/utilities/dummy.js +36 -0
  404. package/dist/utilities/multinomialCPU.d.ts +2 -0
  405. package/dist/utilities/multinomialCPU.js +9 -0
  406. package/dist/utilities/naming.d.ts +4 -0
  407. package/dist/utilities/naming.js +0 -0
  408. package/dist/utilities/packed.d.ts +4 -0
  409. package/dist/utilities/packed.js +13 -0
  410. package/dist/utilities/parameters.d.ts +11 -0
  411. package/dist/utilities/parameters.js +38 -0
  412. package/dist/utilities/performance.d.ts +2 -0
  413. package/dist/utilities/performance.js +16 -0
  414. package/dist/utilities/profile.d.ts +17 -0
  415. package/dist/utilities/profile.js +33 -0
  416. package/dist/utilities/safetensors.d.ts +3 -0
  417. package/dist/utilities/safetensors.js +53 -0
  418. package/dist/utilities/sentences.d.ts +5 -0
  419. package/dist/utilities/sentences.js +32 -0
  420. package/dist/utilities/tokenParse.d.ts +1 -0
  421. package/dist/utilities/tokenParse.js +17 -0
  422. package/dist/utilities/topP.d.ts +1 -0
  423. package/dist/utilities/topP.js +12 -0
  424. package/dist/utilities/waitForModel.d.ts +2 -0
  425. package/dist/utilities/waitForModel.js +12 -0
  426. package/dist/utilities/weights.d.ts +12 -0
  427. package/dist/utilities/weights.js +40 -0
  428. package/dist/utilities/yielder.d.ts +1 -0
  429. package/dist/utilities/yielder.js +7 -0
  430. package/dist/webgpu-Dt7BMzWz.js +525 -0
  431. package/dist/webgpu_program-WOyIVMlZ.js +392 -0
  432. package/dist/webgpu_util-B_F3SShA.js +106 -0
  433. package/package.json +1 -1
@@ -0,0 +1,265 @@
1
+ import { _n as e, di as t, kt as n, ni as r, oi as i, qt as a } from "../dist-BewPQWjc.js";
2
+ import { AdamWOptimizer as o } from "./AdamW.js";
3
+ import { calculateAccuracy as s, calculateLoss as c } from "./loss.js";
4
+ import l from "./Evaluator.js";
5
+ import u from "../utilities/profile.js";
6
+ import { createTensorStatistics as d } from "../checks/weights.js";
7
+ //#region lib/training/BasicTrainer.ts
8
+ var f = {
9
+ logInterval: 1,
10
+ maxEpochs: 100,
11
+ sftMode: "full",
12
+ batchSize: 32
13
+ }, p = {
14
+ learningRate: 3e-4,
15
+ beta1: .9,
16
+ beta2: .99,
17
+ epsilon: 1e-8,
18
+ weightDecay: .01,
19
+ warmupSteps: 100,
20
+ decayEpochs: 100,
21
+ epochSteps: 1e4,
22
+ minLearningRate: 1e-5,
23
+ lossScaling: 1
24
+ }, m = class {
25
+ tokenizer;
26
+ model;
27
+ optimizer;
28
+ running = !1;
29
+ lastState;
30
+ _gradientCheckpointing = !1;
31
+ _mixedPrecision = !1;
32
+ maskedLoss = !1;
33
+ optimizerConfig;
34
+ metrics = /* @__PURE__ */ new Set();
35
+ _labelSmoothing = 0;
36
+ _layerDrop = 0;
37
+ _dropout = 0;
38
+ constructor(e, t, n, r) {
39
+ this.tokenizer = t, this.model = e, this.optimizerConfig = {
40
+ ...p,
41
+ ...n,
42
+ lossScaling: e.lossScaling
43
+ };
44
+ let i = r || new o(this.optimizerConfig);
45
+ r && r.updateConfig(this.optimizerConfig), this.optimizer = i;
46
+ }
47
+ setLossMasking() {
48
+ this.maskedLoss = !0;
49
+ }
50
+ setGradientCheckpointing(e) {
51
+ this._gradientCheckpointing = e;
52
+ }
53
+ setMixedPrecision(e) {
54
+ this._mixedPrecision = e;
55
+ }
56
+ setLabelSmoothing(e) {
57
+ this._labelSmoothing = e;
58
+ }
59
+ setDropout(e) {
60
+ this._dropout = e;
61
+ }
62
+ setLayerDrop(e) {
63
+ this._layerDrop = e;
64
+ }
65
+ setLearningRate(e) {
66
+ this.optimizerConfig.learningRate = e, this.updateOptimizer();
67
+ }
68
+ setMetrics(e) {
69
+ this.metrics = new Set(e);
70
+ }
71
+ reset() {
72
+ this.lastState = void 0, this.running = !1;
73
+ }
74
+ stop() {
75
+ this.running = !1;
76
+ }
77
+ get isRunning() {
78
+ return this.running;
79
+ }
80
+ getOptimizer() {
81
+ return this.optimizer;
82
+ }
83
+ updateOptimizer(e) {
84
+ e && (this.optimizerConfig = {
85
+ ...this.optimizerConfig,
86
+ ...e
87
+ }), this.optimizer.updateConfig(this.optimizerConfig);
88
+ }
89
+ resumeFromLog(e) {
90
+ (!this.lastState || this.lastState.step === 0) && (this.lastState = {
91
+ losses: [],
92
+ validationLosses: [],
93
+ logStartTime: 0,
94
+ step: e.step,
95
+ lastLoss: e.trainingMetrics.loss,
96
+ totalSteps: e.step,
97
+ trainingDuration: e.duration
98
+ });
99
+ }
100
+ trainStep(n, o, l = !1, u = !1) {
101
+ return t(() => {
102
+ this.model.getProfiler()?.startMemory();
103
+ let { xs: t, ys: d } = o, { value: f, grads: p } = a(() => {
104
+ let r = this.model.forward({
105
+ training: !0,
106
+ checkpointing: this._gradientCheckpointing,
107
+ mixedPrecision: this._mixedPrecision,
108
+ dropout: this._dropout,
109
+ layerDrop: this._layerDrop,
110
+ ropePositionOffset: 0
111
+ }, t), a = c(r, d, this.maskedLoss, !1, this._labelSmoothing);
112
+ this.metrics.has("accuracy") && (n.accuracy = s(r, d), i(n.accuracy)), r.dispose();
113
+ let o = a.mul(e(this.optimizerConfig.lossScaling));
114
+ return a.dispose(), o;
115
+ });
116
+ if (l) this.model.getProfiler()?.endMemory("Training");
117
+ else {
118
+ let e = this.optimizer.applyGradients(p);
119
+ this.metrics.has("gradientNorm") ? (n.gradientNorm = e, i(e)) : (n.gradientNorm = void 0, e.dispose());
120
+ let t = Object.keys(p);
121
+ this.model.weightStore.touchVariables(t), this.model.getProfiler()?.endMemory("Training"), u ? (n.gradients = p, Object.values(p).forEach((e) => i(e))) : r(p);
122
+ }
123
+ return f.mul(e(1 / this.optimizerConfig.lossScaling));
124
+ });
125
+ }
126
+ async dummyPass() {
127
+ let e = n([1, this.model.config.blockSize], "int32"), t = n([1, this.model.config.blockSize], "int32");
128
+ try {
129
+ let n = this.trainStep({}, {
130
+ xs: e,
131
+ ys: t
132
+ }, !0);
133
+ await n.data(), n.dispose();
134
+ } catch (e) {
135
+ console.error("Error during dummy pass:", e);
136
+ } finally {
137
+ e.dispose(), t.dispose();
138
+ }
139
+ }
140
+ dispose() {
141
+ this.optimizer && this.optimizer.dispose();
142
+ }
143
+ createEmptyState() {
144
+ return {
145
+ step: 0,
146
+ lastLoss: 1e6,
147
+ totalSteps: 0,
148
+ losses: [],
149
+ validationLosses: [],
150
+ logStartTime: 0,
151
+ trainingDuration: 0,
152
+ ...this.lastState || {}
153
+ };
154
+ }
155
+ async stepDataset(e, t, n) {
156
+ let { logInterval: i = 10 } = {
157
+ ...f,
158
+ ...t
159
+ };
160
+ t.metrics && this.setMetrics(t.metrics);
161
+ let a = Date.now(), o = this.createEmptyState();
162
+ this.lastState = o, await this.dummyPass(), this.metrics.has("memoryUsage") && (this.model.getProfiler() || this.model.setProfiler(new u())), this.running = !0, o.logStartTime = a;
163
+ let s = n ? new l(this.model, n, this.maskedLoss) : void 0, c = await e.iterator();
164
+ try {
165
+ for (; this.running;) {
166
+ let e = await c.next();
167
+ if (e.done) break;
168
+ let n = e.value, r = this.trainStep(o, n, !1);
169
+ if (t.debug) {
170
+ let e = (await r.data())[0];
171
+ if (isNaN(e) || !isFinite(e)) throw console.error("Invalid loss value:", e), console.error("Batch xs:", n.xs.toString()), console.error("Batch ys:", n.ys.toString()), console.error("State:", o), Error("Loss is NaN or Infinity");
172
+ console.log(`Step ${o.step}: Loss = ${e}`);
173
+ }
174
+ n.xs.dispose(), n.ys.dispose(), o.step++, o.totalSteps++, o.step % i === 0 ? await this.performLogging(r, n.xs.shape[0], t, s) : (o.gradientNorm &&= (o.gradientNorm.dispose(), void 0), o.accuracy &&= (o.accuracy.dispose(), void 0)), r.dispose();
175
+ }
176
+ } catch (e) {
177
+ throw console.error("Training error:", e), e;
178
+ }
179
+ throw this.model.trainingState = {
180
+ steps: o.totalSteps,
181
+ learningRate: this.optimizer.lr,
182
+ batchSize: t.batchSize || 32,
183
+ loss: o.lastLoss,
184
+ tokensProcessed: o.totalSteps * (t.batchSize || 32) * this.model.config.blockSize,
185
+ duration: o.trainingDuration
186
+ }, r(), this.running = !1, Error("No log returned before training stopped.");
187
+ }
188
+ async performLogging(e, t, n, r) {
189
+ let i = n?.onStep, a = this.metrics.has("gradientStatistics"), o = (await e.data())[0], s = this.lastState;
190
+ s.lastLoss = o, s.trainingDuration += Date.now() - s.logStartTime;
191
+ let c = s.totalSteps * t * this.model.config.blockSize, l = {
192
+ trainingMetrics: {
193
+ loss: s.lastLoss,
194
+ perplexity: this.metrics.has("perplexity") ? Math.exp(s.lastLoss) : void 0,
195
+ accuracy: s.accuracy ? (await s.accuracy.data())[0] : void 0
196
+ },
197
+ step: s.step,
198
+ time: Date.now() - s.logStartTime,
199
+ gradientNorm: s.gradientNorm ? (await s.gradientNorm.data())[1] : void 0,
200
+ batchSize: t,
201
+ learningRate: this.metrics.has("learningRate") ? this.optimizer.lr : void 0,
202
+ duration: s.trainingDuration,
203
+ totalTokens: c,
204
+ tokensPerSecond: c / (s.trainingDuration / 1e3),
205
+ memoryUsage: this.metrics.has("memoryUsage") ? this.model.getProfiler()?.getPeakMemory() || 0 : void 0
206
+ };
207
+ if (s.gradientNorm &&= (s.gradientNorm.dispose(), void 0), s.accuracy &&= (s.accuracy.dispose(), void 0), this.model.trainingState = {
208
+ steps: s.totalSteps,
209
+ learningRate: this.optimizer.lr,
210
+ batchSize: t,
211
+ loss: s.lastLoss,
212
+ tokensProcessed: c,
213
+ duration: s.trainingDuration
214
+ }, a && s.gradients) {
215
+ let e = /* @__PURE__ */ new Map();
216
+ for (let [t, n] of Object.entries(s.gradients)) e.set(t, await d(n)), n.dispose();
217
+ l.gradientMetrics = e;
218
+ }
219
+ if (r) try {
220
+ let e = await r.evaluate(5);
221
+ Array.isArray(e) ? l.validationMetrics = {
222
+ loss: e[0].loss,
223
+ accuracy: e[0].accuracy
224
+ } : (s.validationLosses.push(e.loss), l.validationMetrics = {
225
+ accuracy: e.accuracy,
226
+ loss: e.loss,
227
+ perplexity: this.metrics.has("perplexity") ? Math.exp(e.loss) : void 0
228
+ });
229
+ } catch (e) {
230
+ console.error("Validation error:", e);
231
+ }
232
+ i && await i(l), s.logStartTime = Date.now();
233
+ }
234
+ async trainOnDataset(e, t, n) {
235
+ let { logInterval: i = 10, maxEpochs: a = Infinity } = {
236
+ ...f,
237
+ ...t
238
+ }, o = a * (t?.epochSteps || 1e3);
239
+ t.metrics && this.setMetrics(t.metrics);
240
+ let s = Date.now(), c = this.createEmptyState();
241
+ this.lastState = c, await this.dummyPass(), t?.metrics?.includes("memoryUsage") && (this.model.getProfiler() || this.model.setProfiler(new u())), this.running = !0, c.logStartTime = s;
242
+ let d = n ? new l(this.model, n, this.maskedLoss) : void 0, p = await e.iterator();
243
+ try {
244
+ for (; this.running;) {
245
+ let e = await p.next();
246
+ if (e.done) break;
247
+ let n = e.value, r = c.step % i === 0, a = (t?.metrics?.includes("gradientStatistics") || !1) && r, s = this.trainStep(c, n, !1, a);
248
+ if (t.debug) {
249
+ let e = (await s.data())[0];
250
+ if (isNaN(e) || !isFinite(e)) throw console.error("Invalid loss value:", e), console.error("Batch xs:", await n.xs.array()), console.error("Batch ys:", await n.ys.array()), console.error("State:", c), Error("Loss is NaN or Infinity");
251
+ console.log(`Step ${c.step}: Loss = ${e}`);
252
+ }
253
+ n.xs.dispose(), n.ys.dispose(), c.step++, c.totalSteps++, r ? await this.performLogging(s, n.xs.shape[0], t, d) : (c.gradientNorm &&= (c.gradientNorm.dispose(), void 0), c.accuracy &&= (c.accuracy.dispose(), void 0)), s.dispose(), c.step >= o && this.stop();
254
+ }
255
+ } catch (e) {
256
+ throw console.error("Training error:", e), r(), e;
257
+ }
258
+ return r(), this.running = !1, {
259
+ losses: c.losses,
260
+ validationLosses: c.validationLosses
261
+ };
262
+ }
263
+ };
264
+ //#endregion
265
+ export { m as default };
@@ -0,0 +1,26 @@
1
+ import { Tensor } from '@tensorflow/tfjs-core';
2
+ import { Conversation, ITokeniser } from '../tokeniser/type';
3
+ import { Dataset } from '@tensorflow/tfjs-data';
/** Exported numeric constant (8). NOTE(review): its role is not visible from this declaration file — document it. */
4
+ export declare const PAGE_FACTOR = 8;
/**
 * Flattens tokenised conversations into a single Uint16Array.
 * Presumably this is the concatenated token ids — confirm against the implementation.
 * NOTE(review): Uint16Array can only hold values 0..65535, so the vocabulary size must stay below that.
 */
5
+ export declare function flattenTokens(textData: Conversation[][], tokenizer: ITokeniser): Uint16Array;
6
/** Like flattenTokens, but also returns a Uint8Array mask. Presumably it is aligned with `tokens` — TODO confirm. */
+ export declare function flattenTokensWithMask(textData: Conversation[][], tokenizer: ITokeniser): {
7
+ tokens: Uint16Array;
8
+ mask: Uint8Array;
9
+ };
/** Shuffles a Uint32Array of indexes. NOTE(review): whether it shuffles in place or returns a copy is not visible here. */
10
+ export declare function shuffle(array: Uint32Array): Uint32Array;
/** Iteration state returned alongside a dataset by createTextDataset. */
11
+ export interface DatasetState {
12
+ shuffledIndexes: Uint32Array;
13
+ step: number;
14
+ }
/** Builds batched {xs, ys} tensor datasets from flattened token arrays for a fixed block size. */
15
+ export declare class DatasetBuilder {
16
+ tokenizer: ITokeniser;
17
+ blockSize: number;
18
+ constructor(tokenizer: ITokeniser, blockSize?: number);
/**
 * Creates a dataset of {xs, ys} tensors from `flatTokens` and resolves to it together with its DatasetState.
 * Optional `indexes`, `mask` and `ignoreIndex` refine sampling and targets.
 * NOTE(review): their exact semantics and defaults are not visible here.
 */
19
+ createTextDataset(flatTokens: Uint16Array, batchSize?: number, indexes?: Uint32Array, mask?: Uint8Array, ignoreIndex?: number): Promise<{
20
+ dataset: Dataset<{
21
+ xs: Tensor;
22
+ ys: Tensor;
23
+ }>;
24
+ state: DatasetState;
25
+ }>;
26
+ }
@@ -0,0 +1,2 @@
1
// Public entry point for the dataset helpers: every binding is forwarded
// straight from the shared bundle chunk (chunk export name -> public name).
export {
	t as DatasetBuilder,
	n as PAGE_FACTOR,
	r as flattenTokens,
	i as flattenTokensWithMask,
	a as shuffle
} from "../DatasetBuilder-DgURD85T.js";
@@ -0,0 +1,19 @@
1
+ import { Dataset } from '@tensorflow/tfjs-data';
2
+ import { TensorContainer } from '@tensorflow/tfjs-core';
3
+ import { default as Model, ModelForwardAttributes } from '../../models/model';
4
/** Loss and accuracy reported by an evaluation run. */
interface Result {
    loss: number;
    accuracy: number;
}
/**
 * Runs a model in inference mode (`training: false`) over dataset batches and
 * reports loss and accuracy.
 */
export default class Evaluator {
    private model;
    /**
     * Dataset iterator, created once in the constructor; later `evaluate()`
     * calls continue from where earlier ones stopped.
     */
    private iterator?;
    private xs?;
    private ys?;
    /** Passed to the loss as its `masked` flag for iterator batches. */
    private masked;
    /**
     * @param model   model to evaluate
     * @param dataset source of `{ xs, ys }` batches
     * @param masked  enable masked loss (defaults to false)
     */
    constructor(model: Model<ModelForwardAttributes>, dataset: Dataset<TensorContainer>, masked?: boolean);
    /** Disposes the optional cached `xs` / `ys` tensors, if present. */
    dispose(): void;
    private calculateBatchLoss;
    /**
     * Averages loss and accuracy over at most `maxBatches` batches (the
     * implementation defaults to 100). The `Result[]` form only arises from
     * the per-batch-row (cached `xs`/`ys`) path.
     * NOTE(review): when the iterator is already exhausted, the averages are NaN.
     * @throws Error("No data available for evaluation") when neither source exists
     */
    evaluate(maxBatches?: number): Promise<Result | Result[]>;
}
export {};
@@ -0,0 +1,48 @@
1
+ import { di as e } from "../dist-BewPQWjc.js";
2
+ import { calculateAccuracy as t, calculateLoss as n } from "./loss.js";
3
+ //#region lib/training/Evaluator.ts
4
+ var r = class {
5
+ model;
6
+ iterator;
7
+ xs;
8
+ ys;
9
+ masked = !1;
10
+ constructor(e, t, n) {
11
+ this.model = e, this.masked = !!n, this.iterator = t.iterator();
12
+ }
13
+ dispose() {
14
+ this.xs && this.xs.dispose(), this.ys && this.ys.dispose();
15
+ }
16
+ async calculateBatchLoss(r, i, a, o) {
17
+ let [s, c] = e(() => {
18
+ let e = this.model.forward({ training: !1 }, r), s = n(e, i, o, a), c = t(e, i);
19
+ return e.dispose(), [s, c];
20
+ }), l = await s.array(), u = await c.array(), d = l, f = u;
21
+ return c.dispose(), s.dispose(), Array.isArray(d) ? d.map((e) => ({
22
+ loss: e,
23
+ accuracy: f
24
+ })) : {
25
+ loss: d,
26
+ accuracy: f
27
+ };
28
+ }
29
+ async evaluate(e = 100) {
30
+ let t = 0, n = 0, r = 0;
31
+ if (this.iterator) {
32
+ let i = await this.iterator;
33
+ for (let a = 0; a < e; a++) {
34
+ let e = await i.next();
35
+ if (e.done) break;
36
+ let { xs: a, ys: o } = e.value, s = await this.calculateBatchLoss(a, o, !1, this.masked);
37
+ a.dispose(), o.dispose(), t += s.loss, n += s.accuracy, r++;
38
+ }
39
+ return {
40
+ loss: t / r,
41
+ accuracy: n / r
42
+ };
43
+ } else if (this.xs && this.ys) return this.calculateBatchLoss(this.xs, this.ys, !0, !0);
44
+ throw Error("No data available for evaluation");
45
+ }
46
+ };
47
+ //#endregion
48
+ export { r as default };
@@ -0,0 +1,12 @@
1
+ import { LRSchedulerConfig } from './types';
2
/**
 * Learning-rate schedule: linear warmup over `warmupSteps`, then cosine decay
 * from the starting rate down to `minLearningRate` at step
 * `epochSteps * decayEpochs`, after which the minimum is held.
 */
export default class LRScheduler {
    /** Current rate: the constructor value until getNextLR() first runs, then its latest result. */
    protected learningRate: number;
    private config;
    /** Number of getNextLR() calls so far; restored from `config.step` when provided. */
    private step;
    /** Peak rate reached at the end of warmup and used as the start of the cosine decay. */
    private startLearningRate;
    constructor(learningRate: number, config: LRSchedulerConfig);
    /** Returns a shallow copy of the config with the current `step`, suitable for resuming. */
    serializeConfig(): LRSchedulerConfig;
    /**
     * Shallow-merges `newConfig` into the current config. When `learningRate`
     * is given, it replaces the peak rate used by subsequent getNextLR() calls.
     */
    updateConfig(newConfig: Partial<LRSchedulerConfig>, learningRate?: number): void;
    /** Current learning rate (see `learningRate`). */
    get lr(): number;
    /** Computes the rate for the current step, stores it, and advances the step counter. */
    getNextLR(): number;
}
@@ -0,0 +1,38 @@
1
+ //#region lib/training/LRScheduler.ts
2
+ var e = class {
3
+ learningRate;
4
+ config;
5
+ step = 0;
6
+ startLearningRate;
7
+ constructor(e, t) {
8
+ this.learningRate = e, this.config = t, this.startLearningRate = e, t.step !== void 0 && (this.step = t.step);
9
+ }
10
+ serializeConfig() {
11
+ return {
12
+ ...this.config,
13
+ step: this.step
14
+ };
15
+ }
16
+ updateConfig(e, t) {
17
+ this.config = {
18
+ ...this.config,
19
+ ...e
20
+ }, t !== void 0 && (this.startLearningRate = t);
21
+ }
22
+ get lr() {
23
+ return this.learningRate;
24
+ }
25
+ getNextLR() {
26
+ let e = this.step;
27
+ if (this.config.warmupSteps > 0 && e < this.config.warmupSteps) {
28
+ let t = (e + 1) / this.config.warmupSteps, n = this.startLearningRate * t;
29
+ return this.learningRate = n, this.step++, n;
30
+ }
31
+ let t = this.config.epochSteps * this.config.decayEpochs;
32
+ if (e >= t || t <= this.config.warmupSteps) return this.learningRate = this.config.minLearningRate, this.step++, this.config.minLearningRate;
33
+ let n = (e - this.config.warmupSteps) / (t - this.config.warmupSteps), r = .5 * (1 + Math.cos(Math.PI * n)), i = this.config.minLearningRate + r * (this.startLearningRate - this.config.minLearningRate);
34
+ return this.learningRate = i, this.step++, i;
35
+ }
36
+ };
37
+ //#endregion
38
+ export { e as default };
@@ -0,0 +1,11 @@
1
+ import { default as Model, ModelForwardAttributes } from '../../models/model';
2
+ import { default as BasicTrainer } from './BasicTrainer';
3
+ import { ITokeniser } from '../../tokeniser/type';
4
+ import { DatasetBuilder } from './DatasetBuilder';
5
+ import { AdamWOptimizer } from './AdamW';
6
+ import { AdamWOptimizerConfig } from './types';
7
/**
 * Trainer preset for language-model pre-training.
 *
 * The implementation's optimizer defaults, which `optConfig` overrides key by
 * key: 1000 warmup steps, learning rate 3e-4, weight decay 0.1,
 * epochSteps 1e4, decayEpochs 100. Unless `optConfig.minLearningRate` is
 * given, the minimum learning rate is set to `learningRate / 20`.
 */
export default class PreTrainer extends BasicTrainer {
    tokenizer: ITokeniser;
    /** Dataset builder created with the model's `config.blockSize`. */
    datasetBuilder: DatasetBuilder;
    constructor(model: Model<ModelForwardAttributes>, tokenizer: ITokeniser, optConfig?: Partial<AdamWOptimizerConfig>, optimizer?: AdamWOptimizer);
}
@@ -0,0 +1,22 @@
1
+ import { t as e } from "../DatasetBuilder-DgURD85T.js";
2
+ import t from "./BasicTrainer.js";
3
+ //#region lib/training/PreTrainer.ts
4
+ var n = {
5
+ decayEpochs: 100,
6
+ epochSteps: 1e4,
7
+ warmupSteps: 1e3,
8
+ minLearningRate: 3e-5,
9
+ weightDecay: .1,
10
+ learningRate: 3e-4
11
+ }, r = class extends t {
12
+ tokenizer;
13
+ datasetBuilder;
14
+ constructor(t, r, i, a) {
15
+ super(t, r, {
16
+ ...n,
17
+ ...i
18
+ }, a), this.tokenizer = r, this.optimizerConfig.minLearningRate = i?.minLearningRate ?? this.optimizerConfig.learningRate / 20, this.updateOptimizer(), this.datasetBuilder = new e(r, t.config.blockSize);
19
+ }
20
+ };
21
+ //#endregion
22
+ export { r as default };
@@ -0,0 +1,12 @@
1
+ import { default as Model, ModelForwardAttributes } from '../../models/model';
2
+ import { default as BasicTrainer } from './BasicTrainer';
3
+ import { ITokeniser } from '../../tokeniser/type';
4
+ import { AdamWOptimizer } from './AdamW';
5
+ import { AdamWOptimizerConfig } from './types';
6
+ import { DatasetBuilder } from './DatasetBuilder';
7
/**
 * Trainer preset for supervised fine-tuning.
 *
 * The implementation's optimizer defaults, which `optConfig` overrides key by
 * key: 100 warmup steps, learning rate 3e-4, weight decay 0.1, beta2 0.95,
 * epochSteps 1e4, decayEpochs 100. Unless `optConfig.minLearningRate` is
 * given, the minimum learning rate is set to `learningRate / 20`. The
 * constructor also turns on `maskedLoss`.
 */
export default class SFTTrainer extends BasicTrainer {
    tokenizer: ITokeniser;
    /** Dataset builder created with the model's `config.blockSize`. */
    datasetBuilder: DatasetBuilder;
    /** Not assigned by the constructor. NOTE(review): presumably the name of a LoRA adapter — confirm. */
    loraName?: string;
    constructor(model: Model<ModelForwardAttributes>, tokenizer: ITokeniser, optConfig?: Partial<AdamWOptimizerConfig>, optimizer?: AdamWOptimizer);
}
@@ -0,0 +1,24 @@
1
+ import { t as e } from "../DatasetBuilder-DgURD85T.js";
2
+ import t from "./BasicTrainer.js";
3
+ //#region lib/training/SFTTrainer.ts
4
+ var n = {
5
+ decayEpochs: 100,
6
+ epochSteps: 1e4,
7
+ warmupSteps: 100,
8
+ minLearningRate: 1e-5,
9
+ weightDecay: .1,
10
+ beta2: .95,
11
+ learningRate: 3e-4
12
+ }, r = class extends t {
13
+ tokenizer;
14
+ datasetBuilder;
15
+ loraName;
16
+ constructor(t, r, i, a) {
17
+ super(t, r, {
18
+ ...n,
19
+ ...i
20
+ }, a), this.tokenizer = r, this.optimizerConfig.minLearningRate = i?.minLearningRate ?? this.optimizerConfig.learningRate / 20, this.updateOptimizer(), this.datasetBuilder = new e(r, t.config.blockSize), this.maskedLoss = !0;
21
+ }
22
+ };
23
+ //#endregion
24
+ export { r as default };
@@ -0,0 +1,3 @@
1
+ import { Tensor } from '@tensorflow/tfjs-core';
2
/**
 * Softmax cross-entropy of `logits` against integer `targets`, computed by the
 * custom-gradient function returned by `createSoftmaxCrossEntropyWithGrad`.
 * @param masked         forwarded as the kernel's `masked` flag (ignored-target handling)
 * @param keepBatch      forwarded as the kernel's `keepBatch` flag (per-row instead of scalar loss)
 * @param labelSmoothing smoothing factor; omitted or <= 0 disables smoothing
 * @throws Error whose message starts with "Loss computation failed:" when the kernel throws
 */
export declare function calculateLoss(logits: Tensor, targets: Tensor, masked?: boolean, keepBatch?: boolean, labelSmoothing?: number): Tensor;
/**
 * Fraction of positions where `argMax(logits, -1)` equals `targets`, as a
 * scalar float32 tensor. No masking is applied.
 * @throws Error whose message starts with "Accuracy computation failed:" when an op throws
 */
export declare function calculateAccuracy(logits: Tensor, targets: Tensor): Tensor;
@@ -0,0 +1,19 @@
1
+ import { createSoftmaxCrossEntropyWithGrad as e } from "./sparseCrossEntropy.js";
2
+ //#region lib/training/loss.ts
3
+ function t(t, n, r, i, a) {
4
+ try {
5
+ return e(r, i, a && a > 0 ? a : void 0)(t, n);
6
+ } catch (e) {
7
+ throw console.error("Error computing loss:", e), Error(`Loss computation failed: ${e}`);
8
+ }
9
+ }
10
+ function n(e, t) {
11
+ try {
12
+ let n = e.argMax(-1), r = n.equal(t).cast("float32"), i = r.mean();
13
+ return n.dispose(), r.dispose(), i;
14
+ } catch (e) {
15
+ throw console.error("Error computing accuracy:", e), Error(`Accuracy computation failed: ${e}`);
16
+ }
17
+ }
18
+ //#endregion
19
+ export { n as calculateAccuracy, t as calculateLoss };
@@ -0,0 +1,2 @@
1
+ import { Tensor } from '@tensorflow/tfjs-core';
2
/**
 * Removes the component of `gradient` that is parallel to `weight` (both
 * flattened), then rescales the remainder to approximately the norm of the
 * original gradient.
 * @param epsilon added to ‖weight‖² and to the norm of the orthogonal part to avoid division by zero
 * @returns a tensor reshaped to `gradient.shape`
 */
export declare function orthogonalizeGradient(weight: Tensor, gradient: Tensor, epsilon: number): Tensor;
@@ -0,0 +1,10 @@
1
+ import { di as e } from "../dist-BewPQWjc.js";
2
+ //#region lib/training/orthoGrad.ts
3
+ function t(t, n, r) {
4
+ return e(() => {
5
+ let e = t.reshape([-1]), i = n.reshape([-1]), a = e.mul(e).sum().add(r), o = e.mul(i).sum().div(a), s = i.sub(e.mul(o)), c = i.norm(), l = s.norm().add(r);
6
+ return s.mul(c.div(l)).reshape(n.shape);
7
+ });
8
+ }
9
+ //#endregion
10
+ export { t as orthogonalizeGradient };
@@ -0,0 +1,7 @@
1
+ import * as tf from '@tensorflow/tfjs-core';
2
/**
 * Forward-only sparse softmax cross-entropy. The implementation appears to
 * subtract each row's maximum logit before the log-sum-exp, which keeps
 * exp() from overflowing.
 * @param logits             class scores; the last axis indexes classes
 * @param labels             integer class ids (cast to int32)
 * @param validMask          optional weights; when given, the result is sum(loss * mask) / sum(mask)
 * @param keepBatch          reduce over the last batch axis only instead of to a scalar
 * @param originalBatchShape batch shape used to regroup rows; defaults to `logits.shape` without its last axis
 * @param labelSmoothing     when > 0, mixes the NLL with the mean negative log-probability over classes
 */
export declare function sparseSoftmaxCrossEntropy(logits: tf.Tensor, labels: tf.Tensor, validMask?: tf.Tensor, keepBatch?: boolean, originalBatchShape?: number[], labelSmoothing?: number): tf.Tensor;
/**
 * Builds a `(logits, targets) => loss` function with a hand-written gradient.
 * @param masked         treat targets equal to 65535 as ignored (as the implementation appears to)
 * @param keepBatch      return the loss per last batch axis instead of a scalar
 * @param labelSmoothing smoothing factor (0 disables)
 * @returns the loss function; targets receive a zero gradient
 */
export declare function createSoftmaxCrossEntropyWithGrad(masked?: boolean, keepBatch?: boolean, labelSmoothing?: number): (...args: tf.Tensor[]) => tf.Tensor<tf.Rank>;
@@ -0,0 +1,47 @@
1
+ import { At as e, Bt as t, Gr as n, Nn as r, Pn as i, Rt as a, St as o, Wr as s, Wt as c, Y as l, _n as u, bn as d, di as f, mn as p, qr as m } from "../dist-BewPQWjc.js";
2
+ import { gatherSub as h } from "../ops/gatherSub.js";
3
+ import { scatterSub as g } from "../ops/scatterSub.js";
4
//#region lib/training/sparseCrossEntropy.ts
// Minified imports, as they APPEAR to be used in this module (inferred from
// call patterns only — NOTE(review): confirm against dist-BewPQWjc.js):
//   f = tidy, c = customGrad, t = sub, d = max, a = logSumExp, e = mean,
//   m = add, s = mul, n = div, p = sum, u = scalar, o = notEqual,
//   i = where, r = zerosLike, l = softmax.
//   h = gatherSub and g = scatterSub are imported from ../ops/.
// Several of these letters are shadowed by parameters/locals inside the
// functions below, so read each use in its own scope.
/**
 * sparseSoftmaxCrossEntropy(logits, labels, validMask?, keepBatch?, originalBatchShape?, labelSmoothing = 0)
 *
 * Forward-only cross-entropy. Logits with more than two dims are reshaped to
 * [rows, classes] and labels to [rows] int32. Per row, the loss is
 * gatherSub(logSumExp(shifted), labels, shifted), presumably
 * logSumExp - shifted[label], where `shifted` is the logits minus their row max.
 * With labelSmoothing u > 0, the per-row value becomes
 * (1 - u) * nll + u * (logSumExp - mean(shifted)).
 * Reduction:
 *   - validMask given: sum(loss * mask) / sum(mask), per last batch axis when keepBatch.
 *     NOTE(review): an all-zero mask divides by zero.
 *   - otherwise: mean, per last batch axis when keepBatch.
 */
function _(r, i, o, c, l, u = 0) {
	return f(() => {
		// f = number of classes, g = batch shape, _ = row count, v/y = flattened logits/labels,
		// b = row-max-shifted logits, x = per-row logSumExp, S = per-row NLL.
		let f = r.shape[r.shape.length - 1], g = l || r.shape.slice(0, -1), _ = g.reduce((e, t) => e * t, 1), v = r.shape.length > 2 ? r.reshape([_, f]) : r, y = i.shape.length > 1 ? i.reshape([_]).cast("int32") : i.cast("int32"), b = t(v, d(v, -1, !0)), x = a(b, -1), S = h(x, y, b), C;
		if (u > 0) {
			// Uniform-target term: logSumExp - mean(shifted logits).
			let n = t(x, e(b, -1));
			C = m(s(S, 1 - u), s(n, u));
		} else C = S;
		// Masked: weight by the mask, then normalise by the mask sum.
		if (o) if (C = s(C, o), c) {
			let e = p(o.reshape(g), -1);
			C = n(p(C.reshape(g), -1), e);
		} else {
			let e = p(o);
			C = n(p(C), e);
		}
		// Unmasked: plain mean (per last batch axis when keepBatch).
		else C = c ? e(C.reshape(g), -1) : e(C);
		return C;
	});
}
/**
 * createSoftmaxCrossEntropyWithGrad(masked?, keepBatch?, labelSmoothing = 0)
 *
 * Returns a customGrad-wrapped `(logits, targets) => loss`. The forward value
 * comes from sparseSoftmaxCrossEntropy; the backward is written by hand.
 * When `masked`, targets equal to 65535 appear to be treated as ignored: they
 * get weight 0 in the float mask and are replaced by 0 so they stay valid
 * gather indices.
 */
function v(e, n, a = 0) {
	return c((c, d, m) => {
		// c = logits, d = targets, m = presumably customGrad's save() callback.
		let h = c.shape[c.shape.length - 1], v = c.shape.slice(0, -1), y = v.reduce((e, t) => e * t, 1), b = c.reshape([y, h]), x = d.reshape([y]).cast("int32"), S, C = null;
		if (e) {
			// t marks positions whose target differs from the 65535 sentinel.
			let e = u(65535, "int32"), t = o(x, e);
			C = t.cast("float32"), S = i(t, x, r(x)), e.dispose(), t.dispose();
		} else S = x;
		let w = _(b, S, C || void 0, n, v, a);
		// Save flattened logits, sanitised targets and (if any) the mask for the
		// backward pass, then release the local handles.
		return m(C ? [
			b,
			S,
			C
		] : [b, S]), b.dispose(), x.dispose(), {
			value: w,
			gradFunc: (n, i) => f(() => {
				// n = upstream gradient, i = saved tensors. The intended result is presumably
				// (softmax - onehot) * dy / count, where count is sum(mask) when masked and
				// the row count otherwise; masked rows get zero weight.
				// NOTE(review): with labelSmoothing a > 0 this passes (softmax - a/h) and a
				// (1 - a) scale to scatterSub. If scatterSub computes (x - onehot) * scale, the
				// result is (1-a)(softmax - a/h - onehot) rather than the analytic
				// softmax - (1-a)·onehot - a/h (both times dy/count) — verify.
				// NOTE(review): broadcastTo([rows]) presumably assumes a scalar upstream
				// gradient; with keepBatch the forward value is per batch row — confirm that
				// path is never differentiated.
				let o = i[0], f = i[1], m = e ? i[2] : void 0, _ = l(o), v = m ? p(m) : u(o.shape[0], "float32"), y = n.div(v).broadcastTo([o.shape[0]]), b = m && e ? s(y, m) : y, x;
				x = a > 0 ? g(t(_, a / h), f, s(b, 1 - a)) : g(_, f, b);
				// Targets are not differentiable: return a zero gradient for them.
				let S = r(d);
				return [x.reshape(c.shape), S];
			})
		};
	});
}
//#endregion
export { v as createSoftmaxCrossEntropyWithGrad, _ as sparseSoftmaxCrossEntropy };
@@ -0,0 +1,18 @@
1
+ import { Conversation, ITokeniser } from '../../../main';
2
+ import { Task } from './Task';
3
/** Task that serves whole conversations, optionally in a shuffled order. */
export default class ConversationTask extends Task {
    private rawConvo;
    /** Visiting order created by the first shuffle(); null before that. */
    private shuffledIndices;
    /** Position of the next conversation to serve. */
    private index;
    /** Number of conversations in the task. */
    get length(): number;
    constructor(conversations: Conversation[][]);
    /** True while nextConversation() still has conversations left in the current pass. */
    hasMoreConversations(): boolean;
    /** Next conversation (in shuffled order once shuffle() has run), or null when the pass is exhausted. */
    nextConversation(): Conversation[] | null;
    /** Encodes the next conversation with `tokeniser`; null when the pass is exhausted. */
    nextTokens(tokeniser: ITokeniser): number[] | null;
    /** Variant that also returns a per-token boolean mask. */
    nextTokens(tokeniser: ITokeniser, masking: boolean): {
        tokens: number[];
        mask: boolean[];
    } | null;
    /** Reorders the visiting order via the dataset `shuffle` helper and restarts from the first position. */
    shuffle(): void;
    /** Rough estimate: token count of the first conversation times the number of conversations. */
    estimateTokens(tokeniser: ITokeniser): Promise<number>;
}
@@ -0,0 +1,38 @@
1
+ import { a as e } from "../../DatasetBuilder-DgURD85T.js";
2
+ import { Task as t } from "./Task.js";
3
+ //#region lib/training/tasks/ConversationTask.ts
4
+ var n = class extends t {
5
+ rawConvo;
6
+ shuffledIndices = null;
7
+ index = 0;
8
+ get length() {
9
+ return this.rawConvo.length;
10
+ }
11
+ constructor(e) {
12
+ super(), this.rawConvo = e;
13
+ }
14
+ hasMoreConversations() {
15
+ return this.index < this.rawConvo.length;
16
+ }
17
+ nextConversation() {
18
+ if (this.index >= this.rawConvo.length) return null;
19
+ let e = this.rawConvo[this.shuffledIndices ? this.shuffledIndices[this.index] : this.index];
20
+ return this.index++, e;
21
+ }
22
+ nextTokens(e, t) {
23
+ let n = this.nextConversation();
24
+ return n ? e.encodeConversation(n, !1, t) : null;
25
+ }
26
+ shuffle() {
27
+ if (!this.shuffledIndices) {
28
+ this.shuffledIndices = new Uint32Array(this.rawConvo.length);
29
+ for (let e = 0; e < this.rawConvo.length; e++) this.shuffledIndices[e] = e;
30
+ }
31
+ e(this.shuffledIndices), this.index = 0;
32
+ }
33
+ async estimateTokens(e) {
34
+ return e.encodeConversation(this.rawConvo[0]).length * this.length;
35
+ }
36
+ };
37
+ //#endregion
38
+ export { n as default };
@@ -0,0 +1,17 @@
1
+ import { Conversation, ITokeniser } from '../../../main';
2
+ import { Task } from './Task';
3
/**
 * Task over plain text strings, presumably used for pre-training.
 * NOTE(review): the implementation is not visible here; the member notes
 * below assume it follows the same contract as ConversationTask — confirm.
 */
export default class PretrainingTask extends Task {
    private rawText;
    private index;
    /** Presumably the number of texts supplied to the constructor. */
    get length(): number;
    constructor(texts: string[]);
    /** Presumably true while texts remain in the current pass. */
    hasMoreConversations(): boolean;
    /** Next item as a conversation, or null (presumably when exhausted). */
    nextConversation(): Conversation[] | null;
    /** Encodes the next item with `tokeniser`; null presumably when exhausted. */
    nextTokens(tokeniser: ITokeniser): number[] | null;
    /** Variant that also returns a per-token boolean mask. */
    nextTokens(tokeniser: ITokeniser, masking: boolean): {
        tokens: number[];
        mask: boolean[];
    } | null;
    /** NOTE(review): presumably reorders the texts and restarts the pass — confirm. */
    shuffle(): void;
    /** Estimated total token count (estimation method not visible here). */
    estimateTokens(tokeniser: ITokeniser): Promise<number>;
}