@genai-fi/nanogpt 0.19.0 → 0.20.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482) hide show
  1. package/package.json +9 -10
  2. package/dist/Generator.d.ts +0 -82
  3. package/dist/Generator.js +0 -11941
  4. package/dist/RealDiv-CGwv0liw.js +0 -365
  5. package/dist/Reshape-BW__R4mZ.js +0 -79
  6. package/dist/Reshape-CPBkTIH2.js +0 -14
  7. package/dist/TeachableLLM.d.ts +0 -70
  8. package/dist/TeachableLLM.js +0 -273
  9. package/dist/Trainer.d.ts +0 -43
  10. package/dist/Trainer.js +0 -244
  11. package/dist/_commonjsHelpers-ByX85dGu.js +0 -33
  12. package/dist/axis_util-GTVlo58H.js +0 -55
  13. package/dist/backend.d.ts +0 -2
  14. package/dist/backend.js +0 -13
  15. package/dist/backend_util-GaFarB78.js +0 -425
  16. package/dist/backend_webgpu-BqASlsbV.js +0 -545
  17. package/dist/binary_op_util-pKXltfxI.js +0 -192
  18. package/dist/broadcast_to-eS93CCN_.js +0 -28
  19. package/dist/checks/appendCache.d.ts +0 -1
  20. package/dist/checks/appendCache.js +0 -22
  21. package/dist/checks/attentionMask.d.ts +0 -1
  22. package/dist/checks/attentionMask.js +0 -37
  23. package/dist/checks/check.d.ts +0 -9
  24. package/dist/checks/check.js +0 -20
  25. package/dist/checks/gelu.d.ts +0 -1
  26. package/dist/checks/gelu.js +0 -18
  27. package/dist/checks/index.d.ts +0 -26
  28. package/dist/checks/index.js +0 -28
  29. package/dist/checks/matMulGelu.d.ts +0 -1
  30. package/dist/checks/matMulGelu.js +0 -28
  31. package/dist/checks/normRMS.d.ts +0 -1
  32. package/dist/checks/normRMS.js +0 -16
  33. package/dist/checks/normRMSGrad.d.ts +0 -1
  34. package/dist/checks/normRMSGrad.js +0 -12
  35. package/dist/checks/packUnpack.d.ts +0 -1
  36. package/dist/checks/packUnpack.js +0 -18
  37. package/dist/checks/qkv.d.ts +0 -1
  38. package/dist/checks/qkv.js +0 -34
  39. package/dist/checks/rope.d.ts +0 -1
  40. package/dist/checks/rope.js +0 -36
  41. package/dist/checks/weights.d.ts +0 -14
  42. package/dist/checks/weights.js +0 -31
  43. package/dist/clip_by_value-DDA7rrcT.js +0 -12
  44. package/dist/complex-DI35Q-gW.js +0 -11
  45. package/dist/complex_util-Yc1A_gV1.js +0 -55
  46. package/dist/concat-CAQpCret.js +0 -17
  47. package/dist/concat_util-D18dJ4fD.js +0 -22
  48. package/dist/data/docx.d.ts +0 -2
  49. package/dist/data/docx.js +0 -15
  50. package/dist/data/parquet.d.ts +0 -2
  51. package/dist/data/parquet.js +0 -17
  52. package/dist/data/pdf.d.ts +0 -2
  53. package/dist/data/pdf.js +0 -14
  54. package/dist/data/textLoader.d.ts +0 -7
  55. package/dist/data/textLoader.js +0 -108
  56. package/dist/dataset-CGGp1z9P.js +0 -1124
  57. package/dist/dropout_util--NxWuYg2.js +0 -27
  58. package/dist/expand_dims-Bkd1YD5x.js +0 -11
  59. package/dist/exports_initializers-CYzKLjN7.js +0 -7
  60. package/dist/floor-BQtb-Azg.js +0 -9
  61. package/dist/gather-qIqEqaGn.js +0 -9
  62. package/dist/gelu-B220X1Go.js +0 -26
  63. package/dist/gpgpu_math-BwvV12df.js +0 -2022
  64. package/dist/index-CUXkjxiT.js +0 -3516
  65. package/dist/index-CieiGp4Y.js +0 -349
  66. package/dist/index-CjOWnMXP.js +0 -7308
  67. package/dist/index-Cp39cXWe.js +0 -1016
  68. package/dist/index-D5v913EJ.js +0 -4
  69. package/dist/index-DmeWGGmS.js +0 -1074
  70. package/dist/index-DvYrXKkX.js +0 -113
  71. package/dist/index-Ksja3su6.js +0 -151
  72. package/dist/index-xuotMAFm.js +0 -118
  73. package/dist/inference/types.d.ts +0 -16
  74. package/dist/inference/types.js +0 -1
  75. package/dist/jszip.min-BZhlzntC.js +0 -2313
  76. package/dist/kernel_funcs_utils-pq0CK9co.js +0 -306
  77. package/dist/layers/BaseLayer.d.ts +0 -44
  78. package/dist/layers/BaseLayer.js +0 -74
  79. package/dist/layers/CausalSelfAttention.d.ts +0 -39
  80. package/dist/layers/CausalSelfAttention.js +0 -86
  81. package/dist/layers/LoRA.d.ts +0 -14
  82. package/dist/layers/LoRA.js +0 -58
  83. package/dist/layers/MLP.d.ts +0 -17
  84. package/dist/layers/MLP.js +0 -44
  85. package/dist/layers/PositionEmbedding.d.ts +0 -8
  86. package/dist/layers/PositionEmbedding.js +0 -31
  87. package/dist/layers/RMSNorm.d.ts +0 -12
  88. package/dist/layers/RMSNorm.js +0 -22
  89. package/dist/layers/RoPECache.d.ts +0 -18
  90. package/dist/layers/RoPECache.js +0 -50
  91. package/dist/layers/TiedEmbedding.d.ts +0 -13
  92. package/dist/layers/TiedEmbedding.js +0 -36
  93. package/dist/layers/TransformerBlock.d.ts +0 -27
  94. package/dist/layers/TransformerBlock.js +0 -40
  95. package/dist/layers/WeightStore.d.ts +0 -20
  96. package/dist/layers/WeightStore.js +0 -76
  97. package/dist/loader/load.d.ts +0 -6
  98. package/dist/loader/load.js +0 -68
  99. package/dist/loader/loadHF.d.ts +0 -8
  100. package/dist/loader/loadHF.js +0 -22
  101. package/dist/loader/loadTransformers.d.ts +0 -4
  102. package/dist/loader/loadTransformers.js +0 -44
  103. package/dist/loader/loadZipMeta.d.ts +0 -3
  104. package/dist/loader/loadZipMeta.js +0 -16
  105. package/dist/loader/newZipLoad.d.ts +0 -3
  106. package/dist/loader/newZipLoad.js +0 -31
  107. package/dist/loader/oldZipLoad.d.ts +0 -9
  108. package/dist/loader/oldZipLoad.js +0 -80
  109. package/dist/loader/save.d.ts +0 -16
  110. package/dist/loader/save.js +0 -90
  111. package/dist/loader/types.d.ts +0 -67
  112. package/dist/loader/types.js +0 -1
  113. package/dist/main.d.ts +0 -50
  114. package/dist/main.js +0 -109
  115. package/dist/matMul16-BcVC_E62.js +0 -80
  116. package/dist/matMulGelu-JNLZqKQp.js +0 -163
  117. package/dist/mat_mul-DhG0Newp.js +0 -11
  118. package/dist/mod-CSdCpRjf.js +0 -11
  119. package/dist/models/NanoGPTV1.d.ts +0 -16
  120. package/dist/models/NanoGPTV1.js +0 -99
  121. package/dist/models/NanoGPTV2.d.ts +0 -16
  122. package/dist/models/NanoGPTV2.js +0 -90
  123. package/dist/models/config.d.ts +0 -27
  124. package/dist/models/config.js +0 -50
  125. package/dist/models/factory.d.ts +0 -3
  126. package/dist/models/factory.js +0 -16
  127. package/dist/models/model.d.ts +0 -44
  128. package/dist/models/model.js +0 -134
  129. package/dist/non_max_suppression_impl-B2W7YjZB.js +0 -102
  130. package/dist/not_equal-hurPF26l.js +0 -64
  131. package/dist/ones-BytntneX.js +0 -14
  132. package/dist/ops/adamAdjust.d.ts +0 -2
  133. package/dist/ops/adamAdjust.js +0 -9
  134. package/dist/ops/adamMoments.d.ts +0 -2
  135. package/dist/ops/adamMoments.js +0 -9
  136. package/dist/ops/add16.d.ts +0 -2
  137. package/dist/ops/add16.js +0 -9
  138. package/dist/ops/appendCache.d.ts +0 -2
  139. package/dist/ops/appendCache.js +0 -22
  140. package/dist/ops/attentionMask.d.ts +0 -2
  141. package/dist/ops/attentionMask.js +0 -10
  142. package/dist/ops/concat16.d.ts +0 -2
  143. package/dist/ops/concat16.js +0 -9
  144. package/dist/ops/cpu/adamAdjust.d.ts +0 -1
  145. package/dist/ops/cpu/adamAdjust.js +0 -18
  146. package/dist/ops/cpu/adamMoments.d.ts +0 -1
  147. package/dist/ops/cpu/adamMoments.js +0 -16
  148. package/dist/ops/cpu/appendCache.d.ts +0 -1
  149. package/dist/ops/cpu/appendCache.js +0 -23
  150. package/dist/ops/cpu/attentionMask.d.ts +0 -1
  151. package/dist/ops/cpu/attentionMask.js +0 -22
  152. package/dist/ops/cpu/fusedSoftmax.d.ts +0 -9
  153. package/dist/ops/cpu/fusedSoftmax.js +0 -29
  154. package/dist/ops/cpu/gatherSub.d.ts +0 -1
  155. package/dist/ops/cpu/gatherSub.js +0 -18
  156. package/dist/ops/cpu/gelu.d.ts +0 -1
  157. package/dist/ops/cpu/gelu.js +0 -40
  158. package/dist/ops/cpu/matMul16.d.ts +0 -1
  159. package/dist/ops/cpu/matMul16.js +0 -15
  160. package/dist/ops/cpu/matMulGelu.d.ts +0 -1
  161. package/dist/ops/cpu/matMulGelu.js +0 -53
  162. package/dist/ops/cpu/matMulMul.d.ts +0 -1
  163. package/dist/ops/cpu/matMulMul.js +0 -23
  164. package/dist/ops/cpu/mulDropout.d.ts +0 -1
  165. package/dist/ops/cpu/mulDropout.js +0 -23
  166. package/dist/ops/cpu/normRMS.d.ts +0 -1
  167. package/dist/ops/cpu/normRMS.js +0 -39
  168. package/dist/ops/cpu/qkv.d.ts +0 -5
  169. package/dist/ops/cpu/qkv.js +0 -41
  170. package/dist/ops/cpu/rope.d.ts +0 -6
  171. package/dist/ops/cpu/rope.js +0 -38
  172. package/dist/ops/cpu/scatterSub.d.ts +0 -1
  173. package/dist/ops/cpu/scatterSub.js +0 -23
  174. package/dist/ops/dot16.d.ts +0 -2
  175. package/dist/ops/dot16.js +0 -42
  176. package/dist/ops/dropout.d.ts +0 -2
  177. package/dist/ops/dropout.js +0 -14
  178. package/dist/ops/dropout16.d.ts +0 -2
  179. package/dist/ops/dropout16.js +0 -25
  180. package/dist/ops/gatherSub.d.ts +0 -2
  181. package/dist/ops/gatherSub.js +0 -9
  182. package/dist/ops/gelu.d.ts +0 -3
  183. package/dist/ops/gelu.js +0 -8
  184. package/dist/ops/globalNorm.d.ts +0 -2
  185. package/dist/ops/globalNorm.js +0 -13
  186. package/dist/ops/grads/add16.d.ts +0 -1
  187. package/dist/ops/grads/add16.js +0 -26
  188. package/dist/ops/grads/attentionMask.d.ts +0 -1
  189. package/dist/ops/grads/attentionMask.js +0 -21
  190. package/dist/ops/grads/dropout16.d.ts +0 -1
  191. package/dist/ops/grads/dropout16.js +0 -2
  192. package/dist/ops/grads/gelu.d.ts +0 -2
  193. package/dist/ops/grads/gelu.js +0 -5
  194. package/dist/ops/grads/matMul16.d.ts +0 -2
  195. package/dist/ops/grads/matMul16.js +0 -9
  196. package/dist/ops/grads/matMulGelu.d.ts +0 -1
  197. package/dist/ops/grads/matMulGelu.js +0 -17
  198. package/dist/ops/grads/mul16.d.ts +0 -1
  199. package/dist/ops/grads/mul16.js +0 -4
  200. package/dist/ops/grads/normRMS.d.ts +0 -3
  201. package/dist/ops/grads/normRMS.js +0 -33
  202. package/dist/ops/grads/pack16.d.ts +0 -2
  203. package/dist/ops/grads/pack16.js +0 -6
  204. package/dist/ops/grads/qkv.d.ts +0 -3
  205. package/dist/ops/grads/qkv.js +0 -34
  206. package/dist/ops/grads/rope.d.ts +0 -2
  207. package/dist/ops/grads/rope.js +0 -5
  208. package/dist/ops/grads/softmax16.d.ts +0 -2
  209. package/dist/ops/grads/softmax16.js +0 -25
  210. package/dist/ops/grads/unpack16.d.ts +0 -2
  211. package/dist/ops/grads/unpack16.js +0 -5
  212. package/dist/ops/grads/utils.d.ts +0 -4
  213. package/dist/ops/grads/utils.js +0 -14
  214. package/dist/ops/log.d.ts +0 -0
  215. package/dist/ops/log.js +0 -1
  216. package/dist/ops/matMul16.d.ts +0 -15
  217. package/dist/ops/matMul16.js +0 -13
  218. package/dist/ops/matMulGelu.d.ts +0 -3
  219. package/dist/ops/matMulGelu.js +0 -14
  220. package/dist/ops/matMulMul.d.ts +0 -2
  221. package/dist/ops/matMulMul.js +0 -9
  222. package/dist/ops/mul16.d.ts +0 -2
  223. package/dist/ops/mul16.js +0 -39
  224. package/dist/ops/mulDrop.d.ts +0 -2
  225. package/dist/ops/mulDrop.js +0 -9
  226. package/dist/ops/normRMS.d.ts +0 -2
  227. package/dist/ops/normRMS.js +0 -19
  228. package/dist/ops/pack16.d.ts +0 -2
  229. package/dist/ops/pack16.js +0 -5
  230. package/dist/ops/qkv.d.ts +0 -2
  231. package/dist/ops/qkv.js +0 -10
  232. package/dist/ops/reshape16.d.ts +0 -2
  233. package/dist/ops/reshape16.js +0 -41
  234. package/dist/ops/rope.d.ts +0 -3
  235. package/dist/ops/rope.js +0 -7
  236. package/dist/ops/scatterSub.d.ts +0 -2
  237. package/dist/ops/scatterSub.js +0 -9
  238. package/dist/ops/slice16.d.ts +0 -2
  239. package/dist/ops/slice16.js +0 -9
  240. package/dist/ops/softmax16.d.ts +0 -2
  241. package/dist/ops/softmax16.js +0 -9
  242. package/dist/ops/sub16.d.ts +0 -2
  243. package/dist/ops/sub16.js +0 -8
  244. package/dist/ops/sum16.d.ts +0 -2
  245. package/dist/ops/sum16.js +0 -13
  246. package/dist/ops/transpose16.d.ts +0 -3
  247. package/dist/ops/transpose16.js +0 -40
  248. package/dist/ops/unpack16.d.ts +0 -2
  249. package/dist/ops/unpack16.js +0 -6
  250. package/dist/ops/webgl/adamAdjust.d.ts +0 -1
  251. package/dist/ops/webgl/adamAdjust.js +0 -49
  252. package/dist/ops/webgl/adamMoments.d.ts +0 -1
  253. package/dist/ops/webgl/adamMoments.js +0 -40
  254. package/dist/ops/webgl/appendCache.d.ts +0 -1
  255. package/dist/ops/webgl/appendCache.js +0 -44
  256. package/dist/ops/webgl/attentionMask.d.ts +0 -1
  257. package/dist/ops/webgl/attentionMask.js +0 -45
  258. package/dist/ops/webgl/dropout16.d.ts +0 -1
  259. package/dist/ops/webgl/dropout16.js +0 -11
  260. package/dist/ops/webgl/fusedSoftmax.d.ts +0 -11
  261. package/dist/ops/webgl/fusedSoftmax.js +0 -80
  262. package/dist/ops/webgl/gatherSub.d.ts +0 -1
  263. package/dist/ops/webgl/gatherSub.js +0 -27
  264. package/dist/ops/webgl/gelu.d.ts +0 -2
  265. package/dist/ops/webgl/gelu.js +0 -50
  266. package/dist/ops/webgl/log.d.ts +0 -17
  267. package/dist/ops/webgl/log.js +0 -23
  268. package/dist/ops/webgl/matMul16.d.ts +0 -1
  269. package/dist/ops/webgl/matMul16.js +0 -45
  270. package/dist/ops/webgl/matMulGelu.d.ts +0 -21
  271. package/dist/ops/webgl/matMulGelu.js +0 -9
  272. package/dist/ops/webgl/matMulMul.d.ts +0 -14
  273. package/dist/ops/webgl/matMulMul.js +0 -28
  274. package/dist/ops/webgl/mulDropout.d.ts +0 -1
  275. package/dist/ops/webgl/mulDropout.js +0 -41
  276. package/dist/ops/webgl/normRMS.d.ts +0 -1
  277. package/dist/ops/webgl/normRMS.js +0 -93
  278. package/dist/ops/webgl/qkv.d.ts +0 -1
  279. package/dist/ops/webgl/qkv.js +0 -46
  280. package/dist/ops/webgl/rope.d.ts +0 -1
  281. package/dist/ops/webgl/rope.js +0 -56
  282. package/dist/ops/webgl/scatterSub.d.ts +0 -1
  283. package/dist/ops/webgl/scatterSub.js +0 -27
  284. package/dist/ops/webgpu/adamAdjust.d.ts +0 -1
  285. package/dist/ops/webgpu/adamAdjust.js +0 -57
  286. package/dist/ops/webgpu/adamMoments.d.ts +0 -1
  287. package/dist/ops/webgpu/adamMoments.js +0 -60
  288. package/dist/ops/webgpu/add16.d.ts +0 -1
  289. package/dist/ops/webgpu/add16.js +0 -13
  290. package/dist/ops/webgpu/appendCache.d.ts +0 -1
  291. package/dist/ops/webgpu/appendCache.js +0 -105
  292. package/dist/ops/webgpu/attentionMask.d.ts +0 -1
  293. package/dist/ops/webgpu/attentionMask.js +0 -26
  294. package/dist/ops/webgpu/attentionMask32_program.d.ts +0 -19
  295. package/dist/ops/webgpu/attentionMask32_program.js +0 -54
  296. package/dist/ops/webgpu/clipScale.d.ts +0 -1
  297. package/dist/ops/webgpu/clipScale.js +0 -58
  298. package/dist/ops/webgpu/concat16.d.ts +0 -19
  299. package/dist/ops/webgpu/concat16.js +0 -126
  300. package/dist/ops/webgpu/dropout16.d.ts +0 -1
  301. package/dist/ops/webgpu/dropout16.js +0 -51
  302. package/dist/ops/webgpu/gatherSub.d.ts +0 -1
  303. package/dist/ops/webgpu/gatherSub.js +0 -39
  304. package/dist/ops/webgpu/gelu.d.ts +0 -14
  305. package/dist/ops/webgpu/gelu.js +0 -141
  306. package/dist/ops/webgpu/index.d.ts +0 -0
  307. package/dist/ops/webgpu/index.js +0 -26
  308. package/dist/ops/webgpu/matMul16.d.ts +0 -1
  309. package/dist/ops/webgpu/matMul16.js +0 -65
  310. package/dist/ops/webgpu/matMul16_program.d.ts +0 -42
  311. package/dist/ops/webgpu/matMul16_program.js +0 -343
  312. package/dist/ops/webgpu/mul16.d.ts +0 -1
  313. package/dist/ops/webgpu/mul16.js +0 -13
  314. package/dist/ops/webgpu/norm2.d.ts +0 -1
  315. package/dist/ops/webgpu/norm2.js +0 -76
  316. package/dist/ops/webgpu/normRMS.d.ts +0 -1
  317. package/dist/ops/webgpu/normRMS.js +0 -34
  318. package/dist/ops/webgpu/normRMS16_program.d.ts +0 -10
  319. package/dist/ops/webgpu/normRMS16_program.js +0 -25
  320. package/dist/ops/webgpu/normRMS32_program.d.ts +0 -10
  321. package/dist/ops/webgpu/normRMS32_program.js +0 -25
  322. package/dist/ops/webgpu/normRMSGrad.d.ts +0 -1
  323. package/dist/ops/webgpu/normRMSGrad.js +0 -284
  324. package/dist/ops/webgpu/pack16.d.ts +0 -1
  325. package/dist/ops/webgpu/pack16.js +0 -18
  326. package/dist/ops/webgpu/pack16_program.d.ts +0 -19
  327. package/dist/ops/webgpu/pack16_program.js +0 -92
  328. package/dist/ops/webgpu/qkv.d.ts +0 -1
  329. package/dist/ops/webgpu/qkv.js +0 -24
  330. package/dist/ops/webgpu/rope.d.ts +0 -1
  331. package/dist/ops/webgpu/rope.js +0 -135
  332. package/dist/ops/webgpu/scatterSub.d.ts +0 -1
  333. package/dist/ops/webgpu/scatterSub.js +0 -40
  334. package/dist/ops/webgpu/slice16.d.ts +0 -7
  335. package/dist/ops/webgpu/slice16.js +0 -69
  336. package/dist/ops/webgpu/softmax16.d.ts +0 -17
  337. package/dist/ops/webgpu/softmax16.js +0 -21
  338. package/dist/ops/webgpu/softmax16_program.d.ts +0 -13
  339. package/dist/ops/webgpu/softmax16_program.js +0 -73
  340. package/dist/ops/webgpu/softmax16_subgroup_program.d.ts +0 -17
  341. package/dist/ops/webgpu/softmax16_subgroup_program.js +0 -75
  342. package/dist/ops/webgpu/softmax16grad.d.ts +0 -1
  343. package/dist/ops/webgpu/softmax16grad.js +0 -37
  344. package/dist/ops/webgpu/sub16.d.ts +0 -1
  345. package/dist/ops/webgpu/sub16.js +0 -13
  346. package/dist/ops/webgpu/sum16.d.ts +0 -1
  347. package/dist/ops/webgpu/sum16.js +0 -38
  348. package/dist/ops/webgpu/transpose16.d.ts +0 -1
  349. package/dist/ops/webgpu/transpose16.js +0 -34
  350. package/dist/ops/webgpu/transpose16_program.d.ts +0 -16
  351. package/dist/ops/webgpu/transpose16_program.js +0 -50
  352. package/dist/ops/webgpu/transpose16_shared_program.d.ts +0 -15
  353. package/dist/ops/webgpu/transpose16_shared_program.js +0 -70
  354. package/dist/ops/webgpu/unpack16.d.ts +0 -1
  355. package/dist/ops/webgpu/unpack16.js +0 -48
  356. package/dist/ops/webgpu/utils/binary_op.d.ts +0 -35
  357. package/dist/ops/webgpu/utils/binary_op.js +0 -139
  358. package/dist/ops/webgpu/utils/deviceInfo.d.ts +0 -7
  359. package/dist/ops/webgpu/utils/deviceInfo.js +0 -11
  360. package/dist/ops/webgpu/utils/reductions.d.ts +0 -43
  361. package/dist/ops/webgpu/utils/reductions.js +0 -275
  362. package/dist/ops-CsXeTq1P.js +0 -476
  363. package/dist/pack16-bqltoUlR.js +0 -39
  364. package/dist/papaparse.min-C0cScC2i.js +0 -418
  365. package/dist/parquet-Bqjmp2vo.js +0 -44231
  366. package/dist/patches/webgpu_backend.d.ts +0 -18
  367. package/dist/patches/webgpu_backend.js +0 -56
  368. package/dist/patches/webgpu_base.d.ts +0 -21
  369. package/dist/patches/webgpu_base.js +0 -34
  370. package/dist/patches/webgpu_program.d.ts +0 -36
  371. package/dist/patches/webgpu_program.js +0 -400
  372. package/dist/pdf-NIhmP3sq.js +0 -19477
  373. package/dist/rand_util-CZ7yLoUm.js +0 -50
  374. package/dist/random_normal-IBRrha8a.js +0 -14
  375. package/dist/random_width-DN5ZtQkM.js +0 -9796
  376. package/dist/range-C-CjF-LI.js +0 -10
  377. package/dist/relu-J_X6MUzx.js +0 -9
  378. package/dist/reshape-BDOuCSNW.js +0 -9
  379. package/dist/resize_nearest_neighbor-BojqlfRe.js +0 -150
  380. package/dist/rope-DcrZM_e6.js +0 -24
  381. package/dist/scatter_nd_util-ByNJaL6I.js +0 -46
  382. package/dist/segment_util-Dasb2Zaf.js +0 -43
  383. package/dist/selu_util-BLhIqRkw.js +0 -44
  384. package/dist/shared-3agzAqQ_.js +0 -53
  385. package/dist/shared-CagdqkLh.js +0 -2143
  386. package/dist/slice-BzS11Qh0.js +0 -12
  387. package/dist/slice_util-CC35pLmT.js +0 -153
  388. package/dist/softmax-D4q1LJN7.js +0 -12
  389. package/dist/split-C2Sj255c.js +0 -9
  390. package/dist/squeeze-ho4wLUek.js +0 -10
  391. package/dist/stack-DudVrtmG.js +0 -11
  392. package/dist/step-BTxPtq1r.js +0 -261
  393. package/dist/sum-BpiwSWvg.js +0 -11
  394. package/dist/tensor-BWFldCso.js +0 -8
  395. package/dist/tensor1d-LMGMIUlr.js +0 -11
  396. package/dist/tensor2d-BnXMKScO.js +0 -14
  397. package/dist/tensor4d-C6UCG_u8.js +0 -14
  398. package/dist/tfjs_backend-BGnG-ppu.js +0 -654
  399. package/dist/tile-CFy-xTO6.js +0 -11
  400. package/dist/tokeniser/BaseTokeniser.d.ts +0 -33
  401. package/dist/tokeniser/BaseTokeniser.js +0 -124
  402. package/dist/tokeniser/CharTokeniser.d.ts +0 -24
  403. package/dist/tokeniser/CharTokeniser.js +0 -107
  404. package/dist/tokeniser/bpe.d.ts +0 -28
  405. package/dist/tokeniser/bpe.js +0 -173
  406. package/dist/tokeniser/messages.d.ts +0 -61
  407. package/dist/tokeniser/messages.js +0 -1
  408. package/dist/tokeniser/type.d.ts +0 -34
  409. package/dist/tokeniser/type.js +0 -1
  410. package/dist/training/AdamW.d.ts +0 -36
  411. package/dist/training/AdamW.js +0 -138
  412. package/dist/training/BasicTrainer.d.ts +0 -63
  413. package/dist/training/BasicTrainer.js +0 -265
  414. package/dist/training/DatasetBuilder.d.ts +0 -26
  415. package/dist/training/DatasetBuilder.js +0 -86
  416. package/dist/training/Evaluator.d.ts +0 -19
  417. package/dist/training/Evaluator.js +0 -39
  418. package/dist/training/LRScheduler.d.ts +0 -12
  419. package/dist/training/LRScheduler.js +0 -34
  420. package/dist/training/PreTrainer.d.ts +0 -11
  421. package/dist/training/PreTrainer.js +0 -20
  422. package/dist/training/SFTTrainer.d.ts +0 -12
  423. package/dist/training/SFTTrainer.js +0 -22
  424. package/dist/training/loss.d.ts +0 -3
  425. package/dist/training/loss.js +0 -24
  426. package/dist/training/orthoGrad.d.ts +0 -2
  427. package/dist/training/orthoGrad.js +0 -10
  428. package/dist/training/sparseCrossEntropy.d.ts +0 -7
  429. package/dist/training/sparseCrossEntropy.js +0 -69
  430. package/dist/training/tasks/ConversationTask.d.ts +0 -18
  431. package/dist/training/tasks/ConversationTask.js +0 -40
  432. package/dist/training/tasks/PretrainingTask.d.ts +0 -17
  433. package/dist/training/tasks/PretrainingTask.js +0 -47
  434. package/dist/training/tasks/StartSentenceTask.d.ts +0 -18
  435. package/dist/training/tasks/StartSentenceTask.js +0 -49
  436. package/dist/training/tasks/Task.d.ts +0 -22
  437. package/dist/training/tasks/Task.js +0 -68
  438. package/dist/training/tasks/splitter.d.ts +0 -5
  439. package/dist/training/tasks/splitter.js +0 -21
  440. package/dist/training/types.d.ts +0 -78
  441. package/dist/training/types.js +0 -1
  442. package/dist/training/validation.d.ts +0 -17
  443. package/dist/training/validation.js +0 -84
  444. package/dist/transpose-9kRxIXWR.js +0 -36
  445. package/dist/unsorted_segment_sum-DJvk5xnh.js +0 -277
  446. package/dist/utilities/arrayClose.d.ts +0 -1
  447. package/dist/utilities/arrayClose.js +0 -20
  448. package/dist/utilities/datasetID.d.ts +0 -2
  449. package/dist/utilities/datasetID.js +0 -21
  450. package/dist/utilities/dummy.d.ts +0 -9
  451. package/dist/utilities/dummy.js +0 -43
  452. package/dist/utilities/multinomialCPU.d.ts +0 -2
  453. package/dist/utilities/multinomialCPU.js +0 -13
  454. package/dist/utilities/naming.d.ts +0 -4
  455. package/dist/utilities/naming.js +0 -1
  456. package/dist/utilities/packed.d.ts +0 -4
  457. package/dist/utilities/packed.js +0 -15
  458. package/dist/utilities/parameters.d.ts +0 -11
  459. package/dist/utilities/parameters.js +0 -57
  460. package/dist/utilities/performance.d.ts +0 -2
  461. package/dist/utilities/performance.js +0 -16
  462. package/dist/utilities/profile.d.ts +0 -17
  463. package/dist/utilities/profile.js +0 -38
  464. package/dist/utilities/safetensors.d.ts +0 -3
  465. package/dist/utilities/safetensors.js +0 -83
  466. package/dist/utilities/sentences.d.ts +0 -5
  467. package/dist/utilities/sentences.js +0 -41
  468. package/dist/utilities/tokenParse.d.ts +0 -1
  469. package/dist/utilities/tokenParse.js +0 -21
  470. package/dist/utilities/topP.d.ts +0 -1
  471. package/dist/utilities/topP.js +0 -13
  472. package/dist/utilities/waitForModel.d.ts +0 -2
  473. package/dist/utilities/waitForModel.js +0 -12
  474. package/dist/utilities/weights.d.ts +0 -12
  475. package/dist/utilities/weights.js +0 -45
  476. package/dist/utilities/yielder.d.ts +0 -1
  477. package/dist/utilities/yielder.js +0 -7
  478. package/dist/variable-Ck482e3n.js +0 -7
  479. package/dist/webgpu_program-B4HmApL1.js +0 -525
  480. package/dist/webgpu_util-DYlGSwOJ.js +0 -64
  481. package/dist/zeros-DvZpK8s6.js +0 -13
  482. package/dist/zeros_like-CWjDdwr-.js +0 -721
@@ -1,306 +0,0 @@
1
- import { _ as B, U as G, aU as K, a7 as W, aH as z, aV as V, ab as N, aI as F, am as S } from "./index-CUXkjxiT.js";
2
- import { u as O, f as H } from "./gpgpu_math-BwvV12df.js";
3
- import { f as v } from "./backend_util-GaFarB78.js";
4
- function Y(t, e) {
5
- return ["x", "y", "z", "w", "u", "v"].slice(0, e).map((s) => `${t}.${s}`);
6
- }
7
- function Z(t, e) {
8
- return e === 1 ? [t] : Y(t, e);
9
- }
10
- function pe(t, e) {
11
- if (t === 1)
12
- return "rc";
13
- let s = "";
14
- for (let r = 0; r < t; r++)
15
- s += e[r], r < t - 1 && (s += ",");
16
- return s;
17
- }
18
- class q {
19
- constructor(e, s) {
20
- this.variableNames = ["A"], this.outputShape = e, this.enableShapeUniforms = O(this.outputShape.length), this.userCode = `
21
- float unaryOperation(float x) {
22
- ${s}
23
- }
24
-
25
- void main() {
26
- float x = getAAtOutCoords();
27
- float y = unaryOperation(x);
28
-
29
- setOutput(y);
30
- }
31
- `;
32
- }
33
- }
34
- const T = "if (isnan(x)) return x;", M = "return x;", de = "return abs(x);", j = "return (x >= 0.0) ? x : (exp(x) - 1.0);", J = T + `
35
- return (x < 0.0) ? 0.0 : x;
36
- `, Q = T + `
37
- return (x < 0.0) ? 0.0 : min(6.0, x);
38
- `, he = "return x;", X = "return 1.0 / (1.0 + exp(-1.0 * x));";
39
- const ee = "return x;", te = `
40
- vec4 result;
41
-
42
- result.r = (x.r >= 0.0) ? x.r : (exp(x.r) - 1.0);
43
- result.g = (x.g >= 0.0) ? x.g : (exp(x.g) - 1.0);
44
- result.b = (x.b >= 0.0) ? x.b : (exp(x.b) - 1.0);
45
- result.a = (x.a >= 0.0) ? x.a : (exp(x.a) - 1.0);
46
-
47
- return result;
48
- `, se = `
49
- vec4 result = x * vec4(greaterThanEqual(x, vec4(0.0)));
50
- bvec4 isNaN = isnan(x);
51
-
52
- result.r = isNaN.r ? x.r : result.r;
53
- result.g = isNaN.g ? x.g : result.g;
54
- result.b = isNaN.b ? x.b : result.b;
55
- result.a = isNaN.a ? x.a : result.a;
56
-
57
- return result;
58
- `, ae = `
59
- vec4 result = min(x, vec4(6.)) * vec4(greaterThanEqual(x, vec4(0.0)));
60
- bvec4 isNaN = isnan(x);
61
-
62
- result.r = isNaN.r ? x.r : result.r;
63
- result.g = isNaN.g ? x.g : result.g;
64
- result.b = isNaN.b ? x.b : result.b;
65
- result.a = isNaN.a ? x.a : result.a;
66
-
67
- return result;
68
- `, re = "return 1.0 / (1.0 + exp(-1.0 * x));";
69
- class ne {
70
- constructor(e, s) {
71
- this.variableNames = ["A"], this.packedInputs = !0, this.packedOutput = !0, this.outputShape = e, this.enableShapeUniforms = O(this.outputShape.length), this.userCode = `
72
- vec4 unaryOperation(vec4 x) {
73
- ${s}
74
- }
75
-
76
- void main() {
77
- vec4 x = getAAtOutCoords();
78
- vec4 y = unaryOperation(x);
79
-
80
- setOutput(y);
81
- }
82
- `;
83
- }
84
- }
85
- const fe = `
86
- if (isnan(a)) return a;
87
- if (isnan(b)) return b;
88
- `;
89
- class b {
90
- constructor(e, s, r) {
91
- this.variableNames = ["A", "B"], this.outputShape = B(s, r), this.enableShapeUniforms = O(this.outputShape.length), this.userCode = `
92
- float binaryOperation(float a, float b) {
93
- ${e}
94
- }
95
-
96
- void main() {
97
- float a = getAAtOutCoords();
98
- float b = getBAtOutCoords();
99
- setOutput(binaryOperation(a, b));
100
- }
101
- `;
102
- }
103
- }
104
- const xe = `
105
- result.r = isNaN.r ? NAN : result.r;
106
- result.g = isNaN.g ? NAN : result.g;
107
- result.b = isNaN.b ? NAN : result.b;
108
- result.a = isNaN.a ? NAN : result.a;
109
- `;
110
- class E {
111
- constructor(e, s, r, u = !1) {
112
- this.variableNames = ["A", "B"], this.supportsBroadcasting = !0, this.packedInputs = !0, this.packedOutput = !0, this.outputShape = B(s, r);
113
- const n = this.outputShape.length;
114
- this.enableShapeUniforms = O(n);
115
- let o = "";
116
- if (u)
117
- if (n === 0 || G(this.outputShape) === 1)
118
- o = `
119
- result.y = 0.;
120
- result.z = 0.;
121
- result.w = 0.;
122
- `;
123
- else if (o = `
124
- ${H(n)} coords = getOutputCoords();
125
- `, n === 1)
126
- this.enableShapeUniforms ? o += `
127
- result.y = (coords + 1) >= outShape ? 0. : result.y;
128
- result.z = 0.;
129
- result.w = 0.;
130
- ` : o += `
131
- result.y = (coords + 1) >= ${this.outputShape[0]} ? 0. : result.y;
132
- result.z = 0.;
133
- result.w = 0.;
134
- `;
135
- else {
136
- const a = Z("coords", n);
137
- this.enableShapeUniforms ? o += `
138
- bool nextRowOutOfBounds =
139
- (${a[n - 2]} + 1) >= outShape[${n} - 2];
140
- bool nextColOutOfBounds =
141
- (${a[n - 1]} + 1) >= outShape[${n} - 1];
142
- result.y = nextColOutOfBounds ? 0. : result.y;
143
- result.z = nextRowOutOfBounds ? 0. : result.z;
144
- result.w = nextColOutOfBounds || nextRowOutOfBounds ? 0. : result.w;
145
- ` : o += `
146
- bool nextRowOutOfBounds =
147
- (${a[n - 2]} + 1) >= ${this.outputShape[n - 2]};
148
- bool nextColOutOfBounds =
149
- (${a[n - 1]} + 1) >= ${this.outputShape[n - 1]};
150
- result.y = nextColOutOfBounds ? 0. : result.y;
151
- result.z = nextRowOutOfBounds ? 0. : result.z;
152
- result.w = nextColOutOfBounds || nextRowOutOfBounds ? 0. : result.w;
153
- `;
154
- }
155
- this.userCode = `
156
- vec4 binaryOperation(vec4 a, vec4 b) {
157
- ${e}
158
- }
159
-
160
- void main() {
161
- vec4 a = getAAtOutCoords();
162
- vec4 b = getBAtOutCoords();
163
-
164
- vec4 result = binaryOperation(a, b);
165
- ${o}
166
-
167
- setOutput(result);
168
- }
169
- `;
170
- }
171
- }
172
- function P(t) {
173
- const { inputs: e, backend: s } = t, { x: r } = e;
174
- return s.incRef(r.dataId), { dataId: r.dataId, shape: r.shape, dtype: r.dtype };
175
- }
176
- const ge = {
177
- kernelName: K,
178
- backendName: "webgl",
179
- kernelFunc: P
180
- };
181
- function L(t) {
182
- const { inputs: e, backend: s } = t, { real: r, imag: u } = e, n = s.makeTensorInfo(r.shape, "complex64"), o = s.texData.get(n.dataId), i = P({ inputs: { x: r }, backend: s }), a = P({ inputs: { x: u }, backend: s });
183
- return o.complexTensorInfos = { real: i, imag: a }, n;
184
- }
185
- const me = {
186
- kernelName: W,
187
- backendName: "webgl",
188
- kernelFunc: L
189
- };
190
- const w = "return (a < 0.) ? b * a : a;", R = `
191
- vec4 aLessThanZero = vec4(lessThan(a, vec4(0.)));
192
- return (aLessThanZero * (b * a)) + ((vec4(1.0) - aLessThanZero) * a);
193
- `;
194
- function oe(t) {
195
- const { inputs: e, backend: s, attrs: r } = t, { x: u } = e, { alpha: n } = r, o = s.makeTensorInfo([], "float32", V(n, "float32")), i = N().getBool("WEBGL_PACK_BINARY_OPERATIONS") ? new E(R, u.shape, o.shape) : new b(w, u.shape, o.shape), a = s.runWebGLProgram(i, [u, o], "float32");
196
- return s.disposeIntermediateTensorInfo(o), a;
197
- }
198
- const be = {
199
- kernelName: z,
200
- backendName: "webgl",
201
- kernelFunc: oe
202
- };
203
- const U = "return (a < 0.) ? b * a : a;", k = `
204
- vec4 aLessThanZero = vec4(lessThan(a, vec4(0.)));
205
- return (aLessThanZero * (b * a)) + ((vec4(1.0) - aLessThanZero) * a);
206
- `;
207
- function ue(t) {
208
- const { inputs: e, backend: s } = t, { x: r, alpha: u } = e, n = N().getBool("WEBGL_PACK_BINARY_OPERATIONS") ? new E(k, r.shape, u.shape) : new b(U, r.shape, u.shape);
209
- return s.runWebGLProgram(n, [r, u], "float32");
210
- }
211
- const Ne = {
212
- kernelName: F,
213
- backendName: "webgl",
214
- kernelFunc: ue
215
- };
216
- const Oe = "if (isnan(x)) return x;";
217
- function ye({ opSnippet: t, packedOpSnippet: e, cpuKernelImpl: s, dtype: r }) {
218
- return ({ inputs: u, backend: n }) => {
219
- const { x: o } = u, i = n, a = r || o.dtype;
220
- if (i.shouldExecuteOnCPU([o]) && s != null) {
221
- const d = i.texData.get(o.dataId), y = s(d.values, a);
222
- return i.makeTensorInfo(o.shape, a, y);
223
- }
224
- const c = N().getBool("WEBGL_PACK_UNARY_OPERATIONS") && e != null;
225
- let l;
226
- return c ? l = new ne(o.shape, e) : l = new q(o.shape, t), i.runWebGLProgram(l, [o], a);
227
- };
228
- }
229
- function Ie({ opSnippet: t, packedOpSnippet: e, checkOutOfBounds: s = !1, supportsComplex: r = !1, cpuKernelImpl: u, dtype: n }) {
230
- return ({ inputs: o, backend: i }) => {
231
- const { a, b: c } = o, l = i;
232
- if (r && a.dtype === "complex64") {
233
- const h = l.texData.get(a.dataId), f = l.texData.get(c.dataId), [g, m] = [
234
- [h.complexTensorInfos.real, f.complexTensorInfos.real],
235
- [h.complexTensorInfos.imag, f.complexTensorInfos.imag]
236
- ].map((C) => {
237
- const [p, x] = C, $ = {
238
- dataId: p.dataId,
239
- dtype: p.dtype,
240
- shape: a.shape
241
- }, _ = {
242
- dataId: x.dataId,
243
- dtype: x.dtype,
244
- shape: c.shape
245
- }, D = new b(t, a.shape, c.shape);
246
- return l.runWebGLProgram(D, [$, _], S(p.dtype, x.dtype));
247
- }), A = L({ inputs: { real: g, imag: m }, backend: l });
248
- return l.disposeIntermediateTensorInfo(g), l.disposeIntermediateTensorInfo(m), A;
249
- }
250
- const d = n || S(a.dtype, c.dtype);
251
- if ((a.dtype === "string" || c.dtype === "string" || l.shouldExecuteOnCPU([a, c])) && u != null) {
252
- const h = l.texData.get(a.dataId).values, f = l.texData.get(c.dataId).values, g = a.dtype === "string" ? (
253
- // tslint:disable-next-line: no-any
254
- v(h)
255
- ) : h, m = a.dtype === "string" ? (
256
- // tslint:disable-next-line: no-any
257
- v(f)
258
- ) : f, [A, C] = u(a.shape, c.shape, g, m, d), p = l.makeTensorInfo(C, d), x = l.texData.get(p.dataId);
259
- return x.values = A, p;
260
- }
261
- const y = N().getBool("WEBGL_PACK_BINARY_OPERATIONS") && e != null;
262
- let I;
263
- return y ? I = new E(e, a.shape, c.shape, s) : I = new b(t, a.shape, c.shape), l.runWebGLProgram(I, [a, c], d);
264
- };
265
- }
266
- function Ae(t, e = !1) {
267
- if (t === "linear")
268
- return e ? ee : M;
269
- if (t === "relu")
270
- return e ? se : J;
271
- if (t === "elu")
272
- return e ? te : j;
273
- if (t === "relu6")
274
- return e ? ae : Q;
275
- if (t === "prelu")
276
- return e ? k : U;
277
- if (t === "leakyrelu")
278
- return e ? R : w;
279
- if (t === "sigmoid")
280
- return e ? re : X;
281
- throw new Error(`Activation ${t} has not been implemented for the WebGL backend.`);
282
- }
283
- export {
284
- de as A,
285
- E as B,
286
- T as C,
287
- ne as U,
288
- Z as a,
289
- Ie as b,
290
- pe as c,
291
- he as d,
292
- q as e,
293
- L as f,
294
- Y as g,
295
- b as h,
296
- P as i,
297
- fe as j,
298
- xe as k,
299
- Oe as l,
300
- Ae as m,
301
- me as n,
302
- ge as o,
303
- be as p,
304
- Ne as q,
305
- ye as u
306
- };
@@ -1,44 +0,0 @@
1
- import { GPTConfig } from '../models/config';
2
- import { default as MemoryProfiler } from '../utilities/profile';
3
- import { default as RoPECache } from './RoPECache';
4
- import { Tensor, Variable } from '@tensorflow/tfjs-core';
5
- import { default as WeightStore } from './WeightStore';
6
- export interface ForwardAttributes {
7
- training: boolean;
8
- checkpointing?: boolean;
9
- mixedPrecision?: boolean;
10
- ropeCache?: RoPECache;
11
- outputEmbeddings?: boolean;
12
- embeddings?: {
13
- name: string;
14
- tensor: Tensor;
15
- }[];
16
- dropout?: number;
17
- layerDrop?: number;
18
- }
19
- export default abstract class BaseLayer<ATTR extends ForwardAttributes = ForwardAttributes, CONFIG extends GPTConfig = GPTConfig> {
20
- readonly parent?: BaseLayer;
21
- readonly config: CONFIG;
22
- weightStore: WeightStore;
23
- readonly children: BaseLayer[];
24
- private profiler?;
25
- private ownVariables;
26
- constructor(config: CONFIG, parent?: BaseLayer);
27
- getProfiler(): MemoryProfiler | undefined;
28
- setProfiler(profiler: MemoryProfiler | null): void;
29
- startMemory(): void;
30
- endMemory(label: string): void;
31
- addVariable(name: string, variable?: Variable): void;
32
- addChildVariable(name: string): void;
33
- get variables(): Variable[];
34
- get trainableVariables(): Variable[];
35
- getVariable(name: string): Tensor;
36
- hasVariable(name: string): boolean;
37
- setVariable(name: string, variable: Variable): void;
38
- dispose(): void;
39
- protected build(): void;
40
- abstract forward(attrs: ATTR, ...x: Tensor[]): Tensor | Tensor[];
41
- call(attrs: ATTR, ...x: Tensor[]): Tensor | Tensor[];
42
- callCheckpoint(attrs: ATTR, ...x: Tensor[]): Tensor;
43
- private checkpointingFn;
44
- }
@@ -1,74 +0,0 @@
1
- import { a2 as p, h as s, a4 as g } from "../index-CUXkjxiT.js";
2
- import b from "./WeightStore.js";
3
- class T {
4
- parent;
5
- config;
6
- weightStore;
7
- children = [];
8
- profiler;
9
- ownVariables = /* @__PURE__ */ new Set();
10
- constructor(t, e) {
11
- this.config = t, this.parent = e, this.parent ? (this.parent.children.push(this), this.weightStore = this.parent.weightStore) : this.weightStore = new b();
12
- }
13
- getProfiler() {
14
- return this.profiler;
15
- }
16
- setProfiler(t) {
17
- this.profiler = t || void 0, this.children.forEach((e) => {
18
- e.setProfiler(t);
19
- });
20
- }
21
- startMemory() {
22
- this.profiler?.startMemory();
23
- }
24
- endMemory(t) {
25
- this.profiler?.endMemory(t);
26
- }
27
- addVariable(t, e) {
28
- this.weightStore.addVariable(t, e), this.ownVariables.add(t), this.parent && this.parent.addChildVariable(t);
29
- }
30
- addChildVariable(t) {
31
- this.ownVariables.add(t);
32
- }
33
- get variables() {
34
- return this.weightStore.variables;
35
- }
36
- get trainableVariables() {
37
- return this.weightStore.trainableVariables.filter((t) => this.ownVariables.has(t.name));
38
- }
39
- getVariable(t) {
40
- return this.weightStore.getVariable(t);
41
- }
42
- hasVariable(t) {
43
- return this.weightStore.hasVariable(t);
44
- }
45
- setVariable(t, e) {
46
- this.weightStore.setVariable(t, e);
47
- }
48
- dispose() {
49
- this.weightStore.dispose();
50
- }
51
- build() {
52
- }
53
- call(t, ...e) {
54
- return this.build(), this.forward(t, ...e);
55
- }
56
- callCheckpoint(t, ...e) {
57
- return this.build(), this.checkpointingFn(t, ...e);
58
- }
59
- checkpointingFn(t, ...e) {
60
- const r = this.trainableVariables;
61
- return p((...i) => {
62
- const o = i[i.length - 1], a = i.slice(0, e.length), h = this.forward(t, ...a);
63
- return o(a), { value: h, gradFunc: (n, l) => {
64
- const c = s().state.activeTape;
65
- s().state.activeTape = [];
66
- const d = g((...u) => this.forward(t, ...u.slice(0, a.length)))([...l, ...r], n);
67
- return s().state.activeTape = c, d;
68
- } };
69
- })(...e, ...r);
70
- }
71
- }
72
- export {
73
- T as default
74
- };
@@ -1,39 +0,0 @@
1
- import { default as BaseLayer, ForwardAttributes } from './BaseLayer';
2
- import { Tensor } from '@tensorflow/tfjs-core';
3
- import { GPTConfig } from '../models/config';
4
- export interface KVCache {
5
- k?: Tensor;
6
- v?: Tensor;
7
- length: number;
8
- cumulativeLength: number;
9
- }
10
- export interface AttentionScores {
11
- meanOfHeads?: boolean;
12
- attentionOut?: Tensor[];
13
- }
14
- interface AttentionForwardAttributes extends ForwardAttributes {
15
- attentionScores?: AttentionScores;
16
- pastKV?: KVCache;
17
- seed?: number;
18
- ropePositionOffset?: number;
19
- }
20
- export interface CausalSelfAttentionConfig {
21
- useQKNorm?: boolean;
22
- }
23
- export default class CausalSelfAttention extends BaseLayer<AttentionForwardAttributes> {
24
- private readonly attentionConfig;
25
- private divisor;
26
- private index;
27
- private units;
28
- private projUnits;
29
- private ATTN;
30
- private PROJ;
31
- constructor(index: number, config: GPTConfig, attentionConfig: CausalSelfAttentionConfig, parent?: BaseLayer);
32
- protected build(): void;
33
- private getAttentionScores;
34
- private getQKV;
35
- private getOutputProjection;
36
- private updateCache;
37
- forward(attr: AttentionForwardAttributes, x: Tensor): Tensor;
38
- }
39
- export {};
@@ -1,86 +0,0 @@
1
- import { attentionMask as R } from "../ops/attentionMask.js";
2
- import J from "./BaseLayer.js";
3
- import { r as v } from "../rope-DcrZM_e6.js";
4
- import { appendCache as A } from "../ops/appendCache.js";
5
- import { k as c, t as L } from "../index-CUXkjxiT.js";
6
- import { softmax16 as y } from "../ops/softmax16.js";
7
- import { b as M } from "../matMul16-BcVC_E62.js";
8
- import { p as K } from "../pack16-bqltoUlR.js";
9
- import { transpose16 as j } from "../ops/transpose16.js";
10
- import { dot16 as E } from "../ops/dot16.js";
11
- import { reshape16 as _ } from "../ops/reshape16.js";
12
- import { isPackedTensor as f } from "../utilities/packed.js";
13
- import { qkv as q } from "../ops/qkv.js";
14
- import { normRMS as O } from "../ops/normRMS.js";
15
- import { dropout16 as x } from "../ops/dropout16.js";
16
- import { v as P } from "../variable-Ck482e3n.js";
17
- import { r as S } from "../random_normal-IBRrha8a.js";
18
- class it extends J {
19
- constructor(t, o, s, i) {
20
- super(o, i), this.attentionConfig = s, this.index = t, this.units = o.nEmbed * 3, this.projUnits = o.nEmbed, this.ATTN = `block_${this.index}_cAttn`, this.PROJ = `block_${this.index}_cProj`, this.addVariable(this.ATTN), this.addVariable(this.PROJ), this.divisor = 1 / Math.sqrt(o.nEmbed / o.nHead);
21
- }
22
- divisor;
23
- index;
24
- units;
25
- projUnits;
26
- ATTN;
27
- PROJ;
28
- build() {
29
- this.hasVariable(this.ATTN) === !1 && this.setVariable(
30
- this.ATTN,
31
- P(S([this.config.nEmbed, this.units], 0, 0.02), !0, this.ATTN)
32
- ), this.hasVariable(this.PROJ) === !1 && this.setVariable(
33
- this.PROJ,
34
- P(S([this.projUnits, this.config.nEmbed], 0, 0.02), !0, this.PROJ)
35
- );
36
- }
37
- getAttentionScores(t, o, s) {
38
- const i = R(t, o, this.divisor, s), e = y(i);
39
- return i.dispose(), e;
40
- }
41
- getQKV(t) {
42
- const o = f(t) ? K(this.getVariable(this.ATTN)) : this.getVariable(this.ATTN), s = q(t, o, this.config.nHead);
43
- return f(t) && o.dispose(), s;
44
- }
45
- getOutputProjection(t) {
46
- const o = t.shape[0], s = t.shape[2], i = this.config.nEmbed, e = f(t), r = j(t, [0, 2, 1, 3]), n = _(r, [o, s, e ? i / 2 : i]);
47
- r.dispose();
48
- const p = e ? K(this.getVariable(this.PROJ)) : this.getVariable(this.PROJ), a = E(n, p);
49
- return e && p.dispose(), n.dispose(), a;
50
- }
51
- updateCache(t, o, s) {
52
- const i = this.config.blockSize, e = t.shape[2], r = s.length || 0, n = A(t, i, r, s.k);
53
- t.dispose(), s.k && s.k.dispose();
54
- const p = A(o, i, r, s.v);
55
- o.dispose(), s.v && s.v.dispose();
56
- const a = Math.min(r + e, i), h = s.cumulativeLength + e;
57
- s.length = a, s.cumulativeLength = h, s.k = c(n), s.v = c(p);
58
- }
59
- forward(t, o) {
60
- return L(() => {
61
- this.startMemory();
62
- const [s, i, e] = this.getQKV(o), r = t.pastKV ? t.pastKV.cumulativeLength : t.ropePositionOffset || 0, n = t.ropeCache, p = n ? v(s, n, r) : s, a = n ? v(i, n, r) : i, h = this.attentionConfig.useQKNorm ?? !1, m = h ? O(p) : p;
63
- h && p.dispose();
64
- const l = h ? O(a) : a;
65
- h && a.dispose(), n && (s.dispose(), i.dispose());
66
- const T = t.pastKV ? t.pastKV.length : 0;
67
- t.pastKV && !t.training && this.updateCache(l, e, t.pastKV);
68
- const u = t.pastKV?.k ? t.pastKV.k : l, V = t.pastKV?.v ? t.pastKV.v : e;
69
- let d;
70
- T > 0 ? d = this.getAttentionScores(m, u, T) : d = this.getAttentionScores(m, u), m.dispose(), t.pastKV || u.dispose();
71
- const g = M(d, V), b = t.attentionScores !== void 0 && t.attentionScores.attentionOut !== void 0;
72
- b || d.dispose(), t.pastKV || V.dispose();
73
- const k = this.getOutputProjection(g);
74
- if (g.dispose(), b && t.attentionScores && t.attentionScores.attentionOut !== void 0) {
75
- const N = d.shape[1], C = d.shape[2];
76
- t.attentionScores.attentionOut?.push(
77
- c(d.slice([0, 0, 0, 0], [1, -1, -1, -1]).reshape([N, C, -1]))
78
- );
79
- }
80
- return this.endMemory("CausalSelfAttention"), t.dropout && t.dropout > 0 ? x(k, t.dropout) : k;
81
- });
82
- }
83
- }
84
- export {
85
- it as default
86
- };
@@ -1,14 +0,0 @@
1
- import { default as WeightStore } from './WeightStore';
2
- export default class LoRA {
3
- private weightStore;
4
- readonly alpha: number;
5
- readonly rank: number;
6
- readonly variables: Set<string>;
7
- private scale;
8
- readonly name: string;
9
- constructor(name: string, weightStore: WeightStore, alpha: number, rank: number, variables: string[]);
10
- attach(): void;
11
- merge(): void;
12
- detach(): void;
13
- dispose(): void;
14
- }
@@ -1,58 +0,0 @@
1
- import { a as m, t as n } from "../index-CUXkjxiT.js";
2
- import { p } from "../index-DmeWGGmS.js";
3
- import { v as g } from "../variable-Ck482e3n.js";
4
- import { r as S } from "../random_normal-IBRrha8a.js";
5
- import { z as _ } from "../zeros-DvZpK8s6.js";
6
- class B {
7
- weightStore;
8
- alpha;
9
- rank;
10
- variables;
11
- scale;
12
- name;
13
- constructor(t, e, a, s, r) {
14
- this.name = t, this.weightStore = e, this.alpha = a, this.rank = s;
15
- const c = p(r), w = e.variableNames.filter(
16
- (i) => c(i) && !i.endsWith("_loraA") && !i.endsWith("_loraB")
17
- );
18
- this.variables = new Set(w), this.scale = m(a / s), this.variables.forEach((i) => {
19
- const o = this.weightStore.getRawVariable(i), [d, b] = o.shape, h = `${i}_${this.name}_loraA`, l = `${i}_${this.name}_loraB`;
20
- if (o.shape.length !== 2) {
21
- console.warn(
22
- `LoRA currently only supports 2D weight matrices. Variable ${i} has shape ${o.shape}`
23
- ), this.variables.delete(i);
24
- return;
25
- }
26
- this.weightStore.hasVariable(h) || this.weightStore.hasVariable(l) || (this.weightStore.addVariable(
27
- h,
28
- g(S([d, this.rank], 0, 0.02), !0, h)
29
- ), this.weightStore.addVariable(l, g(_([this.rank, b]), !0, l)));
30
- });
31
- }
32
- attach() {
33
- if (this.weightStore.onWeightRead)
34
- throw new Error("LoRA cannot be applied to a WeightStore that already has a onWeightRead hook.");
35
- this.weightStore.onWeightRead = (t, e) => this.variables.has(t) ? n(() => {
36
- const a = this.weightStore.getRawVariable(`${t}_${this.name}_loraA`), s = this.weightStore.getRawVariable(`${t}_${this.name}_loraB`);
37
- return e.add(a.matMul(s).mul(this.scale));
38
- }) : e, this.weightStore.setTrainable([`*_${this.name}_loraA`, `*_${this.name}_loraB`]);
39
- }
40
- merge() {
41
- this.variables.forEach((t) => {
42
- const e = this.weightStore.getRawVariable(t), a = this.weightStore.getRawVariable(`${t}_${this.name}_loraA`), s = this.weightStore.getRawVariable(`${t}_${this.name}_loraB`), r = n(() => e.add(a.matMul(s).mul(this.scale)));
43
- e.assign(r), r.dispose();
44
- });
45
- }
46
- detach() {
47
- this.weightStore.onWeightRead = void 0, this.weightStore.setTrainable(["*"]);
48
- }
49
- dispose() {
50
- this.detach(), this.scale.dispose(), this.variables.forEach((t) => {
51
- const e = `${t}_${this.name}_loraA`, a = `${t}_${this.name}_loraB`;
52
- this.weightStore.getRawVariable(e).dispose(), this.weightStore.getRawVariable(a).dispose(), this.weightStore.deleteVariable(e), this.weightStore.deleteVariable(a);
53
- }), this.variables.clear();
54
- }
55
- }
56
- export {
57
- B as default
58
- };
@@ -1,17 +0,0 @@
1
- import { Tensor } from '@tensorflow/tfjs-core';
2
- import { default as BaseLayer, ForwardAttributes } from './BaseLayer';
3
- import { GPTConfig } from '../main';
4
- export interface MLPConfig {
5
- activation?: 'gelu' | 'relu2';
6
- hiddenFactor?: number;
7
- }
8
- export default class MLP extends BaseLayer {
9
- private index;
10
- private hiddenUnits;
11
- private MLPHIDDEN;
12
- private MLPOUT;
13
- private mlpConfig;
14
- constructor(index: number, config: GPTConfig, mlpConfig: MLPConfig, parent?: BaseLayer);
15
- protected build(): void;
16
- forward(attr: ForwardAttributes, x: Tensor): Tensor;
17
- }
@@ -1,44 +0,0 @@
1
- import { t as M } from "../index-CUXkjxiT.js";
2
- import f from "./BaseLayer.js";
3
- import { b as h } from "../matMul16-BcVC_E62.js";
4
- import { reshape16 as d } from "../ops/reshape16.js";
5
- import { dropout16 as L } from "../ops/dropout16.js";
6
- import { v as n } from "../variable-Ck482e3n.js";
7
- import { r as m } from "../random_normal-IBRrha8a.js";
8
- class N extends f {
9
- index;
10
- hiddenUnits;
11
- MLPHIDDEN;
12
- MLPOUT;
13
- mlpConfig;
14
- constructor(i, t, s, e) {
15
- super(t, e), this.index = i, this.mlpConfig = s, this.hiddenUnits = (s.hiddenFactor ?? t.mlpFactor) * t.nEmbed, this.MLPHIDDEN = `block_${this.index}_mlpHidden`, this.MLPOUT = `block_${this.index}_mlpOut`, this.addVariable(this.MLPHIDDEN), this.addVariable(this.MLPOUT);
16
- }
17
- build() {
18
- this.hasVariable(this.MLPHIDDEN) === !1 && this.setVariable(
19
- this.MLPHIDDEN,
20
- n(m([this.config.nEmbed, this.hiddenUnits], 0, 0.02), !0, this.MLPHIDDEN)
21
- ), this.hasVariable(this.MLPOUT) === !1 && this.setVariable(
22
- this.MLPOUT,
23
- n(
24
- m([this.hiddenUnits, this.config.nEmbed], 0, 0.02 / Math.sqrt(2 * this.config.nLayer)),
25
- !0,
26
- this.MLPOUT
27
- )
28
- );
29
- }
30
- forward(i, t) {
31
- return M(() => {
32
- this.startMemory();
33
- const [s, e, r] = t.shape, l = d(t, [s * e, r]), a = h(l, this.getVariable(this.MLPHIDDEN), !1, !1, {
34
- activation: this.mlpConfig.activation ?? "gelu"
35
- }), p = h(a, this.getVariable(this.MLPOUT));
36
- a.dispose();
37
- const o = d(p, [s, e, r]);
38
- return this.endMemory("MLP"), i.dropout && i.dropout > 0 ? L(o, i.dropout) : o;
39
- });
40
- }
41
- }
42
- export {
43
- N as default
44
- };