@genai-fi/nanogpt 0.20.0 → 0.20.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (433)
  1. package/dist/BaseTokeniser-DSg9zcYq.js +221 -0
  2. package/dist/DatasetBuilder-DgURD85T.js +712 -0
  3. package/dist/Generator.d.ts +82 -0
  4. package/dist/Generator.js +2 -0
  5. package/dist/RealDiv-DBu0FQqT.js +362 -0
  6. package/dist/Reshape-CABOPB9d.js +94 -0
  7. package/dist/Reshape-DqO3r8BC.js +17 -0
  8. package/dist/TeachableLLM.d.ts +70 -0
  9. package/dist/TeachableLLM.js +2 -0
  10. package/dist/Trainer.d.ts +43 -0
  11. package/dist/Trainer.js +2 -0
  12. package/dist/backend.d.ts +2 -0
  13. package/dist/backend.js +13 -0
  14. package/dist/backend_util-Cg-roD1p.js +399 -0
  15. package/dist/binary_op_util-CrYk9LXL.js +103 -0
  16. package/dist/checks/appendCache.d.ts +1 -0
  17. package/dist/checks/appendCache.js +55 -0
  18. package/dist/checks/attentionMask.d.ts +1 -0
  19. package/dist/checks/attentionMask.js +56 -0
  20. package/dist/checks/check.d.ts +9 -0
  21. package/dist/checks/check.js +32 -0
  22. package/dist/checks/gelu.d.ts +1 -0
  23. package/dist/checks/gelu.js +46 -0
  24. package/dist/checks/index.d.ts +26 -0
  25. package/dist/checks/index.js +28 -0
  26. package/dist/checks/matMulGelu.d.ts +1 -0
  27. package/dist/checks/matMulGelu.js +84 -0
  28. package/dist/checks/normRMS.d.ts +1 -0
  29. package/dist/checks/normRMS.js +28 -0
  30. package/dist/checks/normRMSGrad.d.ts +1 -0
  31. package/dist/checks/normRMSGrad.js +22 -0
  32. package/dist/checks/packUnpack.d.ts +1 -0
  33. package/dist/checks/packUnpack.js +46 -0
  34. package/dist/checks/qkv.d.ts +1 -0
  35. package/dist/checks/qkv.js +34 -0
  36. package/dist/checks/rope.d.ts +1 -0
  37. package/dist/checks/rope.js +30 -0
  38. package/dist/checks/weights.d.ts +14 -0
  39. package/dist/checks/weights.js +27 -0
  40. package/dist/chunk-BPntVaq0.js +23 -0
  41. package/dist/complex_util-CkazZsaH.js +60 -0
  42. package/dist/concat_util-CWDZCBlA.js +19 -0
  43. package/dist/data/docx.d.ts +2 -0
  44. package/dist/data/docx.js +3046 -0
  45. package/dist/data/pdf.d.ts +2 -0
  46. package/dist/data/pdf.js +17 -0
  47. package/dist/data/textLoader.d.ts +7 -0
  48. package/dist/data/textLoader.js +613 -0
  49. package/dist/dist-BewPQWjc.js +7572 -0
  50. package/dist/dist-DVmq73nz.js +8775 -0
  51. package/dist/dist-DXwIvKxl.js +896 -0
  52. package/dist/dist-VEU5mfO0.js +7545 -0
  53. package/dist/gelu-Bf1HW1RY.js +27 -0
  54. package/dist/gpgpu_math-DvLcCH6u.js +1612 -0
  55. package/dist/inference/types.d.ts +16 -0
  56. package/dist/inference/types.js +0 -0
  57. package/dist/kernel_funcs_utils-HiXOOx3f.js +229 -0
  58. package/dist/layers/BaseLayer.d.ts +44 -0
  59. package/dist/layers/BaseLayer.js +76 -0
  60. package/dist/layers/CausalSelfAttention.d.ts +39 -0
  61. package/dist/layers/CausalSelfAttention.js +99 -0
  62. package/dist/layers/LoRA.d.ts +14 -0
  63. package/dist/layers/LoRA.js +48 -0
  64. package/dist/layers/MLP.d.ts +17 -0
  65. package/dist/layers/MLP.js +34 -0
  66. package/dist/layers/PositionEmbedding.d.ts +8 -0
  67. package/dist/layers/PositionEmbedding.js +27 -0
  68. package/dist/layers/RMSNorm.d.ts +12 -0
  69. package/dist/layers/RMSNorm.js +20 -0
  70. package/dist/layers/RoPECache.d.ts +18 -0
  71. package/dist/layers/RoPECache.js +337 -0
  72. package/dist/layers/TiedEmbedding.d.ts +13 -0
  73. package/dist/layers/TiedEmbedding.js +32 -0
  74. package/dist/layers/TransformerBlock.d.ts +27 -0
  75. package/dist/layers/TransformerBlock.js +51 -0
  76. package/dist/layers/WeightStore.d.ts +20 -0
  77. package/dist/layers/WeightStore.js +69 -0
  78. package/dist/loader/load.d.ts +6 -0
  79. package/dist/loader/load.js +2 -0
  80. package/dist/loader/loadHF.d.ts +8 -0
  81. package/dist/loader/loadHF.js +2 -0
  82. package/dist/loader/loadTransformers.d.ts +4 -0
  83. package/dist/loader/loadTransformers.js +2 -0
  84. package/dist/loader/loadZipMeta.d.ts +3 -0
  85. package/dist/loader/loadZipMeta.js +16 -0
  86. package/dist/loader/newZipLoad.d.ts +3 -0
  87. package/dist/loader/newZipLoad.js +2 -0
  88. package/dist/loader/oldZipLoad.d.ts +9 -0
  89. package/dist/loader/oldZipLoad.js +2 -0
  90. package/dist/loader/save.d.ts +16 -0
  91. package/dist/loader/save.js +2 -0
  92. package/dist/loader/types.d.ts +68 -0
  93. package/dist/loader/types.js +0 -0
  94. package/dist/main-D5CbfCiV.js +13500 -0
  95. package/dist/main.d.ts +50 -0
  96. package/dist/main.js +16 -0
  97. package/dist/matMul16-BNfZSnNM.js +81 -0
  98. package/dist/matMulGelu-CPTntosE.js +162 -0
  99. package/dist/models/NanoGPTV1.d.ts +16 -0
  100. package/dist/models/NanoGPTV1.js +2 -0
  101. package/dist/models/NanoGPTV2.d.ts +16 -0
  102. package/dist/models/NanoGPTV2.js +2 -0
  103. package/dist/models/config.d.ts +27 -0
  104. package/dist/models/config.js +37 -0
  105. package/dist/models/factory.d.ts +3 -0
  106. package/dist/models/factory.js +2 -0
  107. package/dist/models/model.d.ts +44 -0
  108. package/dist/models/model.js +2 -0
  109. package/dist/ops/adamAdjust.d.ts +2 -0
  110. package/dist/ops/adamAdjust.js +18 -0
  111. package/dist/ops/adamMoments.d.ts +2 -0
  112. package/dist/ops/adamMoments.js +16 -0
  113. package/dist/ops/add16.d.ts +2 -0
  114. package/dist/ops/add16.js +12 -0
  115. package/dist/ops/appendCache.d.ts +2 -0
  116. package/dist/ops/appendCache.js +25 -0
  117. package/dist/ops/attentionMask.d.ts +2 -0
  118. package/dist/ops/attentionMask.js +16 -0
  119. package/dist/ops/concat16.d.ts +2 -0
  120. package/dist/ops/concat16.js +8 -0
  121. package/dist/ops/cpu/adamAdjust.d.ts +1 -0
  122. package/dist/ops/cpu/adamAdjust.js +16 -0
  123. package/dist/ops/cpu/adamMoments.d.ts +1 -0
  124. package/dist/ops/cpu/adamMoments.js +16 -0
  125. package/dist/ops/cpu/appendCache.d.ts +1 -0
  126. package/dist/ops/cpu/appendCache.js +65 -0
  127. package/dist/ops/cpu/attentionMask.d.ts +1 -0
  128. package/dist/ops/cpu/attentionMask.js +16 -0
  129. package/dist/ops/cpu/fusedSoftmax.d.ts +9 -0
  130. package/dist/ops/cpu/fusedSoftmax.js +22 -0
  131. package/dist/ops/cpu/gatherSub.d.ts +1 -0
  132. package/dist/ops/cpu/gatherSub.js +12 -0
  133. package/dist/ops/cpu/gelu.d.ts +1 -0
  134. package/dist/ops/cpu/gelu.js +36 -0
  135. package/dist/ops/cpu/matMul16.d.ts +1 -0
  136. package/dist/ops/cpu/matMul16.js +14 -0
  137. package/dist/ops/cpu/matMulGelu.d.ts +1 -0
  138. package/dist/ops/cpu/matMulGelu.js +41 -0
  139. package/dist/ops/cpu/matMulMul.d.ts +1 -0
  140. package/dist/ops/cpu/matMulMul.js +20 -0
  141. package/dist/ops/cpu/mulDropout.d.ts +1 -0
  142. package/dist/ops/cpu/mulDropout.js +20 -0
  143. package/dist/ops/cpu/normRMS.d.ts +1 -0
  144. package/dist/ops/cpu/normRMS.js +35 -0
  145. package/dist/ops/cpu/qkv.d.ts +5 -0
  146. package/dist/ops/cpu/qkv.js +73 -0
  147. package/dist/ops/cpu/rope.d.ts +6 -0
  148. package/dist/ops/cpu/rope.js +81 -0
  149. package/dist/ops/cpu/scatterSub.d.ts +1 -0
  150. package/dist/ops/cpu/scatterSub.js +12 -0
  151. package/dist/ops/dot16.d.ts +2 -0
  152. package/dist/ops/dot16.js +29 -0
  153. package/dist/ops/dropout.d.ts +2 -0
  154. package/dist/ops/dropout.js +11 -0
  155. package/dist/ops/dropout16.d.ts +2 -0
  156. package/dist/ops/dropout16.js +22 -0
  157. package/dist/ops/gatherSub.d.ts +2 -0
  158. package/dist/ops/gatherSub.js +13 -0
  159. package/dist/ops/gelu.d.ts +3 -0
  160. package/dist/ops/gelu.js +2 -0
  161. package/dist/ops/globalNorm.d.ts +2 -0
  162. package/dist/ops/globalNorm.js +19 -0
  163. package/dist/ops/grads/add16.d.ts +1 -0
  164. package/dist/ops/grads/add16.js +27 -0
  165. package/dist/ops/grads/attentionMask.d.ts +1 -0
  166. package/dist/ops/grads/attentionMask.js +26 -0
  167. package/dist/ops/grads/dropout16.d.ts +1 -0
  168. package/dist/ops/grads/dropout16.js +1 -0
  169. package/dist/ops/grads/gelu.d.ts +2 -0
  170. package/dist/ops/grads/gelu.js +2 -0
  171. package/dist/ops/grads/matMul16.d.ts +2 -0
  172. package/dist/ops/grads/matMul16.js +2 -0
  173. package/dist/ops/grads/matMulGelu.d.ts +1 -0
  174. package/dist/ops/grads/matMulGelu.js +22 -0
  175. package/dist/ops/grads/mul16.d.ts +1 -0
  176. package/dist/ops/grads/mul16.js +1 -0
  177. package/dist/ops/grads/normRMS.d.ts +3 -0
  178. package/dist/ops/grads/normRMS.js +37 -0
  179. package/dist/ops/grads/pack16.d.ts +2 -0
  180. package/dist/ops/grads/pack16.js +2 -0
  181. package/dist/ops/grads/qkv.d.ts +3 -0
  182. package/dist/ops/grads/qkv.js +46 -0
  183. package/dist/ops/grads/rope.d.ts +2 -0
  184. package/dist/ops/grads/rope.js +2 -0
  185. package/dist/ops/grads/softmax16.d.ts +2 -0
  186. package/dist/ops/grads/softmax16.js +23 -0
  187. package/dist/ops/grads/unpack16.d.ts +2 -0
  188. package/dist/ops/grads/unpack16.js +2 -0
  189. package/dist/ops/grads/utils.d.ts +4 -0
  190. package/dist/ops/grads/utils.js +12 -0
  191. package/dist/ops/log.d.ts +0 -0
  192. package/dist/ops/log.js +1 -0
  193. package/dist/ops/matMul16.d.ts +15 -0
  194. package/dist/ops/matMul16.js +2 -0
  195. package/dist/ops/matMulGelu.d.ts +3 -0
  196. package/dist/ops/matMulGelu.js +20 -0
  197. package/dist/ops/matMulMul.d.ts +2 -0
  198. package/dist/ops/matMulMul.js +16 -0
  199. package/dist/ops/mul16.d.ts +2 -0
  200. package/dist/ops/mul16.js +43 -0
  201. package/dist/ops/mulDrop.d.ts +2 -0
  202. package/dist/ops/mulDrop.js +15 -0
  203. package/dist/ops/normRMS.d.ts +2 -0
  204. package/dist/ops/normRMS.js +22 -0
  205. package/dist/ops/pack16.d.ts +2 -0
  206. package/dist/ops/pack16.js +2 -0
  207. package/dist/ops/qkv.d.ts +2 -0
  208. package/dist/ops/qkv.js +16 -0
  209. package/dist/ops/reshape16.d.ts +2 -0
  210. package/dist/ops/reshape16.js +33 -0
  211. package/dist/ops/rope.d.ts +3 -0
  212. package/dist/ops/rope.js +2 -0
  213. package/dist/ops/scatterSub.d.ts +2 -0
  214. package/dist/ops/scatterSub.js +13 -0
  215. package/dist/ops/slice16.d.ts +2 -0
  216. package/dist/ops/slice16.js +11 -0
  217. package/dist/ops/softmax16.d.ts +2 -0
  218. package/dist/ops/softmax16.js +9 -0
  219. package/dist/ops/sub16.d.ts +2 -0
  220. package/dist/ops/sub16.js +11 -0
  221. package/dist/ops/sum16.d.ts +2 -0
  222. package/dist/ops/sum16.js +13 -0
  223. package/dist/ops/transpose16.d.ts +3 -0
  224. package/dist/ops/transpose16.js +32 -0
  225. package/dist/ops/unpack16.d.ts +2 -0
  226. package/dist/ops/unpack16.js +2 -0
  227. package/dist/ops/webgl/adamAdjust.d.ts +1 -0
  228. package/dist/ops/webgl/adamAdjust.js +82 -0
  229. package/dist/ops/webgl/adamMoments.d.ts +1 -0
  230. package/dist/ops/webgl/adamMoments.js +44 -0
  231. package/dist/ops/webgl/appendCache.d.ts +1 -0
  232. package/dist/ops/webgl/appendCache.js +53 -0
  233. package/dist/ops/webgl/attentionMask.d.ts +1 -0
  234. package/dist/ops/webgl/attentionMask.js +64 -0
  235. package/dist/ops/webgl/dropout16.d.ts +1 -0
  236. package/dist/ops/webgl/dropout16.js +12 -0
  237. package/dist/ops/webgl/fusedSoftmax.d.ts +11 -0
  238. package/dist/ops/webgl/fusedSoftmax.js +70 -0
  239. package/dist/ops/webgl/gatherSub.d.ts +1 -0
  240. package/dist/ops/webgl/gatherSub.js +28 -0
  241. package/dist/ops/webgl/gelu.d.ts +2 -0
  242. package/dist/ops/webgl/gelu.js +48 -0
  243. package/dist/ops/webgl/log.d.ts +17 -0
  244. package/dist/ops/webgl/log.js +14 -0
  245. package/dist/ops/webgl/matMul16.d.ts +1 -0
  246. package/dist/ops/webgl/matMul16.js +37 -0
  247. package/dist/ops/webgl/matMulGelu.d.ts +21 -0
  248. package/dist/ops/webgl/matMulGelu.js +2 -0
  249. package/dist/ops/webgl/matMulMul.d.ts +14 -0
  250. package/dist/ops/webgl/matMulMul.js +24 -0
  251. package/dist/ops/webgl/mulDropout.d.ts +1 -0
  252. package/dist/ops/webgl/mulDropout.js +32 -0
  253. package/dist/ops/webgl/normRMS.d.ts +1 -0
  254. package/dist/ops/webgl/normRMS.js +114 -0
  255. package/dist/ops/webgl/qkv.d.ts +1 -0
  256. package/dist/ops/webgl/qkv.js +54 -0
  257. package/dist/ops/webgl/rope.d.ts +1 -0
  258. package/dist/ops/webgl/rope.js +72 -0
  259. package/dist/ops/webgl/scatterSub.d.ts +1 -0
  260. package/dist/ops/webgl/scatterSub.js +28 -0
  261. package/dist/ops/webgpu/adamAdjust.d.ts +1 -0
  262. package/dist/ops/webgpu/adamAdjust.js +77 -0
  263. package/dist/ops/webgpu/adamMoments.d.ts +1 -0
  264. package/dist/ops/webgpu/adamMoments.js +76 -0
  265. package/dist/ops/webgpu/add16.d.ts +1 -0
  266. package/dist/ops/webgpu/add16.js +14 -0
  267. package/dist/ops/webgpu/appendCache.d.ts +1 -0
  268. package/dist/ops/webgpu/appendCache.js +130 -0
  269. package/dist/ops/webgpu/attentionMask.d.ts +1 -0
  270. package/dist/ops/webgpu/attentionMask.js +42 -0
  271. package/dist/ops/webgpu/attentionMask32_program.d.ts +19 -0
  272. package/dist/ops/webgpu/attentionMask32_program.js +62 -0
  273. package/dist/ops/webgpu/clipScale.d.ts +1 -0
  274. package/dist/ops/webgpu/clipScale.js +45 -0
  275. package/dist/ops/webgpu/concat16.d.ts +19 -0
  276. package/dist/ops/webgpu/concat16.js +111 -0
  277. package/dist/ops/webgpu/dropout16.d.ts +1 -0
  278. package/dist/ops/webgpu/dropout16.js +59 -0
  279. package/dist/ops/webgpu/gatherSub.d.ts +1 -0
  280. package/dist/ops/webgpu/gatherSub.js +52 -0
  281. package/dist/ops/webgpu/gelu.d.ts +14 -0
  282. package/dist/ops/webgpu/gelu.js +147 -0
  283. package/dist/ops/webgpu/index.d.ts +0 -0
  284. package/dist/ops/webgpu/index.js +26 -0
  285. package/dist/ops/webgpu/matMul16.d.ts +1 -0
  286. package/dist/ops/webgpu/matMul16.js +70 -0
  287. package/dist/ops/webgpu/matMul16_program.d.ts +42 -0
  288. package/dist/ops/webgpu/matMul16_program.js +303 -0
  289. package/dist/ops/webgpu/mul16.d.ts +1 -0
  290. package/dist/ops/webgpu/mul16.js +14 -0
  291. package/dist/ops/webgpu/norm2.d.ts +1 -0
  292. package/dist/ops/webgpu/norm2.js +46 -0
  293. package/dist/ops/webgpu/normRMS.d.ts +1 -0
  294. package/dist/ops/webgpu/normRMS.js +26 -0
  295. package/dist/ops/webgpu/normRMS16_program.d.ts +10 -0
  296. package/dist/ops/webgpu/normRMS16_program.js +28 -0
  297. package/dist/ops/webgpu/normRMS32_program.d.ts +10 -0
  298. package/dist/ops/webgpu/normRMS32_program.js +28 -0
  299. package/dist/ops/webgpu/normRMSGrad.d.ts +1 -0
  300. package/dist/ops/webgpu/normRMSGrad.js +225 -0
  301. package/dist/ops/webgpu/pack16.d.ts +1 -0
  302. package/dist/ops/webgpu/pack16.js +21 -0
  303. package/dist/ops/webgpu/pack16_program.d.ts +19 -0
  304. package/dist/ops/webgpu/pack16_program.js +93 -0
  305. package/dist/ops/webgpu/qkv.d.ts +1 -0
  306. package/dist/ops/webgpu/qkv.js +64 -0
  307. package/dist/ops/webgpu/rope.d.ts +1 -0
  308. package/dist/ops/webgpu/rope.js +163 -0
  309. package/dist/ops/webgpu/scatterSub.d.ts +1 -0
  310. package/dist/ops/webgpu/scatterSub.js +53 -0
  311. package/dist/ops/webgpu/slice16.d.ts +7 -0
  312. package/dist/ops/webgpu/slice16.js +74 -0
  313. package/dist/ops/webgpu/softmax16.d.ts +17 -0
  314. package/dist/ops/webgpu/softmax16.js +18 -0
  315. package/dist/ops/webgpu/softmax16_program.d.ts +13 -0
  316. package/dist/ops/webgpu/softmax16_program.js +89 -0
  317. package/dist/ops/webgpu/softmax16_subgroup_program.d.ts +17 -0
  318. package/dist/ops/webgpu/softmax16_subgroup_program.js +70 -0
  319. package/dist/ops/webgpu/softmax16grad.d.ts +1 -0
  320. package/dist/ops/webgpu/softmax16grad.js +31 -0
  321. package/dist/ops/webgpu/sub16.d.ts +1 -0
  322. package/dist/ops/webgpu/sub16.js +14 -0
  323. package/dist/ops/webgpu/sum16.d.ts +1 -0
  324. package/dist/ops/webgpu/sum16.js +29 -0
  325. package/dist/ops/webgpu/transpose16.d.ts +1 -0
  326. package/dist/ops/webgpu/transpose16.js +37 -0
  327. package/dist/ops/webgpu/transpose16_program.d.ts +16 -0
  328. package/dist/ops/webgpu/transpose16_program.js +51 -0
  329. package/dist/ops/webgpu/transpose16_shared_program.d.ts +15 -0
  330. package/dist/ops/webgpu/transpose16_shared_program.js +79 -0
  331. package/dist/ops/webgpu/unpack16.d.ts +1 -0
  332. package/dist/ops/webgpu/unpack16.js +60 -0
  333. package/dist/ops/webgpu/utils/binary_op.d.ts +35 -0
  334. package/dist/ops/webgpu/utils/binary_op.js +141 -0
  335. package/dist/ops/webgpu/utils/deviceInfo.d.ts +7 -0
  336. package/dist/ops/webgpu/utils/deviceInfo.js +11 -0
  337. package/dist/ops/webgpu/utils/reductions.d.ts +43 -0
  338. package/dist/ops/webgpu/utils/reductions.js +263 -0
  339. package/dist/pack16-Ck-spx_F.js +39 -0
  340. package/dist/patches/webgpu_backend.d.ts +18 -0
  341. package/dist/patches/webgpu_backend.js +43 -0
  342. package/dist/patches/webgpu_base.d.ts +21 -0
  343. package/dist/patches/webgpu_base.js +22 -0
  344. package/dist/patches/webgpu_program.d.ts +36 -0
  345. package/dist/patches/webgpu_program.js +293 -0
  346. package/dist/pdf-UoDqCYzz.js +16726 -0
  347. package/dist/picomatch-3tUnMMbd.js +1063 -0
  348. package/dist/rope-CbeGlsV8.js +25 -0
  349. package/dist/selu_util-zkAx5doH.js +24 -0
  350. package/dist/shared-D1coEFea.js +1314 -0
  351. package/dist/shared-DOgWaqvL.js +5 -0
  352. package/dist/slice_util-Dgb3ANWI.js +208 -0
  353. package/dist/tfjs_backend-BjuQ5FqB.js +614 -0
  354. package/dist/tokeniser/BaseTokeniser.d.ts +33 -0
  355. package/dist/tokeniser/BaseTokeniser.js +2 -0
  356. package/dist/tokeniser/CharTokeniser.d.ts +24 -0
  357. package/dist/tokeniser/CharTokeniser.js +92 -0
  358. package/dist/tokeniser/bpe.d.ts +28 -0
  359. package/dist/tokeniser/bpe.js +170 -0
  360. package/dist/tokeniser/messages.d.ts +61 -0
  361. package/dist/tokeniser/messages.js +0 -0
  362. package/dist/tokeniser/type.d.ts +34 -0
  363. package/dist/tokeniser/type.js +0 -0
  364. package/dist/training/AdamW.d.ts +36 -0
  365. package/dist/training/AdamW.js +128 -0
  366. package/dist/training/BasicTrainer.d.ts +63 -0
  367. package/dist/training/BasicTrainer.js +265 -0
  368. package/dist/training/DatasetBuilder.d.ts +26 -0
  369. package/dist/training/DatasetBuilder.js +2 -0
  370. package/dist/training/Evaluator.d.ts +19 -0
  371. package/dist/training/Evaluator.js +48 -0
  372. package/dist/training/LRScheduler.d.ts +12 -0
  373. package/dist/training/LRScheduler.js +38 -0
  374. package/dist/training/PreTrainer.d.ts +11 -0
  375. package/dist/training/PreTrainer.js +22 -0
  376. package/dist/training/SFTTrainer.d.ts +12 -0
  377. package/dist/training/SFTTrainer.js +24 -0
  378. package/dist/training/loss.d.ts +3 -0
  379. package/dist/training/loss.js +19 -0
  380. package/dist/training/orthoGrad.d.ts +2 -0
  381. package/dist/training/orthoGrad.js +10 -0
  382. package/dist/training/sparseCrossEntropy.d.ts +7 -0
  383. package/dist/training/sparseCrossEntropy.js +47 -0
  384. package/dist/training/tasks/ConversationTask.d.ts +18 -0
  385. package/dist/training/tasks/ConversationTask.js +38 -0
  386. package/dist/training/tasks/PretrainingTask.d.ts +17 -0
  387. package/dist/training/tasks/PretrainingTask.js +42 -0
  388. package/dist/training/tasks/StartSentenceTask.d.ts +18 -0
  389. package/dist/training/tasks/StartSentenceTask.js +45 -0
  390. package/dist/training/tasks/Task.d.ts +22 -0
  391. package/dist/training/tasks/Task.js +55 -0
  392. package/dist/training/tasks/splitter.d.ts +5 -0
  393. package/dist/training/tasks/splitter.js +18 -0
  394. package/dist/training/types.d.ts +78 -0
  395. package/dist/training/types.js +0 -0
  396. package/dist/training/validation.d.ts +17 -0
  397. package/dist/training/validation.js +2 -0
  398. package/dist/utilities/arrayClose.d.ts +1 -0
  399. package/dist/utilities/arrayClose.js +16 -0
  400. package/dist/utilities/datasetID.d.ts +2 -0
  401. package/dist/utilities/datasetID.js +18 -0
  402. package/dist/utilities/dummy.d.ts +9 -0
  403. package/dist/utilities/dummy.js +36 -0
  404. package/dist/utilities/multinomialCPU.d.ts +2 -0
  405. package/dist/utilities/multinomialCPU.js +9 -0
  406. package/dist/utilities/naming.d.ts +4 -0
  407. package/dist/utilities/naming.js +0 -0
  408. package/dist/utilities/packed.d.ts +4 -0
  409. package/dist/utilities/packed.js +13 -0
  410. package/dist/utilities/parameters.d.ts +11 -0
  411. package/dist/utilities/parameters.js +38 -0
  412. package/dist/utilities/performance.d.ts +2 -0
  413. package/dist/utilities/performance.js +16 -0
  414. package/dist/utilities/profile.d.ts +17 -0
  415. package/dist/utilities/profile.js +33 -0
  416. package/dist/utilities/safetensors.d.ts +3 -0
  417. package/dist/utilities/safetensors.js +53 -0
  418. package/dist/utilities/sentences.d.ts +5 -0
  419. package/dist/utilities/sentences.js +32 -0
  420. package/dist/utilities/tokenParse.d.ts +1 -0
  421. package/dist/utilities/tokenParse.js +17 -0
  422. package/dist/utilities/topP.d.ts +1 -0
  423. package/dist/utilities/topP.js +12 -0
  424. package/dist/utilities/waitForModel.d.ts +2 -0
  425. package/dist/utilities/waitForModel.js +12 -0
  426. package/dist/utilities/weights.d.ts +12 -0
  427. package/dist/utilities/weights.js +40 -0
  428. package/dist/utilities/yielder.d.ts +1 -0
  429. package/dist/utilities/yielder.js +7 -0
  430. package/dist/webgpu-Dt7BMzWz.js +525 -0
  431. package/dist/webgpu_program-WOyIVMlZ.js +392 -0
  432. package/dist/webgpu_util-B_F3SShA.js +106 -0
  433. package/package.json +1 -1
@@ -0,0 +1,92 @@
1
+ import { yieldIfNeeded as e } from "../utilities/yielder.js";
2
+ import { n as t, t as n } from "../BaseTokeniser-DSg9zcYq.js";
3
+ //#region lib/tokeniser/CharTokeniser.ts
4
/**
 * Character-level tokeniser: each token is a single Unicode code point.
 *
 * `vocab` holds token strings indexed by id. An empty string ("") marks a free
 * slot that `addToken`/`train` may fill. `cache` is the reverse lookup
 * (token string -> id) and is rebuilt whenever `vocab` is rewritten.
 */
var i = class extends n {
	vocabSize = 0;
	eosToken = 0;
	bosToken = 0;
	unkToken = 0;
	vocab = [];
	cache = /* @__PURE__ */ new Map();
	_trained = !1;
	/**
	 * @param e Either an existing vocabulary (token id = array index), which is
	 *   treated as already trained, or the number of slots for a fresh,
	 *   untrained vocabulary.
	 * @throws Error("Vocab cannot be empty") when given an empty array.
	 */
	constructor(e) {
		super();
		if (Array.isArray(e)) {
			this.vocab = e;
			if (this.vocab.length === 0) throw Error("Vocab cannot be empty");
			this.vocabSize = this.vocab.length;
			// Register whichever of the base special tokens occur in the supplied vocab.
			t.forEach((special) => {
				const idx = this.vocab.indexOf(special);
				if (idx !== -1) this.addSpecialToken(special, idx);
			});
			this.eosToken = this.getSpecialTokenIndex("<eos>");
			this.bosToken = this.getSpecialTokenIndex("<bos>") ?? this.eosToken;
			// Unknown-token fallbacks, tried in priority order.
			this.unkToken = this.getSpecialTokenIndex("") ?? -1;
			if (this.unkToken === -1) this.unkToken = this.vocab.indexOf("<unk>");
			if (this.unkToken === -1) this.unkToken = this.vocab.indexOf("<pad>");
			if (this.unkToken === -1) this.unkToken = this.vocab.indexOf("_");
			if (this.unkToken === -1) this.unkToken = this.vocab.indexOf(" ");
			if (this.unkToken === -1) this.unkToken = this.eosToken;
			// "<pad>" entries are stored as empty (free) slots.
			this.vocab = this.vocab.map((tok) => tok === "<pad>" ? "" : tok);
			this.vocab.forEach((tok, idx) => {
				this.cache.set(tok, idx);
			});
			this._trained = !0;
		} else {
			this.vocabSize = e;
			this.vocab = Array(this.vocabSize).fill("");
			this.addSpecialTokens();
			this.eosToken = this.getSpecialTokenIndex("<eos>");
			this.bosToken = this.getSpecialTokenIndex("<bos>") ?? this.eosToken;
			this.unkToken = this.getSpecialTokenIndex("");
			this.vocab.forEach((tok, idx) => {
				this.cache.set(tok, idx);
			});
			// The empty string resolves to the unknown token.
			this.cache.set("", this.unkToken);
		}
	}
	/**
	 * Place `e` in the vocabulary and return its id.
	 *
	 * If the token is already known, its existing id is returned. Without an
	 * explicit index `t`, the first free slot after the unknown token is used.
	 *
	 * @throws Error("Invalid token index") for a negative or non-integer index.
	 * @throws Error("Vocab size exceeded") when no slot below `vocabSize` is available.
	 */
	addToken(e, t) {
		if (this.cache.has(e)) return this.cache.get(e);
		let idx;
		if (t === void 0) {
			idx = this.vocab.indexOf("", this.unkToken + 1);
			if (idx === -1) idx = this.vocabSize;
		} else idx = t;
		// A negative/fractional index would set an array *property*, not an element.
		if (!Number.isInteger(idx) || idx < 0) throw Error("Invalid token index");
		if (idx >= this.vocabSize) throw Error("Vocab size exceeded");
		// When overwriting a slot, forget the displaced token so it no longer maps here.
		const previous = this.vocab[idx];
		if (previous && previous !== e && this.cache.get(previous) === idx) this.cache.delete(previous);
		this.vocab[idx] = e;
		this.cache.set(e, idx);
		return idx;
	}
	get trained() {
		return this.vocab.length === this.vocabSize && this._trained;
	}
	destroy() {
		this.cache.clear();
		this.vocab = [];
	}
	/**
	 * Fill the free vocabulary slots with characters seen in `t` (conversations of turns).
	 *
	 * When there are more new characters than free slots, the most frequent are kept.
	 * `n` is forwarded to `yieldIfNeeded`; `i` is stored as `datasetID`.
	 * Resolves with `vocabSize`.
	 */
	async train(t, n, i) {
		this.datasetID = i;
		const freq = /* @__PURE__ */ new Map();
		let lastYield = performance.now();
		for (const conversation of t) {
			conversation.forEach((turn) => {
				// for...of walks code points, matching how `tokenise` splits text.
				for (const ch of turn.content) freq.set(ch, (freq.get(ch) || 0) + 1);
			});
			lastYield = await e(lastYield, n, 0);
		}
		const firstFree = this.vocab.indexOf("", this.unkToken + 1);
		if (firstFree === -1) return this.generateID(), this.vocabSize;
		this._trained = !0;
		let freeSlots = 0;
		for (let k = firstFree; k < this.vocab.length; k++) if (this.vocab[k] === "") freeSlots++;
		const present = new Set(this.vocab);
		// Candidates in first-seen order. If they do not all fit, rank them
		// most-frequent first (stable sort keeps first-seen order among ties) so
		// the slot-filling loop below never drops a frequent character for a rare one.
		const candidates = Array.from(freq.keys()).filter((ch) => !present.has(ch));
		if (candidates.length > freeSlots) candidates.sort((x, y) => freq.get(y) - freq.get(x));
		let slot = firstFree;
		for (const ch of candidates) {
			if (slot === -1) break;
			this.vocab[slot] = ch;
			slot = this.vocab.indexOf("", slot + 1);
		}
		this.cache.clear();
		this.vocab.forEach((tok, idx) => {
			this.cache.set(tok, idx);
		});
		this.generateID();
		this.emit("trainStatus", "trained");
		return this.vocabSize;
	}
	/**
	 * Split each string into code points.
	 *
	 * With `t` truthy, returns ids (unknown characters map to `unkToken`).
	 * Otherwise returns token strings (unknown characters map to "").
	 *
	 * @throws Error("Tokeniser not trained") before training.
	 */
	tokenise(e, t) {
		if (!this.trained) throw Error("Tokeniser not trained");
		return e.map((text) => {
			// Array.from iterates code points, so astral characters (e.g. emoji)
			// stay whole, exactly as `train` collected them.
			const chars = Array.from(text);
			if (t) return chars.map((ch) => this.cache.get(ch) ?? this.unkToken);
			return chars.map((ch) => {
				const id = this.cache.get(ch);
				return id === void 0 ? "" : this.vocab[id];
			});
		});
	}
	/** Map each id sequence back to text; unknown ids become "". Accepts any array-like of ids. */
	detokenise(e) {
		return e.map((ids) => Array.from(ids).map((id) => this.vocab[id] || "").join(""));
	}
	encode(e) {
		return this.tokenise([e], !0)[0];
	}
	decode(e) {
		return this.detokenise([e])[0];
	}
	getVocab() {
		return this.vocab;
	}
	getMerges() {
		return [];
	}
	/**
	 * For each of the first `e.length - t` strings, push its first `t` ids as
	 * input and the first id of the following string as target.
	 */
	async createTrainingData(e, t = 5) {
		const tokens = await this.tokenise(e, !0), inputs = [], targets = [];
		for (let k = 0; k < tokens.length - t; k++) {
			inputs.push(...tokens[k].slice(0, t));
			targets.push(tokens[k + 1][0]);
		}
		return [inputs, targets];
	}
};
91
+ //#endregion
92
+ export { i as default };
@@ -0,0 +1,28 @@
1
+ import { default as BaseTokeniser } from './BaseTokeniser';
2
+ import { Conversation } from './type';
3
/**
 * Byte-pair-encoding tokeniser. Text is pre-tokenised into words, each word is
 * split into characters, and learnt merge rules are applied in order.
 */
export default class BPETokeniser extends BaseTokeniser {
    private targetSize;
    private vocab;
    private vocabIndex;
    private merges;
    private pretokenMap;
    /** Create an untrained tokeniser; `train` grows the vocabulary up to `vocabSize` entries. */
    constructor(vocabSize: number);
    /**
     * Restore a tokeniser from a vocabulary (token id = array index).
     * `merges` are the optional ordered merge rules applied to words.
     */
    constructor(vocab: string[], merges?: [string, string][]);
    /**
     * Register `token` at `index` (default: the next id, i.e. the current vocabulary size)
     * and return its id. Returns the existing id when the token is already known.
     */
    addToken(token: string, index?: number): number;
    /** Clear the vocabulary, id index, merge rules and word cache. */
    destroy(): void;
    /** True when the vocabulary holds more than the base special tokens and does not exceed the target size. */
    get trained(): boolean;
    /** Number of tokens currently in the vocabulary. */
    get vocabSize(): number;
    /** Id of `<eos>`, or 0 when absent. */
    get eosToken(): number;
    /** Id of `<bos>`, or 0 when absent. */
    get bosToken(): number;
    /** Id of the unknown token (the empty string), or 1 when absent. */
    get unkToken(): number;
    /**
     * Learn a vocabulary and merge rules from `text`; resolves with the final vocabulary size.
     * `cb` is forwarded to the yield helper together with the current vocabulary size
     * (presumably called with that size as training progresses; confirm).
     */
    train(text?: Conversation[][], cb?: (vocab: number) => void, datasetID?: string): Promise<number>;
    /** Tokens in id order. */
    getVocab(): string[];
    /** The ordered merge rules (the internal array, not a copy). */
    getMerges(): [string, string][];
    private tokeniseWord;
    private tokeniseStrings;
    /** Tokenise to ids; tokens missing from the vocabulary map to `unkToken`. */
    tokenise(text: string[], numeric: true): number[][];
    /** Tokenise to token strings; tokens missing from the vocabulary map to "". */
    tokenise(text: string[], numeric?: false): string[][];
    /** Tokenise with a runtime-chosen output form. */
    tokenise(text: string[], numeric: boolean): number[][] | string[][];
    /** Map id sequences back to text. */
    detokenise(tokens: number[][]): string[];
    /** Tokenise a single string to ids. */
    encode(text: string): number[];
    /** Detokenise a single id sequence. */
    decode(tokens: number[]): string;
}
@@ -0,0 +1,170 @@
1
+ import { yieldIfNeeded as e } from "../utilities/yielder.js";
2
+ import { n as t, t as n } from "../BaseTokeniser-DSg9zcYq.js";
3
+ import r from "../utilities/tokenParse.js";
4
+ //#region lib/tokeniser/bpe.ts
5
/**
 * Lookup key for the adjacent token pair (e, t).
 *
 * The left token's length is encoded, so distinct pairs can never share a key.
 * A fixed separator such as "-::-" collides when the separator itself occurs
 * inside a token.
 */
function i(e, t) {
	return `${e.length}:${e}${t}`;
}
/**
 * Count every adjacent token pair in `e`, an array of per-word token arrays.
 *
 * Returns `{ pairs, tokens }`:
 * - `pairs` maps a pair key (see `i`) to `{ a, b, count, instances }`, where
 *   `count` is the total number of occurrences and `instances` holds the
 *   indices of words containing the pair.
 * - `tokens` is `e` itself (not a copy), so merges update the caller's array.
 */
function a(e) {
	const pairs = /* @__PURE__ */ new Map();
	for (let wordIdx = 0; wordIdx < e.length; wordIdx++) {
		const word = e[wordIdx];
		for (let k = 0; k + 1 < word.length; k++) {
			const key = i(word[k], word[k + 1]);
			let entry = pairs.get(key);
			if (!entry) {
				entry = {
					a: word[k],
					b: word[k + 1],
					count: 0,
					instances: /* @__PURE__ */ new Set()
				};
				pairs.set(key, entry);
			}
			entry.count += 1;
			entry.instances.add(wordIdx);
		}
	}
	return {
		pairs,
		tokens: e
	};
}
/**
 * Add `a` (positive or negative) to the occurrence count of pair (t, n) for word `r`.
 *
 * A pair whose count drops to zero or below is removed. Positive deltas record
 * `r` in `instances`. Negative deltas deliberately leave it there, because the
 * word may still contain the pair elsewhere. `instances` is therefore a
 * superset, and `l` skips words that no longer contain the pair.
 */
function o(e, t, n, r, a) {
	const key = i(t, n);
	const entry = e.pairs.get(key);
	if (entry === void 0) {
		if (a > 0) e.pairs.set(key, {
			a: t,
			b: n,
			count: a,
			instances: new Set([r])
		});
		return;
	}
	entry.count += a;
	if (entry.count <= 0) e.pairs.delete(key);
	else if (a > 0) entry.instances.add(r);
}
/** Most frequent pair (first one wins ties), or null when no pair has a positive count. */
function s(e) {
	let best = null;
	for (const entry of e.pairs.values()) if (entry.count > (best ? best.count : 0)) best = entry;
	return best;
}
/** Apply merge rule `t` = [left, right] left-to-right to each token array in `e`, returning new arrays. */
function c(e, t) {
	const [left, right] = t;
	return e.map((word) => {
		const out = [];
		for (let k = 0; k < word.length; k++) if (k + 1 < word.length && word[k] === left && word[k + 1] === right) {
			out.push(left + right);
			k++;
		} else out.push(word[k]);
		return out;
	});
}
/**
 * Merge pair `t` in every word that contains it and update the pair statistics.
 *
 * Each affected word is re-tokenised with `c`, the same routine used at
 * inference. The exact per-word difference between old and new adjacent pairs
 * is then applied. Updating neighbours incrementally from the pre-merge array
 * miscounts consecutive occurrences (e.g. `a b a b`): it creates phantom pairs
 * and misses the real `(ab, ab)`.
 */
function l(e, t) {
	const rule = [t.a, t.b];
	const delta = /* @__PURE__ */ new Map();
	const tally = (word, sign) => {
		for (let k = 0; k + 1 < word.length; k++) {
			const key = i(word[k], word[k + 1]);
			const d = delta.get(key);
			if (d) d.n += sign;
			else delta.set(key, {
				a: word[k],
				b: word[k + 1],
				n: sign
			});
		}
	};
	// Snapshot: `o` may mutate the pair map while we iterate.
	for (const wordIdx of Array.from(t.instances)) {
		const before = e.tokens[wordIdx];
		const after = c([before], rule)[0];
		// Stale instance: this word no longer contains the pair.
		if (after.length === before.length) continue;
		delta.clear();
		tally(before, -1);
		tally(after, 1);
		for (const d of delta.values()) if (d.n !== 0) o(e, d.a, d.b, wordIdx, d.n);
		e.tokens[wordIdx] = after;
	}
	e.pairs.delete(i(t.a, t.b));
}
61
/**
 * Byte-pair-encoding tokeniser.
 *
 * Text is pre-tokenised into words by `r` (tokenParse). Each word is split into
 * characters and the learnt merges are applied in order. Words already seen are
 * cached in `pretokenMap`.
 */
var u = class extends n {
	targetSize;
	/** Known tokens; insertion order is the order reported by `getVocab`. */
	vocab = /* @__PURE__ */ new Set();
	/** Token string -> id; the single source of ids for encoding AND decoding. */
	vocabIndex = /* @__PURE__ */ new Map();
	merges = [];
	pretokenMap = /* @__PURE__ */ new Map();
	/**
	 * @param e Either a vocabulary array (id = array index) or a target vocabulary size.
	 * @param n Optional ordered merge rules (only used with a vocabulary array).
	 */
	constructor(e, n) {
		super();
		if (Array.isArray(e)) {
			e.forEach((token, idx) => {
				this.vocab.add(token);
				this.vocabIndex.set(token, idx);
			});
			if (n) this.merges = n;
			this.targetSize = e.length;
			t.forEach((special) => {
				const idx = e.indexOf(special);
				if (idx !== -1) this.addSpecialToken(special, idx);
			});
		} else {
			this.addSpecialTokens();
			this.targetSize = e;
		}
	}
	/** Register token `e` (at id `t`, default: next id) and return its id; known tokens keep their id. */
	addToken(e, t) {
		if (this.vocab.has(e)) return this.vocabIndex.get(e);
		this.vocab.add(e);
		const idx = t === void 0 ? this.vocab.size - 1 : t;
		this.vocabIndex.set(e, idx);
		return idx;
	}
	destroy() {
		this.vocab.clear();
		this.vocabIndex.clear();
		this.merges = [];
		this.pretokenMap.clear();
	}
	get trained() {
		return this.vocab.size > t.length && this.vocab.size <= this.targetSize;
	}
	get vocabSize() {
		return this.vocab.size;
	}
	get eosToken() {
		return this.vocabIndex.get("<eos>") ?? 0;
	}
	get bosToken() {
		return this.vocabIndex.get("<bos>") ?? 0;
	}
	get unkToken() {
		return this.vocabIndex.get("") ?? 1;
	}
	/**
	 * Learn the vocabulary and merges from `text` (conversations of turns).
	 *
	 * `cb` is forwarded to `yieldIfNeeded` with the current vocabulary size.
	 * Resolves with the final vocabulary size.
	 */
	async train(text = [], cb, datasetID) {
		this.datasetID = datasetID;
		let lastYield = performance.now();
		const pretokenised = Array(text.length);
		for (let ci = 0; ci < text.length; ci++) {
			const conversation = text[ci];
			const turns = Array(conversation.length);
			for (let ti = 0; ti < conversation.length; ti++) turns[ti] = r(conversation[ti].content);
			lastYield = await e(lastYield, cb, this.vocab.size);
			pretokenised[ci] = turns;
		}
		const allWords = pretokenised.flat(2);
		const uniqueWords = Array.from(new Set(allWords));
		this.vocab = /* @__PURE__ */ new Set();
		this.pretokenMap.clear();
		this.merges = [];
		this.addSpecialTokens();
		// Split each distinct word into characters, seeding the vocab with them.
		const wordTokens = uniqueWords.map((word) => Array.from(word).map((ch) => (this.vocab.add(ch), ch)));
		const stats = a(wordTokens);
		lastYield = await e(lastYield, cb, this.vocab.size);
		if (this.vocab.size >= this.targetSize) {
			console.warn("Initial vocab size is greater than or equal to target size. No merges will be performed.");
			// Keep only the most frequent characters that fit.
			const charCounts = /* @__PURE__ */ new Map();
			allWords.forEach((word) => {
				Array.from(word).forEach((ch) => {
					charCounts.set(ch, (charCounts.get(ch) || 0) + 1);
				});
			});
			const ranked = Array.from(charCounts.entries()).sort((x, y) => y[1] - x[1]);
			this.vocab = /* @__PURE__ */ new Set();
			this.addSpecialTokens();
			ranked.slice(0, this.targetSize - this.vocab.size).map(([ch]) => ch).forEach((ch) => this.vocab.add(ch));
			this.vocabIndex.clear();
			let charId = 0;
			for (const token of this.vocab.keys()) this.vocabIndex.set(token, charId++);
			this.generateID();
			this.emit("trainStatus", "trained");
			return this.vocab.size;
		}
		while (this.vocab.size < this.targetSize && this.merges.length < this.targetSize) {
			const best = s(stats);
			if (!best) break;
			this.merges.push([best.a, best.b]);
			this.vocab.add(best.a + best.b);
			l(stats, best);
			lastYield = await e(lastYield, cb, this.vocab.size);
		}
		// `l` rewrote wordTokens in place, so these are the final word tokenisations.
		uniqueWords.forEach((word, idx) => {
			this.pretokenMap.set(word, wordTokens[idx]);
		});
		this.vocabIndex.clear();
		let nextId = 0;
		for (const token of this.vocab.keys()) this.vocabIndex.set(token, nextId++);
		this.generateID();
		this.emit("trainStatus", "trained");
		return this.vocab.size;
	}
	getVocab() {
		return Array.from(this.vocab);
	}
	getMerges() {
		return this.merges;
	}
	/** Apply every merge rule in order to the characters of `word`, caching the result. */
	tokeniseWord(word) {
		let tokens = Array.from(word);
		this.merges.forEach((merge) => {
			tokens = c([tokens], merge)[0];
		});
		this.pretokenMap.set(word, tokens);
		return tokens;
	}
	tokeniseStrings(e) {
		return e.map((text) => r(text).map((word) => this.pretokenMap.has(word) ? this.pretokenMap.get(word) : this.tokeniseWord(word)).flat(1));
	}
	/** Ids when `t` is truthy (unknown -> `unkToken`), otherwise token strings (unknown -> ""). */
	tokenise(e, t) {
		const pieces = this.tokeniseStrings(e);
		return t ? pieces.map((seq) => seq.map((tok) => this.vocabIndex.get(tok) ?? this.unkToken)) : pieces.map((seq) => seq.map((tok) => this.vocab.has(tok) ? tok : ""));
	}
	/**
	 * Map id sequences back to text; unknown ids become "".
	 *
	 * Ids are resolved through the inverse of `vocabIndex`, the same table
	 * `tokenise` encodes with. Set insertion order (`getVocab`) can disagree
	 * with it after duplicate vocab entries or explicit `addToken` indices.
	 * Accepts any array-like of ids (including typed arrays).
	 */
	detokenise(e) {
		const byId = /* @__PURE__ */ new Map();
		for (const [token, id] of this.vocabIndex) byId.set(id, token);
		return e.map((ids) => Array.from(ids, (id) => byId.get(id) ?? "").join(""));
	}
	encode(e) {
		return this.tokenise([e], !0)[0];
	}
	decode(e) {
		return this.detokenise([e])[0];
	}
};
169
+ //#endregion
170
+ export { u as default };
@@ -0,0 +1,61 @@
1
/**
 * Field shared by every tokeniser message. `id` presumably pairs a request
 * with its matching `*Response` (confirm against the message handler).
 */
interface MessageBase {
  id: number;
}

/** Request to train the tokeniser on `text` with the requested `vocabSize`. */
interface TrainMessage extends MessageBase {
  type: 'train';
  text: string[];
  vocabSize: number;
}

/** Reply to a `train` request carrying the resulting vocabulary size. */
interface TrainResponse extends MessageBase {
  type: 'trainResponse';
  vocabSize: number;
}

/** Intermediate training progress report. */
interface TrainStatusMessage extends MessageBase {
  type: 'trainStatus';
  progress: number;
  vocabSize: number;
}

/** Request to tokenise `text`; `numeric` presumably selects ids over token strings. */
interface TokeniseMessage extends MessageBase {
  type: 'tokenise';
  numeric?: boolean;
  text: string[];
}

/**
 * Reply to a `tokenise` request. `tokens` is presumably `number[][]` when
 * `numeric` is true and `string[][]` otherwise; the type does not enforce this pairing.
 */
interface TokeniseResponse extends MessageBase {
  type: 'tokeniseResponse';
  numeric: boolean;
  tokens: string[][] | number[][];
}

/** Request to turn id sequences back into text. */
interface DetokeniseMessage extends MessageBase {
  type: 'detokenise';
  tokens: number[][];
}

/** Reply to a `detokenise` request: one string per input sequence. */
interface DetokeniseResponse extends MessageBase {
  type: 'detokeniseResponse';
  text: string[];
}

/** Request for the tokeniser's token list. */
interface TokensMessage extends MessageBase {
  type: 'tokens';
}

/** Reply to a `tokens` request. */
interface TokensResponse extends MessageBase {
  type: 'tokensResponse';
  tokens: string[];
}

/** Request to build training data from `text` using the given `windowSize`. */
interface BuildTrainingDataMessage extends MessageBase {
  type: 'buildTrainingData';
  text: string[];
  windowSize: number;
}

/** Reply to a `buildTrainingData` request. */
interface BuildTrainingDataResponse extends MessageBase {
  type: 'buildTrainingDataResponse';
  trainingData: [number[], number[]];
}

/** Every message kind, discriminated by its `type` field. */
export type TokeniserMessage =
  | TrainMessage
  | TrainResponse
  | TrainStatusMessage
  | TokeniseMessage
  | TokeniseResponse
  | DetokeniseMessage
  | DetokeniseResponse
  | TokensMessage
  | TokensResponse
  | BuildTrainingDataMessage
  | BuildTrainingDataResponse;
export {};
File without changes
@@ -0,0 +1,34 @@
1
+ import { default as EE } from 'eventemitter3';
2
/** Speaker role attached to each conversation turn; `'text'` presumably marks plain, non-chat text. */
export type Roles =
  | 'user'
  | 'assistant'
  | 'system'
  | 'text';

/** A single conversation turn: who produced it and its content. */
export interface Conversation {
  /** Role of the speaker for this turn. */
  role: Roles;
  /** Text content of the turn. */
  content: string;
}
7
/**
 * Contract shared by tokeniser implementations. Implementations are
 * eventemitter3 emitters of `'trainStatus'` events.
 */
export interface ITokeniser extends EE<'trainStatus'> {
  // --- identity / state ----------------------------------------------------
  /** Identifier of this tokeniser instance. */
  id: string;
  /** Optional dataset identifier, presumably the dataset passed to `train`. */
  datasetID?: string;
  /** Whether the tokeniser has been trained. */
  trained: boolean;

  // --- vocabulary ------------------------------------------------------------
  vocabSize: number;
  bosToken: number;
  eosToken: number;
  /** Vocabulary tokens; index in the array is presumably the token id. */
  getVocab(): string[];
  /** Learned merge rules as `[left, right]` pairs. */
  getMerges(): [string, string][];
  /** Id of a special token by its string form, or `undefined` if it is not special. */
  getSpecialTokenIndex(token: string): number | undefined;
  isSpecialToken(index: number): boolean;

  // --- training --------------------------------------------------------------
  /**
   * Train on a corpus of conversations. Resolves to a number — presumably the
   * resulting vocabulary size (confirm per implementation).
   */
  train(text: Conversation[][], cb?: (vocab: number) => void, datasetID?: string): Promise<number>;

  // --- encoding --------------------------------------------------------------
  encode(text: string): number[];
  encodeSequence(text: string): number[];
  encodeAsSequence(conversation: Conversation[], completion?: boolean): number[];
  // Overload order is significant for resolution and must stay as is.
  encodeConversation(conversation: Conversation[], completion?: boolean): number[];
  // NOTE(review): typed for any boolean `masking`, including `false`; confirm the
  // implementation really returns `{ tokens, mask }` when masking is false.
  encodeConversation(conversation: Conversation[], completion: boolean, masking: boolean): {
    tokens: number[];
    mask: boolean[];
  };
  encodeConversation(conversation: Conversation[], completion?: boolean, masking?: boolean): number[] | {
    tokens: number[];
    mask: boolean[];
  };

  // --- decoding --------------------------------------------------------------
  decode(tokens: number[] | Uint16Array): string;
  decodeConversation(tokens: number[] | Uint16Array): Conversation[];

  // --- lifecycle -------------------------------------------------------------
  destroy(): void;
}
File without changes
@@ -0,0 +1,36 @@
1
+ import { Optimizer, Tensor } from '@tensorflow/tfjs-core';
2
+ import { ConfigDict, Serializable, SerializableConstructor } from '@tensorflow/tfjs-core/dist/serialization';
3
+ import { NamedTensor, NamedVariableMap } from '@tensorflow/tfjs-core/dist/tensor_types';
4
+ import { default as LRScheduler } from './LRScheduler';
5
+ import { AdamWOptimizerConfig } from './types';
6
/**
 * AdamW optimizer with a learning-rate scheduler, optional gradient-norm
 * clipping (`clipNorm`), loss scaling and optional gradient orthogonalisation.
 */
export declare class AdamWOptimizer extends Optimizer {
  readonly className = "AdamW";

  // --- configuration -------------------------------------------------------
  private config;
  protected learningRate: number;
  protected beta1: number;
  protected beta2: number;
  protected epsilon: number | null;
  protected weightDecay: number;
  protected lossScaling: number;
  protected clipNorm?: number;
  protected orthGrad: boolean;
  protected orthGradEpsilon: number;
  protected lrScheduler: LRScheduler;

  // --- optimizer state -----------------------------------------------------
  private accBeta1;
  private accBeta2;
  private accumulatedMoments;

  constructor(config: AdamWOptimizerConfig);

  /** Current learning rate; `applyGradients` replaces it with the scheduler's next value before each step. */
  get lr(): number;

  /** Serialise the moment variables to a safetensors buffer. */
  saveMoments(): Promise<ArrayBuffer>;
  /** Load moment variables from a safetensors buffer produced by `saveMoments`. */
  loadMoments(momentData: ArrayBuffer): Promise<void>;

  /** Full optimizer configuration, including the LR scheduler's settings. */
  serializeConfig(): AdamWOptimizerConfig;
  /** Merge a partial configuration into the optimizer and its LR scheduler. */
  updateConfig(newConfig: Partial<AdamWOptimizerConfig>): void;

  private orthogonalizeGradient;

  /** Perform one optimisation step; resolves to the gradient scale tensor that was applied. */
  applyGradients(variableGradients: NamedVariableMap | NamedTensor[]): Tensor;

  dispose(): void;

  getWeights(): Promise<NamedTensor[]>;
  setWeights(weightValues: NamedTensor[]): Promise<void>;
  getConfig(): ConfigDict;
  /** @nocollapse */
  static fromConfig<T extends Serializable>(cls: SerializableConstructor<T>, config: ConfigDict): T;
}
@@ -0,0 +1,128 @@
1
+ import { _n as e, c as t, di as n, ii as r, kt as i, ni as a } from "../dist-BewPQWjc.js";
2
+ import { load_safetensors as o, save_safetensors as s } from "../utilities/safetensors.js";
3
+ import { adamAdjust as c } from "../ops/adamAdjust.js";
4
+ import { adamMoments as l } from "../ops/adamMoments.js";
5
+ import u from "./LRScheduler.js";
6
+ import { clipScale as d } from "../ops/globalNorm.js";
7
+ //#region lib/training/AdamW.ts
8
/**
 * AdamW optimizer built on the tfjs `Optimizer` base class (`t`).
 *
 * Per-variable state lives in `accumulatedMoments`, one entry per trained
 * variable. Entries are matched to gradients by POSITION (the order of names in
 * the gradient map/list), not by name. Each entry holds one non-trainable
 * variable of shape `[...variable.shape, 2]` updated by `adamMoments` —
 * presumably the first and second moment packed along the trailing axis
 * (TODO confirm against ops/adamMoments).
 *
 * Module-level imports used here: `n` = tidy, `r` = engine accessor,
 * `e` = scalar, `i` = zeros, `a` = dispose, `d` = clipScale, `l` = adamMoments,
 * `c` = adamAdjust, `s`/`o` = save/load safetensors, `u` = LRScheduler.
 */
var f = class extends t {
  config;
  className = "AdamW";
  // Running products of beta used for bias correction (beta^(steps + 1)).
  accBeta1 = 0;
  accBeta2 = 0;
  accumulatedMoments = [];
  learningRate;
  beta1;
  beta2;
  lossScaling;
  weightDecay;
  epsilon = null;
  lrScheduler;
  clipNorm;
  orthGradEpsilon = 1e-30;
  orthGrad;

  /** @param config AdamW settings; the same object is also handed to the LR scheduler. */
  constructor(config) {
    super();
    this.config = config;
    // Start the bias-correction products at beta^1 unless saved values are supplied.
    this.accBeta1 = config.accBeta1 ?? config.beta1;
    this.accBeta2 = config.accBeta2 ?? config.beta2;
    this.learningRate = config.learningRate;
    this.beta1 = config.beta1;
    this.beta2 = config.beta2;
    this.weightDecay = config.weightDecay;
    this.lossScaling = config.lossScaling;
    this.clipNorm = config.clipNorm;
    this.orthGrad = config.orthoGrad ?? false;
    this.epsilon = config.epsilon == null ? r().backend.epsilon() : config.epsilon;
    this.lrScheduler = new u(config.learningRate, config);
  }

  /** Current learning rate; replaced by the scheduler's next value at the start of every step. */
  get lr() {
    return this.learningRate;
  }

  /** Serialise all moment variables to safetensors, keyed by their `originalName`. */
  saveMoments() {
    const byName = {};
    this.accumulatedMoments.forEach(({ originalName, variable }) => {
      byName[originalName] = variable;
    });
    return s(byName);
  }

  /**
   * Replace the moment state with tensors decoded from a safetensors buffer.
   *
   * Moments are matched to gradients by position in `applyGradients`, so loaded
   * entries must replace any existing ones. Entries appended after them would
   * never be read. The replaced variables are disposed.
   */
  async loadMoments(momentData) {
    const loaded = await o(momentData);
    const previous = this.accumulatedMoments;
    this.accumulatedMoments = Object.entries(loaded).map(([originalName, tensor]) => ({
      originalName,
      variable: tensor.variable(false)
    }));
    a(previous.map((m) => m.variable));
  }

  /** Full configuration, including the LR scheduler's own serialised settings. */
  serializeConfig() {
    return {
      learningRate: this.learningRate,
      beta1: this.beta1,
      beta2: this.beta2,
      accBeta1: this.accBeta1,
      accBeta2: this.accBeta2,
      epsilon: this.epsilon ?? undefined,
      weightDecay: this.weightDecay,
      lossScaling: this.lossScaling,
      clipNorm: this.clipNorm,
      orthoGrad: this.orthGrad,
      ...this.lrScheduler.serializeConfig()
    };
  }

  /**
   * Remove from `grad` its projection onto `weights` (both flattened), then
   * rescale the remainder to the original gradient's norm. `orthGradEpsilon`
   * guards both divisions against zero.
   */
  orthogonalizeGradient(weights, grad) {
    return n(() => {
      const w = weights.reshape([-1]);
      const g = grad.reshape([-1]);
      const wSquaredNorm = w.mul(w).sum().add(this.orthGradEpsilon);
      const projection = w.mul(g).sum().div(wSquaredNorm);
      const orthogonal = g.sub(w.mul(projection));
      const gradNorm = g.norm();
      const orthogonalNorm = orthogonal.norm().add(this.orthGradEpsilon);
      return orthogonal.mul(gradNorm.div(orthogonalNorm)).reshape(grad.shape);
    });
  }

  /**
   * Merge a partial configuration into the optimizer and its LR scheduler.
   *
   * The merged result is stored back into `this.config`, so successive partial
   * updates accumulate instead of each one merging against the
   * construction-time config and reverting earlier changes. `orthoGrad` is
   * honoured like every other field.
   */
  updateConfig(newConfig) {
    const merged = {
      ...this.config,
      ...newConfig
    };
    this.config = merged;
    this.learningRate = merged.learningRate;
    this.beta1 = merged.beta1;
    this.beta2 = merged.beta2;
    this.weightDecay = merged.weightDecay;
    this.lossScaling = merged.lossScaling;
    this.epsilon = merged.epsilon ?? this.epsilon;
    this.clipNorm = merged.clipNorm;
    this.orthGrad = merged.orthoGrad ?? false;
    this.lrScheduler.updateConfig(merged, merged.learningRate);
  }

  /**
   * Perform one optimisation step.
   *
   * The gradient scale is `1 / lossScaling`, or the value produced by `clipScale`
   * when `clipNorm` is set. It is passed to `adamMoments` and returned to the caller.
   * Weight decay is only passed to `adamAdjust` for variables of rank > 1.
   * Null gradients are skipped, but their moment slot is still allocated so that
   * positional matching stays aligned.
   */
  applyGradients(variableGradients) {
    this.learningRate = this.lrScheduler.getNextLR();
    const isList = Array.isArray(variableGradients);
    const names = isList ? variableGradients.map((g) => g.name) : Object.keys(variableGradients);
    const gradientAt = (name, index) => (isList ? variableGradients[index].tensor : variableGradients[name]);
    const gradScale = n(() => {
      const biasCorrection1 = 1 - this.accBeta1;
      const biasCorrection2 = 1 - this.accBeta2;
      const scale = this.clipNorm === undefined
        ? e(1 / this.lossScaling)
        : d(names.map((name, index) => gradientAt(name, index)), 1 / this.lossScaling, this.clipNorm);
      names.forEach((name, index) => {
        const variable = r().registeredVariables[name];
        if (this.accumulatedMoments[index] == null) {
          this.accumulatedMoments[index] = {
            originalName: `${name}/m`,
            variable: n(() => i([...variable.shape, 2]).variable(false))
          };
        }
        const grad = gradientAt(name, index);
        if (grad == null) return;
        const effectiveGrad = this.orthGrad ? this.orthogonalizeGradient(variable, grad) : grad;
        const moments = this.accumulatedMoments[index].variable;
        const newMoments = l(moments, effectiveGrad, this.beta1, this.beta2, scale);
        moments.assign(newMoments);
        if (this.orthGrad) effectiveGrad.dispose();
        const updated = c(
          newMoments,
          variable,
          biasCorrection1,
          biasCorrection2,
          this.epsilon ?? 1e-8,
          this.learningRate,
          variable.shape.length > 1 ? this.weightDecay : 0
        );
        variable.assign(updated);
      });
      this.accBeta1 *= this.beta1;
      this.accBeta2 *= this.beta2;
      return scale;
    });
    this.incrementIterations();
    return gradScale;
  }

  /** Dispose every moment variable and forget them, so a repeated dispose is harmless. */
  dispose() {
    if (this.accumulatedMoments != null) {
      a(this.accumulatedMoments.map((m) => m.variable));
      this.accumulatedMoments = [];
    }
  }

  /** Iteration count followed by one entry per moment variable. */
  async getWeights() {
    const moments = [...this.accumulatedMoments];
    const iterations = await this.saveIterations();
    return [iterations].concat(moments.map((m) => ({
      name: m.originalName,
      tensor: m.variable
    })));
  }

  /**
   * Restore state produced by `getWeights`.
   *
   * `getWeights` emits exactly one entry per packed moment variable after the
   * iteration count, so every remaining entry is a moment. Keeping only half,
   * as tfjs's Adam does for its separate first/second-moment lists, would drop
   * optimizer state here. New variables are created before the old ones are
   * disposed, so passing this optimizer's own `getWeights()` output is safe.
   */
  async setWeights(weightValues) {
    const moments = await this.extractIterations(weightValues);
    // Matches the constructor (beta^1 at 0 steps) and the per-step multiply in applyGradients.
    this.accBeta1 = this.beta1 ** (this.iterations_ + 1);
    this.accBeta2 = this.beta2 ** (this.iterations_ + 1);
    const previous = this.accumulatedMoments;
    this.accumulatedMoments = moments.map((w) => ({
      originalName: w.name,
      variable: w.tensor.variable(false)
    }));
    a(previous.map((m) => m.variable));
  }

  /**
   * Serialisable configuration consumed by `fromConfig`. `learningRate` is the
   * current (scheduled) rate. LR-schedule settings are not included.
   */
  getConfig() {
    const config = {
      learningRate: this.learningRate,
      beta1: this.beta1,
      beta2: this.beta2,
      epsilon: this.epsilon,
      weightDecay: this.weightDecay,
      lossScaling: this.lossScaling,
      orthoGrad: this.orthGrad
    };
    if (this.clipNorm !== undefined) config.clipNorm = this.clipNorm;
    return config;
  }

  /** The constructor takes one config object, so the config is forwarded whole (not as positional args). */
  static fromConfig(cls, config) {
    return new cls(config);
  }
};
127
+ //#endregion
128
+ export { f as AdamWOptimizer };
@@ -0,0 +1,63 @@
1
+ import { ITokeniser } from '../tokeniser/type';
2
+ import { Scalar, Tensor } from '@tensorflow/tfjs-core';
3
+ import { Dataset } from '@tensorflow/tfjs-data';
4
+ import { default as Model, ModelForwardAttributes } from '../../models/model';
5
+ import { AdamWOptimizerConfig, TrainingLogEntry, TrainingMetrics, TrainingOptions, TrainingState } from './types';
6
+ import { AdamWOptimizer } from './AdamW';
7
/**
 * Training loop driver that pairs a model, its tokeniser and an AdamW optimizer.
 * Behavioural details live in the implementation file; notes below marked
 * "presumably" are inferred from names only.
 */
export default class BasicTrainer {
  // --- collaborators ---------------------------------------------------------
  tokenizer: ITokeniser;
  model: Model<ModelForwardAttributes>;
  optimizer: AdamWOptimizer;
  protected optimizerConfig: AdamWOptimizerConfig;

  // --- run state -------------------------------------------------------------
  protected running: boolean;
  protected lastState?: TrainingState;
  protected metrics: Set<TrainingMetrics>;

  // --- training options ------------------------------------------------------
  protected _gradientCheckpointing: boolean;
  protected _mixedPrecision: boolean;
  protected maskedLoss: boolean;
  protected _labelSmoothing: number;
  protected _layerDrop: number;
  protected _dropout: number;

  constructor(model: Model<ModelForwardAttributes>, tokenizer: ITokeniser, optConfig?: Partial<AdamWOptimizerConfig>, optimizer?: AdamWOptimizer);

  // --- option setters ----------------------------------------------------------
  /** Presumably enables loss masking (takes no argument). */
  setLossMasking(): void;
  setGradientCheckpointing(enabled: boolean): void;
  setMixedPrecision(enabled: boolean): void;
  setLabelSmoothing(smoothing: number): void;
  setDropout(dropout: number): void;
  setLayerDrop(layerDrop: number): void;
  setLearningRate(learningRate: number): void;
  setMetrics(metrics: TrainingMetrics[]): void;

  // --- optimizer -----------------------------------------------------------------
  getOptimizer(): AdamWOptimizer;
  updateOptimizer(config?: Partial<AdamWOptimizerConfig>): void;

  // --- lifecycle -----------------------------------------------------------------
  reset(): void;
  stop(): void;
  get isRunning(): boolean;
  resumeFromLog(log: TrainingLogEntry): void;
  dispose(): void;

  // --- training ------------------------------------------------------------------
  protected trainStep(state: Partial<TrainingState>, batch: {
    xs: Tensor;
    ys: Tensor;
  }, dummy?: boolean, keepGrads?: boolean): Scalar;
  stepDataset(dataset: Dataset<{
    xs: Tensor;
    ys: Tensor;
  }>, options: Partial<TrainingOptions>, validationDataset?: Dataset<{
    xs: Tensor;
    ys: Tensor;
  }>): Promise<{
    log: TrainingLogEntry;
  }>;
  trainOnDataset(dataset: Dataset<{
    xs: Tensor;
    ys: Tensor;
  }>, options: Partial<TrainingOptions>, validationDataset?: Dataset<{
    xs: Tensor;
    ys: Tensor;
  }>): Promise<{
    losses: number[];
    validationLosses: number[];
  }>;

  // --- internals -------------------------------------------------------------------
  private dummyPass;
  private createEmptyState;
  private performLogging;
}