@genai-fi/nanogpt 0.20.0 → 0.20.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (433) hide show
  1. package/dist/BaseTokeniser-DSg9zcYq.js +221 -0
  2. package/dist/DatasetBuilder-DgURD85T.js +712 -0
  3. package/dist/Generator.d.ts +82 -0
  4. package/dist/Generator.js +2 -0
  5. package/dist/RealDiv-DBu0FQqT.js +362 -0
  6. package/dist/Reshape-CABOPB9d.js +94 -0
  7. package/dist/Reshape-DqO3r8BC.js +17 -0
  8. package/dist/TeachableLLM.d.ts +70 -0
  9. package/dist/TeachableLLM.js +2 -0
  10. package/dist/Trainer.d.ts +43 -0
  11. package/dist/Trainer.js +2 -0
  12. package/dist/backend.d.ts +2 -0
  13. package/dist/backend.js +13 -0
  14. package/dist/backend_util-Cg-roD1p.js +399 -0
  15. package/dist/binary_op_util-CrYk9LXL.js +103 -0
  16. package/dist/checks/appendCache.d.ts +1 -0
  17. package/dist/checks/appendCache.js +55 -0
  18. package/dist/checks/attentionMask.d.ts +1 -0
  19. package/dist/checks/attentionMask.js +56 -0
  20. package/dist/checks/check.d.ts +9 -0
  21. package/dist/checks/check.js +32 -0
  22. package/dist/checks/gelu.d.ts +1 -0
  23. package/dist/checks/gelu.js +46 -0
  24. package/dist/checks/index.d.ts +26 -0
  25. package/dist/checks/index.js +28 -0
  26. package/dist/checks/matMulGelu.d.ts +1 -0
  27. package/dist/checks/matMulGelu.js +84 -0
  28. package/dist/checks/normRMS.d.ts +1 -0
  29. package/dist/checks/normRMS.js +28 -0
  30. package/dist/checks/normRMSGrad.d.ts +1 -0
  31. package/dist/checks/normRMSGrad.js +22 -0
  32. package/dist/checks/packUnpack.d.ts +1 -0
  33. package/dist/checks/packUnpack.js +46 -0
  34. package/dist/checks/qkv.d.ts +1 -0
  35. package/dist/checks/qkv.js +34 -0
  36. package/dist/checks/rope.d.ts +1 -0
  37. package/dist/checks/rope.js +30 -0
  38. package/dist/checks/weights.d.ts +14 -0
  39. package/dist/checks/weights.js +27 -0
  40. package/dist/chunk-BPntVaq0.js +23 -0
  41. package/dist/complex_util-CkazZsaH.js +60 -0
  42. package/dist/concat_util-CWDZCBlA.js +19 -0
  43. package/dist/data/docx.d.ts +2 -0
  44. package/dist/data/docx.js +3046 -0
  45. package/dist/data/pdf.d.ts +2 -0
  46. package/dist/data/pdf.js +17 -0
  47. package/dist/data/textLoader.d.ts +7 -0
  48. package/dist/data/textLoader.js +613 -0
  49. package/dist/dist-BewPQWjc.js +7572 -0
  50. package/dist/dist-DVmq73nz.js +8775 -0
  51. package/dist/dist-DXwIvKxl.js +896 -0
  52. package/dist/dist-VEU5mfO0.js +7545 -0
  53. package/dist/gelu-Bf1HW1RY.js +27 -0
  54. package/dist/gpgpu_math-DvLcCH6u.js +1612 -0
  55. package/dist/inference/types.d.ts +16 -0
  56. package/dist/inference/types.js +0 -0
  57. package/dist/kernel_funcs_utils-HiXOOx3f.js +229 -0
  58. package/dist/layers/BaseLayer.d.ts +44 -0
  59. package/dist/layers/BaseLayer.js +76 -0
  60. package/dist/layers/CausalSelfAttention.d.ts +39 -0
  61. package/dist/layers/CausalSelfAttention.js +99 -0
  62. package/dist/layers/LoRA.d.ts +14 -0
  63. package/dist/layers/LoRA.js +48 -0
  64. package/dist/layers/MLP.d.ts +17 -0
  65. package/dist/layers/MLP.js +34 -0
  66. package/dist/layers/PositionEmbedding.d.ts +8 -0
  67. package/dist/layers/PositionEmbedding.js +27 -0
  68. package/dist/layers/RMSNorm.d.ts +12 -0
  69. package/dist/layers/RMSNorm.js +20 -0
  70. package/dist/layers/RoPECache.d.ts +18 -0
  71. package/dist/layers/RoPECache.js +337 -0
  72. package/dist/layers/TiedEmbedding.d.ts +13 -0
  73. package/dist/layers/TiedEmbedding.js +32 -0
  74. package/dist/layers/TransformerBlock.d.ts +27 -0
  75. package/dist/layers/TransformerBlock.js +51 -0
  76. package/dist/layers/WeightStore.d.ts +20 -0
  77. package/dist/layers/WeightStore.js +69 -0
  78. package/dist/loader/load.d.ts +6 -0
  79. package/dist/loader/load.js +2 -0
  80. package/dist/loader/loadHF.d.ts +8 -0
  81. package/dist/loader/loadHF.js +2 -0
  82. package/dist/loader/loadTransformers.d.ts +4 -0
  83. package/dist/loader/loadTransformers.js +2 -0
  84. package/dist/loader/loadZipMeta.d.ts +3 -0
  85. package/dist/loader/loadZipMeta.js +16 -0
  86. package/dist/loader/newZipLoad.d.ts +3 -0
  87. package/dist/loader/newZipLoad.js +2 -0
  88. package/dist/loader/oldZipLoad.d.ts +9 -0
  89. package/dist/loader/oldZipLoad.js +2 -0
  90. package/dist/loader/save.d.ts +16 -0
  91. package/dist/loader/save.js +2 -0
  92. package/dist/loader/types.d.ts +68 -0
  93. package/dist/loader/types.js +0 -0
  94. package/dist/main-D5CbfCiV.js +13500 -0
  95. package/dist/main.d.ts +50 -0
  96. package/dist/main.js +16 -0
  97. package/dist/matMul16-BNfZSnNM.js +81 -0
  98. package/dist/matMulGelu-CPTntosE.js +162 -0
  99. package/dist/models/NanoGPTV1.d.ts +16 -0
  100. package/dist/models/NanoGPTV1.js +2 -0
  101. package/dist/models/NanoGPTV2.d.ts +16 -0
  102. package/dist/models/NanoGPTV2.js +2 -0
  103. package/dist/models/config.d.ts +27 -0
  104. package/dist/models/config.js +37 -0
  105. package/dist/models/factory.d.ts +3 -0
  106. package/dist/models/factory.js +2 -0
  107. package/dist/models/model.d.ts +44 -0
  108. package/dist/models/model.js +2 -0
  109. package/dist/ops/adamAdjust.d.ts +2 -0
  110. package/dist/ops/adamAdjust.js +18 -0
  111. package/dist/ops/adamMoments.d.ts +2 -0
  112. package/dist/ops/adamMoments.js +16 -0
  113. package/dist/ops/add16.d.ts +2 -0
  114. package/dist/ops/add16.js +12 -0
  115. package/dist/ops/appendCache.d.ts +2 -0
  116. package/dist/ops/appendCache.js +25 -0
  117. package/dist/ops/attentionMask.d.ts +2 -0
  118. package/dist/ops/attentionMask.js +16 -0
  119. package/dist/ops/concat16.d.ts +2 -0
  120. package/dist/ops/concat16.js +8 -0
  121. package/dist/ops/cpu/adamAdjust.d.ts +1 -0
  122. package/dist/ops/cpu/adamAdjust.js +16 -0
  123. package/dist/ops/cpu/adamMoments.d.ts +1 -0
  124. package/dist/ops/cpu/adamMoments.js +16 -0
  125. package/dist/ops/cpu/appendCache.d.ts +1 -0
  126. package/dist/ops/cpu/appendCache.js +65 -0
  127. package/dist/ops/cpu/attentionMask.d.ts +1 -0
  128. package/dist/ops/cpu/attentionMask.js +16 -0
  129. package/dist/ops/cpu/fusedSoftmax.d.ts +9 -0
  130. package/dist/ops/cpu/fusedSoftmax.js +22 -0
  131. package/dist/ops/cpu/gatherSub.d.ts +1 -0
  132. package/dist/ops/cpu/gatherSub.js +12 -0
  133. package/dist/ops/cpu/gelu.d.ts +1 -0
  134. package/dist/ops/cpu/gelu.js +36 -0
  135. package/dist/ops/cpu/matMul16.d.ts +1 -0
  136. package/dist/ops/cpu/matMul16.js +14 -0
  137. package/dist/ops/cpu/matMulGelu.d.ts +1 -0
  138. package/dist/ops/cpu/matMulGelu.js +41 -0
  139. package/dist/ops/cpu/matMulMul.d.ts +1 -0
  140. package/dist/ops/cpu/matMulMul.js +20 -0
  141. package/dist/ops/cpu/mulDropout.d.ts +1 -0
  142. package/dist/ops/cpu/mulDropout.js +20 -0
  143. package/dist/ops/cpu/normRMS.d.ts +1 -0
  144. package/dist/ops/cpu/normRMS.js +35 -0
  145. package/dist/ops/cpu/qkv.d.ts +5 -0
  146. package/dist/ops/cpu/qkv.js +73 -0
  147. package/dist/ops/cpu/rope.d.ts +6 -0
  148. package/dist/ops/cpu/rope.js +81 -0
  149. package/dist/ops/cpu/scatterSub.d.ts +1 -0
  150. package/dist/ops/cpu/scatterSub.js +12 -0
  151. package/dist/ops/dot16.d.ts +2 -0
  152. package/dist/ops/dot16.js +29 -0
  153. package/dist/ops/dropout.d.ts +2 -0
  154. package/dist/ops/dropout.js +11 -0
  155. package/dist/ops/dropout16.d.ts +2 -0
  156. package/dist/ops/dropout16.js +22 -0
  157. package/dist/ops/gatherSub.d.ts +2 -0
  158. package/dist/ops/gatherSub.js +13 -0
  159. package/dist/ops/gelu.d.ts +3 -0
  160. package/dist/ops/gelu.js +2 -0
  161. package/dist/ops/globalNorm.d.ts +2 -0
  162. package/dist/ops/globalNorm.js +19 -0
  163. package/dist/ops/grads/add16.d.ts +1 -0
  164. package/dist/ops/grads/add16.js +27 -0
  165. package/dist/ops/grads/attentionMask.d.ts +1 -0
  166. package/dist/ops/grads/attentionMask.js +26 -0
  167. package/dist/ops/grads/dropout16.d.ts +1 -0
  168. package/dist/ops/grads/dropout16.js +1 -0
  169. package/dist/ops/grads/gelu.d.ts +2 -0
  170. package/dist/ops/grads/gelu.js +2 -0
  171. package/dist/ops/grads/matMul16.d.ts +2 -0
  172. package/dist/ops/grads/matMul16.js +2 -0
  173. package/dist/ops/grads/matMulGelu.d.ts +1 -0
  174. package/dist/ops/grads/matMulGelu.js +22 -0
  175. package/dist/ops/grads/mul16.d.ts +1 -0
  176. package/dist/ops/grads/mul16.js +1 -0
  177. package/dist/ops/grads/normRMS.d.ts +3 -0
  178. package/dist/ops/grads/normRMS.js +37 -0
  179. package/dist/ops/grads/pack16.d.ts +2 -0
  180. package/dist/ops/grads/pack16.js +2 -0
  181. package/dist/ops/grads/qkv.d.ts +3 -0
  182. package/dist/ops/grads/qkv.js +46 -0
  183. package/dist/ops/grads/rope.d.ts +2 -0
  184. package/dist/ops/grads/rope.js +2 -0
  185. package/dist/ops/grads/softmax16.d.ts +2 -0
  186. package/dist/ops/grads/softmax16.js +23 -0
  187. package/dist/ops/grads/unpack16.d.ts +2 -0
  188. package/dist/ops/grads/unpack16.js +2 -0
  189. package/dist/ops/grads/utils.d.ts +4 -0
  190. package/dist/ops/grads/utils.js +12 -0
  191. package/dist/ops/log.d.ts +0 -0
  192. package/dist/ops/log.js +1 -0
  193. package/dist/ops/matMul16.d.ts +15 -0
  194. package/dist/ops/matMul16.js +2 -0
  195. package/dist/ops/matMulGelu.d.ts +3 -0
  196. package/dist/ops/matMulGelu.js +20 -0
  197. package/dist/ops/matMulMul.d.ts +2 -0
  198. package/dist/ops/matMulMul.js +16 -0
  199. package/dist/ops/mul16.d.ts +2 -0
  200. package/dist/ops/mul16.js +43 -0
  201. package/dist/ops/mulDrop.d.ts +2 -0
  202. package/dist/ops/mulDrop.js +15 -0
  203. package/dist/ops/normRMS.d.ts +2 -0
  204. package/dist/ops/normRMS.js +22 -0
  205. package/dist/ops/pack16.d.ts +2 -0
  206. package/dist/ops/pack16.js +2 -0
  207. package/dist/ops/qkv.d.ts +2 -0
  208. package/dist/ops/qkv.js +16 -0
  209. package/dist/ops/reshape16.d.ts +2 -0
  210. package/dist/ops/reshape16.js +33 -0
  211. package/dist/ops/rope.d.ts +3 -0
  212. package/dist/ops/rope.js +2 -0
  213. package/dist/ops/scatterSub.d.ts +2 -0
  214. package/dist/ops/scatterSub.js +13 -0
  215. package/dist/ops/slice16.d.ts +2 -0
  216. package/dist/ops/slice16.js +11 -0
  217. package/dist/ops/softmax16.d.ts +2 -0
  218. package/dist/ops/softmax16.js +9 -0
  219. package/dist/ops/sub16.d.ts +2 -0
  220. package/dist/ops/sub16.js +11 -0
  221. package/dist/ops/sum16.d.ts +2 -0
  222. package/dist/ops/sum16.js +13 -0
  223. package/dist/ops/transpose16.d.ts +3 -0
  224. package/dist/ops/transpose16.js +32 -0
  225. package/dist/ops/unpack16.d.ts +2 -0
  226. package/dist/ops/unpack16.js +2 -0
  227. package/dist/ops/webgl/adamAdjust.d.ts +1 -0
  228. package/dist/ops/webgl/adamAdjust.js +82 -0
  229. package/dist/ops/webgl/adamMoments.d.ts +1 -0
  230. package/dist/ops/webgl/adamMoments.js +44 -0
  231. package/dist/ops/webgl/appendCache.d.ts +1 -0
  232. package/dist/ops/webgl/appendCache.js +53 -0
  233. package/dist/ops/webgl/attentionMask.d.ts +1 -0
  234. package/dist/ops/webgl/attentionMask.js +64 -0
  235. package/dist/ops/webgl/dropout16.d.ts +1 -0
  236. package/dist/ops/webgl/dropout16.js +12 -0
  237. package/dist/ops/webgl/fusedSoftmax.d.ts +11 -0
  238. package/dist/ops/webgl/fusedSoftmax.js +70 -0
  239. package/dist/ops/webgl/gatherSub.d.ts +1 -0
  240. package/dist/ops/webgl/gatherSub.js +28 -0
  241. package/dist/ops/webgl/gelu.d.ts +2 -0
  242. package/dist/ops/webgl/gelu.js +48 -0
  243. package/dist/ops/webgl/log.d.ts +17 -0
  244. package/dist/ops/webgl/log.js +14 -0
  245. package/dist/ops/webgl/matMul16.d.ts +1 -0
  246. package/dist/ops/webgl/matMul16.js +37 -0
  247. package/dist/ops/webgl/matMulGelu.d.ts +21 -0
  248. package/dist/ops/webgl/matMulGelu.js +2 -0
  249. package/dist/ops/webgl/matMulMul.d.ts +14 -0
  250. package/dist/ops/webgl/matMulMul.js +24 -0
  251. package/dist/ops/webgl/mulDropout.d.ts +1 -0
  252. package/dist/ops/webgl/mulDropout.js +32 -0
  253. package/dist/ops/webgl/normRMS.d.ts +1 -0
  254. package/dist/ops/webgl/normRMS.js +114 -0
  255. package/dist/ops/webgl/qkv.d.ts +1 -0
  256. package/dist/ops/webgl/qkv.js +54 -0
  257. package/dist/ops/webgl/rope.d.ts +1 -0
  258. package/dist/ops/webgl/rope.js +72 -0
  259. package/dist/ops/webgl/scatterSub.d.ts +1 -0
  260. package/dist/ops/webgl/scatterSub.js +28 -0
  261. package/dist/ops/webgpu/adamAdjust.d.ts +1 -0
  262. package/dist/ops/webgpu/adamAdjust.js +77 -0
  263. package/dist/ops/webgpu/adamMoments.d.ts +1 -0
  264. package/dist/ops/webgpu/adamMoments.js +76 -0
  265. package/dist/ops/webgpu/add16.d.ts +1 -0
  266. package/dist/ops/webgpu/add16.js +14 -0
  267. package/dist/ops/webgpu/appendCache.d.ts +1 -0
  268. package/dist/ops/webgpu/appendCache.js +130 -0
  269. package/dist/ops/webgpu/attentionMask.d.ts +1 -0
  270. package/dist/ops/webgpu/attentionMask.js +42 -0
  271. package/dist/ops/webgpu/attentionMask32_program.d.ts +19 -0
  272. package/dist/ops/webgpu/attentionMask32_program.js +62 -0
  273. package/dist/ops/webgpu/clipScale.d.ts +1 -0
  274. package/dist/ops/webgpu/clipScale.js +45 -0
  275. package/dist/ops/webgpu/concat16.d.ts +19 -0
  276. package/dist/ops/webgpu/concat16.js +111 -0
  277. package/dist/ops/webgpu/dropout16.d.ts +1 -0
  278. package/dist/ops/webgpu/dropout16.js +59 -0
  279. package/dist/ops/webgpu/gatherSub.d.ts +1 -0
  280. package/dist/ops/webgpu/gatherSub.js +52 -0
  281. package/dist/ops/webgpu/gelu.d.ts +14 -0
  282. package/dist/ops/webgpu/gelu.js +147 -0
  283. package/dist/ops/webgpu/index.d.ts +0 -0
  284. package/dist/ops/webgpu/index.js +26 -0
  285. package/dist/ops/webgpu/matMul16.d.ts +1 -0
  286. package/dist/ops/webgpu/matMul16.js +70 -0
  287. package/dist/ops/webgpu/matMul16_program.d.ts +42 -0
  288. package/dist/ops/webgpu/matMul16_program.js +303 -0
  289. package/dist/ops/webgpu/mul16.d.ts +1 -0
  290. package/dist/ops/webgpu/mul16.js +14 -0
  291. package/dist/ops/webgpu/norm2.d.ts +1 -0
  292. package/dist/ops/webgpu/norm2.js +46 -0
  293. package/dist/ops/webgpu/normRMS.d.ts +1 -0
  294. package/dist/ops/webgpu/normRMS.js +26 -0
  295. package/dist/ops/webgpu/normRMS16_program.d.ts +10 -0
  296. package/dist/ops/webgpu/normRMS16_program.js +28 -0
  297. package/dist/ops/webgpu/normRMS32_program.d.ts +10 -0
  298. package/dist/ops/webgpu/normRMS32_program.js +28 -0
  299. package/dist/ops/webgpu/normRMSGrad.d.ts +1 -0
  300. package/dist/ops/webgpu/normRMSGrad.js +225 -0
  301. package/dist/ops/webgpu/pack16.d.ts +1 -0
  302. package/dist/ops/webgpu/pack16.js +21 -0
  303. package/dist/ops/webgpu/pack16_program.d.ts +19 -0
  304. package/dist/ops/webgpu/pack16_program.js +93 -0
  305. package/dist/ops/webgpu/qkv.d.ts +1 -0
  306. package/dist/ops/webgpu/qkv.js +64 -0
  307. package/dist/ops/webgpu/rope.d.ts +1 -0
  308. package/dist/ops/webgpu/rope.js +163 -0
  309. package/dist/ops/webgpu/scatterSub.d.ts +1 -0
  310. package/dist/ops/webgpu/scatterSub.js +53 -0
  311. package/dist/ops/webgpu/slice16.d.ts +7 -0
  312. package/dist/ops/webgpu/slice16.js +74 -0
  313. package/dist/ops/webgpu/softmax16.d.ts +17 -0
  314. package/dist/ops/webgpu/softmax16.js +18 -0
  315. package/dist/ops/webgpu/softmax16_program.d.ts +13 -0
  316. package/dist/ops/webgpu/softmax16_program.js +89 -0
  317. package/dist/ops/webgpu/softmax16_subgroup_program.d.ts +17 -0
  318. package/dist/ops/webgpu/softmax16_subgroup_program.js +70 -0
  319. package/dist/ops/webgpu/softmax16grad.d.ts +1 -0
  320. package/dist/ops/webgpu/softmax16grad.js +31 -0
  321. package/dist/ops/webgpu/sub16.d.ts +1 -0
  322. package/dist/ops/webgpu/sub16.js +14 -0
  323. package/dist/ops/webgpu/sum16.d.ts +1 -0
  324. package/dist/ops/webgpu/sum16.js +29 -0
  325. package/dist/ops/webgpu/transpose16.d.ts +1 -0
  326. package/dist/ops/webgpu/transpose16.js +37 -0
  327. package/dist/ops/webgpu/transpose16_program.d.ts +16 -0
  328. package/dist/ops/webgpu/transpose16_program.js +51 -0
  329. package/dist/ops/webgpu/transpose16_shared_program.d.ts +15 -0
  330. package/dist/ops/webgpu/transpose16_shared_program.js +79 -0
  331. package/dist/ops/webgpu/unpack16.d.ts +1 -0
  332. package/dist/ops/webgpu/unpack16.js +60 -0
  333. package/dist/ops/webgpu/utils/binary_op.d.ts +35 -0
  334. package/dist/ops/webgpu/utils/binary_op.js +141 -0
  335. package/dist/ops/webgpu/utils/deviceInfo.d.ts +7 -0
  336. package/dist/ops/webgpu/utils/deviceInfo.js +11 -0
  337. package/dist/ops/webgpu/utils/reductions.d.ts +43 -0
  338. package/dist/ops/webgpu/utils/reductions.js +263 -0
  339. package/dist/pack16-Ck-spx_F.js +39 -0
  340. package/dist/patches/webgpu_backend.d.ts +18 -0
  341. package/dist/patches/webgpu_backend.js +43 -0
  342. package/dist/patches/webgpu_base.d.ts +21 -0
  343. package/dist/patches/webgpu_base.js +22 -0
  344. package/dist/patches/webgpu_program.d.ts +36 -0
  345. package/dist/patches/webgpu_program.js +293 -0
  346. package/dist/pdf-UoDqCYzz.js +16726 -0
  347. package/dist/picomatch-3tUnMMbd.js +1063 -0
  348. package/dist/rope-CbeGlsV8.js +25 -0
  349. package/dist/selu_util-zkAx5doH.js +24 -0
  350. package/dist/shared-D1coEFea.js +1314 -0
  351. package/dist/shared-DOgWaqvL.js +5 -0
  352. package/dist/slice_util-Dgb3ANWI.js +208 -0
  353. package/dist/tfjs_backend-BjuQ5FqB.js +614 -0
  354. package/dist/tokeniser/BaseTokeniser.d.ts +33 -0
  355. package/dist/tokeniser/BaseTokeniser.js +2 -0
  356. package/dist/tokeniser/CharTokeniser.d.ts +24 -0
  357. package/dist/tokeniser/CharTokeniser.js +92 -0
  358. package/dist/tokeniser/bpe.d.ts +28 -0
  359. package/dist/tokeniser/bpe.js +170 -0
  360. package/dist/tokeniser/messages.d.ts +61 -0
  361. package/dist/tokeniser/messages.js +0 -0
  362. package/dist/tokeniser/type.d.ts +34 -0
  363. package/dist/tokeniser/type.js +0 -0
  364. package/dist/training/AdamW.d.ts +36 -0
  365. package/dist/training/AdamW.js +128 -0
  366. package/dist/training/BasicTrainer.d.ts +63 -0
  367. package/dist/training/BasicTrainer.js +265 -0
  368. package/dist/training/DatasetBuilder.d.ts +26 -0
  369. package/dist/training/DatasetBuilder.js +2 -0
  370. package/dist/training/Evaluator.d.ts +19 -0
  371. package/dist/training/Evaluator.js +48 -0
  372. package/dist/training/LRScheduler.d.ts +12 -0
  373. package/dist/training/LRScheduler.js +38 -0
  374. package/dist/training/PreTrainer.d.ts +11 -0
  375. package/dist/training/PreTrainer.js +22 -0
  376. package/dist/training/SFTTrainer.d.ts +12 -0
  377. package/dist/training/SFTTrainer.js +24 -0
  378. package/dist/training/loss.d.ts +3 -0
  379. package/dist/training/loss.js +19 -0
  380. package/dist/training/orthoGrad.d.ts +2 -0
  381. package/dist/training/orthoGrad.js +10 -0
  382. package/dist/training/sparseCrossEntropy.d.ts +7 -0
  383. package/dist/training/sparseCrossEntropy.js +47 -0
  384. package/dist/training/tasks/ConversationTask.d.ts +18 -0
  385. package/dist/training/tasks/ConversationTask.js +38 -0
  386. package/dist/training/tasks/PretrainingTask.d.ts +17 -0
  387. package/dist/training/tasks/PretrainingTask.js +42 -0
  388. package/dist/training/tasks/StartSentenceTask.d.ts +18 -0
  389. package/dist/training/tasks/StartSentenceTask.js +45 -0
  390. package/dist/training/tasks/Task.d.ts +22 -0
  391. package/dist/training/tasks/Task.js +55 -0
  392. package/dist/training/tasks/splitter.d.ts +5 -0
  393. package/dist/training/tasks/splitter.js +18 -0
  394. package/dist/training/types.d.ts +78 -0
  395. package/dist/training/types.js +0 -0
  396. package/dist/training/validation.d.ts +17 -0
  397. package/dist/training/validation.js +2 -0
  398. package/dist/utilities/arrayClose.d.ts +1 -0
  399. package/dist/utilities/arrayClose.js +16 -0
  400. package/dist/utilities/datasetID.d.ts +2 -0
  401. package/dist/utilities/datasetID.js +18 -0
  402. package/dist/utilities/dummy.d.ts +9 -0
  403. package/dist/utilities/dummy.js +36 -0
  404. package/dist/utilities/multinomialCPU.d.ts +2 -0
  405. package/dist/utilities/multinomialCPU.js +9 -0
  406. package/dist/utilities/naming.d.ts +4 -0
  407. package/dist/utilities/naming.js +0 -0
  408. package/dist/utilities/packed.d.ts +4 -0
  409. package/dist/utilities/packed.js +13 -0
  410. package/dist/utilities/parameters.d.ts +11 -0
  411. package/dist/utilities/parameters.js +38 -0
  412. package/dist/utilities/performance.d.ts +2 -0
  413. package/dist/utilities/performance.js +16 -0
  414. package/dist/utilities/profile.d.ts +17 -0
  415. package/dist/utilities/profile.js +33 -0
  416. package/dist/utilities/safetensors.d.ts +3 -0
  417. package/dist/utilities/safetensors.js +53 -0
  418. package/dist/utilities/sentences.d.ts +5 -0
  419. package/dist/utilities/sentences.js +32 -0
  420. package/dist/utilities/tokenParse.d.ts +1 -0
  421. package/dist/utilities/tokenParse.js +17 -0
  422. package/dist/utilities/topP.d.ts +1 -0
  423. package/dist/utilities/topP.js +12 -0
  424. package/dist/utilities/waitForModel.d.ts +2 -0
  425. package/dist/utilities/waitForModel.js +12 -0
  426. package/dist/utilities/weights.d.ts +12 -0
  427. package/dist/utilities/weights.js +40 -0
  428. package/dist/utilities/yielder.d.ts +1 -0
  429. package/dist/utilities/yielder.js +7 -0
  430. package/dist/webgpu-Dt7BMzWz.js +525 -0
  431. package/dist/webgpu_program-WOyIVMlZ.js +392 -0
  432. package/dist/webgpu_util-B_F3SShA.js +106 -0
  433. package/package.json +1 -1
@@ -0,0 +1,16 @@
1
+ import { Conversation } from '../tokeniser/type';
2
+ export interface GeneratorConversation extends Conversation {
3
+ _completed?: boolean;
4
+ _timestamp?: number;
5
+ }
6
+ export interface GenerateOptions {
7
+ temperature?: number;
8
+ topK?: number;
9
+ topP?: number;
10
+ usePadding?: boolean;
11
+ attentionScores?: boolean;
12
+ includeProbabilities?: boolean;
13
+ embeddings?: 'embedding' | 'logits' | 'softmax' | 'all';
14
+ targets?: number[];
15
+ loraName?: string;
16
+ }
File without changes
@@ -0,0 +1,229 @@
1
+ import { Ci as e, Di as t, In as n, Ms as r, Ya as i, ca as a, ko as o, oc as s, to as c } from "./dist-BewPQWjc.js";
2
+ import { r as l } from "./backend_util-Cg-roD1p.js";
3
+ import { a as u, o as d } from "./gpgpu_math-DvLcCH6u.js";
4
+ //#region node_modules/@tensorflow/tfjs-backend-webgl/dist/packing_util.js
5
+ function f(e, t) {
6
+ return [
7
+ "x",
8
+ "y",
9
+ "z",
10
+ "w",
11
+ "u",
12
+ "v"
13
+ ].slice(0, t).map((t) => `${e}.${t}`);
14
+ }
15
+ function p(e, t) {
16
+ return t === 1 ? [e] : f(e, t);
17
+ }
18
+ function m(e, t) {
19
+ if (e === 1) return "rc";
20
+ let n = "";
21
+ for (let r = 0; r < e; r++) n += t[r], r < e - 1 && (n += ",");
22
+ return n;
23
+ }
24
+ //#endregion
25
+ //#region node_modules/@tensorflow/tfjs-backend-webgl/dist/unaryop_gpu.js
26
+ var h = class {
27
+ constructor(e, t) {
28
+ this.variableNames = ["A"], this.outputShape = e, this.enableShapeUniforms = u(this.outputShape.length), this.userCode = `
29
+ float unaryOperation(float x) {
30
+ ${t}
31
+ }
32
+
33
+ void main() {
34
+ float x = getAAtOutCoords();
35
+ float y = unaryOperation(x);
36
+
37
+ setOutput(y);
38
+ }
39
+ `;
40
+ }
41
+ }, g = "if (isnan(x)) return x;", _ = "return x;", v = "return abs(x);", y = "return (x >= 0.0) ? x : (exp(x) - 1.0);", b = g + "\n return (x < 0.0) ? 0.0 : x;\n", x = g + "\n return (x < 0.0) ? 0.0 : min(6.0, x);\n", S = "return x;", C = "return 1.0 / (1.0 + exp(-1.0 * x));", w = "return x;", T = "\n vec4 result;\n\n result.r = (x.r >= 0.0) ? x.r : (exp(x.r) - 1.0);\n result.g = (x.g >= 0.0) ? x.g : (exp(x.g) - 1.0);\n result.b = (x.b >= 0.0) ? x.b : (exp(x.b) - 1.0);\n result.a = (x.a >= 0.0) ? x.a : (exp(x.a) - 1.0);\n\n return result;\n", E = "\n vec4 result = x * vec4(greaterThanEqual(x, vec4(0.0)));\n bvec4 isNaN = isnan(x);\n\n result.r = isNaN.r ? x.r : result.r;\n result.g = isNaN.g ? x.g : result.g;\n result.b = isNaN.b ? x.b : result.b;\n result.a = isNaN.a ? x.a : result.a;\n\n return result;\n", D = "\n vec4 result = min(x, vec4(6.)) * vec4(greaterThanEqual(x, vec4(0.0)));\n bvec4 isNaN = isnan(x);\n\n result.r = isNaN.r ? x.r : result.r;\n result.g = isNaN.g ? x.g : result.g;\n result.b = isNaN.b ? x.b : result.b;\n result.a = isNaN.a ? x.a : result.a;\n\n return result;\n", O = "return 1.0 / (1.0 + exp(-1.0 * x));", k = class {
42
+ constructor(e, t) {
43
+ this.variableNames = ["A"], this.packedInputs = !0, this.packedOutput = !0, this.outputShape = e, this.enableShapeUniforms = u(this.outputShape.length), this.userCode = `
44
+ vec4 unaryOperation(vec4 x) {
45
+ ${t}
46
+ }
47
+
48
+ void main() {
49
+ vec4 x = getAAtOutCoords();
50
+ vec4 y = unaryOperation(x);
51
+
52
+ setOutput(y);
53
+ }
54
+ `;
55
+ }
56
+ }, A = "\n if (isnan(a)) return a;\n if (isnan(b)) return b;\n", j = class {
57
+ constructor(e, t, r) {
58
+ this.variableNames = ["A", "B"], this.outputShape = n(t, r), this.enableShapeUniforms = u(this.outputShape.length), this.userCode = `
59
+ float binaryOperation(float a, float b) {
60
+ ${e}
61
+ }
62
+
63
+ void main() {
64
+ float a = getAAtOutCoords();
65
+ float b = getBAtOutCoords();
66
+ setOutput(binaryOperation(a, b));
67
+ }
68
+ `;
69
+ }
70
+ }, M = "\n result.r = isNaN.r ? NAN : result.r;\n result.g = isNaN.g ? NAN : result.g;\n result.b = isNaN.b ? NAN : result.b;\n result.a = isNaN.a ? NAN : result.a;\n", N = class {
71
+ constructor(e, t, r, i = !1) {
72
+ this.variableNames = ["A", "B"], this.supportsBroadcasting = !0, this.packedInputs = !0, this.packedOutput = !0, this.outputShape = n(t, r);
73
+ let a = this.outputShape.length;
74
+ this.enableShapeUniforms = u(a);
75
+ let o = "";
76
+ if (i) if (a === 0 || s(this.outputShape) === 1) o = "\n result.y = 0.;\n result.z = 0.;\n result.w = 0.;\n ";
77
+ else if (o = `
78
+ ${d(a)} coords = getOutputCoords();
79
+ `, a === 1) this.enableShapeUniforms ? o += "\n result.y = (coords + 1) >= outShape ? 0. : result.y;\n result.z = 0.;\n result.w = 0.;\n " : o += `
80
+ result.y = (coords + 1) >= ${this.outputShape[0]} ? 0. : result.y;
81
+ result.z = 0.;
82
+ result.w = 0.;
83
+ `;
84
+ else {
85
+ let e = p("coords", a);
86
+ this.enableShapeUniforms ? o += `
87
+ bool nextRowOutOfBounds =
88
+ (${e[a - 2]} + 1) >= outShape[${a} - 2];
89
+ bool nextColOutOfBounds =
90
+ (${e[a - 1]} + 1) >= outShape[${a} - 1];
91
+ result.y = nextColOutOfBounds ? 0. : result.y;
92
+ result.z = nextRowOutOfBounds ? 0. : result.z;
93
+ result.w = nextColOutOfBounds || nextRowOutOfBounds ? 0. : result.w;
94
+ ` : o += `
95
+ bool nextRowOutOfBounds =
96
+ (${e[a - 2]} + 1) >= ${this.outputShape[a - 2]};
97
+ bool nextColOutOfBounds =
98
+ (${e[a - 1]} + 1) >= ${this.outputShape[a - 1]};
99
+ result.y = nextColOutOfBounds ? 0. : result.y;
100
+ result.z = nextRowOutOfBounds ? 0. : result.z;
101
+ result.w = nextColOutOfBounds || nextRowOutOfBounds ? 0. : result.w;
102
+ `;
103
+ }
104
+ this.userCode = `
105
+ vec4 binaryOperation(vec4 a, vec4 b) {
106
+ ${e}
107
+ }
108
+
109
+ void main() {
110
+ vec4 a = getAAtOutCoords();
111
+ vec4 b = getBAtOutCoords();
112
+
113
+ vec4 result = binaryOperation(a, b);
114
+ ${o}
115
+
116
+ setOutput(result);
117
+ }
118
+ `;
119
+ }
120
+ };
121
+ //#endregion
122
+ //#region node_modules/@tensorflow/tfjs-backend-webgl/dist/kernels/Identity.js
123
+ function P(e) {
124
+ let { inputs: t, backend: n } = e, { x: r } = t;
125
+ return n.incRef(r.dataId), {
126
+ dataId: r.dataId,
127
+ shape: r.shape,
128
+ dtype: r.dtype
129
+ };
130
+ }
131
+ var F = {
132
+ kernelName: i,
133
+ backendName: "webgl",
134
+ kernelFunc: P
135
+ };
136
+ //#endregion
137
+ //#region node_modules/@tensorflow/tfjs-backend-webgl/dist/kernels/Complex.js
138
+ function I(e) {
139
+ let { inputs: t, backend: n } = e, { real: r, imag: i } = t, a = n.makeTensorInfo(r.shape, "complex64"), o = n.texData.get(a.dataId);
140
+ return o.complexTensorInfos = {
141
+ real: P({
142
+ inputs: { x: r },
143
+ backend: n
144
+ }),
145
+ imag: P({
146
+ inputs: { x: i },
147
+ backend: n
148
+ })
149
+ }, a;
150
+ }
151
+ var L = {
152
+ kernelName: a,
153
+ backendName: "webgl",
154
+ kernelFunc: I
155
+ }, R = "return (a < 0.) ? b * a : a;", z = "\n vec4 aLessThanZero = vec4(lessThan(a, vec4(0.)));\n return (aLessThanZero * (b * a)) + ((vec4(1.0) - aLessThanZero) * a);\n";
156
+ function B(e) {
157
+ let { inputs: n, backend: i, attrs: a } = e, { x: o } = n, { alpha: s } = a, c = i.makeTensorInfo([], "float32", t(s, "float32")), l = r().getBool("WEBGL_PACK_BINARY_OPERATIONS") ? new N(z, o.shape, c.shape) : new j(R, o.shape, c.shape), u = i.runWebGLProgram(l, [o, c], "float32");
158
+ return i.disposeIntermediateTensorInfo(c), u;
159
+ }
160
+ var V = {
161
+ kernelName: c,
162
+ backendName: "webgl",
163
+ kernelFunc: B
164
+ }, H = "return (a < 0.) ? b * a : a;", U = "\n vec4 aLessThanZero = vec4(lessThan(a, vec4(0.)));\n return (aLessThanZero * (b * a)) + ((vec4(1.0) - aLessThanZero) * a);\n";
165
+ function W(e) {
166
+ let { inputs: t, backend: n } = e, { x: i, alpha: a } = t, o = r().getBool("WEBGL_PACK_BINARY_OPERATIONS") ? new N(U, i.shape, a.shape) : new j(H, i.shape, a.shape);
167
+ return n.runWebGLProgram(o, [i, a], "float32");
168
+ }
169
+ var G = {
170
+ kernelName: o,
171
+ backendName: "webgl",
172
+ kernelFunc: W
173
+ }, K = "if (isnan(x)) return x;";
174
+ function q({ opSnippet: e, packedOpSnippet: t, cpuKernelImpl: n, dtype: i }) {
175
+ return ({ inputs: a, backend: o }) => {
176
+ let { x: s } = a, c = o, l = i || s.dtype;
177
+ if (c.shouldExecuteOnCPU([s]) && n != null) {
178
+ let e = n(c.texData.get(s.dataId).values, l);
179
+ return c.makeTensorInfo(s.shape, l, e);
180
+ }
181
+ let u = r().getBool("WEBGL_PACK_UNARY_OPERATIONS") && t != null, d;
182
+ return d = u ? new k(s.shape, t) : new h(s.shape, e), c.runWebGLProgram(d, [s], l);
183
+ };
184
+ }
185
+ function J({ opSnippet: t, packedOpSnippet: n, checkOutOfBounds: i = !1, supportsComplex: a = !1, cpuKernelImpl: o, dtype: s }) {
186
+ return ({ inputs: c, backend: u }) => {
187
+ let { a: d, b: f } = c, p = u;
188
+ if (a && d.dtype === "complex64") {
189
+ let n = p.texData.get(d.dataId), r = p.texData.get(f.dataId), [i, a] = [[n.complexTensorInfos.real, r.complexTensorInfos.real], [n.complexTensorInfos.imag, r.complexTensorInfos.imag]].map((n) => {
190
+ let [r, i] = n, a = {
191
+ dataId: r.dataId,
192
+ dtype: r.dtype,
193
+ shape: d.shape
194
+ }, o = {
195
+ dataId: i.dataId,
196
+ dtype: i.dtype,
197
+ shape: f.shape
198
+ }, s = new j(t, d.shape, f.shape);
199
+ return p.runWebGLProgram(s, [a, o], e(r.dtype, i.dtype));
200
+ }), o = I({
201
+ inputs: {
202
+ real: i,
203
+ imag: a
204
+ },
205
+ backend: p
206
+ });
207
+ return p.disposeIntermediateTensorInfo(i), p.disposeIntermediateTensorInfo(a), o;
208
+ }
209
+ let m = s || e(d.dtype, f.dtype);
210
+ if ((d.dtype === "string" || f.dtype === "string" || p.shouldExecuteOnCPU([d, f])) && o != null) {
211
+ let e = p.texData.get(d.dataId).values, t = p.texData.get(f.dataId).values, n = d.dtype === "string" ? l(e) : e, r = d.dtype === "string" ? l(t) : t, [i, a] = o(d.shape, f.shape, n, r, m), s = p.makeTensorInfo(a, m), c = p.texData.get(s.dataId);
212
+ return c.values = i, s;
213
+ }
214
+ let h = r().getBool("WEBGL_PACK_BINARY_OPERATIONS") && n != null, g;
215
+ return g = h ? new N(n, d.shape, f.shape, i) : new j(t, d.shape, f.shape), p.runWebGLProgram(g, [d, f], m);
216
+ };
217
+ }
218
+ function Y(e, t = !1) {
219
+ if (e === "linear") return t ? w : _;
220
+ if (e === "relu") return t ? E : b;
221
+ if (e === "elu") return t ? T : y;
222
+ if (e === "relu6") return t ? D : x;
223
+ if (e === "prelu") return t ? U : H;
224
+ if (e === "leakyrelu") return t ? z : R;
225
+ if (e === "sigmoid") return t ? O : C;
226
+ throw Error(`Activation ${e} has not been implemented for the WebGL backend.`);
227
+ }
228
+ //#endregion
229
+ export { f as S, g as _, G as a, p as b, L as c, N as d, M as f, v as g, k as h, q as i, P as l, A as m, J as n, V as o, j as p, Y as r, I as s, K as t, F as u, S as v, m as x, h as y };
@@ -0,0 +1,44 @@
1
+ import { GPTConfig } from '../../models/config';
2
+ import { default as MemoryProfiler } from '../../utilities/profile';
3
+ import { default as RoPECache } from './RoPECache';
4
+ import { Tensor, Variable } from '@tensorflow/tfjs-core';
5
+ import { default as WeightStore } from './WeightStore';
6
/**
 * Per-call options handed to a layer's `call`/`forward`/`callCheckpoint`.
 */
export interface ForwardAttributes {
    /** True during a training step (CausalSelfAttention does not update a KV cache while set). */
    training: boolean;
    /** Output dropout rate; layers such as MLP and CausalSelfAttention apply it only when > 0. */
    dropout?: number;
    /** RoPE tables used by attention to rotate queries/keys; no rotation when absent. */
    ropeCache?: RoPECache;
    /** NOTE(review): consumer not visible here — presumably toggles gradient checkpointing; confirm. */
    checkpointing?: boolean;
    /** NOTE(review): consumer not visible here — confirm semantics. */
    mixedPrecision?: boolean;
    /** NOTE(review): consumer not visible here — presumably a layer-drop probability; confirm. */
    layerDrop?: number;
    /** NOTE(review): consumer not visible here — presumably requests intermediate embeddings; confirm. */
    outputEmbeddings?: boolean;
    /** Named tensors collected during the forward pass (see `outputEmbeddings`). */
    embeddings?: Array<{
        name: string;
        tensor: Tensor;
    }>;
}
19
/**
 * Base class for model layers. Layers form a tree: a layer constructed with a
 * parent is appended to the parent's `children` and shares its WeightStore; a
 * root layer creates its own WeightStore.
 */
export default abstract class BaseLayer<ATTR extends ForwardAttributes = ForwardAttributes, CONFIG extends GPTConfig = GPTConfig> {
    readonly parent?: BaseLayer;
    readonly config: CONFIG;
    /** Variable storage shared by every layer in the same tree. */
    weightStore: WeightStore;
    /** Layers that were constructed with this layer as their parent. */
    readonly children: Array<BaseLayer>;
    private profiler?;
    private ownVariables;
    constructor(config: CONFIG, parent?: BaseLayer);
    /** The profiler set via `setProfiler`, if any. */
    getProfiler(): MemoryProfiler | undefined;
    /** Sets (or clears, with null) the profiler on this layer and, recursively, on its children. */
    setProfiler(profiler: MemoryProfiler | null): void;
    /** Forwards to the profiler's `startMemory`, if a profiler is set. */
    startMemory(): void;
    /** Forwards to the profiler's `endMemory(label)`, if a profiler is set. */
    endMemory(label: string): void;
    /** Registers `name` in the WeightStore and records it for this layer and its parent. */
    addVariable(name: string, variable?: Variable): void;
    /** Records a variable name registered by a descendant layer. */
    addChildVariable(name: string): void;
    /** Every variable in the shared WeightStore (not only this layer's). */
    get variables(): Array<Variable>;
    /** The WeightStore's trainable variables, filtered to names recorded for this layer. */
    get trainableVariables(): Array<Variable>;
    getVariable(name: string): Tensor;
    hasVariable(name: string): boolean;
    setVariable(name: string, variable: Variable): void;
    /** Disposes the WeightStore — NOTE(review): it is shared with every layer in the tree. */
    dispose(): void;
    /** Lazily creates missing variables; called before each forward pass. */
    protected build(): void;
    abstract forward(attrs: ATTR, ...x: Array<Tensor>): Tensor | Array<Tensor>;
    /** Runs `build()` then `forward()`. */
    call(attrs: ATTR, ...x: Array<Tensor>): Tensor | Array<Tensor>;
    /** Runs `build()` then a gradient-checkpointed `forward()` (recomputed during backprop). */
    callCheckpoint(attrs: ATTR, ...x: Array<Tensor>): Tensor;
    private checkpointingFn;
}
@@ -0,0 +1,76 @@
1
+ import { Gt as e, Wt as t, ii as n } from "../dist-BewPQWjc.js";
2
+ import r from "./WeightStore.js";
3
+ //#region lib/layers/BaseLayer.ts
4
/**
 * Base class for model layers (compiled from lib/layers/BaseLayer.ts).
 * Layers form a tree that shares a single WeightStore; each layer tracks the
 * variable names that belong to it or to any of its descendants.
 */
var i = class {
	parent;
	config;
	weightStore;
	children = [];
	profiler;
	ownVariables = /* @__PURE__ */ new Set();
	/**
	 * @param config model configuration
	 * @param parent optional parent layer; when given, this layer is appended to
	 *   the parent's children and reuses the parent's WeightStore
	 */
	constructor(config, parent) {
		this.config = config;
		this.parent = parent;
		if (this.parent) {
			this.parent.children.push(this);
			this.weightStore = this.parent.weightStore;
		} else {
			this.weightStore = new r();
		}
	}
	getProfiler() {
		return this.profiler;
	}
	/** Sets (or clears, when falsy) the profiler here and on every child. */
	setProfiler(e) {
		this.profiler = e || void 0;
		this.children.forEach((child) => {
			child.setProfiler(e);
		});
	}
	startMemory() {
		this.profiler?.startMemory();
	}
	endMemory(e) {
		this.profiler?.endMemory(e);
	}
	/** Registers a variable in the shared store and records it for this layer and its ancestors. */
	addVariable(e, t) {
		this.weightStore.addVariable(e, t);
		this.ownVariables.add(e);
		if (this.parent) this.parent.addChildVariable(e);
	}
	/**
	 * Records a variable registered by a descendant. Propagates up the tree so
	 * that every ancestor's `trainableVariables` (and checkpoint gradient list)
	 * includes variables of arbitrarily deep descendants, not just direct children.
	 */
	addChildVariable(e) {
		this.ownVariables.add(e);
		if (this.parent) this.parent.addChildVariable(e);
	}
	/** All variables in the shared WeightStore. */
	get variables() {
		return this.weightStore.variables;
	}
	/** Trainable variables of the store restricted to names recorded for this layer. */
	get trainableVariables() {
		return this.weightStore.trainableVariables.filter((v) => this.ownVariables.has(v.name));
	}
	getVariable(e) {
		return this.weightStore.getVariable(e);
	}
	hasVariable(e) {
		return this.weightStore.hasVariable(e);
	}
	setVariable(e, t) {
		this.weightStore.setVariable(e, t);
	}
	/** NOTE(review): disposes the whole WeightStore, which is shared with the rest of the tree. */
	dispose() {
		this.weightStore.dispose();
	}
	build() {}
	call(e, ...t) {
		this.build();
		return this.forward(e, ...t);
	}
	callCheckpoint(e, ...t) {
		this.build();
		return this.checkpointingFn(e, ...t);
	}
	/**
	 * Gradient checkpointing: runs forward inside a custom-gradient op that saves
	 * only the inputs, then recomputes forward on a fresh tape during backprop to
	 * obtain gradients for the inputs and this layer's trainable variables.
	 */
	checkpointingFn(attrs, ...inputs) {
		const variables = this.trainableVariables;
		return t((...args) => {
			// The custom-gradient callback receives its inputs followed by a `save` function.
			const save = args[args.length - 1];
			const xs = args.slice(0, inputs.length);
			const value = this.forward(attrs, ...xs);
			save(xs);
			return {
				value,
				gradFunc: (dy, saved) => {
					// Recompute on an isolated tape; restore the outer tape even if forward throws.
					const previousTape = n().state.activeTape;
					n().state.activeTape = [];
					try {
						return e((...tensors) => this.forward(attrs, ...tensors.slice(0, xs.length)))([...saved, ...variables], dy);
					} finally {
						n().state.activeTape = previousTape;
					}
				}
			};
		})(...inputs, ...variables);
	}
};
75
+ //#endregion
76
+ export { i as default };
@@ -0,0 +1,39 @@
1
+ import { default as BaseLayer, ForwardAttributes } from './BaseLayer';
2
+ import { Tensor } from '@tensorflow/tfjs-core';
3
+ import { GPTConfig } from '../../models/config';
4
/**
 * Per-layer key/value cache used during incremental decoding. It is updated in
 * place by CausalSelfAttention when a forward pass is not training.
 */
export interface KVCache {
    /** Number of cached positions, capped at the model's blockSize. */
    length: number;
    /** Total positions ever appended (not capped); used as the RoPE offset. */
    cumulativeLength: number;
    /** Cached keys, replaced on each update. */
    k?: Tensor;
    /** Cached values, replaced on each update. */
    v?: Tensor;
}
10
/** Optional sink for attention probabilities produced during a forward pass. */
export interface AttentionScores {
    /**
     * When provided, each attention layer pushes a kept copy of the first batch
     * item's attention probabilities, reshaped to [heads, queryLength, -1].
     */
    attentionOut?: Array<Tensor>;
    /** NOTE(review): not read by CausalSelfAttention here — presumably consumed by the caller; confirm. */
    meanOfHeads?: boolean;
}
14
/** Forward options specific to CausalSelfAttention. */
interface AttentionForwardAttributes extends ForwardAttributes {
    /** KV cache; appended to (and read from) when not training. */
    pastKV?: KVCache;
    /** RoPE position offset used when no `pastKV` is supplied (defaults to 0). */
    ropePositionOffset?: number;
    /** Sink for attention probabilities. */
    attentionScores?: AttentionScores;
    /** NOTE(review): not read by CausalSelfAttention.forward in this build — confirm purpose. */
    seed?: number;
}
20
/** Structural options for a CausalSelfAttention layer. */
export interface CausalSelfAttentionConfig {
    /** Apply RMS normalisation to queries and keys before the attention scores (default false). */
    useQKNorm?: boolean;
}
23
/**
 * Causal multi-head self-attention for transformer block `index`. Owns the
 * fused QKV weight `block_{index}_cAttn` ([nEmbed, 3*nEmbed]) and the output
 * projection `block_{index}_cProj` ([nEmbed, nEmbed]).
 */
export default class CausalSelfAttention extends BaseLayer<AttentionForwardAttributes> {
    private readonly attentionConfig;
    private index;
    private units;
    private projUnits;
    /** Score scale, 1 / sqrt(nEmbed / nHead). */
    private divisor;
    private ATTN;
    private PROJ;
    constructor(index: number, config: GPTConfig, attentionConfig: CausalSelfAttentionConfig, parent?: BaseLayer);
    /** Creates the QKV and projection weights if absent. */
    protected build(): void;
    private getQKV;
    private getAttentionScores;
    private getOutputProjection;
    private updateCache;
    /** Self-attention over `x`, optionally using/updating `attr.pastKV`. */
    forward(attr: AttentionForwardAttributes, x: Tensor): Tensor;
}
39
+ export {};
@@ -0,0 +1,99 @@
1
+ import { T as e, di as t, oi as n, pt as r } from "../dist-BewPQWjc.js";
2
+ import { isPackedTensor as i } from "../utilities/packed.js";
3
+ import { transpose16 as a } from "../ops/transpose16.js";
4
+ import { reshape16 as o } from "../ops/reshape16.js";
5
+ import { t as s } from "../matMul16-BNfZSnNM.js";
6
+ import { t as c } from "../pack16-Ck-spx_F.js";
7
+ import { attentionMask as l } from "../ops/attentionMask.js";
8
+ import u from "./BaseLayer.js";
9
+ import { t as d } from "../rope-CbeGlsV8.js";
10
+ import { appendCache as f } from "../ops/appendCache.js";
11
+ import { softmax16 as p } from "../ops/softmax16.js";
12
+ import { dot16 as m } from "../ops/dot16.js";
13
+ import { qkv as h } from "../ops/qkv.js";
14
+ import { normRMS as g } from "../ops/normRMS.js";
15
+ import { dropout16 as _ } from "../ops/dropout16.js";
16
+ //#region lib/layers/CausalSelfAttention.ts
17
/**
 * Causal multi-head self-attention layer (compiled from lib/layers/CausalSelfAttention.ts).
 * Tensor lifetimes are managed explicitly with dispose(); forward additionally
 * runs inside a tidy scope.
 */
var v = class extends u {
	attentionConfig;
	divisor;
	index;
	units;
	projUnits;
	ATTN;
	PROJ;
	constructor(index, config, attentionConfig, parent) {
		super(config, parent);
		this.attentionConfig = attentionConfig;
		this.index = index;
		this.units = config.nEmbed * 3;
		this.projUnits = config.nEmbed;
		this.ATTN = `block_${this.index}_cAttn`;
		this.PROJ = `block_${this.index}_cProj`;
		this.addVariable(this.ATTN);
		this.addVariable(this.PROJ);
		this.divisor = 1 / Math.sqrt(config.nEmbed / config.nHead);
	}
	/** Creates the fused QKV and output-projection weights (N(0, 0.02)) when not yet present. */
	build() {
		if (this.hasVariable(this.ATTN) === false) {
			const attnInit = r([this.config.nEmbed, this.units], 0, 0.02);
			this.setVariable(this.ATTN, e(attnInit, true, this.ATTN));
		}
		if (this.hasVariable(this.PROJ) === false) {
			const projInit = r([this.projUnits, this.config.nEmbed], 0, 0.02);
			this.setVariable(this.PROJ, e(projInit, true, this.PROJ));
		}
	}
	/** Masked, scaled scores followed by softmax; `offset` is forwarded to the mask op. */
	getAttentionScores(query, keys, offset) {
		const masked = l(query, keys, this.divisor, offset);
		const probs = p(masked);
		masked.dispose();
		return probs;
	}
	/** Splits `x` into [q, k, v] with the fused ATTN weight (packed to match `x` when needed). */
	getQKV(x) {
		const packed = i(x);
		const raw = this.getVariable(this.ATTN);
		const weight = packed ? c(raw) : raw;
		const result = h(x, weight, this.config.nHead);
		if (packed) weight.dispose();
		return result;
	}
	/** Merges heads back to [batch, seq, channels] and applies the output projection. */
	getOutputProjection(x) {
		const batch = x.shape[0];
		const seqLen = x.shape[2];
		const embed = this.config.nEmbed;
		const packed = i(x);
		const transposed = a(x, [0, 2, 1, 3]);
		// Packed tensors carry half as many elements in the last dimension.
		const merged = o(transposed, [batch, seqLen, packed ? embed / 2 : embed]);
		transposed.dispose();
		const weight = packed ? c(this.getVariable(this.PROJ)) : this.getVariable(this.PROJ);
		const out = m(merged, weight);
		if (packed) weight.dispose();
		merged.dispose();
		return out;
	}
	/** Appends new keys/values to `cache` (disposing inputs and old entries) and advances its counters. */
	updateCache(newKeys, newValues, cache) {
		const maxLen = this.config.blockSize;
		const added = newKeys.shape[2];
		const pastLen = cache.length || 0;
		const nextKeys = f(newKeys, maxLen, pastLen, cache.k);
		newKeys.dispose();
		if (cache.k) cache.k.dispose();
		const nextValues = f(newValues, maxLen, pastLen, cache.v);
		newValues.dispose();
		if (cache.v) cache.v.dispose();
		cache.length = Math.min(pastLen + added, maxLen);
		cache.cumulativeLength = cache.cumulativeLength + added;
		cache.k = n(nextKeys);
		cache.v = n(nextValues);
	}
	forward(attrs, x) {
		return t(() => {
			this.startMemory();
			const [qRaw, kRaw, valRaw] = this.getQKV(x);
			const pastKV = attrs.pastKV;
			const ropePosition = pastKV ? pastKV.cumulativeLength : attrs.ropePositionOffset || 0;
			const ropeCache = attrs.ropeCache;
			const qRot = ropeCache ? d(qRaw, ropeCache, ropePosition) : qRaw;
			const kRot = ropeCache ? d(kRaw, ropeCache, ropePosition) : kRaw;
			const useQKNorm = this.attentionConfig.useQKNorm ?? false;
			const query = useQKNorm ? g(qRot) : qRot;
			if (useQKNorm) qRot.dispose();
			const key = useQKNorm ? g(kRot) : kRot;
			if (useQKNorm) kRot.dispose();
			if (ropeCache) {
				qRaw.dispose();
				kRaw.dispose();
			}
			// Cache length must be read before the cache is updated below.
			const pastLength = pastKV ? pastKV.length : 0;
			if (pastKV && !attrs.training) this.updateCache(key, valRaw, pastKV);
			const keys = pastKV?.k ? pastKV.k : key;
			const values = pastKV?.v ? pastKV.v : valRaw;
			let probs;
			if (pastLength > 0) {
				probs = this.getAttentionScores(query, keys, pastLength);
			} else {
				probs = this.getAttentionScores(query, keys);
			}
			query.dispose();
			if (!pastKV) keys.dispose();
			const context = s(probs, values);
			const keepScores = attrs.attentionScores !== void 0 && attrs.attentionScores.attentionOut !== void 0;
			if (!keepScores) probs.dispose();
			if (!pastKV) values.dispose();
			const projected = this.getOutputProjection(context);
			context.dispose();
			if (keepScores && attrs.attentionScores && attrs.attentionScores.attentionOut !== void 0) {
				const heads = probs.shape[1];
				const queryLen = probs.shape[2];
				// Keep only the first batch item so it survives the surrounding tidy.
				attrs.attentionScores.attentionOut?.push(n(probs.slice([0, 0, 0, 0], [1, -1, -1, -1]).reshape([heads, queryLen, -1])));
			}
			this.endMemory("CausalSelfAttention");
			return attrs.dropout && attrs.dropout > 0 ? _(projected, attrs.dropout) : projected;
		});
	}
};
98
+ //#endregion
99
+ export { v as default };
@@ -0,0 +1,14 @@
1
+ import { default as WeightStore } from './WeightStore';
2
/**
 * Low-rank adapter (A·B scaled by alpha/rank) over selected 2-D weights of a WeightStore.
 * Adapter weights are stored as `{weight}_{name}_loraA` / `{weight}_{name}_loraB`.
 */
export default class LoRA {
    readonly name: string;
    readonly alpha: number;
    readonly rank: number;
    /** Base weight names this adapter applies to (non-2-D matches are dropped with a warning). */
    readonly variables: Set<string>;
    private weightStore;
    private scale;
    /**
     * @param variables glob patterns selecting base weights by name
     */
    constructor(name: string, weightStore: WeightStore, alpha: number, rank: number, variables: Array<string>);
    /** Installs a weight-read hook adding the low-rank delta and makes only this adapter's weights trainable. */
    attach(): void;
    /** Adds the low-rank delta into the base weights in place. */
    merge(): void;
    /** Removes the weight-read hook and marks all variables trainable again. */
    detach(): void;
    /** Detaches and frees the adapter weights and scale. */
    dispose(): void;
}
@@ -0,0 +1,48 @@
1
+ import { a as e } from "../chunk-BPntVaq0.js";
2
+ import { T as t, _n as n, di as r, kt as i, pt as a } from "../dist-BewPQWjc.js";
3
+ import { t as o } from "../picomatch-3tUnMMbd.js";
4
+ //#region lib/layers/LoRA.ts
5
/**
 * Low-rank adapter over a WeightStore (compiled from lib/layers/LoRA.ts).
 * `s` is the glob matcher module used to select base weights by name.
 */
var s = /* @__PURE__ */ e(o(), 1), c = class {
	weightStore;
	alpha;
	rank;
	variables;
	scale;
	name;
	/** The onWeightRead hook installed by attach(); undefined while detached. */
	attachedHook;
	/**
	 * @param name        adapter name, embedded in the adapter weight names
	 * @param weightStore store holding the base weights
	 * @param alpha       LoRA alpha; the delta is scaled by alpha / rank
	 * @param rank        adapter rank
	 * @param variables   glob patterns selecting base weights
	 */
	constructor(name, weightStore, alpha, rank, variables) {
		this.name = name;
		this.weightStore = weightStore;
		this.alpha = alpha;
		this.rank = rank;
		const isMatch = (0, s.default)(variables);
		// Never adapt existing adapter weights.
		const targets = weightStore.variableNames.filter((v) => isMatch(v) && !v.endsWith("_loraA") && !v.endsWith("_loraB"));
		this.variables = new Set(targets);
		this.scale = n(alpha / rank);
		this.variables.forEach((varName) => {
			const base = this.weightStore.getRawVariable(varName);
			const [rows, cols] = base.shape;
			const nameA = `${varName}_${this.name}_loraA`;
			const nameB = `${varName}_${this.name}_loraB`;
			if (base.shape.length !== 2) {
				console.warn(`LoRA currently only supports 2D weight matrices. Variable ${varName} has shape ${base.shape}`);
				this.variables.delete(varName);
				return;
			}
			// Reuse previously created/loaded adapter weights when either exists.
			if (!this.weightStore.hasVariable(nameA) && !this.weightStore.hasVariable(nameB)) {
				this.weightStore.addVariable(nameA, t(a([rows, this.rank], 0, 0.02), true, nameA));
				this.weightStore.addVariable(nameB, t(i([this.rank, cols]), true, nameB));
			}
		});
	}
	/**
	 * Installs a weight-read hook that returns W + A·B·scale for adapted weights
	 * and restricts training to this adapter's A/B weights.
	 * @throws Error if the store already has an onWeightRead hook
	 */
	attach() {
		if (this.weightStore.onWeightRead) throw Error("LoRA cannot be applied to a WeightStore that already has a onWeightRead hook.");
		const hook = (varName, weight) => {
			if (!this.variables.has(varName)) return weight;
			return r(() => {
				const loraA = this.weightStore.getRawVariable(`${varName}_${this.name}_loraA`);
				const loraB = this.weightStore.getRawVariable(`${varName}_${this.name}_loraB`);
				return weight.add(loraA.matMul(loraB).mul(this.scale));
			});
		};
		this.attachedHook = hook;
		this.weightStore.onWeightRead = hook;
		this.weightStore.setTrainable([`*_${this.name}_loraA`, `*_${this.name}_loraB`]);
	}
	/** Folds A·B·scale into each adapted base weight in place. */
	merge() {
		this.variables.forEach((varName) => {
			const base = this.weightStore.getRawVariable(varName);
			const loraA = this.weightStore.getRawVariable(`${varName}_${this.name}_loraA`);
			const loraB = this.weightStore.getRawVariable(`${varName}_${this.name}_loraB`);
			const merged = r(() => base.add(loraA.matMul(loraB).mul(this.scale)));
			base.assign(merged);
			merged.dispose();
		});
	}
	/**
	 * Removes this adapter's weight-read hook and marks all variables trainable.
	 * A hook installed by someone else (e.g. another attached LoRA) is left
	 * untouched, so detaching/disposing an unattached adapter cannot break it.
	 */
	detach() {
		const current = this.weightStore.onWeightRead;
		if (current && current !== this.attachedHook) return;
		this.weightStore.onWeightRead = void 0;
		this.attachedHook = void 0;
		this.weightStore.setTrainable(["*"]);
	}
	/** Detaches, then frees the scale and every adapter weight owned by this instance. */
	dispose() {
		this.detach();
		this.scale.dispose();
		this.variables.forEach((varName) => {
			const nameA = `${varName}_${this.name}_loraA`;
			const nameB = `${varName}_${this.name}_loraB`;
			this.weightStore.getRawVariable(nameA).dispose();
			this.weightStore.getRawVariable(nameB).dispose();
			this.weightStore.deleteVariable(nameA);
			this.weightStore.deleteVariable(nameB);
		});
		this.variables.clear();
	}
};
47
+ //#endregion
48
+ export { c as default };
@@ -0,0 +1,17 @@
1
+ import { Tensor } from '@tensorflow/tfjs-core';
2
+ import { default as BaseLayer, ForwardAttributes } from './BaseLayer';
3
+ import { GPTConfig } from '../../main';
4
/** Structural options for an MLP block. */
export interface MLPConfig {
    /** Hidden width multiplier; falls back to `config.mlpFactor` when omitted. */
    hiddenFactor?: number;
    /** Activation fused into the hidden matmul; defaults to "gelu". */
    activation?: 'gelu' | 'relu2';
}
8
/**
 * Two-layer feed-forward block for transformer block `index`, owning
 * `block_{index}_mlpHidden` and `block_{index}_mlpOut`.
 */
export default class MLP extends BaseLayer {
    private index;
    private mlpConfig;
    private hiddenUnits;
    private MLPHIDDEN;
    private MLPOUT;
    constructor(index: number, config: GPTConfig, mlpConfig: MLPConfig, parent?: BaseLayer);
    /** Creates the hidden and output weights if absent. */
    protected build(): void;
    /** Maps [batch, seq, nEmbed] to the same shape through the hidden layer. */
    forward(attr: ForwardAttributes, x: Tensor): Tensor;
}
@@ -0,0 +1,34 @@
1
+ import { T as e, di as t, pt as n } from "../dist-BewPQWjc.js";
2
+ import { reshape16 as r } from "../ops/reshape16.js";
3
+ import { t as i } from "../matMul16-BNfZSnNM.js";
4
+ import a from "./BaseLayer.js";
5
+ import { dropout16 as o } from "../ops/dropout16.js";
6
+ //#region lib/layers/MLP.ts
7
/**
 * Feed-forward block (compiled from lib/layers/MLP.ts): a matmul with a fused
 * activation into `hiddenUnits`, followed by a projection back to nEmbed.
 */
var s = class extends a {
	index;
	hiddenUnits;
	MLPHIDDEN;
	MLPOUT;
	mlpConfig;
	constructor(index, config, mlpConfig, parent) {
		super(config, parent);
		this.index = index;
		this.mlpConfig = mlpConfig;
		const factor = mlpConfig.hiddenFactor ?? config.mlpFactor;
		this.hiddenUnits = factor * config.nEmbed;
		this.MLPHIDDEN = `block_${this.index}_mlpHidden`;
		this.MLPOUT = `block_${this.index}_mlpOut`;
		this.addVariable(this.MLPHIDDEN);
		this.addVariable(this.MLPOUT);
	}
	/** Creates missing weights; the output projection uses stddev 0.02 / sqrt(2 * nLayer). */
	build() {
		if (this.hasVariable(this.MLPHIDDEN) === false) {
			const hiddenInit = n([this.config.nEmbed, this.hiddenUnits], 0, 0.02);
			this.setVariable(this.MLPHIDDEN, e(hiddenInit, true, this.MLPHIDDEN));
		}
		if (this.hasVariable(this.MLPOUT) === false) {
			const outInit = n([this.hiddenUnits, this.config.nEmbed], 0, 0.02 / Math.sqrt(2 * this.config.nLayer));
			this.setVariable(this.MLPOUT, e(outInit, true, this.MLPOUT));
		}
	}
	forward(attrs, x) {
		return t(() => {
			this.startMemory();
			const [batch, seqLen, channels] = x.shape;
			// Flatten batch and sequence so both matmuls are plain 2-D products.
			const flat = r(x, [batch * seqLen, channels]);
			const activation = this.mlpConfig.activation ?? "gelu";
			const hidden = i(flat, this.getVariable(this.MLPHIDDEN), false, false, { activation });
			const projected = i(hidden, this.getVariable(this.MLPOUT));
			hidden.dispose();
			const output = r(projected, [batch, seqLen, channels]);
			this.endMemory("MLP");
			return attrs.dropout && attrs.dropout > 0 ? o(output, attrs.dropout) : output;
		});
	}
};
33
+ //#endregion
34
+ export { s as default };
@@ -0,0 +1,8 @@
1
+ import { Tensor } from '@tensorflow/tfjs-core';
2
+ import { default as BaseLayer } from './BaseLayer';
3
+ import { GPTConfig, ModelForwardAttributes } from '../../main';
4
/**
 * Learned absolute position embedding (blockSize × nEmbed) added to its input.
 */
export default class PositionEmbedding extends BaseLayer {
    private wpe?;
    /**
     * @param name name given to the underlying embedding layer (defaults to "")
     */
    constructor(config: GPTConfig, name?: string, parent?: BaseLayer);
    /** Returns `x` plus the embeddings for its positions, offset by the cached length. */
    forward(attrs: ModelForwardAttributes, x: Tensor): Tensor;
}
@@ -0,0 +1,27 @@
1
+ import { Tt as e, _n as t, di as n, dt as r, vi as i } from "../dist-BewPQWjc.js";
2
+ import { n as a, t as o } from "../dist-VEU5mfO0.js";
3
+ import s from "./BaseLayer.js";
4
+ //#region lib/layers/PositionEmbedding.ts
5
/**
 * Learned position embedding layer (compiled from lib/layers/PositionEmbedding.ts).
 */
var c = class extends s {
	wpe;
	constructor(config, name = "", parent) {
		super(config, parent);
		this.wpe = o({
			inputDim: this.config.blockSize,
			outputDim: this.config.nEmbed,
			name,
			embeddingsInitializer: a({
				mean: 0,
				stddev: 0.02
			})
		});
	}
	/**
	 * Adds position embeddings to `x`. Positions start at the length of the first
	 * cache entry (0 without a cache) and are combined with blockSize by `e`
	 * (presumably a modulo so they stay within the table — NOTE(review): confirm).
	 */
	forward(attrs, x) {
		const offset = attrs.cache?.[0]?.length ?? 0;
		return n(() => {
			const seqLen = x.shape[1];
			const blockSize = this.config.blockSize;
			const steps = r(0, seqLen, 1, "int32");
			const positions = e(i(steps, t(offset, "int32")), t(blockSize, "int32"));
			const embeddings = this.wpe.apply(positions);
			return x.add(embeddings);
		});
	}
};
26
+ //#endregion
27
+ export { c as default };
@@ -0,0 +1,12 @@
1
+ import { Tensor } from '@tensorflow/tfjs-core';
2
+ import { default as BaseLayer, ForwardAttributes } from './BaseLayer';
3
+ import { GPTConfig } from '../../main';
4
/** Options for an RMSNorm layer. */
export interface RMSNormConfig {
    /** NOTE(review): implementation not visible here — presumably enables a learned gamma scale; confirm. */
    useGamma?: boolean;
}
7
/**
 * RMS normalisation layer. NOTE(review): implementation not visible in this
 * chunk — behaviour described only by the declaration.
 */
export default class RMSNorm extends BaseLayer {
    private rmsConfig;
    private GAMMA;
    constructor(config: GPTConfig, rmsConfig: RMSNormConfig, name?: string, parent?: BaseLayer);
    /** Normalises `x`; the forward attributes are ignored. */
    forward(_: ForwardAttributes, x: Tensor): Tensor;
}