@genai-fi/nanogpt 0.9.0 → 0.10.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (343)
  1. package/README.md +352 -14
  2. package/dist/Generator.js +69 -78
  3. package/dist/{RealDiv-D4EzDsC0.js → RealDiv-DgA3z9oO.js} +32 -206
  4. package/dist/Reshape-CF6odzV4.js +16 -0
  5. package/dist/Reshape-_kILl6tK.js +81 -0
  6. package/dist/TeachableLLM.js +28 -22
  7. package/dist/Trainer.d.ts +2 -0
  8. package/dist/Trainer.js +3 -2
  9. package/dist/{axis_util-TbGYJ208.js → axis_util-BvHEw88j.js} +7 -23
  10. package/dist/backend.d.ts +2 -1
  11. package/dist/backend.js +10 -4
  12. package/dist/backend_util-D-rUb2ty.js +474 -0
  13. package/dist/backend_webgpu-B0u2ndUn.js +547 -0
  14. package/dist/binary_op_util-pKXltfxI.js +192 -0
  15. package/dist/broadcast_to-CwF7XIeu.js +30 -0
  16. package/dist/checks/appendCache.js +2 -2
  17. package/dist/checks/attentionMask.js +3 -3
  18. package/dist/checks/check.d.ts +1 -1
  19. package/dist/checks/check.js +8 -8
  20. package/dist/checks/gelu.js +2 -2
  21. package/dist/checks/index.d.ts +2 -0
  22. package/dist/checks/index.js +7 -5
  23. package/dist/checks/matMulGelu.js +6 -6
  24. package/dist/checks/normRMS.js +7 -7
  25. package/dist/checks/normRMSGrad.js +3 -3
  26. package/dist/checks/packUnpack.d.ts +1 -0
  27. package/dist/checks/packUnpack.js +18 -0
  28. package/dist/checks/qkv.js +12 -27
  29. package/dist/checks/rope.js +2 -2
  30. package/dist/checks/weights.js +18 -16
  31. package/dist/complex-CSlYz-2T.js +13 -0
  32. package/dist/complex_util-Yc1A_gV1.js +55 -0
  33. package/dist/concat-BHlIJeyT.js +19 -0
  34. package/dist/concat_util-DcJk7YHS.js +22 -0
  35. package/dist/data/docx.js +1 -1
  36. package/dist/data/parquet.js +2 -2
  37. package/dist/data/pdf.js +1 -1
  38. package/dist/data/textLoader.js +1 -1
  39. package/dist/{dataset-DlZtKmBq.js → dataset-0xP8GjwI.js} +136 -236
  40. package/dist/dropout-C1pM3f11.js +99 -0
  41. package/dist/expand_dims-BPG4fwBP.js +13 -0
  42. package/dist/exports_initializers-xuidcwI4.js +7 -0
  43. package/dist/gather-DykLGqmW.js +10 -0
  44. package/dist/{gelu-Bp_-935b.js → gelu-CNLFZWea.js} +11 -10
  45. package/dist/{gpgpu_math-CDaYiyE_.js → gpgpu_math-DDVJCn6-.js} +90 -265
  46. package/dist/{index-C4L8Cm77.js → index-CieiGp4Y.js} +14 -14
  47. package/dist/index-CjOj7j-u.js +7308 -0
  48. package/dist/{index-Tf7vU29b.js → index-Cp39cXWe.js} +3 -10
  49. package/dist/{index-Dwqa6Zy2.js → index-DvYrXKkX.js} +2 -2
  50. package/dist/index-ZyQhjEPo.js +2157 -0
  51. package/dist/{jszip.min-CjP2V1VV.js → jszip.min-Bz5-11Bk.js} +56 -57
  52. package/dist/kernel_funcs_utils-Dg_-E44D.js +308 -0
  53. package/dist/layers/BaseLayer.d.ts +1 -0
  54. package/dist/layers/BaseLayer.js +7 -6
  55. package/dist/layers/CausalSelfAttention.d.ts +0 -1
  56. package/dist/layers/CausalSelfAttention.js +56 -55
  57. package/dist/layers/MLP.js +15 -16
  58. package/dist/layers/PositionEmbedding.js +5 -14
  59. package/dist/layers/RMSNorm.js +3 -3
  60. package/dist/layers/RoPECache.d.ts +2 -0
  61. package/dist/layers/RoPECache.js +22 -17
  62. package/dist/layers/TiedEmbedding.js +22 -17
  63. package/dist/layers/TransformerBlock.js +21 -20
  64. package/dist/loader/load.js +1 -1
  65. package/dist/loader/loadTransformers.js +1 -1
  66. package/dist/loader/oldZipLoad.js +39 -33
  67. package/dist/loader/save.js +1 -1
  68. package/dist/log_sum_exp-DWI-76TI.js +41 -0
  69. package/dist/main.d.ts +8 -0
  70. package/dist/main.js +63 -52
  71. package/dist/matMul16--R5hOwDG.js +77 -0
  72. package/dist/mat_mul-DeAh4uTH.js +12 -0
  73. package/dist/mod-Gt1rMB4n.js +12 -0
  74. package/dist/models/NanoGPTV1.js +40 -31
  75. package/dist/models/model.d.ts +2 -0
  76. package/dist/models/model.js +37 -29
  77. package/dist/{mulmat_packed_gpu-BT60jmzP.js → mulmat_packed_gpu-BMFhLwta.js} +1 -17
  78. package/dist/{non_max_suppression_impl-CsEgBuMA.js → non_max_suppression_impl-B2W7YjZB.js} +0 -32
  79. package/dist/ones-CAMiP4I2.js +15 -0
  80. package/dist/ops/adamAdjust.js +1 -1
  81. package/dist/ops/adamMoments.d.ts +1 -1
  82. package/dist/ops/adamMoments.js +4 -4
  83. package/dist/ops/add16.d.ts +2 -0
  84. package/dist/ops/add16.js +9 -0
  85. package/dist/ops/appendCache.js +16 -9
  86. package/dist/ops/attentionMask.js +4 -4
  87. package/dist/ops/concat16.d.ts +2 -0
  88. package/dist/ops/concat16.js +9 -0
  89. package/dist/ops/cpu/adamAdjust.js +14 -13
  90. package/dist/ops/cpu/adamMoments.js +10 -9
  91. package/dist/ops/cpu/appendCache.js +9 -8
  92. package/dist/ops/cpu/attentionMask.js +15 -14
  93. package/dist/ops/cpu/fusedSoftmax.js +13 -12
  94. package/dist/ops/cpu/gatherSub.js +9 -24
  95. package/dist/ops/cpu/gelu.js +13 -12
  96. package/dist/ops/cpu/matMul16.d.ts +1 -0
  97. package/dist/ops/cpu/matMul16.js +16 -0
  98. package/dist/ops/cpu/matMulGelu.js +18 -16
  99. package/dist/ops/cpu/matMulMul.js +8 -7
  100. package/dist/ops/cpu/mulDropout.js +4 -3
  101. package/dist/ops/cpu/normRMS.js +11 -10
  102. package/dist/ops/cpu/qkv.js +17 -13
  103. package/dist/ops/cpu/rope.js +23 -22
  104. package/dist/ops/cpu/scatterSub.js +16 -30
  105. package/dist/ops/dot16.d.ts +2 -0
  106. package/dist/ops/dot16.js +42 -0
  107. package/dist/ops/gatherSub.js +1 -1
  108. package/dist/ops/gelu.js +2 -2
  109. package/dist/ops/grads/add16.d.ts +1 -0
  110. package/dist/ops/grads/add16.js +27 -0
  111. package/dist/ops/grads/attentionMask.js +12 -19
  112. package/dist/ops/grads/gelu.js +4 -3
  113. package/dist/ops/grads/matMul16.d.ts +2 -0
  114. package/dist/ops/grads/matMul16.js +9 -0
  115. package/dist/ops/grads/matMulGelu.js +8 -7
  116. package/dist/ops/grads/normRMS.js +8 -7
  117. package/dist/ops/grads/{fusedSoftmax.d.ts → pack16.d.ts} +1 -1
  118. package/dist/ops/grads/pack16.js +7 -0
  119. package/dist/ops/grads/qkv.d.ts +3 -1
  120. package/dist/ops/grads/qkv.js +28 -22
  121. package/dist/ops/grads/rope.d.ts +2 -1
  122. package/dist/ops/grads/rope.js +6 -13
  123. package/dist/ops/grads/softmax16.d.ts +2 -0
  124. package/dist/ops/grads/softmax16.js +26 -0
  125. package/dist/ops/grads/unpack16.d.ts +2 -0
  126. package/dist/ops/grads/unpack16.js +6 -0
  127. package/dist/ops/grads/utils.d.ts +3 -0
  128. package/dist/ops/grads/utils.js +10 -0
  129. package/dist/ops/matMul16.d.ts +15 -0
  130. package/dist/ops/matMul16.js +13 -0
  131. package/dist/ops/matMulGelu.js +1 -1
  132. package/dist/ops/matMulMul.js +1 -1
  133. package/dist/ops/mul16.d.ts +2 -0
  134. package/dist/ops/mul16.js +8 -0
  135. package/dist/ops/mulDrop.js +1 -1
  136. package/dist/ops/normRMS.js +1 -1
  137. package/dist/ops/pack16.d.ts +2 -0
  138. package/dist/ops/pack16.js +6 -0
  139. package/dist/ops/qkv.d.ts +1 -1
  140. package/dist/ops/qkv.js +8 -4
  141. package/dist/ops/reshape16.d.ts +2 -0
  142. package/dist/ops/reshape16.js +43 -0
  143. package/dist/ops/rope.d.ts +1 -1
  144. package/dist/ops/rope.js +8 -10
  145. package/dist/ops/scatterSub.js +1 -1
  146. package/dist/ops/slice16.d.ts +2 -0
  147. package/dist/ops/slice16.js +9 -0
  148. package/dist/ops/softmax16.d.ts +2 -0
  149. package/dist/ops/softmax16.js +12 -0
  150. package/dist/ops/sub16.d.ts +2 -0
  151. package/dist/ops/sub16.js +8 -0
  152. package/dist/ops/sum16.d.ts +2 -0
  153. package/dist/ops/sum16.js +13 -0
  154. package/dist/ops/transpose16.d.ts +3 -0
  155. package/dist/ops/transpose16.js +41 -0
  156. package/dist/ops/unpack16.d.ts +2 -0
  157. package/dist/ops/unpack16.js +6 -0
  158. package/dist/ops/webgl/adamAdjust.js +3 -2
  159. package/dist/ops/webgl/adamMoments.js +2 -1
  160. package/dist/ops/webgl/appendCache.js +2 -1
  161. package/dist/ops/webgl/attentionMask.js +5 -4
  162. package/dist/ops/webgl/fusedSoftmax.js +6 -4
  163. package/dist/ops/webgl/gatherSub.js +7 -6
  164. package/dist/ops/webgl/gelu.js +3 -2
  165. package/dist/ops/webgl/log.js +12 -27
  166. package/dist/ops/webgl/matMul16.d.ts +1 -0
  167. package/dist/ops/webgl/matMul16.js +37 -0
  168. package/dist/ops/webgl/matMulGelu.js +17 -15
  169. package/dist/ops/webgl/matMulMul.js +13 -12
  170. package/dist/ops/webgl/mulDropout.js +9 -8
  171. package/dist/ops/webgl/normRMS.js +8 -7
  172. package/dist/ops/webgl/qkv.js +6 -5
  173. package/dist/ops/webgl/rope.js +11 -10
  174. package/dist/ops/webgl/scatterSub.js +6 -5
  175. package/dist/ops/webgpu/adamAdjust.js +12 -10
  176. package/dist/ops/webgpu/adamMoments.js +27 -22
  177. package/dist/ops/webgpu/add16.d.ts +1 -0
  178. package/dist/ops/webgpu/add16.js +14 -0
  179. package/dist/ops/webgpu/appendCache.js +64 -17
  180. package/dist/ops/webgpu/attentionMask.js +19 -62
  181. package/dist/ops/webgpu/attentionMask32_program.d.ts +19 -0
  182. package/dist/ops/webgpu/attentionMask32_program.js +54 -0
  183. package/dist/ops/webgpu/concat16.d.ts +19 -0
  184. package/dist/ops/webgpu/concat16.js +128 -0
  185. package/dist/ops/webgpu/gatherSub.js +9 -7
  186. package/dist/ops/webgpu/gelu.js +78 -31
  187. package/dist/ops/webgpu/index.js +12 -0
  188. package/dist/ops/webgpu/matMul16.d.ts +1 -0
  189. package/dist/ops/webgpu/matMul16.js +58 -0
  190. package/dist/ops/webgpu/matMul16_program.d.ts +42 -0
  191. package/dist/ops/webgpu/matMul16_program.js +336 -0
  192. package/dist/ops/webgpu/mul16.d.ts +1 -0
  193. package/dist/ops/webgpu/mul16.js +14 -0
  194. package/dist/ops/webgpu/normRMS.js +21 -40
  195. package/dist/ops/webgpu/normRMS16_program.d.ts +9 -0
  196. package/dist/ops/webgpu/normRMS16_program.js +24 -0
  197. package/dist/ops/webgpu/normRMS32_program.d.ts +9 -0
  198. package/dist/ops/webgpu/normRMS32_program.js +24 -0
  199. package/dist/ops/webgpu/normRMSGrad.js +113 -64
  200. package/dist/ops/webgpu/pack16.d.ts +1 -0
  201. package/dist/ops/webgpu/pack16.js +19 -0
  202. package/dist/ops/webgpu/pack16_program.d.ts +19 -0
  203. package/dist/ops/webgpu/pack16_program.js +92 -0
  204. package/dist/ops/webgpu/qkv.js +20 -55
  205. package/dist/ops/webgpu/rope.js +77 -22
  206. package/dist/ops/webgpu/scatterSub.js +9 -7
  207. package/dist/ops/webgpu/slice16.d.ts +7 -0
  208. package/dist/ops/webgpu/slice16.js +71 -0
  209. package/dist/{variable-Bm2OFwGI.js → ops/webgpu/softmax16.d.ts} +2 -8
  210. package/dist/ops/webgpu/softmax16.js +23 -0
  211. package/dist/ops/webgpu/softmax16_program.d.ts +13 -0
  212. package/dist/ops/webgpu/softmax16_program.js +73 -0
  213. package/dist/ops/webgpu/softmax16_subgroup_program.d.ts +17 -0
  214. package/dist/ops/webgpu/softmax16_subgroup_program.js +75 -0
  215. package/dist/ops/webgpu/softmax16grad.d.ts +1 -0
  216. package/dist/ops/webgpu/softmax16grad.js +38 -0
  217. package/dist/ops/webgpu/sub16.d.ts +1 -0
  218. package/dist/ops/webgpu/sub16.js +14 -0
  219. package/dist/ops/webgpu/sum16.d.ts +1 -0
  220. package/dist/ops/webgpu/sum16.js +40 -0
  221. package/dist/ops/webgpu/transpose16.d.ts +1 -0
  222. package/dist/ops/webgpu/transpose16.js +35 -0
  223. package/dist/ops/webgpu/transpose16_program.d.ts +16 -0
  224. package/dist/ops/webgpu/transpose16_program.js +50 -0
  225. package/dist/ops/webgpu/transpose16_shared_program.d.ts +15 -0
  226. package/dist/ops/webgpu/transpose16_shared_program.js +71 -0
  227. package/dist/ops/webgpu/unpack16.d.ts +1 -0
  228. package/dist/ops/webgpu/unpack16.js +49 -0
  229. package/dist/ops/webgpu/utils/binary_op.d.ts +19 -0
  230. package/dist/ops/webgpu/utils/binary_op.js +79 -0
  231. package/dist/ops/webgpu/utils/deviceInfo.d.ts +7 -0
  232. package/dist/ops/webgpu/utils/deviceInfo.js +11 -0
  233. package/dist/ops/webgpu/utils/reductions.d.ts +32 -4
  234. package/dist/ops/webgpu/utils/reductions.js +236 -45
  235. package/dist/ops-CNI3TwqM.js +645 -0
  236. package/dist/pack16-CFUqumar.js +41 -0
  237. package/dist/{papaparse.min-C8l2Kvo1.js → papaparse.min-C0cScC2i.js} +2 -8
  238. package/dist/{parquet-C0Tlmv9c.js → parquet-BE8MU_ge.js} +201 -278
  239. package/dist/patches/PackedTensor.d.ts +12 -0
  240. package/dist/patches/PackedTensor.js +11 -0
  241. package/dist/patches/engine.d.ts +261 -0
  242. package/dist/patches/engine.js +10 -0
  243. package/dist/patches/tape.d.ts +12 -0
  244. package/dist/patches/tape.js +5 -0
  245. package/dist/patches/webgpu_backend.d.ts +18 -0
  246. package/dist/patches/webgpu_backend.js +57 -0
  247. package/dist/{tensor-CZr4dh61.js → patches/webgpu_base.d.ts} +5 -8
  248. package/dist/patches/webgpu_base.js +34 -0
  249. package/dist/patches/webgpu_program.d.ts +36 -0
  250. package/dist/patches/webgpu_program.js +401 -0
  251. package/dist/{pdf-kJD-f258.js → pdf-NIhmP3sq.js} +424 -428
  252. package/dist/random_width-DY6Kk2Dl.js +10051 -0
  253. package/dist/range-BMS52eQi.js +11 -0
  254. package/dist/reciprocal-CTmshQ9J.js +10 -0
  255. package/dist/{register_all_kernels-DIGpEwcf.js → register_all_kernels-Bwu1PTuU.js} +719 -9766
  256. package/dist/relu-yZ2-7WxU.js +10 -0
  257. package/dist/reshape-DevtBWtf.js +10 -0
  258. package/dist/rope-B5UUMsPi.js +32 -0
  259. package/dist/{scatter_nd_util-BQdz--Gn.js → scatter_nd_util-5EL-8VAQ.js} +1 -1
  260. package/dist/selu_util-D1w6yyTO.js +303 -0
  261. package/dist/{shared-DuP7ue-R.js → shared-BRksrJb3.js} +1 -17
  262. package/dist/shared-BuAXb4CI.js +2145 -0
  263. package/dist/sin-BGfy2HZo.js +16 -0
  264. package/dist/slice-D_gkkqZK.js +13 -0
  265. package/dist/slice_util-DtEldBfK.js +261 -0
  266. package/dist/softmax-ZHVebtR1.js +13 -0
  267. package/dist/split-DrfihRpZ.js +10 -0
  268. package/dist/squeeze-DZEpeblb.js +11 -0
  269. package/dist/stack-yOIAalTq.js +13 -0
  270. package/dist/sum-_fzj5ZTB.js +12 -0
  271. package/dist/tensor-DdQUJZlz.js +909 -0
  272. package/dist/tensor-f35l8Odg.js +8 -0
  273. package/dist/tensor1d-CeZuc-Rv.js +12 -0
  274. package/dist/tensor2d-G4Ys2GxX.js +15 -0
  275. package/dist/tensor4d-B8roDgtc.js +15 -0
  276. package/dist/tensor_util-DV-FP5Q3.js +523 -0
  277. package/dist/tfjs_backend-kNyO5L2d.js +653 -0
  278. package/dist/tile-BzyEiF-F.js +13 -0
  279. package/dist/tokeniser/CharTokeniser.js +1 -1
  280. package/dist/tokeniser/bpe.js +1 -1
  281. package/dist/training/Adam.d.ts +2 -1
  282. package/dist/training/Adam.js +12 -28
  283. package/dist/training/AdamExt.d.ts +1 -0
  284. package/dist/training/AdamExt.js +2 -2
  285. package/dist/training/DatasetBuilder.js +3 -20
  286. package/dist/training/FullTrainer.js +82 -64
  287. package/dist/training/Trainer.d.ts +11 -6
  288. package/dist/training/Trainer.js +51 -39
  289. package/dist/training/sparseCrossEntropy.js +3 -3
  290. package/dist/transpose-DKELTqhe.js +38 -0
  291. package/dist/utilities/arrayClose.js +7 -7
  292. package/dist/utilities/dummy.js +35 -27
  293. package/dist/utilities/multinomialCPU.js +2 -2
  294. package/dist/utilities/packed.d.ts +7 -0
  295. package/dist/utilities/packed.js +716 -0
  296. package/dist/utilities/performance.js +1 -1
  297. package/dist/utilities/profile.js +1 -1
  298. package/dist/utilities/safetensors.js +2 -2
  299. package/dist/utilities/sentences.d.ts +5 -0
  300. package/dist/utilities/sentences.js +41 -0
  301. package/dist/utilities/weights.js +2 -2
  302. package/dist/variable-Bhn5bHYv.js +7 -0
  303. package/dist/{webgpu_program-DkQJOJSd.js → webgpu_program-Cigz-7RF.js} +15 -44
  304. package/dist/webgpu_util-BBCnKm2X.js +65 -0
  305. package/dist/zeros-2gldETuK.js +14 -0
  306. package/package.json +4 -3
  307. package/dist/Reshape-Bowtk9BP.js +0 -127
  308. package/dist/Reshape-DUqYftGC.js +0 -30
  309. package/dist/backend_util-CJIiDoV1.js +0 -749
  310. package/dist/broadcast_to-DzlNweb8.js +0 -44
  311. package/dist/concat-B912vBbo.js +0 -33
  312. package/dist/dropout-C-csYCLj.js +0 -193
  313. package/dist/exports_initializers-B8iZMgQ0.js +0 -16
  314. package/dist/gather-Dnpgw-YQ.js +0 -25
  315. package/dist/index-BzFyqcy-.js +0 -4457
  316. package/dist/index-C1rx_Ajs.js +0 -12076
  317. package/dist/kernel_funcs_utils-DKLK0Mg3.js +0 -466
  318. package/dist/log_sum_exp-DO6z8tSE.js +0 -103
  319. package/dist/mat_mul-DzjTFx-u.js +0 -27
  320. package/dist/mod-Dobti4j4.js +0 -27
  321. package/dist/ones-tIJeHlq-.js +0 -29
  322. package/dist/ops/fusedSoftmax.d.ts +0 -2
  323. package/dist/ops/fusedSoftmax.js +0 -10
  324. package/dist/ops/grads/fusedSoftmax.js +0 -22
  325. package/dist/ops-LuCMAnmM.js +0 -1525
  326. package/dist/random_width-CXVRloNK.js +0 -13670
  327. package/dist/range-CWcz7xFA.js +0 -26
  328. package/dist/reciprocal-C4rNcM-S.js +0 -25
  329. package/dist/relu-BjCh_SYb.js +0 -25
  330. package/dist/reshape-CnIwVG1c.js +0 -25
  331. package/dist/selu_util-OtRzVwW5.js +0 -719
  332. package/dist/shared-DmRsFyaJ.js +0 -3134
  333. package/dist/sin-gpDNRxE0.js +0 -47
  334. package/dist/slice-d0Vo9XTN.js +0 -28
  335. package/dist/softmax-D7Jj3p_P.js +0 -28
  336. package/dist/split-DK2k5eHf.js +0 -25
  337. package/dist/stack-DFatutCx.js +0 -27
  338. package/dist/sum-CJ0ULhmt.js +0 -27
  339. package/dist/tensor1d-vML0r3q6.js +0 -27
  340. package/dist/tensor2d-D76QGjF3.js +0 -30
  341. package/dist/tensor4d-Df1WlVDY.js +0 -30
  342. package/dist/webgpu_util-pLEV9tks.js +0 -80
  343. package/dist/zeros-Bj5rMYA7.js +0 -52
package/README.md CHANGED
@@ -1,28 +1,366 @@
1
1
  # GenAI NanoGPT
2
2
 
3
- Developed as a part of the Finnish Generation AI research project. This is an implementation of [NanoGPT](https://github.com/karpathy/nanoGPT) for Tensorflow.js. It allows GPT models to be trained and loaded within a web browser and exposes some XAI functionality.
3
+ A browser-native implementation of GPT language models built on TensorFlow.js, developed as part of the Finnish Generation AI research project. This library enables training, fine-tuning, and inference of transformer-based language models entirely in the browser with support for explainable AI (XAI) features. It is intended to be used as an educational tool for learning about the model training process since it targets mostly tiny models. In principle it could be adapted to load other pre-trained models from Hugging Face.
4
4
 
5
- Work in progress...
5
+ Live version available here: https://lm.gen-ai.fi
6
6
 
7
- # Install
7
+ ## Overview
8
8
 
9
- ```
9
+ GenAI NanoGPT is inspired by [Andrej Karpathy's NanoGPT](https://github.com/karpathy/nanoGPT) but reimagined for the browser using TensorFlow.js. It provides a complete pipeline for:
10
+
11
+ - **Training** language models from scratch in the browser
12
+ - **Loading** pre-trained models from various sources (Hugging Face, local files)
13
+ - **Generating** text efficiently on a wide range of devices
14
+ - **Analyzing** model behavior through attention visualization and embeddings
15
+ - **Optimizing** performance across CPU, WebGL, and WebGPU backends
16
+
17
+ ### Key Features
18
+
19
+ - 🚀 **Browser-Native**: No server required - train and run models entirely client-side
20
+ - 📱 **Works on Small Devices**: Train models on iPads, phones, and Chromebooks - no powerful hardware needed
21
+ - 🎯 **Multiple Backends**: Automatic backend selection (CPU, WebGL, WebGPU) for optimal performance
22
+ - 🔧 **Flexible Tokenization**: Support for both character-level and BPE tokenizers
23
+ - 📊 **XAI Support**: Attention score visualization, gradient analysis, and embedding extraction
24
+ - 💾 **Model Persistence**: Save and load models in SafeTensors format
25
+ - ⚡ **Performance Optimizations**: Custom WebGPU kernels, gradient checkpointing, and mixed precision training
26
+ - 🎨 **Real-time Training**: Live training metrics and generation during training
27
+
28
+ ## Installation
29
+
30
+ ```bash
10
31
  npm install @genai-fi/nanogpt
11
32
  ```
12
33
 
13
- # Usage
34
+ ## Quick Start
35
+
36
+ ### Creating and Training a Model
37
+
38
+ ```javascript
39
+ import { TeachableLLM, selectBackend } from '@genai-fi/nanogpt';
40
+
41
+ // Select the best available backend
42
+ await selectBackend('webgpu'); // or 'webgl', 'cpu'
43
+
44
+ // Create a new model
45
+ const model = TeachableLLM.create('char', {
46
+ vocabSize: 200,
47
+ blockSize: 128, // Context window size
48
+ nLayer: 4, // Number of transformer layers
49
+ nHead: 4, // Number of attention heads
50
+ nEmbed: 192, // Embedding dimension
51
+ dropout: 0.1,
52
+ useRope: true, // Use Rotary Position Embeddings
53
+ });
54
+
55
+ // Training data
56
+ const trainingText = [
57
+ 'The quick brown fox jumps over the lazy dog.',
58
+ 'A journey of a thousand miles begins with a single step.',
59
+ // ... more text
60
+ ];
61
+
62
+ // Train the model
63
+ await model.train(trainingText, {
64
+ batchSize: 16,
65
+ learningRate: 3e-4,
66
+ maxSteps: 1000,
67
+ logInterval: 10,
68
+ validationSplit: 0.1,
69
+ });
14
70
 
71
+ // Generate text
72
+ const output = await model.generateText('Once upon a time', {
73
+ maxLength: 100,
74
+ temperature: 0.8,
75
+ topP: 0.9,
76
+ });
77
+
78
+ console.log(output);
79
+ ```
80
+
81
+ ### Loading a Pre-trained Model
82
+
83
+ ```javascript
84
+ import { TeachableLLM, waitForModel } from '@genai-fi/nanogpt';
85
+
86
+ // Load from Hugging Face
87
+ const model = TeachableLLM.loadModel('username/model-name');
88
+
89
+ // Or load from a file
90
+ const fileInput = document.getElementById('fileInput');
91
+ fileInput.addEventListener('change', async (event) => {
92
+ const file = event.target.files[0];
93
+ const model = TeachableLLM.loadModel(file);
94
+ await waitForModel(model);
95
+
96
+ const text = await model.generateText('Hello');
97
+ console.log(text);
98
+ });
15
99
  ```
16
- import { TeachableLLM, CharTokeniser } from '@genai-fi/nanogpt';
17
- import * as tf from '@tensorflow/tfjs';
18
100
 
19
- const tokeniser = new CharTokeniser();
20
- const model = TeachableLLM.create(tf, tokeniser, {
101
+ ## Event Handlers and Real-time Updates
102
+
103
+ ### Monitoring Training Progress
104
+
105
+ Track training metrics in real-time with event handlers:
106
+
107
+ ```javascript
108
+ const model = TeachableLLM.create('char', config);
109
+
110
+ // Listen for training step updates
111
+ model.on('trainStep', (step, progress) => {
112
+ console.log(`Step ${step.step}/${progress.totalSteps}`);
113
+ console.log(`Loss: ${step.loss.toFixed(4)}`);
114
+ console.log(`Validation Loss: ${step.valLoss?.toFixed(4) || 'N/A'}`);
115
+ console.log(`Progress: ${(progress.progress * 100).toFixed(1)}%`);
116
+ console.log(`Time Remaining: ${progress.timeRemaining}s`);
117
+
118
+ // Update UI progress bar
119
+ updateProgressBar(progress.progress);
120
+ updateLossChart(step.loss, step.valLoss);
121
+ });
122
+
123
+ await model.train(trainingText, options);
124
+ ```
125
+
126
+ ### Real-time Token Generation
127
+
128
+ Stream generated tokens as they're produced:
129
+
130
+ ```javascript
131
+ const generator = model.generator();
132
+
133
+ // Listen for generated tokens
134
+ generator.on('tokens', (tokens) => {
135
+ // tokens is an array of new token IDs
136
+ const text = model.tokeniser.decode(tokens);
137
+ console.log('New tokens:', text);
138
+
139
+ // Update UI incrementally
140
+ appendToOutput(text);
141
+ });
142
+
143
+ // Generation lifecycle events
144
+ generator.on('start', () => {
145
+ console.log('Generation started');
146
+ showSpinner();
147
+ });
148
+
149
+ generator.on('stop', () => {
150
+ console.log('Generation complete');
151
+ hideSpinner();
152
+ });
153
+
154
+ generator.on('error', (error) => {
155
+ console.error('Generation error:', error);
156
+ });
157
+
158
+ // Start generation
159
+ await generator.generate('Once upon a time', {
160
+ maxLength: 200,
161
+ temperature: 0.8,
162
+ });
163
+ ```
164
+
165
+ ## Training on Small Devices
166
+
167
+ GenAI NanoGPT is designed to work efficiently on resource-constrained devices like iPads, phones, and Chromebooks:
168
+
169
+ ### Recommended Settings for Small Devices
170
+
171
+ ```javascript
172
+ // Smaller model configuration for mobile devices
173
+ const mobileModel = TeachableLLM.create('char', {
21
174
  vocabSize: 200,
22
- blockSize: 128,
23
- nLayer: 4,
24
- nHead: 3,
25
- nEmbed: 192,
26
- dropout: 0.0,
175
+ blockSize: 128, // Smaller context window
176
+ nLayer: 4, // Fewer layers
177
+ nHead: 3, // Fewer attention heads
178
+ nEmbed: 192, // Smaller embeddings
179
+ });
180
+
181
+ // Training options optimized for limited memory
182
+ await mobileModel.train(trainingText, {
183
+ batchSize: 8, // Smaller batch size
184
+ learningRate: 3e-4,
185
+ maxSteps: 500,
186
+ validationSplit: 0.1,
187
+ logInterval: 50,
188
+ gradientCheckpointing: true,
189
+ mixedPrecision: true,
190
+ });
191
+ ```
192
+
193
+ ### Tips for Training on Mobile Devices
194
+
195
+ 1. **Start Small**: Use smaller models (4 layers) and shorter context windows (128 tokens)
196
+ 2. **Reduce Batch Size**: Use batch sizes of 8-16 depending on available memory
197
+ 3. **Use Character Tokenization**: Character-level tokenizers use less memory than BPE
198
+ 4. **Optimize Training Data**: Use smaller datasets or train in stages
199
+
200
+ ## Advanced Usage
201
+
202
+ ### Attention Visualization
203
+
204
+ ```javascript
205
+ const generator = model.generator();
206
+
207
+ const text = await generator.generate('Prompt', {
208
+ attentionScores: true,
209
+ maxLength: 50,
27
210
  });
211
+
212
+ // Get attention data for visualization
213
+ const attentionData = generator.getAttentionData();
214
+ // Shape: [num_tokens][num_layers][num_heads][seq_len][seq_len]
215
+
216
+ const probabilities = generator.getProbabilitiesData();
217
+ // Shape: [num_tokens][seq_len][vocab_size]
218
+ ```
219
+
220
+ ### Streaming Generation
221
+
222
+ ```javascript
223
+ const generator = model.generator();
224
+
225
+ generator.on('tokens', (tokens) => {
226
+ // Update UI with new tokens in real-time
227
+ updateDisplay(tokens);
228
+ });
229
+
230
+ generator.on('start', () => console.log('Generation started'));
231
+ generator.on('stop', () => console.log('Generation complete'));
232
+
233
+ await generator.generate('Once upon a time', {
234
+ maxLength: 200,
235
+ });
236
+ ```
237
+
238
+ ### Memory Management
239
+
240
+ ```javascript
241
+ // Enable profiling
242
+ model.enableProfiler = true;
243
+
244
+ // After training/generation
245
+ const profiler = model.getProfiler();
246
+ if (profiler) {
247
+ console.log('Memory stats:', profiler.getStats());
248
+ }
249
+
250
+ // Clean up
251
+ model.dispose();
252
+ ```
253
+
254
+ ## Examples
255
+
256
+ See the [`browser-tests`](browser-tests/) directory for complete examples:
257
+
258
+ - [`generate.html`](browser-tests/generate.html): Text generation with UI
259
+ - [`rope-train.html`](browser-tests/rope-train.html): Training a model with RoPE
260
+ - [`hf.html`](browser-tests/hf.html): Loading from Hugging Face
261
+ - [`loader.html`](browser-tests/loader.html): Loading different file formats
262
+ - [`perf.html`](browser-tests/perf.html): Performance testing
263
+
264
+ ## Development
265
+
266
+ ### Setup
267
+
268
+ ```bash
269
+ git clone https://github.com/knicos/genai-nanogpt.git
270
+ cd genai-nanogpt
271
+ npm install
272
+ ```
273
+
274
+ ### Building
275
+
276
+ ```bash
277
+ npm run build # Build for production
278
+ npm run dev # Development mode with watch
279
+ ```
280
+
281
+ ### Testing
282
+
283
+ ```bash
284
+ npm test # Run all tests
285
+ ```
286
+
287
+ ### Browser Tests
288
+
289
+ ```bash
290
+ npm run test:gl # Start dev server
291
+ ```
292
+
293
+ ### Project Structure
294
+
295
+ ```
296
+ lib/
297
+ ├── models/ # Model architectures (NanoGPT)
298
+ ├── layers/ # Transformer layers (attention, MLP, etc.)
299
+ ├── ops/ # Custom TensorFlow.js operations
300
+ │ ├── cpu/ # CPU kernels
301
+ │ ├── webgl/ # WebGL kernels
302
+ │ └── webgpu/ # WebGPU kernels
303
+ ├── training/ # Training utilities and optimizers
304
+ ├── tokeniser/ # Tokenization implementations
305
+ ├── loader/ # Model loading/saving
306
+ ├── utilities/ # Helper functions
307
+ └── TeachableLLM.ts # Main API
308
+ ```
309
+
310
+ ### Custom Operations
311
+
312
+ This library implements several custom TensorFlow.js operations optimized for transformer models:
313
+
314
+ - **RoPE**: Rotary Position Embeddings
315
+ - **Attention Mask**: Causal attention masking
316
+ - **RMS Norm**: Root Mean Square normalization
317
+ - **Adam Optimizer**: Extended Adam with weight decay
318
+ - **16-bit Operators**: To enable mixed-precision training
319
+
320
+ See [`lib/ops`](lib/ops/) for implementations.
321
+
322
+ ### Contributing
323
+
324
+ 1. Fork the repository
325
+ 2. Create a feature branch: `git checkout -b feature/amazing-feature`
326
+ 3. Commit your changes: `git commit -m 'Add amazing feature'`
327
+ 4. Push to the branch: `git push origin feature/amazing-feature`
328
+ 5. Open a Pull Request
329
+
330
+ ### Code Style
331
+
332
+ This project uses ESLint and Prettier for code formatting:
333
+
334
+ ```bash
335
+ npm run lint # Check code style
336
+ ```
337
+
338
+ ## Performance Tips
339
+
340
+ 1. **Use WebGPU**: Provides the best performance for training and inference
341
+ 2. **Batch Size**: Larger batches improve GPU utilization but require more memory
342
+ 3. **Mixed Precision**: Enable for faster training on supported hardware (coming soon)
343
+ 4. **Gradient Checkpointing**: Reduces memory usage during training, at the cost of slower training steps
344
+ 5. **Use RoPE**: More efficient than absolute position embeddings
345
+ 6. **Start Small on Mobile**: Use 2-4 layers and batch size 2-8 on phones/tablets
346
+
347
+ ## Acknowledgments
348
+
349
+ - Inspired by [Andrej Karpathy's NanoGPT](https://github.com/karpathy/nanoGPT)
350
+ - Built with [TensorFlow.js](https://www.tensorflow.org/js)
351
+ - Developed as part of the Finnish [Generation AI research project](https://generation-ai-stn.fi)
352
+
353
+ ## Citation
354
+
355
+ If you use this library in your research, please cite:
356
+
357
+ ```bibtex
358
+ @inproceedings{10.1145/3769994.3770061,
359
+ author = {Pope, Nicolas and Tedre, Matti},
360
+ title = {A Teachable Machine for Transformers},
361
+ year = {2025},
362
+ publisher = {Association for Computing Machinery},
363
+ doi = {10.1145/3769994.3770061},
364
+ booktitle = {Proceedings of the 25th Koli Calling International Conference on Computing Education Research},
365
+ }
28
366
  ```
package/dist/Generator.js CHANGED
@@ -1,82 +1,73 @@
1
- import { E as C } from "./index-Dwqa6Zy2.js";
2
- import { E as _, F as I, G as O, a6 as R, t as q, k as K } from "./index-BzFyqcy-.js";
1
+ import { E as C } from "./index-DvYrXKkX.js";
2
+ import { A as _, B as I, E as O, t as R, k as q } from "./index-ZyQhjEPo.js";
3
+ import "./utilities/packed.js";
3
4
  import "./ops/cpu/attentionMask.js";
4
5
  import "./ops/webgl/attentionMask.js";
5
6
  import "./ops/grads/attentionMask.js";
6
- import "./ops/cpu/qkv.js";
7
- import "./ops/webgl/qkv.js";
8
- import "./ops/grads/qkv.js";
9
- import { p as j } from "./random_width-CXVRloNK.js";
10
- import { t as G } from "./register_all_kernels-DIGpEwcf.js";
11
- import "./index-Tf7vU29b.js";
12
- import "./dataset-DlZtKmBq.js";
7
+ import { p as K } from "./random_width-DY6Kk2Dl.js";
8
+ import { t as j } from "./register_all_kernels-Bwu1PTuU.js";
9
+ import "./index-Cp39cXWe.js";
10
+ import "./dataset-0xP8GjwI.js";
13
11
  import "./ops/cpu/rope.js";
14
12
  import "./ops/webgl/rope.js";
15
- import "./ops/grads/rope.js";
13
+ import "./rope-B5UUMsPi.js";
16
14
  import "./ops/cpu/appendCache.js";
17
15
  import "./ops/webgl/appendCache.js";
18
- import "./ops/cpu/fusedSoftmax.js";
19
- import "./ops/webgl/fusedSoftmax.js";
20
- import "./ops/grads/fusedSoftmax.js";
21
- import "./ops/cpu/matMulGelu.js";
22
- import "./ops/webgl/matMulGelu.js";
23
- import "./ops/grads/matMulGelu.js";
16
+ import "./ops/grads/softmax16.js";
17
+ import "./matMul16--R5hOwDG.js";
18
+ import "./ops/webgl/matMul16.js";
19
+ import "./ops/cpu/matMul16.js";
20
+ import "./pack16-CFUqumar.js";
21
+ import "./ops/transpose16.js";
22
+ import "./ops/reshape16.js";
23
+ import "./ops/cpu/qkv.js";
24
+ import "./ops/webgl/qkv.js";
25
+ import "./ops/grads/qkv.js";
24
26
  import "./ops/cpu/normRMS.js";
25
27
  import "./ops/webgl/normRMS.js";
26
28
  import "./ops/grads/normRMS.js";
29
+ import "./ops/grads/add16.js";
27
30
  import { sparseSoftmaxCrossEntropy as V } from "./training/sparseCrossEntropy.js";
28
- import "./jszip.min-CjP2V1VV.js";
31
+ import "./jszip.min-Bz5-11Bk.js";
29
32
  import $ from "./tokeniser/CharTokeniser.js";
30
33
  import "./ops/cpu/adamAdjust.js";
31
34
  import "./ops/webgl/adamAdjust.js";
32
35
  import "./ops/cpu/adamMoments.js";
33
36
  import "./ops/webgl/adamMoments.js";
34
- import "./papaparse.min-C8l2Kvo1.js";
35
- import M from "./utilities/topP.js";
37
+ import "./papaparse.min-C0cScC2i.js";
38
+ import G from "./utilities/topP.js";
36
39
  import "./ops/cpu/scatterSub.js";
37
40
  import "./ops/webgl/scatterSub.js";
38
41
  import "./ops/cpu/gatherSub.js";
39
42
  import "./ops/webgl/gatherSub.js";
43
+ import "./ops/cpu/matMulGelu.js";
44
+ import "./ops/webgl/matMulGelu.js";
45
+ import "./ops/grads/matMulGelu.js";
40
46
  import "./ops/cpu/gelu.js";
41
47
  import "./ops/webgl/gelu.js";
42
- import "./gelu-Bp_-935b.js";
48
+ import "./gelu-CNLFZWea.js";
43
49
  import "./ops/webgl/log.js";
44
50
  import "./checks/normRMS.js";
45
51
  import "./checks/normRMSGrad.js";
46
- import N from "./utilities/multinomialCPU.js";
47
- import { r as E } from "./reshape-CnIwVG1c.js";
48
- import { t as P } from "./tensor2d-D76QGjF3.js";
49
- import { s as S } from "./softmax-D7Jj3p_P.js";
50
- import { g as F } from "./gather-Dnpgw-YQ.js";
51
- import { c as H } from "./concat-B912vBbo.js";
52
- /**
53
- * @license
54
- * Copyright 2020 Google LLC. All Rights Reserved.
55
- * Licensed under the Apache License, Version 2.0 (the "License");
56
- * you may not use this file except in compliance with the License.
57
- * You may obtain a copy of the License at
58
- *
59
- * http://www.apache.org/licenses/LICENSE-2.0
60
- *
61
- * Unless required by applicable law or agreed to in writing, software
62
- * distributed under the License is distributed on an "AS IS" BASIS,
63
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
64
- * See the License for the specific language governing permissions and
65
- * limitations under the License.
66
- * =============================================================================
67
- */
68
- function U(p, t, s, e = !1) {
69
- const o = I(p, "logits", "multinomial"), i = o.size, c = o.rank;
52
+ import M from "./utilities/multinomialCPU.js";
53
+ import { i as N } from "./tensor_util-DV-FP5Q3.js";
54
+ import { r as E } from "./reshape-DevtBWtf.js";
55
+ import { t as P } from "./tensor2d-G4Ys2GxX.js";
56
+ import { s as S } from "./softmax-ZHVebtR1.js";
57
+ import { g as B } from "./gather-DykLGqmW.js";
58
+ import { c as H } from "./concat-BHlIJeyT.js";
59
+ function U(l, t, s, e = !1) {
60
+ const o = I(l, "logits", "multinomial"), i = o.size, c = o.rank;
70
61
  if (i < 2)
71
62
  throw new Error(`Error in multinomial: you need at least 2 outcomes, but got ${i}.`);
72
63
  if (c > 2)
73
64
  throw new Error(`Rank of probabilities must be 1 or 2, but is ${c}`);
74
65
  s = s || Math.random();
75
- const n = { logits: c === 1 ? E(o, [1, -1]) : o }, l = { numSamples: t, seed: s, normalized: e }, d = O.runKernel(R, n, l);
66
+ const n = { logits: c === 1 ? E(o, [1, -1]) : o }, p = { numSamples: t, seed: s, normalized: e }, d = O.runKernel(N, n, p);
76
67
  return c === 1 ? E(d, [d.size]) : d;
77
68
  }
78
69
  const z = /* @__PURE__ */ _({ multinomial_: U }), W = [
79
- ...Array.from({ length: 95 }, (p, t) => String.fromCharCode(t + 32)),
70
+ ...Array.from({ length: 95 }, (l, t) => String.fromCharCode(t + 32)),
80
71
  // ASCII
81
72
  // Spanish accented letters and punctuation
82
73
  ..."áéíóúüñ¿¡",
@@ -87,10 +78,10 @@ const z = /* @__PURE__ */ _({ multinomial_: U }), W = [
87
78
  // Cyrillic letters
88
79
  ..."абвгдеёжзийклмнопрстуфхцчшщъыьэюяАБВГДЕЁЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЫЬЭЮЯ"
89
80
  ];
90
- function B(p, t) {
91
- return p.length === t ? p : p.length > t ? p.slice(0, t) : p.concat(Array(t - p.length).fill(""));
81
+ function F(l, t) {
82
+ return l.length === t ? l : l.length > t ? l.slice(0, t) : l.concat(Array(t - l.length).fill(""));
92
83
  }
93
- class Wt extends C {
84
+ class te extends C {
94
85
  constructor(t, s) {
95
86
  super(), this.model = t, this.tokeniser = s, this.actualTokeniser = s;
96
87
  }
@@ -116,7 +107,7 @@ class Wt extends C {
116
107
  const c = await t.decode([i]);
117
108
  if (e) {
118
109
  const T = await Promise.all(
119
- e.map((n) => n.array().then((l) => l))
110
+ e.map((n) => n.array().then((p) => p))
120
111
  );
121
112
  e.forEach((n) => n.dispose()), this.attentionData.push(T);
122
113
  }
@@ -131,14 +122,14 @@ class Wt extends C {
131
122
  } : void 0,
132
123
  cache: s,
133
124
  outputEmbeddings: !!e?.embeddings
134
- }, [l, d] = q(() => {
135
- const a = t, m = a.shape[1], h = m <= this.model.config.blockSize ? a : a.slice(
125
+ }, [p, d] = R(() => {
126
+ const r = t, m = r.shape[1], h = m <= this.model.config.blockSize ? r : r.slice(
136
127
  [0, m - this.model.config.blockSize],
137
- [a.shape[0], this.model.config.blockSize]
138
- ), r = T ? this.model.config.blockSize - h.shape[1] : 0, v = r > 0 ? j(h, [
128
+ [r.shape[0], this.model.config.blockSize]
129
+ ), a = T ? this.model.config.blockSize - h.shape[1] : 0, v = a > 0 ? K(h, [
139
130
  [0, 0],
140
- [0, r]
141
- ]) : h, [g] = this.model.forward(n, v), u = g.shape[1] - 1 - r, f = g.slice([0, u, 0], [g.shape[0], 1, g.shape[2]]);
131
+ [0, a]
132
+ ]) : h, [g] = this.model.forward(n, v), u = g.shape[1] - 1 - a, f = g.slice([0, u, 0], [g.shape[0], 1, g.shape[2]]);
142
133
  let y;
143
134
  if (e?.targets) {
144
135
  const k = e.targets.shift();
@@ -148,46 +139,46 @@ class Wt extends C {
148
139
  }
149
140
  }
150
141
  return n.attentionScores?.attentionOut && n.attentionScores.attentionOut.forEach((k, w) => {
151
- k.shape[1] !== 1 && (n.attentionScores.attentionOut[w] = K(
142
+ k.shape[1] !== 1 && (n.attentionScores.attentionOut[w] = q(
152
143
  k.slice([0, u, 0], [k.shape[0], 1, k.shape[2]])
153
144
  ), k.dispose());
154
145
  }), g.dispose(), [f.div(o).squeeze([1]), y];
155
146
  });
156
147
  let b, x;
157
148
  if (c) {
158
- const a = S(l), m = await a.array();
159
- a.dispose();
160
- const h = M(m, c);
161
- e?.includeProbabilities && (x = m), b = N(h);
149
+ const r = S(p), m = await r.array();
150
+ r.dispose();
151
+ const h = G(m, c);
152
+ e?.includeProbabilities && (x = m), b = M(h);
162
153
  } else if (i) {
163
- const { values: a, indices: m } = G(l, i), h = z(a, 1);
164
- b = F(m, h, 1), a.dispose(), m.dispose(), h.dispose();
165
- } else if (b = z(l, 1), e?.includeProbabilities) {
166
- const a = S(l);
167
- x = await a.array(), a.dispose();
154
+ const { values: r, indices: m } = j(p, i), h = z(r, 1);
155
+ b = B(m, h, 1), r.dispose(), m.dispose(), h.dispose();
156
+ } else if (b = z(p, 1), e?.includeProbabilities) {
157
+ const r = S(p);
158
+ x = await r.array(), r.dispose();
168
159
  }
169
160
  if (n.embeddings) {
170
- const m = (e?.embeddings === "all" ? n.embeddings : n.embeddings.filter((r) => r.name.startsWith("block_output_"))).map(async (r) => {
171
- const v = r.tensor.shape[1], g = r.tensor.slice([0, v - 1, 0], [r.tensor.shape[0], 1, r.tensor.shape[2]]);
172
- r.tensor.dispose();
161
+ const m = (e?.embeddings === "all" ? n.embeddings : n.embeddings.filter((a) => a.name.startsWith("block_output_"))).map(async (a) => {
162
+ const v = a.tensor.shape[1], g = a.tensor.slice([0, v - 1, 0], [a.tensor.shape[0], 1, a.tensor.shape[2]]);
163
+ a.tensor.dispose();
173
164
  const u = g.squeeze([1]);
174
165
  if (g.dispose(), e?.embeddings === "softmax") {
175
166
  const f = this.model.project(u);
176
167
  u.dispose();
177
168
  const y = S(f, -1);
178
- return f.dispose(), { name: r.name, tensor: await y.array() };
169
+ return f.dispose(), { name: a.name, tensor: await y.array() };
179
170
  } else if (e?.embeddings === "logits") {
180
171
  const f = this.model.project(u);
181
- return u.dispose(), { name: r.name, tensor: await f.array() };
172
+ return u.dispose(), { name: a.name, tensor: await f.array() };
182
173
  } else {
183
174
  const f = await u.array();
184
- return u.dispose(), { name: r.name, tensor: f };
175
+ return u.dispose(), { name: a.name, tensor: f };
185
176
  }
186
177
  }), h = await Promise.all(m);
187
178
  this.embeddingsData.push(h);
188
179
  }
189
180
  const A = b.reshape([1, 1]);
190
- b.dispose(), b = A, l.dispose();
181
+ b.dispose(), b = A, p.dispose();
191
182
  let L;
192
183
  return d && (L = await d.array(), d.dispose()), { output: b, probabilities: x, attention: n.attentionScores?.attentionOut, loss: L };
193
184
  }
@@ -211,10 +202,10 @@ class Wt extends C {
211
202
  const d = s;
212
203
  s = H([s, i], 1), d.dispose();
213
204
  }
214
- const l = await this.processResponse(this.actualTokeniser, i, T, c);
215
- if (this.cache || i.dispose(), l === null)
205
+ const p = await this.processResponse(this.actualTokeniser, i, T, c);
206
+ if (this.cache || i.dispose(), p === null)
216
207
  break;
217
- this.outputText += l;
208
+ this.outputText += p;
218
209
  }
219
210
  return s.dispose(), this.outputText;
220
211
  }
@@ -233,7 +224,7 @@ class Wt extends C {
233
224
  o[i] = { k: void 0, v: void 0, length: 0, cumulativeLength: 0 };
234
225
  this.cache = o, this.lastToken = -1;
235
226
  }
236
- const e = this.tokeniser.trained ? this.tokeniser : new $(B(W, this.tokeniser.vocabSize));
227
+ const e = this.tokeniser.trained ? this.tokeniser : new $(F(W, this.tokeniser.vocabSize));
237
228
  this.actualTokeniser = e;
238
229
  }
239
230
  async step(t, s) {
@@ -268,5 +259,5 @@ class Wt extends C {
268
259
  }
269
260
  }
270
261
  export {
271
- Wt as default
262
+ te as default
272
263
  };