@genai-fi/nanogpt 0.19.0 → 0.20.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482) hide show
  1. package/package.json +9 -10
  2. package/dist/Generator.d.ts +0 -82
  3. package/dist/Generator.js +0 -11941
  4. package/dist/RealDiv-CGwv0liw.js +0 -365
  5. package/dist/Reshape-BW__R4mZ.js +0 -79
  6. package/dist/Reshape-CPBkTIH2.js +0 -14
  7. package/dist/TeachableLLM.d.ts +0 -70
  8. package/dist/TeachableLLM.js +0 -273
  9. package/dist/Trainer.d.ts +0 -43
  10. package/dist/Trainer.js +0 -244
  11. package/dist/_commonjsHelpers-ByX85dGu.js +0 -33
  12. package/dist/axis_util-GTVlo58H.js +0 -55
  13. package/dist/backend.d.ts +0 -2
  14. package/dist/backend.js +0 -13
  15. package/dist/backend_util-GaFarB78.js +0 -425
  16. package/dist/backend_webgpu-BqASlsbV.js +0 -545
  17. package/dist/binary_op_util-pKXltfxI.js +0 -192
  18. package/dist/broadcast_to-eS93CCN_.js +0 -28
  19. package/dist/checks/appendCache.d.ts +0 -1
  20. package/dist/checks/appendCache.js +0 -22
  21. package/dist/checks/attentionMask.d.ts +0 -1
  22. package/dist/checks/attentionMask.js +0 -37
  23. package/dist/checks/check.d.ts +0 -9
  24. package/dist/checks/check.js +0 -20
  25. package/dist/checks/gelu.d.ts +0 -1
  26. package/dist/checks/gelu.js +0 -18
  27. package/dist/checks/index.d.ts +0 -26
  28. package/dist/checks/index.js +0 -28
  29. package/dist/checks/matMulGelu.d.ts +0 -1
  30. package/dist/checks/matMulGelu.js +0 -28
  31. package/dist/checks/normRMS.d.ts +0 -1
  32. package/dist/checks/normRMS.js +0 -16
  33. package/dist/checks/normRMSGrad.d.ts +0 -1
  34. package/dist/checks/normRMSGrad.js +0 -12
  35. package/dist/checks/packUnpack.d.ts +0 -1
  36. package/dist/checks/packUnpack.js +0 -18
  37. package/dist/checks/qkv.d.ts +0 -1
  38. package/dist/checks/qkv.js +0 -34
  39. package/dist/checks/rope.d.ts +0 -1
  40. package/dist/checks/rope.js +0 -36
  41. package/dist/checks/weights.d.ts +0 -14
  42. package/dist/checks/weights.js +0 -31
  43. package/dist/clip_by_value-DDA7rrcT.js +0 -12
  44. package/dist/complex-DI35Q-gW.js +0 -11
  45. package/dist/complex_util-Yc1A_gV1.js +0 -55
  46. package/dist/concat-CAQpCret.js +0 -17
  47. package/dist/concat_util-D18dJ4fD.js +0 -22
  48. package/dist/data/docx.d.ts +0 -2
  49. package/dist/data/docx.js +0 -15
  50. package/dist/data/parquet.d.ts +0 -2
  51. package/dist/data/parquet.js +0 -17
  52. package/dist/data/pdf.d.ts +0 -2
  53. package/dist/data/pdf.js +0 -14
  54. package/dist/data/textLoader.d.ts +0 -7
  55. package/dist/data/textLoader.js +0 -108
  56. package/dist/dataset-CGGp1z9P.js +0 -1124
  57. package/dist/dropout_util--NxWuYg2.js +0 -27
  58. package/dist/expand_dims-Bkd1YD5x.js +0 -11
  59. package/dist/exports_initializers-CYzKLjN7.js +0 -7
  60. package/dist/floor-BQtb-Azg.js +0 -9
  61. package/dist/gather-qIqEqaGn.js +0 -9
  62. package/dist/gelu-B220X1Go.js +0 -26
  63. package/dist/gpgpu_math-BwvV12df.js +0 -2022
  64. package/dist/index-CUXkjxiT.js +0 -3516
  65. package/dist/index-CieiGp4Y.js +0 -349
  66. package/dist/index-CjOWnMXP.js +0 -7308
  67. package/dist/index-Cp39cXWe.js +0 -1016
  68. package/dist/index-D5v913EJ.js +0 -4
  69. package/dist/index-DmeWGGmS.js +0 -1074
  70. package/dist/index-DvYrXKkX.js +0 -113
  71. package/dist/index-Ksja3su6.js +0 -151
  72. package/dist/index-xuotMAFm.js +0 -118
  73. package/dist/inference/types.d.ts +0 -16
  74. package/dist/inference/types.js +0 -1
  75. package/dist/jszip.min-BZhlzntC.js +0 -2313
  76. package/dist/kernel_funcs_utils-pq0CK9co.js +0 -306
  77. package/dist/layers/BaseLayer.d.ts +0 -44
  78. package/dist/layers/BaseLayer.js +0 -74
  79. package/dist/layers/CausalSelfAttention.d.ts +0 -39
  80. package/dist/layers/CausalSelfAttention.js +0 -86
  81. package/dist/layers/LoRA.d.ts +0 -14
  82. package/dist/layers/LoRA.js +0 -58
  83. package/dist/layers/MLP.d.ts +0 -17
  84. package/dist/layers/MLP.js +0 -44
  85. package/dist/layers/PositionEmbedding.d.ts +0 -8
  86. package/dist/layers/PositionEmbedding.js +0 -31
  87. package/dist/layers/RMSNorm.d.ts +0 -12
  88. package/dist/layers/RMSNorm.js +0 -22
  89. package/dist/layers/RoPECache.d.ts +0 -18
  90. package/dist/layers/RoPECache.js +0 -50
  91. package/dist/layers/TiedEmbedding.d.ts +0 -13
  92. package/dist/layers/TiedEmbedding.js +0 -36
  93. package/dist/layers/TransformerBlock.d.ts +0 -27
  94. package/dist/layers/TransformerBlock.js +0 -40
  95. package/dist/layers/WeightStore.d.ts +0 -20
  96. package/dist/layers/WeightStore.js +0 -76
  97. package/dist/loader/load.d.ts +0 -6
  98. package/dist/loader/load.js +0 -68
  99. package/dist/loader/loadHF.d.ts +0 -8
  100. package/dist/loader/loadHF.js +0 -22
  101. package/dist/loader/loadTransformers.d.ts +0 -4
  102. package/dist/loader/loadTransformers.js +0 -44
  103. package/dist/loader/loadZipMeta.d.ts +0 -3
  104. package/dist/loader/loadZipMeta.js +0 -16
  105. package/dist/loader/newZipLoad.d.ts +0 -3
  106. package/dist/loader/newZipLoad.js +0 -31
  107. package/dist/loader/oldZipLoad.d.ts +0 -9
  108. package/dist/loader/oldZipLoad.js +0 -80
  109. package/dist/loader/save.d.ts +0 -16
  110. package/dist/loader/save.js +0 -90
  111. package/dist/loader/types.d.ts +0 -67
  112. package/dist/loader/types.js +0 -1
  113. package/dist/main.d.ts +0 -50
  114. package/dist/main.js +0 -109
  115. package/dist/matMul16-BcVC_E62.js +0 -80
  116. package/dist/matMulGelu-JNLZqKQp.js +0 -163
  117. package/dist/mat_mul-DhG0Newp.js +0 -11
  118. package/dist/mod-CSdCpRjf.js +0 -11
  119. package/dist/models/NanoGPTV1.d.ts +0 -16
  120. package/dist/models/NanoGPTV1.js +0 -99
  121. package/dist/models/NanoGPTV2.d.ts +0 -16
  122. package/dist/models/NanoGPTV2.js +0 -90
  123. package/dist/models/config.d.ts +0 -27
  124. package/dist/models/config.js +0 -50
  125. package/dist/models/factory.d.ts +0 -3
  126. package/dist/models/factory.js +0 -16
  127. package/dist/models/model.d.ts +0 -44
  128. package/dist/models/model.js +0 -134
  129. package/dist/non_max_suppression_impl-B2W7YjZB.js +0 -102
  130. package/dist/not_equal-hurPF26l.js +0 -64
  131. package/dist/ones-BytntneX.js +0 -14
  132. package/dist/ops/adamAdjust.d.ts +0 -2
  133. package/dist/ops/adamAdjust.js +0 -9
  134. package/dist/ops/adamMoments.d.ts +0 -2
  135. package/dist/ops/adamMoments.js +0 -9
  136. package/dist/ops/add16.d.ts +0 -2
  137. package/dist/ops/add16.js +0 -9
  138. package/dist/ops/appendCache.d.ts +0 -2
  139. package/dist/ops/appendCache.js +0 -22
  140. package/dist/ops/attentionMask.d.ts +0 -2
  141. package/dist/ops/attentionMask.js +0 -10
  142. package/dist/ops/concat16.d.ts +0 -2
  143. package/dist/ops/concat16.js +0 -9
  144. package/dist/ops/cpu/adamAdjust.d.ts +0 -1
  145. package/dist/ops/cpu/adamAdjust.js +0 -18
  146. package/dist/ops/cpu/adamMoments.d.ts +0 -1
  147. package/dist/ops/cpu/adamMoments.js +0 -16
  148. package/dist/ops/cpu/appendCache.d.ts +0 -1
  149. package/dist/ops/cpu/appendCache.js +0 -23
  150. package/dist/ops/cpu/attentionMask.d.ts +0 -1
  151. package/dist/ops/cpu/attentionMask.js +0 -22
  152. package/dist/ops/cpu/fusedSoftmax.d.ts +0 -9
  153. package/dist/ops/cpu/fusedSoftmax.js +0 -29
  154. package/dist/ops/cpu/gatherSub.d.ts +0 -1
  155. package/dist/ops/cpu/gatherSub.js +0 -18
  156. package/dist/ops/cpu/gelu.d.ts +0 -1
  157. package/dist/ops/cpu/gelu.js +0 -40
  158. package/dist/ops/cpu/matMul16.d.ts +0 -1
  159. package/dist/ops/cpu/matMul16.js +0 -15
  160. package/dist/ops/cpu/matMulGelu.d.ts +0 -1
  161. package/dist/ops/cpu/matMulGelu.js +0 -53
  162. package/dist/ops/cpu/matMulMul.d.ts +0 -1
  163. package/dist/ops/cpu/matMulMul.js +0 -23
  164. package/dist/ops/cpu/mulDropout.d.ts +0 -1
  165. package/dist/ops/cpu/mulDropout.js +0 -23
  166. package/dist/ops/cpu/normRMS.d.ts +0 -1
  167. package/dist/ops/cpu/normRMS.js +0 -39
  168. package/dist/ops/cpu/qkv.d.ts +0 -5
  169. package/dist/ops/cpu/qkv.js +0 -41
  170. package/dist/ops/cpu/rope.d.ts +0 -6
  171. package/dist/ops/cpu/rope.js +0 -38
  172. package/dist/ops/cpu/scatterSub.d.ts +0 -1
  173. package/dist/ops/cpu/scatterSub.js +0 -23
  174. package/dist/ops/dot16.d.ts +0 -2
  175. package/dist/ops/dot16.js +0 -42
  176. package/dist/ops/dropout.d.ts +0 -2
  177. package/dist/ops/dropout.js +0 -14
  178. package/dist/ops/dropout16.d.ts +0 -2
  179. package/dist/ops/dropout16.js +0 -25
  180. package/dist/ops/gatherSub.d.ts +0 -2
  181. package/dist/ops/gatherSub.js +0 -9
  182. package/dist/ops/gelu.d.ts +0 -3
  183. package/dist/ops/gelu.js +0 -8
  184. package/dist/ops/globalNorm.d.ts +0 -2
  185. package/dist/ops/globalNorm.js +0 -13
  186. package/dist/ops/grads/add16.d.ts +0 -1
  187. package/dist/ops/grads/add16.js +0 -26
  188. package/dist/ops/grads/attentionMask.d.ts +0 -1
  189. package/dist/ops/grads/attentionMask.js +0 -21
  190. package/dist/ops/grads/dropout16.d.ts +0 -1
  191. package/dist/ops/grads/dropout16.js +0 -2
  192. package/dist/ops/grads/gelu.d.ts +0 -2
  193. package/dist/ops/grads/gelu.js +0 -5
  194. package/dist/ops/grads/matMul16.d.ts +0 -2
  195. package/dist/ops/grads/matMul16.js +0 -9
  196. package/dist/ops/grads/matMulGelu.d.ts +0 -1
  197. package/dist/ops/grads/matMulGelu.js +0 -17
  198. package/dist/ops/grads/mul16.d.ts +0 -1
  199. package/dist/ops/grads/mul16.js +0 -4
  200. package/dist/ops/grads/normRMS.d.ts +0 -3
  201. package/dist/ops/grads/normRMS.js +0 -33
  202. package/dist/ops/grads/pack16.d.ts +0 -2
  203. package/dist/ops/grads/pack16.js +0 -6
  204. package/dist/ops/grads/qkv.d.ts +0 -3
  205. package/dist/ops/grads/qkv.js +0 -34
  206. package/dist/ops/grads/rope.d.ts +0 -2
  207. package/dist/ops/grads/rope.js +0 -5
  208. package/dist/ops/grads/softmax16.d.ts +0 -2
  209. package/dist/ops/grads/softmax16.js +0 -25
  210. package/dist/ops/grads/unpack16.d.ts +0 -2
  211. package/dist/ops/grads/unpack16.js +0 -5
  212. package/dist/ops/grads/utils.d.ts +0 -4
  213. package/dist/ops/grads/utils.js +0 -14
  214. package/dist/ops/log.d.ts +0 -0
  215. package/dist/ops/log.js +0 -1
  216. package/dist/ops/matMul16.d.ts +0 -15
  217. package/dist/ops/matMul16.js +0 -13
  218. package/dist/ops/matMulGelu.d.ts +0 -3
  219. package/dist/ops/matMulGelu.js +0 -14
  220. package/dist/ops/matMulMul.d.ts +0 -2
  221. package/dist/ops/matMulMul.js +0 -9
  222. package/dist/ops/mul16.d.ts +0 -2
  223. package/dist/ops/mul16.js +0 -39
  224. package/dist/ops/mulDrop.d.ts +0 -2
  225. package/dist/ops/mulDrop.js +0 -9
  226. package/dist/ops/normRMS.d.ts +0 -2
  227. package/dist/ops/normRMS.js +0 -19
  228. package/dist/ops/pack16.d.ts +0 -2
  229. package/dist/ops/pack16.js +0 -5
  230. package/dist/ops/qkv.d.ts +0 -2
  231. package/dist/ops/qkv.js +0 -10
  232. package/dist/ops/reshape16.d.ts +0 -2
  233. package/dist/ops/reshape16.js +0 -41
  234. package/dist/ops/rope.d.ts +0 -3
  235. package/dist/ops/rope.js +0 -7
  236. package/dist/ops/scatterSub.d.ts +0 -2
  237. package/dist/ops/scatterSub.js +0 -9
  238. package/dist/ops/slice16.d.ts +0 -2
  239. package/dist/ops/slice16.js +0 -9
  240. package/dist/ops/softmax16.d.ts +0 -2
  241. package/dist/ops/softmax16.js +0 -9
  242. package/dist/ops/sub16.d.ts +0 -2
  243. package/dist/ops/sub16.js +0 -8
  244. package/dist/ops/sum16.d.ts +0 -2
  245. package/dist/ops/sum16.js +0 -13
  246. package/dist/ops/transpose16.d.ts +0 -3
  247. package/dist/ops/transpose16.js +0 -40
  248. package/dist/ops/unpack16.d.ts +0 -2
  249. package/dist/ops/unpack16.js +0 -6
  250. package/dist/ops/webgl/adamAdjust.d.ts +0 -1
  251. package/dist/ops/webgl/adamAdjust.js +0 -49
  252. package/dist/ops/webgl/adamMoments.d.ts +0 -1
  253. package/dist/ops/webgl/adamMoments.js +0 -40
  254. package/dist/ops/webgl/appendCache.d.ts +0 -1
  255. package/dist/ops/webgl/appendCache.js +0 -44
  256. package/dist/ops/webgl/attentionMask.d.ts +0 -1
  257. package/dist/ops/webgl/attentionMask.js +0 -45
  258. package/dist/ops/webgl/dropout16.d.ts +0 -1
  259. package/dist/ops/webgl/dropout16.js +0 -11
  260. package/dist/ops/webgl/fusedSoftmax.d.ts +0 -11
  261. package/dist/ops/webgl/fusedSoftmax.js +0 -80
  262. package/dist/ops/webgl/gatherSub.d.ts +0 -1
  263. package/dist/ops/webgl/gatherSub.js +0 -27
  264. package/dist/ops/webgl/gelu.d.ts +0 -2
  265. package/dist/ops/webgl/gelu.js +0 -50
  266. package/dist/ops/webgl/log.d.ts +0 -17
  267. package/dist/ops/webgl/log.js +0 -23
  268. package/dist/ops/webgl/matMul16.d.ts +0 -1
  269. package/dist/ops/webgl/matMul16.js +0 -45
  270. package/dist/ops/webgl/matMulGelu.d.ts +0 -21
  271. package/dist/ops/webgl/matMulGelu.js +0 -9
  272. package/dist/ops/webgl/matMulMul.d.ts +0 -14
  273. package/dist/ops/webgl/matMulMul.js +0 -28
  274. package/dist/ops/webgl/mulDropout.d.ts +0 -1
  275. package/dist/ops/webgl/mulDropout.js +0 -41
  276. package/dist/ops/webgl/normRMS.d.ts +0 -1
  277. package/dist/ops/webgl/normRMS.js +0 -93
  278. package/dist/ops/webgl/qkv.d.ts +0 -1
  279. package/dist/ops/webgl/qkv.js +0 -46
  280. package/dist/ops/webgl/rope.d.ts +0 -1
  281. package/dist/ops/webgl/rope.js +0 -56
  282. package/dist/ops/webgl/scatterSub.d.ts +0 -1
  283. package/dist/ops/webgl/scatterSub.js +0 -27
  284. package/dist/ops/webgpu/adamAdjust.d.ts +0 -1
  285. package/dist/ops/webgpu/adamAdjust.js +0 -57
  286. package/dist/ops/webgpu/adamMoments.d.ts +0 -1
  287. package/dist/ops/webgpu/adamMoments.js +0 -60
  288. package/dist/ops/webgpu/add16.d.ts +0 -1
  289. package/dist/ops/webgpu/add16.js +0 -13
  290. package/dist/ops/webgpu/appendCache.d.ts +0 -1
  291. package/dist/ops/webgpu/appendCache.js +0 -105
  292. package/dist/ops/webgpu/attentionMask.d.ts +0 -1
  293. package/dist/ops/webgpu/attentionMask.js +0 -26
  294. package/dist/ops/webgpu/attentionMask32_program.d.ts +0 -19
  295. package/dist/ops/webgpu/attentionMask32_program.js +0 -54
  296. package/dist/ops/webgpu/clipScale.d.ts +0 -1
  297. package/dist/ops/webgpu/clipScale.js +0 -58
  298. package/dist/ops/webgpu/concat16.d.ts +0 -19
  299. package/dist/ops/webgpu/concat16.js +0 -126
  300. package/dist/ops/webgpu/dropout16.d.ts +0 -1
  301. package/dist/ops/webgpu/dropout16.js +0 -51
  302. package/dist/ops/webgpu/gatherSub.d.ts +0 -1
  303. package/dist/ops/webgpu/gatherSub.js +0 -39
  304. package/dist/ops/webgpu/gelu.d.ts +0 -14
  305. package/dist/ops/webgpu/gelu.js +0 -141
  306. package/dist/ops/webgpu/index.d.ts +0 -0
  307. package/dist/ops/webgpu/index.js +0 -26
  308. package/dist/ops/webgpu/matMul16.d.ts +0 -1
  309. package/dist/ops/webgpu/matMul16.js +0 -65
  310. package/dist/ops/webgpu/matMul16_program.d.ts +0 -42
  311. package/dist/ops/webgpu/matMul16_program.js +0 -343
  312. package/dist/ops/webgpu/mul16.d.ts +0 -1
  313. package/dist/ops/webgpu/mul16.js +0 -13
  314. package/dist/ops/webgpu/norm2.d.ts +0 -1
  315. package/dist/ops/webgpu/norm2.js +0 -76
  316. package/dist/ops/webgpu/normRMS.d.ts +0 -1
  317. package/dist/ops/webgpu/normRMS.js +0 -34
  318. package/dist/ops/webgpu/normRMS16_program.d.ts +0 -10
  319. package/dist/ops/webgpu/normRMS16_program.js +0 -25
  320. package/dist/ops/webgpu/normRMS32_program.d.ts +0 -10
  321. package/dist/ops/webgpu/normRMS32_program.js +0 -25
  322. package/dist/ops/webgpu/normRMSGrad.d.ts +0 -1
  323. package/dist/ops/webgpu/normRMSGrad.js +0 -284
  324. package/dist/ops/webgpu/pack16.d.ts +0 -1
  325. package/dist/ops/webgpu/pack16.js +0 -18
  326. package/dist/ops/webgpu/pack16_program.d.ts +0 -19
  327. package/dist/ops/webgpu/pack16_program.js +0 -92
  328. package/dist/ops/webgpu/qkv.d.ts +0 -1
  329. package/dist/ops/webgpu/qkv.js +0 -24
  330. package/dist/ops/webgpu/rope.d.ts +0 -1
  331. package/dist/ops/webgpu/rope.js +0 -135
  332. package/dist/ops/webgpu/scatterSub.d.ts +0 -1
  333. package/dist/ops/webgpu/scatterSub.js +0 -40
  334. package/dist/ops/webgpu/slice16.d.ts +0 -7
  335. package/dist/ops/webgpu/slice16.js +0 -69
  336. package/dist/ops/webgpu/softmax16.d.ts +0 -17
  337. package/dist/ops/webgpu/softmax16.js +0 -21
  338. package/dist/ops/webgpu/softmax16_program.d.ts +0 -13
  339. package/dist/ops/webgpu/softmax16_program.js +0 -73
  340. package/dist/ops/webgpu/softmax16_subgroup_program.d.ts +0 -17
  341. package/dist/ops/webgpu/softmax16_subgroup_program.js +0 -75
  342. package/dist/ops/webgpu/softmax16grad.d.ts +0 -1
  343. package/dist/ops/webgpu/softmax16grad.js +0 -37
  344. package/dist/ops/webgpu/sub16.d.ts +0 -1
  345. package/dist/ops/webgpu/sub16.js +0 -13
  346. package/dist/ops/webgpu/sum16.d.ts +0 -1
  347. package/dist/ops/webgpu/sum16.js +0 -38
  348. package/dist/ops/webgpu/transpose16.d.ts +0 -1
  349. package/dist/ops/webgpu/transpose16.js +0 -34
  350. package/dist/ops/webgpu/transpose16_program.d.ts +0 -16
  351. package/dist/ops/webgpu/transpose16_program.js +0 -50
  352. package/dist/ops/webgpu/transpose16_shared_program.d.ts +0 -15
  353. package/dist/ops/webgpu/transpose16_shared_program.js +0 -70
  354. package/dist/ops/webgpu/unpack16.d.ts +0 -1
  355. package/dist/ops/webgpu/unpack16.js +0 -48
  356. package/dist/ops/webgpu/utils/binary_op.d.ts +0 -35
  357. package/dist/ops/webgpu/utils/binary_op.js +0 -139
  358. package/dist/ops/webgpu/utils/deviceInfo.d.ts +0 -7
  359. package/dist/ops/webgpu/utils/deviceInfo.js +0 -11
  360. package/dist/ops/webgpu/utils/reductions.d.ts +0 -43
  361. package/dist/ops/webgpu/utils/reductions.js +0 -275
  362. package/dist/ops-CsXeTq1P.js +0 -476
  363. package/dist/pack16-bqltoUlR.js +0 -39
  364. package/dist/papaparse.min-C0cScC2i.js +0 -418
  365. package/dist/parquet-Bqjmp2vo.js +0 -44231
  366. package/dist/patches/webgpu_backend.d.ts +0 -18
  367. package/dist/patches/webgpu_backend.js +0 -56
  368. package/dist/patches/webgpu_base.d.ts +0 -21
  369. package/dist/patches/webgpu_base.js +0 -34
  370. package/dist/patches/webgpu_program.d.ts +0 -36
  371. package/dist/patches/webgpu_program.js +0 -400
  372. package/dist/pdf-NIhmP3sq.js +0 -19477
  373. package/dist/rand_util-CZ7yLoUm.js +0 -50
  374. package/dist/random_normal-IBRrha8a.js +0 -14
  375. package/dist/random_width-DN5ZtQkM.js +0 -9796
  376. package/dist/range-C-CjF-LI.js +0 -10
  377. package/dist/relu-J_X6MUzx.js +0 -9
  378. package/dist/reshape-BDOuCSNW.js +0 -9
  379. package/dist/resize_nearest_neighbor-BojqlfRe.js +0 -150
  380. package/dist/rope-DcrZM_e6.js +0 -24
  381. package/dist/scatter_nd_util-ByNJaL6I.js +0 -46
  382. package/dist/segment_util-Dasb2Zaf.js +0 -43
  383. package/dist/selu_util-BLhIqRkw.js +0 -44
  384. package/dist/shared-3agzAqQ_.js +0 -53
  385. package/dist/shared-CagdqkLh.js +0 -2143
  386. package/dist/slice-BzS11Qh0.js +0 -12
  387. package/dist/slice_util-CC35pLmT.js +0 -153
  388. package/dist/softmax-D4q1LJN7.js +0 -12
  389. package/dist/split-C2Sj255c.js +0 -9
  390. package/dist/squeeze-ho4wLUek.js +0 -10
  391. package/dist/stack-DudVrtmG.js +0 -11
  392. package/dist/step-BTxPtq1r.js +0 -261
  393. package/dist/sum-BpiwSWvg.js +0 -11
  394. package/dist/tensor-BWFldCso.js +0 -8
  395. package/dist/tensor1d-LMGMIUlr.js +0 -11
  396. package/dist/tensor2d-BnXMKScO.js +0 -14
  397. package/dist/tensor4d-C6UCG_u8.js +0 -14
  398. package/dist/tfjs_backend-BGnG-ppu.js +0 -654
  399. package/dist/tile-CFy-xTO6.js +0 -11
  400. package/dist/tokeniser/BaseTokeniser.d.ts +0 -33
  401. package/dist/tokeniser/BaseTokeniser.js +0 -124
  402. package/dist/tokeniser/CharTokeniser.d.ts +0 -24
  403. package/dist/tokeniser/CharTokeniser.js +0 -107
  404. package/dist/tokeniser/bpe.d.ts +0 -28
  405. package/dist/tokeniser/bpe.js +0 -173
  406. package/dist/tokeniser/messages.d.ts +0 -61
  407. package/dist/tokeniser/messages.js +0 -1
  408. package/dist/tokeniser/type.d.ts +0 -34
  409. package/dist/tokeniser/type.js +0 -1
  410. package/dist/training/AdamW.d.ts +0 -36
  411. package/dist/training/AdamW.js +0 -138
  412. package/dist/training/BasicTrainer.d.ts +0 -63
  413. package/dist/training/BasicTrainer.js +0 -265
  414. package/dist/training/DatasetBuilder.d.ts +0 -26
  415. package/dist/training/DatasetBuilder.js +0 -86
  416. package/dist/training/Evaluator.d.ts +0 -19
  417. package/dist/training/Evaluator.js +0 -39
  418. package/dist/training/LRScheduler.d.ts +0 -12
  419. package/dist/training/LRScheduler.js +0 -34
  420. package/dist/training/PreTrainer.d.ts +0 -11
  421. package/dist/training/PreTrainer.js +0 -20
  422. package/dist/training/SFTTrainer.d.ts +0 -12
  423. package/dist/training/SFTTrainer.js +0 -22
  424. package/dist/training/loss.d.ts +0 -3
  425. package/dist/training/loss.js +0 -24
  426. package/dist/training/orthoGrad.d.ts +0 -2
  427. package/dist/training/orthoGrad.js +0 -10
  428. package/dist/training/sparseCrossEntropy.d.ts +0 -7
  429. package/dist/training/sparseCrossEntropy.js +0 -69
  430. package/dist/training/tasks/ConversationTask.d.ts +0 -18
  431. package/dist/training/tasks/ConversationTask.js +0 -40
  432. package/dist/training/tasks/PretrainingTask.d.ts +0 -17
  433. package/dist/training/tasks/PretrainingTask.js +0 -47
  434. package/dist/training/tasks/StartSentenceTask.d.ts +0 -18
  435. package/dist/training/tasks/StartSentenceTask.js +0 -49
  436. package/dist/training/tasks/Task.d.ts +0 -22
  437. package/dist/training/tasks/Task.js +0 -68
  438. package/dist/training/tasks/splitter.d.ts +0 -5
  439. package/dist/training/tasks/splitter.js +0 -21
  440. package/dist/training/types.d.ts +0 -78
  441. package/dist/training/types.js +0 -1
  442. package/dist/training/validation.d.ts +0 -17
  443. package/dist/training/validation.js +0 -84
  444. package/dist/transpose-9kRxIXWR.js +0 -36
  445. package/dist/unsorted_segment_sum-DJvk5xnh.js +0 -277
  446. package/dist/utilities/arrayClose.d.ts +0 -1
  447. package/dist/utilities/arrayClose.js +0 -20
  448. package/dist/utilities/datasetID.d.ts +0 -2
  449. package/dist/utilities/datasetID.js +0 -21
  450. package/dist/utilities/dummy.d.ts +0 -9
  451. package/dist/utilities/dummy.js +0 -43
  452. package/dist/utilities/multinomialCPU.d.ts +0 -2
  453. package/dist/utilities/multinomialCPU.js +0 -13
  454. package/dist/utilities/naming.d.ts +0 -4
  455. package/dist/utilities/naming.js +0 -1
  456. package/dist/utilities/packed.d.ts +0 -4
  457. package/dist/utilities/packed.js +0 -15
  458. package/dist/utilities/parameters.d.ts +0 -11
  459. package/dist/utilities/parameters.js +0 -57
  460. package/dist/utilities/performance.d.ts +0 -2
  461. package/dist/utilities/performance.js +0 -16
  462. package/dist/utilities/profile.d.ts +0 -17
  463. package/dist/utilities/profile.js +0 -38
  464. package/dist/utilities/safetensors.d.ts +0 -3
  465. package/dist/utilities/safetensors.js +0 -83
  466. package/dist/utilities/sentences.d.ts +0 -5
  467. package/dist/utilities/sentences.js +0 -41
  468. package/dist/utilities/tokenParse.d.ts +0 -1
  469. package/dist/utilities/tokenParse.js +0 -21
  470. package/dist/utilities/topP.d.ts +0 -1
  471. package/dist/utilities/topP.js +0 -13
  472. package/dist/utilities/waitForModel.d.ts +0 -2
  473. package/dist/utilities/waitForModel.js +0 -12
  474. package/dist/utilities/weights.d.ts +0 -12
  475. package/dist/utilities/weights.js +0 -45
  476. package/dist/utilities/yielder.d.ts +0 -1
  477. package/dist/utilities/yielder.js +0 -7
  478. package/dist/variable-Ck482e3n.js +0 -7
  479. package/dist/webgpu_program-B4HmApL1.js +0 -525
  480. package/dist/webgpu_util-DYlGSwOJ.js +0 -64
  481. package/dist/zeros-DvZpK8s6.js +0 -13
  482. package/dist/zeros_like-CWjDdwr-.js +0 -721
@@ -1,138 +0,0 @@
1
- import { adamAdjust as B } from "../ops/adamAdjust.js";
2
- import { adamMoments as N } from "../ops/adamMoments.js";
3
- import { O as S, h as b, t as c, a as M, d as w } from "../index-CUXkjxiT.js";
4
- import R from "./LRScheduler.js";
5
- import { clipScale as f } from "../ops/globalNorm.js";
6
- import { save_safetensors as v, load_safetensors as A } from "../utilities/safetensors.js";
7
- import { z as O } from "../zeros-DvZpK8s6.js";
8
- class _ extends S {
9
- constructor(t) {
10
- super(), this.config = t, this.accBeta1 = t.accBeta1 ?? t.beta1, this.accBeta2 = t.accBeta2 ?? t.beta2, this.learningRate = t.learningRate, this.beta1 = t.beta1, this.beta2 = t.beta2, this.weightDecay = t.weightDecay, this.lossScaling = t.lossScaling, this.clipNorm = t.clipNorm, this.orthGrad = t.orthoGrad ?? !1, t.epsilon === null || t.epsilon === void 0 ? this.epsilon = b().backend.epsilon() : this.epsilon = t.epsilon, this.lrScheduler = new R(t.learningRate, t);
11
- }
12
- className = "AdamW";
13
- accBeta1 = 0;
14
- accBeta2 = 0;
15
- accumulatedMoments = [];
16
- learningRate;
17
- beta1;
18
- beta2;
19
- lossScaling;
20
- weightDecay;
21
- epsilon = null;
22
- lrScheduler;
23
- clipNorm;
24
- orthGradEpsilon = 1e-30;
25
- orthGrad;
26
- get lr() {
27
- return this.learningRate;
28
- }
29
- saveMoments() {
30
- const t = {};
31
- return this.accumulatedMoments.forEach((e) => {
32
- t[e.originalName] = e.variable;
33
- }), v(t);
34
- }
35
- async loadMoments(t) {
36
- const e = await A(t);
37
- Object.entries(e).forEach(([a, s]) => {
38
- const n = s.variable(!1);
39
- this.accumulatedMoments.push({ originalName: a, variable: n });
40
- });
41
- }
42
- serializeConfig() {
43
- return {
44
- learningRate: this.learningRate,
45
- beta1: this.beta1,
46
- beta2: this.beta2,
47
- accBeta1: this.accBeta1,
48
- accBeta2: this.accBeta2,
49
- epsilon: this.epsilon ?? void 0,
50
- weightDecay: this.weightDecay,
51
- lossScaling: this.lossScaling,
52
- clipNorm: this.clipNorm,
53
- orthoGrad: this.orthGrad,
54
- ...this.lrScheduler.serializeConfig()
55
- };
56
- }
57
- orthogonalizeGradient(t, e) {
58
- return c(() => {
59
- const a = t.reshape([-1]), s = e.reshape([-1]), n = a.mul(a).sum().add(this.orthGradEpsilon), h = a.mul(s).sum().div(n), o = s.sub(a.mul(h)), l = s.norm(), i = o.norm().add(this.orthGradEpsilon);
60
- return o.mul(l.div(i)).reshape(e.shape);
61
- });
62
- }
63
- updateConfig(t) {
64
- const e = { ...this.config, ...t };
65
- this.learningRate = e.learningRate, this.beta1 = e.beta1, this.beta2 = e.beta2, this.weightDecay = e.weightDecay, this.lossScaling = e.lossScaling, this.epsilon = e.epsilon ?? this.epsilon, this.clipNorm = e.clipNorm, this.lrScheduler.updateConfig(e, e.learningRate);
66
- }
67
- applyGradients(t) {
68
- const e = this.lrScheduler.getNextLR();
69
- this.learningRate = e;
70
- const a = Array.isArray(t) ? t.map((n) => n.name) : Object.keys(t), s = c(() => {
71
- const n = 1 - this.accBeta1, h = 1 - this.accBeta2;
72
- let o;
73
- if (this.clipNorm !== void 0) {
74
- const l = a.map((i, r) => Array.isArray(t) ? t[r].tensor : t[i]);
75
- o = f(l, 1 / this.lossScaling, this.clipNorm);
76
- } else
77
- o = M(1 / this.lossScaling);
78
- return a.forEach((l, i) => {
79
- const r = b().registeredVariables[l], p = !1;
80
- this.accumulatedMoments[i] == null && (this.accumulatedMoments[i] = {
81
- originalName: `${l}/m`,
82
- variable: c(() => O([...r.shape, 2]).variable(p))
83
- });
84
- const m = Array.isArray(t) ? t[i].tensor : t[l];
85
- if (m == null)
86
- return;
87
- const u = this.orthGrad ? this.orthogonalizeGradient(r, m) : m, d = this.accumulatedMoments[i].variable, g = N(d, u, this.beta1, this.beta2, o);
88
- d.assign(g), this.orthGrad && u.dispose();
89
- const y = B(
90
- g,
91
- r,
92
- n,
93
- h,
94
- this.epsilon ?? 1e-8,
95
- this.learningRate,
96
- // Only apply weight decay if the variable is multi-dimensional (e.g. weights, not biases)
97
- r.shape.length > 1 ? this.weightDecay : 0
98
- );
99
- r.assign(y);
100
- }), this.accBeta1 = this.accBeta1 * this.beta1, this.accBeta2 = this.accBeta2 * this.beta2, o;
101
- });
102
- return this.incrementIterations(), s;
103
- }
104
- dispose() {
105
- this.accumulatedMoments != null && w(this.accumulatedMoments.map((t) => t.variable));
106
- }
107
- async getWeights() {
108
- const t = [...this.accumulatedMoments];
109
- return [await this.saveIterations()].concat(
110
- t.map((e) => ({ name: e.originalName, tensor: e.variable }))
111
- );
112
- }
113
- async setWeights(t) {
114
- t = await this.extractIterations(t), c(() => {
115
- this.accBeta1 = Math.pow(this.beta1, this.iterations_ + 1), this.accBeta2 = Math.pow(this.beta2, this.iterations_ + 1);
116
- });
117
- const e = t.length / 2, a = !1;
118
- this.accumulatedMoments = t.slice(0, e).map((s) => ({
119
- originalName: s.name,
120
- variable: s.tensor.variable(a)
121
- }));
122
- }
123
- getConfig() {
124
- return {
125
- learningRate: this.learningRate,
126
- beta1: this.beta1,
127
- beta2: this.beta2,
128
- epsilon: this.epsilon
129
- };
130
- }
131
- /** @nocollapse */
132
- static fromConfig(t, e) {
133
- return new t(e.learningRate, e.beta1, e.beta2, e.epsilon);
134
- }
135
- }
136
- export {
137
- _ as AdamWOptimizer
138
- };
@@ -1,63 +0,0 @@
1
- import { ITokeniser } from '../tokeniser/type';
2
- import { Scalar, Tensor } from '@tensorflow/tfjs-core';
3
- import { Dataset } from '@tensorflow/tfjs-data';
4
- import { default as Model, ModelForwardAttributes } from '../models/model';
5
- import { AdamWOptimizerConfig, TrainingLogEntry, TrainingMetrics, TrainingOptions, TrainingState } from './types';
6
- import { AdamWOptimizer } from './AdamW';
7
- export default class BasicTrainer {
8
- tokenizer: ITokeniser;
9
- model: Model<ModelForwardAttributes>;
10
- optimizer: AdamWOptimizer;
11
- protected running: boolean;
12
- protected lastState?: TrainingState;
13
- protected _gradientCheckpointing: boolean;
14
- protected _mixedPrecision: boolean;
15
- protected maskedLoss: boolean;
16
- protected optimizerConfig: AdamWOptimizerConfig;
17
- protected metrics: Set<TrainingMetrics>;
18
- protected _labelSmoothing: number;
19
- protected _layerDrop: number;
20
- protected _dropout: number;
21
- constructor(model: Model<ModelForwardAttributes>, tokenizer: ITokeniser, optConfig?: Partial<AdamWOptimizerConfig>, optimizer?: AdamWOptimizer);
22
- setLossMasking(): void;
23
- setGradientCheckpointing(enabled: boolean): void;
24
- setMixedPrecision(enabled: boolean): void;
25
- setLabelSmoothing(smoothing: number): void;
26
- setDropout(dropout: number): void;
27
- setLayerDrop(layerDrop: number): void;
28
- setLearningRate(learningRate: number): void;
29
- setMetrics(metrics: TrainingMetrics[]): void;
30
- reset(): void;
31
- stop(): void;
32
- get isRunning(): boolean;
33
- getOptimizer(): AdamWOptimizer;
34
- updateOptimizer(config?: Partial<AdamWOptimizerConfig>): void;
35
- resumeFromLog(log: TrainingLogEntry): void;
36
- protected trainStep(state: Partial<TrainingState>, batch: {
37
- xs: Tensor;
38
- ys: Tensor;
39
- }, dummy?: boolean, keepGrads?: boolean): Scalar;
40
- private dummyPass;
41
- dispose(): void;
42
- private createEmptyState;
43
- stepDataset(dataset: Dataset<{
44
- xs: Tensor;
45
- ys: Tensor;
46
- }>, options: Partial<TrainingOptions>, validationDataset?: Dataset<{
47
- xs: Tensor;
48
- ys: Tensor;
49
- }>): Promise<{
50
- log: TrainingLogEntry;
51
- }>;
52
- private performLogging;
53
- trainOnDataset(dataset: Dataset<{
54
- xs: Tensor;
55
- ys: Tensor;
56
- }>, options: Partial<TrainingOptions>, validationDataset?: Dataset<{
57
- xs: Tensor;
58
- ys: Tensor;
59
- }>): Promise<{
60
- losses: number[];
61
- validationLosses: number[];
62
- }>;
63
- }
@@ -1,265 +0,0 @@
1
- import S from "./Evaluator.js";
2
- import { t as k, v as x, k as y, d as u, a as w } from "../index-CUXkjxiT.js";
3
- import v from "../utilities/profile.js";
4
- import { createTensorStatistics as N } from "../checks/weights.js";
5
- import { calculateLoss as b, calculateAccuracy as P } from "./loss.js";
6
- import { AdamWOptimizer as T } from "./AdamW.js";
7
- import { z as L } from "../zeros-DvZpK8s6.js";
8
- const z = {
9
- logInterval: 1,
10
- maxEpochs: 100,
11
- sftMode: "full",
12
- batchSize: 32
13
- }, D = {
14
- learningRate: 3e-4,
15
- beta1: 0.9,
16
- beta2: 0.99,
17
- epsilon: 1e-8,
18
- weightDecay: 0.01,
19
- warmupSteps: 100,
20
- decayEpochs: 100,
21
- epochSteps: 1e4,
22
- minLearningRate: 1e-5,
23
- lossScaling: 1
24
- };
25
- class B {
26
- constructor(s, i, n, l) {
27
- this.tokenizer = i, this.model = s, this.optimizerConfig = {
28
- ...D,
29
- ...n,
30
- lossScaling: s.lossScaling
31
- };
32
- const d = l || new T(this.optimizerConfig);
33
- l && l.updateConfig(this.optimizerConfig), this.optimizer = d;
34
- }
35
- model;
36
- optimizer;
37
- running = !1;
38
- lastState;
39
- _gradientCheckpointing = !1;
40
- _mixedPrecision = !1;
41
- maskedLoss = !1;
42
- optimizerConfig;
43
- metrics = /* @__PURE__ */ new Set();
44
- _labelSmoothing = 0;
45
- _layerDrop = 0;
46
- _dropout = 0;
47
- setLossMasking() {
48
- this.maskedLoss = !0;
49
- }
50
- setGradientCheckpointing(s) {
51
- this._gradientCheckpointing = s;
52
- }
53
- setMixedPrecision(s) {
54
- this._mixedPrecision = s;
55
- }
56
- setLabelSmoothing(s) {
57
- this._labelSmoothing = s;
58
- }
59
- setDropout(s) {
60
- this._dropout = s;
61
- }
62
- setLayerDrop(s) {
63
- this._layerDrop = s;
64
- }
65
- setLearningRate(s) {
66
- this.optimizerConfig.learningRate = s, this.updateOptimizer();
67
- }
68
- setMetrics(s) {
69
- this.metrics = new Set(s);
70
- }
71
- reset() {
72
- this.lastState = void 0, this.running = !1;
73
- }
74
- stop() {
75
- this.running = !1;
76
- }
77
- get isRunning() {
78
- return this.running;
79
- }
80
- getOptimizer() {
81
- return this.optimizer;
82
- }
83
- updateOptimizer(s) {
84
- s && (this.optimizerConfig = { ...this.optimizerConfig, ...s }), this.optimizer.updateConfig(this.optimizerConfig);
85
- }
86
- resumeFromLog(s) {
87
- (!this.lastState || this.lastState.step === 0) && (this.lastState = {
88
- losses: [],
89
- validationLosses: [],
90
- logStartTime: 0,
91
- step: s.step,
92
- lastLoss: s.trainingMetrics.loss,
93
- totalSteps: s.step,
94
- trainingDuration: s.duration
95
- });
96
- }
97
- // A single forward pass, backward pass, and optimizer step
98
- trainStep(s, i, n = !1, l = !1) {
99
- return k(() => {
100
- this.model.getProfiler()?.startMemory();
101
- const { xs: d, ys: r } = i, m = () => {
102
- const a = this.model.forward(
103
- {
104
- training: !0,
105
- checkpointing: this._gradientCheckpointing,
106
- mixedPrecision: this._mixedPrecision,
107
- dropout: this._dropout,
108
- layerDrop: this._layerDrop,
109
- ropePositionOffset: 0
110
- },
111
- d
112
- ), o = b(a, r, this.maskedLoss, !1, this._labelSmoothing);
113
- this.metrics.has("accuracy") && (s.accuracy = P(a, r), y(s.accuracy)), a.dispose();
114
- const e = o.mul(w(this.optimizerConfig.lossScaling));
115
- return o.dispose(), e;
116
- }, { value: t, grads: c } = x(m);
117
- if (n)
118
- this.model.getProfiler()?.endMemory("Training");
119
- else {
120
- const a = this.optimizer.applyGradients(c);
121
- this.metrics.has("gradientNorm") ? (s.gradientNorm = a, y(a)) : (s.gradientNorm = void 0, a.dispose());
122
- const o = Object.keys(c);
123
- this.model.weightStore.touchVariables(o), this.model.getProfiler()?.endMemory("Training"), l ? (s.gradients = c, Object.values(c).forEach((e) => y(e))) : u(c);
124
- }
125
- return t.mul(w(1 / this.optimizerConfig.lossScaling));
126
- });
127
- }
128
- async dummyPass() {
129
- const s = L([1, this.model.config.blockSize], "int32"), i = L([1, this.model.config.blockSize], "int32");
130
- try {
131
- const n = this.trainStep({}, { xs: s, ys: i }, !0);
132
- await n.data(), n.dispose();
133
- } catch (n) {
134
- console.error("Error during dummy pass:", n);
135
- } finally {
136
- s.dispose(), i.dispose();
137
- }
138
- }
139
- dispose() {
140
- this.optimizer && this.optimizer.dispose();
141
- }
142
- createEmptyState() {
143
- return {
144
- step: 0,
145
- lastLoss: 1e6,
146
- totalSteps: 0,
147
- losses: [],
148
- validationLosses: [],
149
- logStartTime: 0,
150
- trainingDuration: 0,
151
- ...this.lastState || {}
152
- };
153
- }
154
- async stepDataset(s, i, n) {
155
- const { logInterval: l = 10 } = {
156
- ...z,
157
- ...i
158
- };
159
- i.metrics && this.setMetrics(i.metrics);
160
- const d = Date.now(), r = this.createEmptyState();
161
- this.lastState = r, await this.dummyPass(), this.metrics.has("memoryUsage") && (this.model.getProfiler() || this.model.setProfiler(new v())), this.running = !0, r.logStartTime = d;
162
- const m = n ? new S(this.model, n, this.maskedLoss) : void 0, t = await s.iterator();
163
- try {
164
- for (; this.running; ) {
165
- const c = await t.next();
166
- if (c.done) break;
167
- const a = c.value, o = this.trainStep(r, a, !1);
168
- if (i.debug) {
169
- const e = (await o.data())[0];
170
- if (isNaN(e) || !isFinite(e))
171
- throw console.error("Invalid loss value:", e), console.error("Batch xs:", a.xs.toString()), console.error("Batch ys:", a.ys.toString()), console.error("State:", r), new Error("Loss is NaN or Infinity");
172
- console.log(`Step ${r.step}: Loss = ${e}`);
173
- }
174
- a.xs.dispose(), a.ys.dispose(), r.step++, r.totalSteps++, r.step % l === 0 ? await this.performLogging(o, a.xs.shape[0], i, m) : (r.gradientNorm && (r.gradientNorm.dispose(), r.gradientNorm = void 0), r.accuracy && (r.accuracy.dispose(), r.accuracy = void 0)), o.dispose();
175
- }
176
- } catch (c) {
177
- throw console.error("Training error:", c), c;
178
- }
179
- throw this.model.trainingState = {
180
- steps: r.totalSteps,
181
- learningRate: this.optimizer.lr,
182
- batchSize: i.batchSize || 32,
183
- loss: r.lastLoss,
184
- tokensProcessed: r.totalSteps * (i.batchSize || 32) * this.model.config.blockSize,
185
- duration: r.trainingDuration
186
- }, u(), this.running = !1, new Error("No log returned before training stopped.");
187
- }
188
- async performLogging(s, i, n, l) {
189
- const d = n?.onStep, r = this.metrics.has("gradientStatistics"), m = (await s.data())[0], t = this.lastState;
190
- t.lastLoss = m;
191
- const c = Date.now();
192
- t.trainingDuration += c - t.logStartTime;
193
- const a = t.totalSteps * i * this.model.config.blockSize, o = {
194
- trainingMetrics: {
195
- loss: t.lastLoss,
196
- perplexity: this.metrics.has("perplexity") ? Math.exp(t.lastLoss) : void 0,
197
- accuracy: t.accuracy ? (await t.accuracy.data())[0] : void 0
198
- },
199
- step: t.step,
200
- time: Date.now() - t.logStartTime,
201
- gradientNorm: t.gradientNorm ? (await t.gradientNorm.data())[1] : void 0,
202
- batchSize: i,
203
- learningRate: this.metrics.has("learningRate") ? this.optimizer.lr : void 0,
204
- duration: t.trainingDuration,
205
- totalTokens: a,
206
- tokensPerSecond: a / (t.trainingDuration / 1e3),
207
- memoryUsage: this.metrics.has("memoryUsage") ? this.model.getProfiler()?.getPeakMemory() || 0 : void 0
208
- };
209
- if (t.gradientNorm && (t.gradientNorm.dispose(), t.gradientNorm = void 0), t.accuracy && (t.accuracy.dispose(), t.accuracy = void 0), this.model.trainingState = {
210
- steps: t.totalSteps,
211
- learningRate: this.optimizer.lr,
212
- batchSize: i,
213
- loss: t.lastLoss,
214
- tokensProcessed: a,
215
- duration: t.trainingDuration
216
- }, r && t.gradients) {
217
- const e = /* @__PURE__ */ new Map();
218
- for (const [h, g] of Object.entries(t.gradients))
219
- e.set(h, await N(g)), g.dispose();
220
- o.gradientMetrics = e;
221
- }
222
- if (l)
223
- try {
224
- const e = await l.evaluate(5);
225
- Array.isArray(e) ? o.validationMetrics = { loss: e[0].loss, accuracy: e[0].accuracy } : (t.validationLosses.push(e.loss), o.validationMetrics = {
226
- accuracy: e.accuracy,
227
- loss: e.loss,
228
- perplexity: this.metrics.has("perplexity") ? Math.exp(e.loss) : void 0
229
- });
230
- } catch (e) {
231
- console.error("Validation error:", e);
232
- }
233
- d && await d(o), t.logStartTime = Date.now();
234
- }
235
- async trainOnDataset(s, i, n) {
236
- const { logInterval: l = 10, maxEpochs: d = 1 / 0 } = {
237
- ...z,
238
- ...i
239
- }, r = d * (i?.epochSteps || 1e3);
240
- i.metrics && this.setMetrics(i.metrics);
241
- const m = Date.now(), t = this.createEmptyState();
242
- this.lastState = t, await this.dummyPass(), i?.metrics?.includes("memoryUsage") && (this.model.getProfiler() || this.model.setProfiler(new v())), this.running = !0, t.logStartTime = m;
243
- const c = n ? new S(this.model, n, this.maskedLoss) : void 0, a = await s.iterator();
244
- try {
245
- for (; this.running; ) {
246
- const o = await a.next();
247
- if (o.done) break;
248
- const e = o.value, h = t.step % l === 0, g = (i?.metrics?.includes("gradientStatistics") || !1) && h, f = this.trainStep(t, e, !1, g);
249
- if (i.debug) {
250
- const p = (await f.data())[0];
251
- if (isNaN(p) || !isFinite(p))
252
- throw console.error("Invalid loss value:", p), console.error("Batch xs:", await e.xs.array()), console.error("Batch ys:", await e.ys.array()), console.error("State:", t), new Error("Loss is NaN or Infinity");
253
- console.log(`Step ${t.step}: Loss = ${p}`);
254
- }
255
- e.xs.dispose(), e.ys.dispose(), t.step++, t.totalSteps++, h ? await this.performLogging(f, e.xs.shape[0], i, c) : (t.gradientNorm && (t.gradientNorm.dispose(), t.gradientNorm = void 0), t.accuracy && (t.accuracy.dispose(), t.accuracy = void 0)), f.dispose(), t.step >= r && this.stop();
256
- }
257
- } catch (o) {
258
- throw console.error("Training error:", o), u(), o;
259
- }
260
- return u(), this.running = !1, { losses: t.losses, validationLosses: t.validationLosses };
261
- }
262
- }
263
- export {
264
- B as default
265
- };
@@ -1,26 +0,0 @@
1
- import { Tensor } from '@tensorflow/tfjs-core';
2
- import { Conversation, ITokeniser } from '../tokeniser/type';
3
- import { Dataset } from '@tensorflow/tfjs-data';
4
- export declare const PAGE_FACTOR = 8;
5
- export declare function flattenTokens(textData: Conversation[][], tokenizer: ITokeniser): Uint16Array;
6
- export declare function flattenTokensWithMask(textData: Conversation[][], tokenizer: ITokeniser): {
7
- tokens: Uint16Array;
8
- mask: Uint8Array;
9
- };
10
- export declare function shuffle(array: Uint32Array): Uint32Array;
11
- export interface DatasetState {
12
- shuffledIndexes: Uint32Array;
13
- step: number;
14
- }
15
- export declare class DatasetBuilder {
16
- tokenizer: ITokeniser;
17
- blockSize: number;
18
- constructor(tokenizer: ITokeniser, blockSize?: number);
19
- createTextDataset(flatTokens: Uint16Array, batchSize?: number, indexes?: Uint32Array, mask?: Uint8Array, ignoreIndex?: number): Promise<{
20
- dataset: Dataset<{
21
- xs: Tensor;
22
- ys: Tensor;
23
- }>;
24
- state: DatasetState;
25
- }>;
26
- }
@@ -1,86 +0,0 @@
1
- import { t as x } from "../index-CUXkjxiT.js";
2
- import { d as g, i as m } from "../dataset-CGGp1z9P.js";
3
- import "../index-Cp39cXWe.js";
4
- function p(e) {
5
- return g(async () => {
6
- const t = await e();
7
- return m(() => t.next());
8
- });
9
- }
10
- const I = 8;
11
- function z(e, t) {
12
- const r = e.map((c) => t.encodeConversation(c)).flat();
13
- return new Uint16Array(r);
14
- }
15
- function A(e, t) {
16
- const s = e.map((i) => t.encodeConversation(i, !1, !0));
17
- console.log("Tokenised Texts with Mask:", s);
18
- const r = s.map((i) => i.tokens).flat(), c = s.map((i) => i.mask).flat();
19
- return { tokens: new Uint16Array(r), mask: new Uint8Array(c.map((i) => i ? 1 : 0)) };
20
- }
21
- function u(e) {
22
- for (let t = e.length - 1; t > 0; t--) {
23
- const s = Math.floor(Math.random() * (t + 1));
24
- [e[t], e[s]] = [e[s], e[t]];
25
- }
26
- return e;
27
- }
28
- class S {
29
- tokenizer;
30
- blockSize;
31
- constructor(t, s = 128) {
32
- this.tokenizer = t, this.blockSize = s;
33
- }
34
- // Create dataset from text files
35
- async createTextDataset(t, s = 32, r, c, i = 65535) {
36
- if (t.length < this.blockSize + 1)
37
- throw new Error(`Not enough tokens (${t.length}) for block size ${this.blockSize}`);
38
- const o = {
39
- shuffledIndexes: new Uint32Array(t.length),
40
- step: 0
41
- };
42
- if (r)
43
- o.shuffledIndexes = r;
44
- else {
45
- o.shuffledIndexes = new Uint32Array(t.length);
46
- for (let n = 0; n < t.length; n++)
47
- o.shuffledIndexes[n] = n;
48
- u(o.shuffledIndexes);
49
- }
50
- const d = (function* () {
51
- for (; ; ) {
52
- const n = o.shuffledIndexes[o.step++];
53
- if (o.step >= o.shuffledIndexes.length && (o.step = 0, u(o.shuffledIndexes)), n + this.blockSize + 1 > t.length)
54
- continue;
55
- const a = new Int32Array(t.subarray(n, n + this.blockSize)), k = t.subarray(n + 1, n + this.blockSize + 1), l = new Int32Array(k);
56
- if (c) {
57
- let h = 0;
58
- for (let f = 0; f < l.length; f++)
59
- c[n + 1 + f] === 0 && (l[f] = i, h++);
60
- if (h === l.length)
61
- continue;
62
- }
63
- yield { xs: a, ys: l };
64
- }
65
- }).bind(this);
66
- return {
67
- dataset: p(d).batch(s).map((n) => {
68
- const a = n;
69
- return x(() => ({
70
- xs: a.xs.cast("int32"),
71
- ys: a.ys.cast("int32")
72
- // this.tf.oneHot(batchData.ys.cast('int32'), this.tokenizer.vocabSize),
73
- }));
74
- }).prefetch(2),
75
- // Smaller prefetch to reduce memory pressure
76
- state: o
77
- };
78
- }
79
- }
80
- export {
81
- S as DatasetBuilder,
82
- I as PAGE_FACTOR,
83
- z as flattenTokens,
84
- A as flattenTokensWithMask,
85
- u as shuffle
86
- };
@@ -1,19 +0,0 @@
1
- import { Dataset } from '@tensorflow/tfjs-data';
2
- import { TensorContainer } from '@tensorflow/tfjs-core';
3
- import { default as Model, ModelForwardAttributes } from '../models/model';
4
- interface Result {
5
- loss: number;
6
- accuracy: number;
7
- }
8
- export default class Evaluator {
9
- private model;
10
- private iterator?;
11
- private xs?;
12
- private ys?;
13
- private masked;
14
- constructor(model: Model<ModelForwardAttributes>, dataset: Dataset<TensorContainer>, masked?: boolean);
15
- dispose(): void;
16
- private calculateBatchLoss;
17
- evaluate(maxBatches?: number): Promise<Result | Result[]>;
18
- }
19
- export {};
@@ -1,39 +0,0 @@
1
- import { t as d } from "../index-CUXkjxiT.js";
2
- import { calculateLoss as f, calculateAccuracy as p } from "./loss.js";
3
- class b {
4
- constructor(o, t, a) {
5
- this.model = o, this.masked = !!a, this.iterator = t.iterator();
6
- }
7
- iterator;
8
- xs;
9
- ys;
10
- masked = !1;
11
- dispose() {
12
- this.xs && this.xs.dispose(), this.ys && this.ys.dispose();
13
- }
14
- async calculateBatchLoss(o, t, a, r) {
15
- const [l, e] = d(() => {
16
- const s = this.model.forward({ training: !1 }, o), h = f(s, t, r, a), y = p(s, t);
17
- return s.dispose(), [h, y];
18
- }), u = await l.array(), n = await e.array(), c = u, i = n;
19
- return e.dispose(), l.dispose(), Array.isArray(c) ? c.map((s) => ({ loss: s, accuracy: i })) : { loss: c, accuracy: i };
20
- }
21
- async evaluate(o = 100) {
22
- let t = 0, a = 0, r = 0;
23
- if (this.iterator) {
24
- const l = await this.iterator;
25
- for (let e = 0; e < o; e++) {
26
- const u = await l.next();
27
- if (u.done) break;
28
- const n = u.value, { xs: c, ys: i } = n, s = await this.calculateBatchLoss(c, i, !1, this.masked);
29
- c.dispose(), i.dispose(), t += s.loss, a += s.accuracy, r++;
30
- }
31
- return { loss: t / r, accuracy: a / r };
32
- } else if (this.xs && this.ys)
33
- return this.calculateBatchLoss(this.xs, this.ys, !0, !0);
34
- throw new Error("No data available for evaluation");
35
- }
36
- }
37
- export {
38
- b as default
39
- };
@@ -1,12 +0,0 @@
1
- import { LRSchedulerConfig } from './types';
2
- export default class LRScheduler {
3
- protected learningRate: number;
4
- private config;
5
- private step;
6
- private startLearningRate;
7
- constructor(learningRate: number, config: LRSchedulerConfig);
8
- serializeConfig(): LRSchedulerConfig;
9
- updateConfig(newConfig: Partial<LRSchedulerConfig>, learningRate?: number): void;
10
- get lr(): number;
11
- getNextLR(): number;
12
- }
@@ -1,34 +0,0 @@
1
- class o {
2
- constructor(i, t) {
3
- this.learningRate = i, this.config = t, this.startLearningRate = i, t.step !== void 0 && (this.step = t.step);
4
- }
5
- step = 0;
6
- startLearningRate;
7
- serializeConfig() {
8
- return {
9
- ...this.config,
10
- step: this.step
11
- };
12
- }
13
- updateConfig(i, t) {
14
- this.config = { ...this.config, ...i }, t !== void 0 && (this.startLearningRate = t);
15
- }
16
- get lr() {
17
- return this.learningRate;
18
- }
19
- getNextLR() {
20
- const i = this.step;
21
- if (this.config.warmupSteps > 0 && i < this.config.warmupSteps) {
22
- const r = (i + 1) / this.config.warmupSteps, e = this.startLearningRate * r;
23
- return this.learningRate = e, this.step++, e;
24
- }
25
- const t = this.config.epochSteps * this.config.decayEpochs;
26
- if (i >= t || t <= this.config.warmupSteps)
27
- return this.learningRate = this.config.minLearningRate, this.step++, this.config.minLearningRate;
28
- const n = (i - this.config.warmupSteps) / (t - this.config.warmupSteps), a = 0.5 * (1 + Math.cos(Math.PI * n)), s = this.config.minLearningRate + a * (this.startLearningRate - this.config.minLearningRate);
29
- return this.learningRate = s, this.step++, s;
30
- }
31
- }
32
- export {
33
- o as default
34
- };