@genai-fi/nanogpt 0.9.0 → 0.10.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (343) hide show
  1. package/README.md +352 -14
  2. package/dist/Generator.js +69 -78
  3. package/dist/{RealDiv-D4EzDsC0.js → RealDiv-DgA3z9oO.js} +32 -206
  4. package/dist/Reshape-CF6odzV4.js +16 -0
  5. package/dist/Reshape-_kILl6tK.js +81 -0
  6. package/dist/TeachableLLM.js +28 -22
  7. package/dist/Trainer.d.ts +2 -0
  8. package/dist/Trainer.js +3 -2
  9. package/dist/{axis_util-TbGYJ208.js → axis_util-BvHEw88j.js} +7 -23
  10. package/dist/backend.d.ts +2 -1
  11. package/dist/backend.js +10 -4
  12. package/dist/backend_util-D-rUb2ty.js +474 -0
  13. package/dist/backend_webgpu-B0u2ndUn.js +547 -0
  14. package/dist/binary_op_util-pKXltfxI.js +192 -0
  15. package/dist/broadcast_to-CwF7XIeu.js +30 -0
  16. package/dist/checks/appendCache.js +2 -2
  17. package/dist/checks/attentionMask.js +3 -3
  18. package/dist/checks/check.d.ts +1 -1
  19. package/dist/checks/check.js +8 -8
  20. package/dist/checks/gelu.js +2 -2
  21. package/dist/checks/index.d.ts +2 -0
  22. package/dist/checks/index.js +7 -5
  23. package/dist/checks/matMulGelu.js +6 -6
  24. package/dist/checks/normRMS.js +7 -7
  25. package/dist/checks/normRMSGrad.js +3 -3
  26. package/dist/checks/packUnpack.d.ts +1 -0
  27. package/dist/checks/packUnpack.js +18 -0
  28. package/dist/checks/qkv.js +12 -27
  29. package/dist/checks/rope.js +2 -2
  30. package/dist/checks/weights.js +18 -16
  31. package/dist/complex-CSlYz-2T.js +13 -0
  32. package/dist/complex_util-Yc1A_gV1.js +55 -0
  33. package/dist/concat-BHlIJeyT.js +19 -0
  34. package/dist/concat_util-DcJk7YHS.js +22 -0
  35. package/dist/data/docx.js +1 -1
  36. package/dist/data/parquet.js +2 -2
  37. package/dist/data/pdf.js +1 -1
  38. package/dist/data/textLoader.js +1 -1
  39. package/dist/{dataset-DlZtKmBq.js → dataset-0xP8GjwI.js} +136 -236
  40. package/dist/dropout-C1pM3f11.js +99 -0
  41. package/dist/expand_dims-BPG4fwBP.js +13 -0
  42. package/dist/exports_initializers-xuidcwI4.js +7 -0
  43. package/dist/gather-DykLGqmW.js +10 -0
  44. package/dist/{gelu-Bp_-935b.js → gelu-CNLFZWea.js} +11 -10
  45. package/dist/{gpgpu_math-CDaYiyE_.js → gpgpu_math-DDVJCn6-.js} +90 -265
  46. package/dist/{index-C4L8Cm77.js → index-CieiGp4Y.js} +14 -14
  47. package/dist/index-CjOj7j-u.js +7308 -0
  48. package/dist/{index-Tf7vU29b.js → index-Cp39cXWe.js} +3 -10
  49. package/dist/{index-Dwqa6Zy2.js → index-DvYrXKkX.js} +2 -2
  50. package/dist/index-ZyQhjEPo.js +2157 -0
  51. package/dist/{jszip.min-CjP2V1VV.js → jszip.min-Bz5-11Bk.js} +56 -57
  52. package/dist/kernel_funcs_utils-Dg_-E44D.js +308 -0
  53. package/dist/layers/BaseLayer.d.ts +1 -0
  54. package/dist/layers/BaseLayer.js +7 -6
  55. package/dist/layers/CausalSelfAttention.d.ts +0 -1
  56. package/dist/layers/CausalSelfAttention.js +56 -55
  57. package/dist/layers/MLP.js +15 -16
  58. package/dist/layers/PositionEmbedding.js +5 -14
  59. package/dist/layers/RMSNorm.js +3 -3
  60. package/dist/layers/RoPECache.d.ts +2 -0
  61. package/dist/layers/RoPECache.js +22 -17
  62. package/dist/layers/TiedEmbedding.js +22 -17
  63. package/dist/layers/TransformerBlock.js +21 -20
  64. package/dist/loader/load.js +1 -1
  65. package/dist/loader/loadTransformers.js +1 -1
  66. package/dist/loader/oldZipLoad.js +39 -33
  67. package/dist/loader/save.js +1 -1
  68. package/dist/log_sum_exp-DWI-76TI.js +41 -0
  69. package/dist/main.d.ts +8 -0
  70. package/dist/main.js +63 -52
  71. package/dist/matMul16--R5hOwDG.js +77 -0
  72. package/dist/mat_mul-DeAh4uTH.js +12 -0
  73. package/dist/mod-Gt1rMB4n.js +12 -0
  74. package/dist/models/NanoGPTV1.js +40 -31
  75. package/dist/models/model.d.ts +2 -0
  76. package/dist/models/model.js +37 -29
  77. package/dist/{mulmat_packed_gpu-BT60jmzP.js → mulmat_packed_gpu-BMFhLwta.js} +1 -17
  78. package/dist/{non_max_suppression_impl-CsEgBuMA.js → non_max_suppression_impl-B2W7YjZB.js} +0 -32
  79. package/dist/ones-CAMiP4I2.js +15 -0
  80. package/dist/ops/adamAdjust.js +1 -1
  81. package/dist/ops/adamMoments.d.ts +1 -1
  82. package/dist/ops/adamMoments.js +4 -4
  83. package/dist/ops/add16.d.ts +2 -0
  84. package/dist/ops/add16.js +9 -0
  85. package/dist/ops/appendCache.js +16 -9
  86. package/dist/ops/attentionMask.js +4 -4
  87. package/dist/ops/concat16.d.ts +2 -0
  88. package/dist/ops/concat16.js +9 -0
  89. package/dist/ops/cpu/adamAdjust.js +14 -13
  90. package/dist/ops/cpu/adamMoments.js +10 -9
  91. package/dist/ops/cpu/appendCache.js +9 -8
  92. package/dist/ops/cpu/attentionMask.js +15 -14
  93. package/dist/ops/cpu/fusedSoftmax.js +13 -12
  94. package/dist/ops/cpu/gatherSub.js +9 -24
  95. package/dist/ops/cpu/gelu.js +13 -12
  96. package/dist/ops/cpu/matMul16.d.ts +1 -0
  97. package/dist/ops/cpu/matMul16.js +16 -0
  98. package/dist/ops/cpu/matMulGelu.js +18 -16
  99. package/dist/ops/cpu/matMulMul.js +8 -7
  100. package/dist/ops/cpu/mulDropout.js +4 -3
  101. package/dist/ops/cpu/normRMS.js +11 -10
  102. package/dist/ops/cpu/qkv.js +17 -13
  103. package/dist/ops/cpu/rope.js +23 -22
  104. package/dist/ops/cpu/scatterSub.js +16 -30
  105. package/dist/ops/dot16.d.ts +2 -0
  106. package/dist/ops/dot16.js +42 -0
  107. package/dist/ops/gatherSub.js +1 -1
  108. package/dist/ops/gelu.js +2 -2
  109. package/dist/ops/grads/add16.d.ts +1 -0
  110. package/dist/ops/grads/add16.js +27 -0
  111. package/dist/ops/grads/attentionMask.js +12 -19
  112. package/dist/ops/grads/gelu.js +4 -3
  113. package/dist/ops/grads/matMul16.d.ts +2 -0
  114. package/dist/ops/grads/matMul16.js +9 -0
  115. package/dist/ops/grads/matMulGelu.js +8 -7
  116. package/dist/ops/grads/normRMS.js +8 -7
  117. package/dist/ops/grads/{fusedSoftmax.d.ts → pack16.d.ts} +1 -1
  118. package/dist/ops/grads/pack16.js +7 -0
  119. package/dist/ops/grads/qkv.d.ts +3 -1
  120. package/dist/ops/grads/qkv.js +28 -22
  121. package/dist/ops/grads/rope.d.ts +2 -1
  122. package/dist/ops/grads/rope.js +6 -13
  123. package/dist/ops/grads/softmax16.d.ts +2 -0
  124. package/dist/ops/grads/softmax16.js +26 -0
  125. package/dist/ops/grads/unpack16.d.ts +2 -0
  126. package/dist/ops/grads/unpack16.js +6 -0
  127. package/dist/ops/grads/utils.d.ts +3 -0
  128. package/dist/ops/grads/utils.js +10 -0
  129. package/dist/ops/matMul16.d.ts +15 -0
  130. package/dist/ops/matMul16.js +13 -0
  131. package/dist/ops/matMulGelu.js +1 -1
  132. package/dist/ops/matMulMul.js +1 -1
  133. package/dist/ops/mul16.d.ts +2 -0
  134. package/dist/ops/mul16.js +8 -0
  135. package/dist/ops/mulDrop.js +1 -1
  136. package/dist/ops/normRMS.js +1 -1
  137. package/dist/ops/pack16.d.ts +2 -0
  138. package/dist/ops/pack16.js +6 -0
  139. package/dist/ops/qkv.d.ts +1 -1
  140. package/dist/ops/qkv.js +8 -4
  141. package/dist/ops/reshape16.d.ts +2 -0
  142. package/dist/ops/reshape16.js +43 -0
  143. package/dist/ops/rope.d.ts +1 -1
  144. package/dist/ops/rope.js +8 -10
  145. package/dist/ops/scatterSub.js +1 -1
  146. package/dist/ops/slice16.d.ts +2 -0
  147. package/dist/ops/slice16.js +9 -0
  148. package/dist/ops/softmax16.d.ts +2 -0
  149. package/dist/ops/softmax16.js +12 -0
  150. package/dist/ops/sub16.d.ts +2 -0
  151. package/dist/ops/sub16.js +8 -0
  152. package/dist/ops/sum16.d.ts +2 -0
  153. package/dist/ops/sum16.js +13 -0
  154. package/dist/ops/transpose16.d.ts +3 -0
  155. package/dist/ops/transpose16.js +41 -0
  156. package/dist/ops/unpack16.d.ts +2 -0
  157. package/dist/ops/unpack16.js +6 -0
  158. package/dist/ops/webgl/adamAdjust.js +3 -2
  159. package/dist/ops/webgl/adamMoments.js +2 -1
  160. package/dist/ops/webgl/appendCache.js +2 -1
  161. package/dist/ops/webgl/attentionMask.js +5 -4
  162. package/dist/ops/webgl/fusedSoftmax.js +6 -4
  163. package/dist/ops/webgl/gatherSub.js +7 -6
  164. package/dist/ops/webgl/gelu.js +3 -2
  165. package/dist/ops/webgl/log.js +12 -27
  166. package/dist/ops/webgl/matMul16.d.ts +1 -0
  167. package/dist/ops/webgl/matMul16.js +37 -0
  168. package/dist/ops/webgl/matMulGelu.js +17 -15
  169. package/dist/ops/webgl/matMulMul.js +13 -12
  170. package/dist/ops/webgl/mulDropout.js +9 -8
  171. package/dist/ops/webgl/normRMS.js +8 -7
  172. package/dist/ops/webgl/qkv.js +6 -5
  173. package/dist/ops/webgl/rope.js +11 -10
  174. package/dist/ops/webgl/scatterSub.js +6 -5
  175. package/dist/ops/webgpu/adamAdjust.js +12 -10
  176. package/dist/ops/webgpu/adamMoments.js +27 -22
  177. package/dist/ops/webgpu/add16.d.ts +1 -0
  178. package/dist/ops/webgpu/add16.js +14 -0
  179. package/dist/ops/webgpu/appendCache.js +64 -17
  180. package/dist/ops/webgpu/attentionMask.js +19 -62
  181. package/dist/ops/webgpu/attentionMask32_program.d.ts +19 -0
  182. package/dist/ops/webgpu/attentionMask32_program.js +54 -0
  183. package/dist/ops/webgpu/concat16.d.ts +19 -0
  184. package/dist/ops/webgpu/concat16.js +128 -0
  185. package/dist/ops/webgpu/gatherSub.js +9 -7
  186. package/dist/ops/webgpu/gelu.js +78 -31
  187. package/dist/ops/webgpu/index.js +12 -0
  188. package/dist/ops/webgpu/matMul16.d.ts +1 -0
  189. package/dist/ops/webgpu/matMul16.js +58 -0
  190. package/dist/ops/webgpu/matMul16_program.d.ts +42 -0
  191. package/dist/ops/webgpu/matMul16_program.js +336 -0
  192. package/dist/ops/webgpu/mul16.d.ts +1 -0
  193. package/dist/ops/webgpu/mul16.js +14 -0
  194. package/dist/ops/webgpu/normRMS.js +21 -40
  195. package/dist/ops/webgpu/normRMS16_program.d.ts +9 -0
  196. package/dist/ops/webgpu/normRMS16_program.js +24 -0
  197. package/dist/ops/webgpu/normRMS32_program.d.ts +9 -0
  198. package/dist/ops/webgpu/normRMS32_program.js +24 -0
  199. package/dist/ops/webgpu/normRMSGrad.js +113 -64
  200. package/dist/ops/webgpu/pack16.d.ts +1 -0
  201. package/dist/ops/webgpu/pack16.js +19 -0
  202. package/dist/ops/webgpu/pack16_program.d.ts +19 -0
  203. package/dist/ops/webgpu/pack16_program.js +92 -0
  204. package/dist/ops/webgpu/qkv.js +20 -55
  205. package/dist/ops/webgpu/rope.js +77 -22
  206. package/dist/ops/webgpu/scatterSub.js +9 -7
  207. package/dist/ops/webgpu/slice16.d.ts +7 -0
  208. package/dist/ops/webgpu/slice16.js +71 -0
  209. package/dist/{variable-Bm2OFwGI.js → ops/webgpu/softmax16.d.ts} +2 -8
  210. package/dist/ops/webgpu/softmax16.js +23 -0
  211. package/dist/ops/webgpu/softmax16_program.d.ts +13 -0
  212. package/dist/ops/webgpu/softmax16_program.js +73 -0
  213. package/dist/ops/webgpu/softmax16_subgroup_program.d.ts +17 -0
  214. package/dist/ops/webgpu/softmax16_subgroup_program.js +75 -0
  215. package/dist/ops/webgpu/softmax16grad.d.ts +1 -0
  216. package/dist/ops/webgpu/softmax16grad.js +38 -0
  217. package/dist/ops/webgpu/sub16.d.ts +1 -0
  218. package/dist/ops/webgpu/sub16.js +14 -0
  219. package/dist/ops/webgpu/sum16.d.ts +1 -0
  220. package/dist/ops/webgpu/sum16.js +40 -0
  221. package/dist/ops/webgpu/transpose16.d.ts +1 -0
  222. package/dist/ops/webgpu/transpose16.js +35 -0
  223. package/dist/ops/webgpu/transpose16_program.d.ts +16 -0
  224. package/dist/ops/webgpu/transpose16_program.js +50 -0
  225. package/dist/ops/webgpu/transpose16_shared_program.d.ts +15 -0
  226. package/dist/ops/webgpu/transpose16_shared_program.js +71 -0
  227. package/dist/ops/webgpu/unpack16.d.ts +1 -0
  228. package/dist/ops/webgpu/unpack16.js +49 -0
  229. package/dist/ops/webgpu/utils/binary_op.d.ts +19 -0
  230. package/dist/ops/webgpu/utils/binary_op.js +79 -0
  231. package/dist/ops/webgpu/utils/deviceInfo.d.ts +7 -0
  232. package/dist/ops/webgpu/utils/deviceInfo.js +11 -0
  233. package/dist/ops/webgpu/utils/reductions.d.ts +32 -4
  234. package/dist/ops/webgpu/utils/reductions.js +236 -45
  235. package/dist/ops-CNI3TwqM.js +645 -0
  236. package/dist/pack16-CFUqumar.js +41 -0
  237. package/dist/{papaparse.min-C8l2Kvo1.js → papaparse.min-C0cScC2i.js} +2 -8
  238. package/dist/{parquet-C0Tlmv9c.js → parquet-BE8MU_ge.js} +201 -278
  239. package/dist/patches/PackedTensor.d.ts +12 -0
  240. package/dist/patches/PackedTensor.js +11 -0
  241. package/dist/patches/engine.d.ts +261 -0
  242. package/dist/patches/engine.js +10 -0
  243. package/dist/patches/tape.d.ts +12 -0
  244. package/dist/patches/tape.js +5 -0
  245. package/dist/patches/webgpu_backend.d.ts +18 -0
  246. package/dist/patches/webgpu_backend.js +57 -0
  247. package/dist/{tensor-CZr4dh61.js → patches/webgpu_base.d.ts} +5 -8
  248. package/dist/patches/webgpu_base.js +34 -0
  249. package/dist/patches/webgpu_program.d.ts +36 -0
  250. package/dist/patches/webgpu_program.js +401 -0
  251. package/dist/{pdf-kJD-f258.js → pdf-NIhmP3sq.js} +424 -428
  252. package/dist/random_width-DY6Kk2Dl.js +10051 -0
  253. package/dist/range-BMS52eQi.js +11 -0
  254. package/dist/reciprocal-CTmshQ9J.js +10 -0
  255. package/dist/{register_all_kernels-DIGpEwcf.js → register_all_kernels-Bwu1PTuU.js} +719 -9766
  256. package/dist/relu-yZ2-7WxU.js +10 -0
  257. package/dist/reshape-DevtBWtf.js +10 -0
  258. package/dist/rope-B5UUMsPi.js +32 -0
  259. package/dist/{scatter_nd_util-BQdz--Gn.js → scatter_nd_util-5EL-8VAQ.js} +1 -1
  260. package/dist/selu_util-D1w6yyTO.js +303 -0
  261. package/dist/{shared-DuP7ue-R.js → shared-BRksrJb3.js} +1 -17
  262. package/dist/shared-BuAXb4CI.js +2145 -0
  263. package/dist/sin-BGfy2HZo.js +16 -0
  264. package/dist/slice-D_gkkqZK.js +13 -0
  265. package/dist/slice_util-DtEldBfK.js +261 -0
  266. package/dist/softmax-ZHVebtR1.js +13 -0
  267. package/dist/split-DrfihRpZ.js +10 -0
  268. package/dist/squeeze-DZEpeblb.js +11 -0
  269. package/dist/stack-yOIAalTq.js +13 -0
  270. package/dist/sum-_fzj5ZTB.js +12 -0
  271. package/dist/tensor-DdQUJZlz.js +909 -0
  272. package/dist/tensor-f35l8Odg.js +8 -0
  273. package/dist/tensor1d-CeZuc-Rv.js +12 -0
  274. package/dist/tensor2d-G4Ys2GxX.js +15 -0
  275. package/dist/tensor4d-B8roDgtc.js +15 -0
  276. package/dist/tensor_util-DV-FP5Q3.js +523 -0
  277. package/dist/tfjs_backend-kNyO5L2d.js +653 -0
  278. package/dist/tile-BzyEiF-F.js +13 -0
  279. package/dist/tokeniser/CharTokeniser.js +1 -1
  280. package/dist/tokeniser/bpe.js +1 -1
  281. package/dist/training/Adam.d.ts +2 -1
  282. package/dist/training/Adam.js +12 -28
  283. package/dist/training/AdamExt.d.ts +1 -0
  284. package/dist/training/AdamExt.js +2 -2
  285. package/dist/training/DatasetBuilder.js +3 -20
  286. package/dist/training/FullTrainer.js +82 -64
  287. package/dist/training/Trainer.d.ts +11 -6
  288. package/dist/training/Trainer.js +51 -39
  289. package/dist/training/sparseCrossEntropy.js +3 -3
  290. package/dist/transpose-DKELTqhe.js +38 -0
  291. package/dist/utilities/arrayClose.js +7 -7
  292. package/dist/utilities/dummy.js +35 -27
  293. package/dist/utilities/multinomialCPU.js +2 -2
  294. package/dist/utilities/packed.d.ts +7 -0
  295. package/dist/utilities/packed.js +716 -0
  296. package/dist/utilities/performance.js +1 -1
  297. package/dist/utilities/profile.js +1 -1
  298. package/dist/utilities/safetensors.js +2 -2
  299. package/dist/utilities/sentences.d.ts +5 -0
  300. package/dist/utilities/sentences.js +41 -0
  301. package/dist/utilities/weights.js +2 -2
  302. package/dist/variable-Bhn5bHYv.js +7 -0
  303. package/dist/{webgpu_program-DkQJOJSd.js → webgpu_program-Cigz-7RF.js} +15 -44
  304. package/dist/webgpu_util-BBCnKm2X.js +65 -0
  305. package/dist/zeros-2gldETuK.js +14 -0
  306. package/package.json +4 -3
  307. package/dist/Reshape-Bowtk9BP.js +0 -127
  308. package/dist/Reshape-DUqYftGC.js +0 -30
  309. package/dist/backend_util-CJIiDoV1.js +0 -749
  310. package/dist/broadcast_to-DzlNweb8.js +0 -44
  311. package/dist/concat-B912vBbo.js +0 -33
  312. package/dist/dropout-C-csYCLj.js +0 -193
  313. package/dist/exports_initializers-B8iZMgQ0.js +0 -16
  314. package/dist/gather-Dnpgw-YQ.js +0 -25
  315. package/dist/index-BzFyqcy-.js +0 -4457
  316. package/dist/index-C1rx_Ajs.js +0 -12076
  317. package/dist/kernel_funcs_utils-DKLK0Mg3.js +0 -466
  318. package/dist/log_sum_exp-DO6z8tSE.js +0 -103
  319. package/dist/mat_mul-DzjTFx-u.js +0 -27
  320. package/dist/mod-Dobti4j4.js +0 -27
  321. package/dist/ones-tIJeHlq-.js +0 -29
  322. package/dist/ops/fusedSoftmax.d.ts +0 -2
  323. package/dist/ops/fusedSoftmax.js +0 -10
  324. package/dist/ops/grads/fusedSoftmax.js +0 -22
  325. package/dist/ops-LuCMAnmM.js +0 -1525
  326. package/dist/random_width-CXVRloNK.js +0 -13670
  327. package/dist/range-CWcz7xFA.js +0 -26
  328. package/dist/reciprocal-C4rNcM-S.js +0 -25
  329. package/dist/relu-BjCh_SYb.js +0 -25
  330. package/dist/reshape-CnIwVG1c.js +0 -25
  331. package/dist/selu_util-OtRzVwW5.js +0 -719
  332. package/dist/shared-DmRsFyaJ.js +0 -3134
  333. package/dist/sin-gpDNRxE0.js +0 -47
  334. package/dist/slice-d0Vo9XTN.js +0 -28
  335. package/dist/softmax-D7Jj3p_P.js +0 -28
  336. package/dist/split-DK2k5eHf.js +0 -25
  337. package/dist/stack-DFatutCx.js +0 -27
  338. package/dist/sum-CJ0ULhmt.js +0 -27
  339. package/dist/tensor1d-vML0r3q6.js +0 -27
  340. package/dist/tensor2d-D76QGjF3.js +0 -30
  341. package/dist/tensor4d-Df1WlVDY.js +0 -30
  342. package/dist/webgpu_util-pLEV9tks.js +0 -80
  343. package/dist/zeros-Bj5rMYA7.js +0 -52
@@ -1,26 +1,12 @@
1
- import { aq as T, ag as E, p as O, j as V, aB as B, a1 as F, ah as j, aC as K } from "./index-BzFyqcy-.js";
2
- import { r as $ } from "./Reshape-Bowtk9BP.js";
3
- import { g as A, a as C, b as k, c as N, e as R } from "./axis_util-TbGYJ208.js";
4
- import { t as U, m as W } from "./shared-DuP7ue-R.js";
5
- import { c as _ } from "./backend_util-CJIiDoV1.js";
6
- import { f as y } from "./gpgpu_math-CDaYiyE_.js";
7
- import { g as G, b as L } from "./kernel_funcs_utils-DKLK0Mg3.js";
8
- /**
9
- * @license
10
- * Copyright 2020 Google LLC. All Rights Reserved.
11
- * Licensed under the Apache License, Version 2.0 (the "License");
12
- * you may not use this file except in compliance with the License.
13
- * You may obtain a copy of the License at
14
- *
15
- * http://www.apache.org/licenses/LICENSE-2.0
16
- *
17
- * Unless required by applicable law or agreed to in writing, software
18
- * distributed under the License is distributed on an "AS IS" BASIS,
19
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20
- * See the License for the specific language governing permissions and
21
- * limitations under the License.
22
- * =============================================================================
23
- */
1
+ import "./index-ZyQhjEPo.js";
2
+ import { r as $ } from "./Reshape-_kILl6tK.js";
3
+ import { _ as T, g as E, y as B, $ as F } from "./tensor_util-DV-FP5Q3.js";
4
+ import { H as _, e as K, p as O, s as V } from "./tensor-DdQUJZlz.js";
5
+ import { a as A, b as k, d as C, c as N, e as R } from "./axis_util-BvHEw88j.js";
6
+ import { t as U, m as W } from "./shared-BRksrJb3.js";
7
+ import { c as j } from "./backend_util-D-rUb2ty.js";
8
+ import { f as y } from "./gpgpu_math-DDVJCn6-.js";
9
+ import { g as G, b as L } from "./kernel_funcs_utils-Dg_-E44D.js";
24
10
  class w {
25
11
  constructor(s, e) {
26
12
  this.variableNames = ["x"];
@@ -30,7 +16,7 @@ class w {
30
16
  let o = "sumValue += dot(values, ones);";
31
17
  if (e != null) {
32
18
  const p = 1 / e;
33
- o = `sumValue += dot(values * ${T(p) ? p.toPrecision(2) : p}, ones);`;
19
+ o = `sumValue += dot(values * ${_(p) ? p.toPrecision(2) : p}, ones);`;
34
20
  }
35
21
  let u = "";
36
22
  l % t > 0 && (u = `
@@ -89,23 +75,7 @@ class w {
89
75
  `;
90
76
  }
91
77
  }
92
- /**
93
- * @license
94
- * Copyright 2017 Google LLC. All Rights Reserved.
95
- * Licensed under the Apache License, Version 2.0 (the "License");
96
- * you may not use this file except in compliance with the License.
97
- * You may obtain a copy of the License at
98
- *
99
- * http://www.apache.org/licenses/LICENSE-2.0
100
- *
101
- * Unless required by applicable law or agreed to in writing, software
102
- * distributed under the License is distributed on an "AS IS" BASIS,
103
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
104
- * See the License for the specific language governing permissions and
105
- * limitations under the License.
106
- * =============================================================================
107
- */
108
- class q {
78
+ class X {
109
79
  constructor(s, e) {
110
80
  this.variableNames = ["x"];
111
81
  const { windowSize: t, batchSize: n, inSize: l, outSize: r } = s;
@@ -213,26 +183,10 @@ class q {
213
183
  `;
214
184
  }
215
185
  }
216
- /**
217
- * @license
218
- * Copyright 2020 Google LLC. All Rights Reserved.
219
- * Licensed under the Apache License, Version 2.0 (the "License");
220
- * you may not use this file except in compliance with the License.
221
- * You may obtain a copy of the License at
222
- *
223
- * http://www.apache.org/licenses/LICENSE-2.0
224
- *
225
- * Unless required by applicable law or agreed to in writing, software
226
- * distributed under the License is distributed on an "AS IS" BASIS,
227
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
228
- * See the License for the specific language governing permissions and
229
- * limitations under the License.
230
- * =============================================================================
231
- */
232
- function X(a) {
186
+ function q(a) {
233
187
  const s = [];
234
188
  for (; s.length === 0 || s[s.length - 1].outSize !== 1; ) {
235
- const e = s.length ? s[s.length - 1].outSize : a[1], t = _(e);
189
+ const e = s.length ? s[s.length - 1].outSize : a[1], t = j(e);
236
190
  s.push({
237
191
  inSize: e,
238
192
  windowSize: t,
@@ -242,39 +196,23 @@ function X(a) {
242
196
  return s;
243
197
  }
244
198
  function P(a, s, e, t) {
245
- const n = X(a.shape);
199
+ const n = q(a.shape);
246
200
  let l = a;
247
201
  for (let r = 0; r < n.length; r++) {
248
202
  const { inSize: i, windowSize: c, outSize: o } = n[r];
249
203
  let u, p;
250
- e === "mean" ? u = r === 0 ? new w({ windowSize: c, inSize: i, batchSize: a.shape[0], outSize: o }, i) : new w({ windowSize: c, inSize: i, batchSize: a.shape[0], outSize: o }) : u = new q({ windowSize: c, inSize: i, batchSize: a.shape[0], outSize: o }, e), p = l, l = t.runWebGLProgram(u, [l], s), p.dataId !== a.dataId && t.disposeIntermediateTensorInfo(p);
204
+ e === "mean" ? u = r === 0 ? new w({ windowSize: c, inSize: i, batchSize: a.shape[0], outSize: o }, i) : new w({ windowSize: c, inSize: i, batchSize: a.shape[0], outSize: o }) : u = new X({ windowSize: c, inSize: i, batchSize: a.shape[0], outSize: o }, e), p = l, l = t.runWebGLProgram(u, [l], s), p.dataId !== a.dataId && t.disposeIntermediateTensorInfo(p);
251
205
  }
252
206
  return l;
253
207
  }
254
- /**
255
- * @license
256
- * Copyright 2017 Google LLC. All Rights Reserved.
257
- * Licensed under the Apache License, Version 2.0 (the "License");
258
- * you may not use this file except in compliance with the License.
259
- * You may obtain a copy of the License at
260
- *
261
- * http://www.apache.org/licenses/LICENSE-2.0
262
- *
263
- * Unless required by applicable law or agreed to in writing, software
264
- * distributed under the License is distributed on an "AS IS" BASIS,
265
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
266
- * See the License for the specific language governing permissions and
267
- * limitations under the License.
268
- * =============================================================================
269
- */
270
- class Y {
208
+ class H {
271
209
  constructor(s, e) {
272
210
  this.variableNames = ["A"];
273
211
  const t = new Array(s.length);
274
212
  for (let r = 0; r < t.length; r++)
275
213
  t[r] = s[e[r]];
276
214
  this.outputShape = t, this.rank = t.length;
277
- const n = y(this.rank), l = H(e);
215
+ const n = y(this.rank), l = Y(e);
278
216
  this.userCode = `
279
217
  void main() {
280
218
  ${n} resRC = getOutputCoords();
@@ -283,7 +221,7 @@ class Y {
283
221
  `;
284
222
  }
285
223
  }
286
- function H(a) {
224
+ function Y(a) {
287
225
  const s = a.length;
288
226
  if (s > 6)
289
227
  throw Error(`Transpose for rank ${s} is not yet supported`);
@@ -292,22 +230,6 @@ function H(a) {
292
230
  t[a[n]] = e[n];
293
231
  return t.join();
294
232
  }
295
- /**
296
- * @license
297
- * Copyright 2019 Google LLC. All Rights Reserved.
298
- * Licensed under the Apache License, Version 2.0 (the "License");
299
- * you may not use this file except in compliance with the License.
300
- * You may obtain a copy of the License at
301
- *
302
- * http://www.apache.org/licenses/LICENSE-2.0
303
- *
304
- * Unless required by applicable law or agreed to in writing, software
305
- * distributed under the License is distributed on an "AS IS" BASIS,
306
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
307
- * See the License for the specific language governing permissions and
308
- * limitations under the License.
309
- * =============================================================================
310
- */
311
233
  class J {
312
234
  constructor(s, e) {
313
235
  this.variableNames = ["A"], this.packedInputs = !0, this.packedOutput = !0;
@@ -340,115 +262,35 @@ class J {
340
262
  `;
341
263
  }
342
264
  }
343
- /**
344
- * @license
345
- * Copyright 2020 Google LLC. All Rights Reserved.
346
- * Licensed under the Apache License, Version 2.0 (the "License");
347
- * you may not use this file except in compliance with the License.
348
- * You may obtain a copy of the License at
349
- *
350
- * http://www.apache.org/licenses/LICENSE-2.0
351
- *
352
- * Unless required by applicable law or agreed to in writing, software
353
- * distributed under the License is distributed on an "AS IS" BASIS,
354
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
355
- * See the License for the specific language governing permissions and
356
- * limitations under the License.
357
- * =============================================================================
358
- */
359
265
  function D(a, s, e) {
360
- const t = E().getBool("WEBGL_PACK_ARRAY_OPERATIONS") ? new J(a.shape, s) : new Y(a.shape, s);
266
+ const t = K().getBool("WEBGL_PACK_ARRAY_OPERATIONS") ? new J(a.shape, s) : new H(a.shape, s);
361
267
  return e.runWebGLProgram(t, [a], a.dtype);
362
268
  }
363
- /**
364
- * @license
365
- * Copyright 2020 Google LLC. All Rights Reserved.
366
- * Licensed under the Apache License, Version 2.0 (the "License");
367
- * you may not use this file except in compliance with the License.
368
- * You may obtain a copy of the License at
369
- *
370
- * http://www.apache.org/licenses/LICENSE-2.0
371
- *
372
- * Unless required by applicable law or agreed to in writing, software
373
- * distributed under the License is distributed on an "AS IS" BASIS,
374
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
375
- * See the License for the specific language governing permissions and
376
- * limitations under the License.
377
- * =============================================================================
378
- */
379
269
  function Q(a, s, e, t) {
380
270
  const n = s, l = a.shape.length, r = O(n, a.shape);
381
271
  let i = r;
382
272
  const c = A(i, l), o = c != null;
383
273
  let u = a;
384
- o && (u = D(a, c, t), i = C(i.length, l)), k("sum", i, l);
274
+ o && (u = D(a, c, t), i = k(i.length, l)), C("sum", i, l);
385
275
  const [p, h] = N(u.shape, i);
386
276
  let d = p;
387
277
  e && (d = R(p, r));
388
- const f = V(h), g = V(a.shape) / f, x = $({ inputs: { x: u }, attrs: { shape: [g, f] }, backend: t }), b = B(a.dtype), I = P(x, b, "sum", t), m = $({ inputs: { x: I }, attrs: { shape: d }, backend: t });
278
+ const f = V(h), g = V(a.shape) / f, x = $({ inputs: { x: u }, attrs: { shape: [g, f] }, backend: t }), b = T(a.dtype), I = P(x, b, "sum", t), m = $({ inputs: { x: I }, attrs: { shape: d }, backend: t });
389
279
  return t.disposeIntermediateTensorInfo(x), t.disposeIntermediateTensorInfo(I), o && t.disposeIntermediateTensorInfo(u), m;
390
280
  }
391
- /**
392
- * @license
393
- * Copyright 2020 Google LLC. All Rights Reserved.
394
- * Licensed under the Apache License, Version 2.0 (the "License");
395
- * you may not use this file except in compliance with the License.
396
- * You may obtain a copy of the License at
397
- *
398
- * http://www.apache.org/licenses/LICENSE-2.0
399
- *
400
- * Unless required by applicable law or agreed to in writing, software
401
- * distributed under the License is distributed on an "AS IS" BASIS,
402
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
403
- * See the License for the specific language governing permissions and
404
- * limitations under the License.
405
- * =============================================================================
406
- */
407
281
  function Z(a) {
408
282
  const { inputs: s, backend: e, attrs: t } = a, { x: n } = s, { axis: l, keepDims: r } = t;
409
283
  return Q(n, l, r, e);
410
284
  }
411
- const pe = {
412
- kernelName: F,
285
+ const fe = {
286
+ kernelName: E,
413
287
  backendName: "webgl",
414
288
  kernelFunc: Z
415
289
  };
416
- /**
417
- * @license
418
- * Copyright 2020 Google LLC. All Rights Reserved.
419
- * Licensed under the Apache License, Version 2.0 (the "License");
420
- * you may not use this file except in compliance with the License.
421
- * You may obtain a copy of the License at
422
- *
423
- * http://www.apache.org/licenses/LICENSE-2.0
424
- *
425
- * Unless required by applicable law or agreed to in writing, software
426
- * distributed under the License is distributed on an "AS IS" BASIS,
427
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
428
- * See the License for the specific language governing permissions and
429
- * limitations under the License.
430
- * =============================================================================
431
- */
432
290
  function ee(a, s, e, t) {
433
291
  const n = V(s), r = V(a.shape) / n, i = $({ inputs: { x: a }, attrs: { shape: [r, n] }, backend: t }), c = P(i, a.dtype, "max", t), o = $({ inputs: { x: c }, attrs: { shape: e }, backend: t });
434
292
  return t.disposeIntermediateTensorInfo(i), t.disposeIntermediateTensorInfo(c), o;
435
293
  }
436
- /**
437
- * @license
438
- * Copyright 2020 Google LLC. All Rights Reserved.
439
- * Licensed under the Apache License, Version 2.0 (the "License");
440
- * you may not use this file except in compliance with the License.
441
- * You may obtain a copy of the License at
442
- *
443
- * http://www.apache.org/licenses/LICENSE-2.0
444
- *
445
- * Unless required by applicable law or agreed to in writing, software
446
- * distributed under the License is distributed on an "AS IS" BASIS,
447
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
448
- * See the License for the specific language governing permissions and
449
- * limitations under the License.
450
- * =============================================================================
451
- */
452
294
  function te(a) {
453
295
  const { inputs: s, backend: e, attrs: t } = a, { x: n } = s, { reductionIndices: l, keepDims: r } = t, i = n.shape.length, c = O(l, n.shape);
454
296
  let o = c;
@@ -465,9 +307,9 @@ function te(a) {
465
307
  M.values = z;
466
308
  } else
467
309
  d = D(n, u, e);
468
- o = C(o.length, i);
310
+ o = k(o.length, i);
469
311
  }
470
- k("max", o, i);
312
+ C("max", o, i);
471
313
  const [f, S] = N(d.shape, o);
472
314
  let g = f;
473
315
  r && (g = R(f, c));
@@ -481,27 +323,11 @@ function te(a) {
481
323
  x = ee(d, S, g, e);
482
324
  return p && e.disposeIntermediateTensorInfo(d), x;
483
325
  }
484
- const he = {
485
- kernelName: j,
326
+ const me = {
327
+ kernelName: B,
486
328
  backendName: "webgl",
487
329
  kernelFunc: te
488
330
  };
489
- /**
490
- * @license
491
- * Copyright 2020 Google LLC. All Rights Reserved.
492
- * Licensed under the Apache License, Version 2.0 (the "License");
493
- * you may not use this file except in compliance with the License.
494
- * You may obtain a copy of the License at
495
- *
496
- * http://www.apache.org/licenses/LICENSE-2.0
497
- *
498
- * Unless required by applicable law or agreed to in writing, software
499
- * distributed under the License is distributed on an "AS IS" BASIS,
500
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
501
- * See the License for the specific language governing permissions and
502
- * limitations under the License.
503
- * =============================================================================
504
- */
505
331
  const ae = `
506
332
  if (a == b) {
507
333
  return 1.0;
@@ -524,16 +350,16 @@ return a / b;`, se = `
524
350
  }
525
351
 
526
352
  return result;
527
- `, ne = L({ opSnippet: ae, packedOpSnippet: se, checkOutOfBounds: !0 }), fe = {
528
- kernelName: K,
353
+ `, ne = L({ opSnippet: ae, packedOpSnippet: se, checkOutOfBounds: !0 }), xe = {
354
+ kernelName: F,
529
355
  backendName: "webgl",
530
356
  kernelFunc: ne
531
357
  };
532
358
  export {
533
359
  P as a,
534
- he as b,
535
- fe as c,
536
- pe as d,
360
+ me as b,
361
+ xe as c,
362
+ fe as d,
537
363
  te as m,
538
364
  ne as r,
539
365
  Z as s,
@@ -0,0 +1,16 @@
1
+ import "./index-ZyQhjEPo.js";
2
+ import { s as p, n as m, a as d } from "./tensor-DdQUJZlz.js";
3
+ import { b as c } from "./tensor_util-DV-FP5Q3.js";
4
+ function i(t) {
5
+ const { inputs: h, attrs: o } = t, { x: e } = h, { shape: r } = o, a = p(e.shape), s = m(r, a), n = p(s);
6
+ return d(a === n, () => `The new shape (${s}) has ${n} elements and the old shape (${e.shape}) has ${a} elements. The new shape and old shape must have the same number of elements.`), t.backend.incRef(e.dataId), { dataId: e.dataId, shape: s, dtype: e.dtype };
7
+ }
8
+ const $ = {
9
+ kernelName: c,
10
+ backendName: "webgpu",
11
+ kernelFunc: i
12
+ };
13
+ export {
14
+ $ as a,
15
+ i as r
16
+ };
@@ -0,0 +1,81 @@
1
+ import "./index-ZyQhjEPo.js";
2
+ import { u as C, g as f, a as R, b as g, c as I, d as c, e as u, i as m } from "./gpgpu_math-DDVJCn6-.js";
3
+ import { b as x } from "./tensor_util-DV-FP5Q3.js";
4
+ import { s as l, n as F, a as $ } from "./tensor-DdQUJZlz.js";
5
+ class S {
6
+ constructor(t, i) {
7
+ this.variableNames = ["A"], this.packedInputs = !0, this.packedOutput = !0, this.customUniforms = [{ name: "inputShape", type: "ivec3" }], this.outputShape = t, this.enableShapeUniforms = C(this.outputShape.length);
8
+ let a = "";
9
+ for (let e = 0; e < 4; e++) {
10
+ let o = "thisRC = rc;";
11
+ e % 2 === 1 && (o += "thisRC.z += 1;"), e > 1 && (o += "thisRC.y += 1;"), a += `
12
+ ${o}
13
+ ${e > 0 ? "if(thisRC.y < rows && thisRC.z < cols){" : ""}
14
+ int flatIndex = getFlatIndex(thisRC);
15
+
16
+ ivec3 inputRC = inputCoordsFromReshapedOutCoords(flatIndex);
17
+ vec2 inputRCInnerDims = vec2(float(inputRC.y),float(inputRC.z));
18
+
19
+ result[${e}] =
20
+ getChannel(getA(inputRC.x, inputRC.y, inputRC.z), inputRCInnerDims);
21
+ ${e > 0 ? "}" : ""}
22
+ `;
23
+ }
24
+ this.userCode = `
25
+ ${b(i, this.enableShapeUniforms)}
26
+ ${this.enableShapeUniforms ? f() : R(t)}
27
+
28
+ void main() {
29
+ ivec3 rc = getOutputCoords();
30
+
31
+ vec4 result = vec4(0.);
32
+
33
+ ivec3 thisRC;
34
+ int rows = ${this.enableShapeUniforms ? "outShape[1]" : t[1]};
35
+ int cols = ${this.enableShapeUniforms ? "outShape[2]" : t[2]};
36
+
37
+ ${a}
38
+
39
+ setOutput(result);
40
+ }
41
+ `;
42
+ }
43
+ }
44
+ function b(s, t) {
45
+ return `
46
+ ivec3 inputCoordsFromReshapedOutCoords(int index) {
47
+ ${t ? g(["r", "c", "d"], "inputShape") : I(["r", "c", "d"], s)}
48
+ return ivec3(r, c, d);
49
+ }
50
+ `;
51
+ }
52
+ function v(s, t, i) {
53
+ const a = [
54
+ c(s.shape),
55
+ ...u(s.shape)
56
+ ], e = {
57
+ dtype: s.dtype,
58
+ shape: a,
59
+ dataId: s.dataId
60
+ }, o = [
61
+ c(t),
62
+ ...u(t)
63
+ ], r = new S(o, a), p = !0, n = [a], h = i.runWebGLProgram(r, [e], s.dtype, n, p);
64
+ return { dataId: h.dataId, shape: t, dtype: h.dtype };
65
+ }
66
+ function y(s) {
67
+ const { inputs: t, backend: i, attrs: a } = s, { x: e } = t, { shape: o } = a, r = i, p = l(e.shape), n = F(o, p), h = l(n);
68
+ $(p === h, () => `The new shape (${n}) has ${h} elements and the old shape (${e.shape}) has ${p} elements. The new shape and old shape must have the same number of elements.`);
69
+ const d = r.texData.get(e.dataId);
70
+ return d.isPacked && !m(e.shape, n) && !(d.texture !== null && m(d.shape, n)) ? v(e, n, r) : (r.incRef(e.dataId), { dataId: e.dataId, shape: n, dtype: e.dtype });
71
+ }
72
+ const O = {
73
+ kernelName: x,
74
+ backendName: "webgl",
75
+ kernelFunc: y
76
+ };
77
+ export {
78
+ S as R,
79
+ O as a,
80
+ y as r
81
+ };
@@ -2,45 +2,51 @@ import { defaultConfig as d } from "./models/config.js";
2
2
  import { saveModel as l } from "./loader/save.js";
3
3
  import { loadModel as _ } from "./loader/load.js";
4
4
  import u from "./Generator.js";
5
- import f from "./Trainer.js";
6
- import { E as p } from "./index-Dwqa6Zy2.js";
5
+ import p from "./Trainer.js";
6
+ import { E as f } from "./index-DvYrXKkX.js";
7
7
  import { dummyPassTrainAsync as m } from "./utilities/dummy.js";
8
- import "./index-BzFyqcy-.js";
8
+ import "./utilities/packed.js";
9
+ import "./index-ZyQhjEPo.js";
9
10
  import "./ops/cpu/attentionMask.js";
10
11
  import "./ops/webgl/attentionMask.js";
11
12
  import "./ops/grads/attentionMask.js";
12
- import "./ops/cpu/qkv.js";
13
- import "./ops/webgl/qkv.js";
14
- import "./ops/grads/qkv.js";
15
- import "./random_width-CXVRloNK.js";
16
- import "./register_all_kernels-DIGpEwcf.js";
17
- import "./index-Tf7vU29b.js";
18
- import "./dataset-DlZtKmBq.js";
13
+ import "./random_width-DY6Kk2Dl.js";
14
+ import "./register_all_kernels-Bwu1PTuU.js";
15
+ import "./index-Cp39cXWe.js";
16
+ import "./dataset-0xP8GjwI.js";
19
17
  import "./ops/cpu/rope.js";
20
18
  import "./ops/webgl/rope.js";
21
- import "./ops/grads/rope.js";
19
+ import "./rope-B5UUMsPi.js";
22
20
  import "./ops/cpu/appendCache.js";
23
21
  import "./ops/webgl/appendCache.js";
24
- import "./ops/cpu/fusedSoftmax.js";
25
- import "./ops/webgl/fusedSoftmax.js";
26
- import "./ops/grads/fusedSoftmax.js";
27
- import "./ops/cpu/matMulGelu.js";
28
- import "./ops/webgl/matMulGelu.js";
29
- import "./ops/grads/matMulGelu.js";
22
+ import "./ops/grads/softmax16.js";
23
+ import "./matMul16--R5hOwDG.js";
24
+ import "./ops/webgl/matMul16.js";
25
+ import "./ops/cpu/matMul16.js";
26
+ import "./pack16-CFUqumar.js";
27
+ import "./ops/transpose16.js";
28
+ import "./ops/reshape16.js";
29
+ import "./ops/cpu/qkv.js";
30
+ import "./ops/webgl/qkv.js";
31
+ import "./ops/grads/qkv.js";
30
32
  import "./ops/cpu/normRMS.js";
31
33
  import "./ops/webgl/normRMS.js";
32
34
  import "./ops/grads/normRMS.js";
35
+ import "./ops/grads/add16.js";
33
36
  import "./ops/cpu/gatherSub.js";
34
37
  import "./ops/webgl/gatherSub.js";
35
38
  import "./ops/cpu/scatterSub.js";
36
39
  import "./ops/webgl/scatterSub.js";
37
40
  import c from "./tokeniser/CharTokeniser.js";
38
41
  import g from "./tokeniser/bpe.js";
39
- import "./papaparse.min-C8l2Kvo1.js";
40
- import "./jszip.min-CjP2V1VV.js";
42
+ import "./papaparse.min-C0cScC2i.js";
43
+ import "./jszip.min-Bz5-11Bk.js";
44
+ import "./ops/cpu/matMulGelu.js";
45
+ import "./ops/webgl/matMulGelu.js";
46
+ import "./ops/grads/matMulGelu.js";
41
47
  import "./ops/cpu/gelu.js";
42
48
  import "./ops/webgl/gelu.js";
43
- import "./gelu-Bp_-935b.js";
49
+ import "./gelu-CNLFZWea.js";
44
50
  import "./ops/webgl/log.js";
45
51
  import "./ops/cpu/adamMoments.js";
46
52
  import "./ops/webgl/adamMoments.js";
@@ -51,7 +57,7 @@ import "./checks/normRMSGrad.js";
51
57
  import k from "./utilities/profile.js";
52
58
  import w from "./models/factory.js";
53
59
  class a {
54
- ee = new p();
60
+ ee = new f();
55
61
  _config;
56
62
  _model;
57
63
  _tokeniser;
@@ -150,7 +156,7 @@ class a {
150
156
  trainer() {
151
157
  if (!this._model || !this._tokeniser)
152
158
  throw new Error("model_or_tokeniser_not_initialized.");
153
- const t = new f(this._model, this._tokeniser);
159
+ const t = new p(this._model, this._tokeniser);
154
160
  return t.on("start", () => this.setStatus("training")), t.on("stop", () => this.setStatus("ready")), t.on("log", async (e, r) => {
155
161
  const o = this.ee.listeners("trainStep");
156
162
  for (const s of o)
package/dist/Trainer.d.ts CHANGED
@@ -12,6 +12,8 @@ export interface ITrainerOptions {
12
12
  validationSplit?: number;
13
13
  advancedMetrics?: boolean;
14
14
  gradientCheckpointing?: boolean;
15
+ gradientMetrics?: boolean;
16
+ mixedPrecision?: boolean;
15
17
  }
16
18
  interface ExtendedTrainingProgress extends TrainingProgress {
17
19
  progress: number;
package/dist/Trainer.js CHANGED
@@ -1,4 +1,4 @@
1
- import { E as l } from "./index-Dwqa6Zy2.js";
1
+ import { E as l } from "./index-DvYrXKkX.js";
2
2
  import h from "./training/FullTrainer.js";
3
3
  class m extends l {
4
4
  trainer;
@@ -28,7 +28,7 @@ class m extends l {
28
28
  async train(t) {
29
29
  if (!this.trainDataset || !this.validationDataset)
30
30
  throw new Error("Datasets not prepared");
31
- this.hasTrained || this.trainer.setLearningRate(t?.learningRate || 1e-3), this.hasTrained = !0, this.emit("start"), this.trainer.setGradientCheckpointing(t?.gradientCheckpointing || !1), await this.trainer.trainOnDataset(
31
+ this.hasTrained || this.trainer.setLearningRate(t?.learningRate || 1e-3), this.hasTrained = !0, this.emit("start"), this.trainer.setGradientCheckpointing(t?.gradientCheckpointing || !1), this.trainer.setMixedPrecision(t?.mixedPrecision || !1), await this.trainer.trainOnDataset(
32
32
  this.trainDataset,
33
33
  {
34
34
  prompt: t?.prompt,
@@ -36,6 +36,7 @@ class m extends l {
36
36
  desiredLoss: t?.desiredLoss || 0.01,
37
37
  maxSteps: t?.maxSteps || 1e3,
38
38
  advancedMetrics: t?.advancedMetrics || !1,
39
+ gradientMetrics: t?.gradientMetrics || !1,
39
40
  onStep: async (e, a) => {
40
41
  this.log.push(e), this.progress = {
41
42
  ...a,
@@ -1,20 +1,4 @@
1
- import { n as c } from "./index-BzFyqcy-.js";
2
- /**
3
- * @license
4
- * Copyright 2017 Google LLC. All Rights Reserved.
5
- * Licensed under the Apache License, Version 2.0 (the "License");
6
- * you may not use this file except in compliance with the License.
7
- * You may obtain a copy of the License at
8
- *
9
- * http://www.apache.org/licenses/LICENSE-2.0
10
- *
11
- * Unless required by applicable law or agreed to in writing, software
12
- * distributed under the License is distributed on an "AS IS" BASIS,
13
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- * See the License for the specific language governing permissions and
15
- * limitations under the License.
16
- * =============================================================================
17
- */
1
+ import { a as c } from "./tensor-DdQUJZlz.js";
18
2
  function i(e, n) {
19
3
  for (let t = 0; t < e.length; ++t)
20
4
  if (e[e.length - t - 1] !== n - 1 - t)
@@ -60,12 +44,12 @@ function x(e, n) {
60
44
  return t;
61
45
  }
62
46
  export {
63
- x as a,
64
- m as b,
47
+ d as a,
48
+ x as b,
65
49
  l as c,
66
- i as d,
50
+ m as d,
67
51
  h as e,
68
- a as f,
69
- d as g,
70
- g as h
52
+ i as f,
53
+ g,
54
+ a as h
71
55
  };
package/dist/backend.d.ts CHANGED
@@ -1 +1,2 @@
1
- export declare function selectBackend(backendName: 'cpu' | 'webgl' | 'webgpu'): Promise<void>;
1
+ import { GPUOptions } from './patches/webgpu_base';
2
+ export declare function selectBackend(backendName: 'cpu' | 'webgl' | 'webgpu', options?: GPUOptions): Promise<void>;
package/dist/backend.js CHANGED
@@ -1,7 +1,13 @@
1
- import { g as a, s as i, r as o } from "./index-BzFyqcy-.js";
2
- async function e(t) {
3
- a() !== t && (t === "webgpu" && (await import("./index-C1rx_Ajs.js"), await import("./ops/webgpu/index.js")), await i(t), await o(), console.log(`Backend set to ${t}`));
1
+ import { g as o, s as e, r as s } from "./index-ZyQhjEPo.js";
2
+ async function c(t, a) {
3
+ if (o() !== t) {
4
+ if (t === "webgpu") {
5
+ const { registerWebGPUBackend: i } = await import("./patches/webgpu_base.js");
6
+ i(a), await import("./index-CjOj7j-u.js"), await import("./ops/webgpu/index.js");
7
+ }
8
+ await e(t), await s(), console.log(`Backend set to ${t}`);
9
+ }
4
10
  }
5
11
  export {
6
- e as selectBackend
12
+ c as selectBackend
7
13
  };