@genai-fi/nanogpt 0.9.1 → 0.10.1

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (343)
  1. package/README.md +352 -14
  2. package/dist/Generator.js +69 -78
  3. package/dist/{RealDiv-D4EzDsC0.js → RealDiv-DgA3z9oO.js} +32 -206
  4. package/dist/Reshape-CF6odzV4.js +16 -0
  5. package/dist/Reshape-_kILl6tK.js +81 -0
  6. package/dist/TeachableLLM.js +28 -22
  7. package/dist/Trainer.d.ts +2 -0
  8. package/dist/Trainer.js +3 -2
  9. package/dist/{axis_util-TbGYJ208.js → axis_util-BvHEw88j.js} +7 -23
  10. package/dist/backend.d.ts +2 -1
  11. package/dist/backend.js +10 -4
  12. package/dist/backend_util-D-rUb2ty.js +474 -0
  13. package/dist/backend_webgpu-B0u2ndUn.js +547 -0
  14. package/dist/binary_op_util-pKXltfxI.js +192 -0
  15. package/dist/broadcast_to-CwF7XIeu.js +30 -0
  16. package/dist/checks/appendCache.js +2 -2
  17. package/dist/checks/attentionMask.js +3 -3
  18. package/dist/checks/check.d.ts +1 -1
  19. package/dist/checks/check.js +8 -8
  20. package/dist/checks/gelu.js +2 -2
  21. package/dist/checks/index.d.ts +2 -0
  22. package/dist/checks/index.js +7 -5
  23. package/dist/checks/matMulGelu.js +6 -6
  24. package/dist/checks/normRMS.js +7 -7
  25. package/dist/checks/normRMSGrad.js +3 -3
  26. package/dist/checks/packUnpack.d.ts +1 -0
  27. package/dist/checks/packUnpack.js +18 -0
  28. package/dist/checks/qkv.js +12 -27
  29. package/dist/checks/rope.js +2 -2
  30. package/dist/checks/weights.js +18 -16
  31. package/dist/complex-CSlYz-2T.js +13 -0
  32. package/dist/complex_util-Yc1A_gV1.js +55 -0
  33. package/dist/concat-BHlIJeyT.js +19 -0
  34. package/dist/concat_util-DcJk7YHS.js +22 -0
  35. package/dist/data/docx.js +1 -1
  36. package/dist/data/parquet.js +2 -2
  37. package/dist/data/pdf.js +1 -1
  38. package/dist/data/textLoader.js +1 -1
  39. package/dist/{dataset-DlZtKmBq.js → dataset-0xP8GjwI.js} +136 -236
  40. package/dist/dropout-C1pM3f11.js +99 -0
  41. package/dist/expand_dims-BPG4fwBP.js +13 -0
  42. package/dist/exports_initializers-xuidcwI4.js +7 -0
  43. package/dist/gather-DykLGqmW.js +10 -0
  44. package/dist/{gelu-Bp_-935b.js → gelu-CNLFZWea.js} +11 -10
  45. package/dist/{gpgpu_math-CDaYiyE_.js → gpgpu_math-DDVJCn6-.js} +90 -265
  46. package/dist/{index-C4L8Cm77.js → index-CieiGp4Y.js} +14 -14
  47. package/dist/index-CjOj7j-u.js +7308 -0
  48. package/dist/{index-Tf7vU29b.js → index-Cp39cXWe.js} +3 -10
  49. package/dist/{index-Dwqa6Zy2.js → index-DvYrXKkX.js} +2 -2
  50. package/dist/index-ZyQhjEPo.js +2157 -0
  51. package/dist/{jszip.min-CjP2V1VV.js → jszip.min-Bz5-11Bk.js} +56 -57
  52. package/dist/kernel_funcs_utils-Dg_-E44D.js +308 -0
  53. package/dist/layers/BaseLayer.d.ts +1 -0
  54. package/dist/layers/BaseLayer.js +7 -6
  55. package/dist/layers/CausalSelfAttention.d.ts +0 -1
  56. package/dist/layers/CausalSelfAttention.js +56 -55
  57. package/dist/layers/MLP.js +15 -16
  58. package/dist/layers/PositionEmbedding.js +5 -14
  59. package/dist/layers/RMSNorm.js +3 -3
  60. package/dist/layers/RoPECache.d.ts +2 -0
  61. package/dist/layers/RoPECache.js +22 -17
  62. package/dist/layers/TiedEmbedding.js +22 -17
  63. package/dist/layers/TransformerBlock.js +21 -20
  64. package/dist/loader/load.js +1 -1
  65. package/dist/loader/loadTransformers.js +1 -1
  66. package/dist/loader/oldZipLoad.js +39 -33
  67. package/dist/loader/save.js +1 -1
  68. package/dist/log_sum_exp-DWI-76TI.js +41 -0
  69. package/dist/main.d.ts +8 -0
  70. package/dist/main.js +63 -52
  71. package/dist/matMul16--R5hOwDG.js +77 -0
  72. package/dist/mat_mul-DeAh4uTH.js +12 -0
  73. package/dist/mod-Gt1rMB4n.js +12 -0
  74. package/dist/models/NanoGPTV1.js +40 -31
  75. package/dist/models/model.d.ts +2 -0
  76. package/dist/models/model.js +37 -29
  77. package/dist/{mulmat_packed_gpu-BT60jmzP.js → mulmat_packed_gpu-BMFhLwta.js} +1 -17
  78. package/dist/{non_max_suppression_impl-CsEgBuMA.js → non_max_suppression_impl-B2W7YjZB.js} +0 -32
  79. package/dist/ones-CAMiP4I2.js +15 -0
  80. package/dist/ops/adamAdjust.js +1 -1
  81. package/dist/ops/adamMoments.d.ts +1 -1
  82. package/dist/ops/adamMoments.js +4 -4
  83. package/dist/ops/add16.d.ts +2 -0
  84. package/dist/ops/add16.js +9 -0
  85. package/dist/ops/appendCache.js +16 -9
  86. package/dist/ops/attentionMask.js +4 -4
  87. package/dist/ops/concat16.d.ts +2 -0
  88. package/dist/ops/concat16.js +9 -0
  89. package/dist/ops/cpu/adamAdjust.js +14 -13
  90. package/dist/ops/cpu/adamMoments.js +10 -9
  91. package/dist/ops/cpu/appendCache.js +9 -8
  92. package/dist/ops/cpu/attentionMask.js +15 -14
  93. package/dist/ops/cpu/fusedSoftmax.js +13 -12
  94. package/dist/ops/cpu/gatherSub.js +9 -24
  95. package/dist/ops/cpu/gelu.js +13 -12
  96. package/dist/ops/cpu/matMul16.d.ts +1 -0
  97. package/dist/ops/cpu/matMul16.js +16 -0
  98. package/dist/ops/cpu/matMulGelu.js +18 -16
  99. package/dist/ops/cpu/matMulMul.js +8 -7
  100. package/dist/ops/cpu/mulDropout.js +4 -3
  101. package/dist/ops/cpu/normRMS.js +11 -10
  102. package/dist/ops/cpu/qkv.js +17 -13
  103. package/dist/ops/cpu/rope.js +23 -22
  104. package/dist/ops/cpu/scatterSub.js +16 -30
  105. package/dist/ops/dot16.d.ts +2 -0
  106. package/dist/ops/dot16.js +42 -0
  107. package/dist/ops/gatherSub.js +1 -1
  108. package/dist/ops/gelu.js +2 -2
  109. package/dist/ops/grads/add16.d.ts +1 -0
  110. package/dist/ops/grads/add16.js +27 -0
  111. package/dist/ops/grads/attentionMask.js +12 -19
  112. package/dist/ops/grads/gelu.js +4 -3
  113. package/dist/ops/grads/matMul16.d.ts +2 -0
  114. package/dist/ops/grads/matMul16.js +9 -0
  115. package/dist/ops/grads/matMulGelu.js +8 -7
  116. package/dist/ops/grads/normRMS.js +8 -7
  117. package/dist/ops/grads/{fusedSoftmax.d.ts → pack16.d.ts} +1 -1
  118. package/dist/ops/grads/pack16.js +7 -0
  119. package/dist/ops/grads/qkv.d.ts +3 -1
  120. package/dist/ops/grads/qkv.js +28 -22
  121. package/dist/ops/grads/rope.d.ts +2 -1
  122. package/dist/ops/grads/rope.js +6 -13
  123. package/dist/ops/grads/softmax16.d.ts +2 -0
  124. package/dist/ops/grads/softmax16.js +26 -0
  125. package/dist/ops/grads/unpack16.d.ts +2 -0
  126. package/dist/ops/grads/unpack16.js +6 -0
  127. package/dist/ops/grads/utils.d.ts +3 -0
  128. package/dist/ops/grads/utils.js +10 -0
  129. package/dist/ops/matMul16.d.ts +15 -0
  130. package/dist/ops/matMul16.js +13 -0
  131. package/dist/ops/matMulGelu.js +1 -1
  132. package/dist/ops/matMulMul.js +1 -1
  133. package/dist/ops/mul16.d.ts +2 -0
  134. package/dist/ops/mul16.js +8 -0
  135. package/dist/ops/mulDrop.js +1 -1
  136. package/dist/ops/normRMS.js +1 -1
  137. package/dist/ops/pack16.d.ts +2 -0
  138. package/dist/ops/pack16.js +6 -0
  139. package/dist/ops/qkv.d.ts +1 -1
  140. package/dist/ops/qkv.js +8 -4
  141. package/dist/ops/reshape16.d.ts +2 -0
  142. package/dist/ops/reshape16.js +43 -0
  143. package/dist/ops/rope.d.ts +1 -1
  144. package/dist/ops/rope.js +8 -10
  145. package/dist/ops/scatterSub.js +1 -1
  146. package/dist/ops/slice16.d.ts +2 -0
  147. package/dist/ops/slice16.js +9 -0
  148. package/dist/ops/softmax16.d.ts +2 -0
  149. package/dist/ops/softmax16.js +12 -0
  150. package/dist/ops/sub16.d.ts +2 -0
  151. package/dist/ops/sub16.js +8 -0
  152. package/dist/ops/sum16.d.ts +2 -0
  153. package/dist/ops/sum16.js +13 -0
  154. package/dist/ops/transpose16.d.ts +3 -0
  155. package/dist/ops/transpose16.js +41 -0
  156. package/dist/ops/unpack16.d.ts +2 -0
  157. package/dist/ops/unpack16.js +6 -0
  158. package/dist/ops/webgl/adamAdjust.js +3 -2
  159. package/dist/ops/webgl/adamMoments.js +2 -1
  160. package/dist/ops/webgl/appendCache.js +2 -1
  161. package/dist/ops/webgl/attentionMask.js +5 -4
  162. package/dist/ops/webgl/fusedSoftmax.js +6 -4
  163. package/dist/ops/webgl/gatherSub.js +7 -6
  164. package/dist/ops/webgl/gelu.js +3 -2
  165. package/dist/ops/webgl/log.js +12 -27
  166. package/dist/ops/webgl/matMul16.d.ts +1 -0
  167. package/dist/ops/webgl/matMul16.js +37 -0
  168. package/dist/ops/webgl/matMulGelu.js +17 -15
  169. package/dist/ops/webgl/matMulMul.js +13 -12
  170. package/dist/ops/webgl/mulDropout.js +9 -8
  171. package/dist/ops/webgl/normRMS.js +8 -7
  172. package/dist/ops/webgl/qkv.js +6 -5
  173. package/dist/ops/webgl/rope.js +11 -10
  174. package/dist/ops/webgl/scatterSub.js +6 -5
  175. package/dist/ops/webgpu/adamAdjust.js +12 -10
  176. package/dist/ops/webgpu/adamMoments.js +27 -22
  177. package/dist/ops/webgpu/add16.d.ts +1 -0
  178. package/dist/ops/webgpu/add16.js +14 -0
  179. package/dist/ops/webgpu/appendCache.js +64 -17
  180. package/dist/ops/webgpu/attentionMask.js +19 -62
  181. package/dist/ops/webgpu/attentionMask32_program.d.ts +19 -0
  182. package/dist/ops/webgpu/attentionMask32_program.js +54 -0
  183. package/dist/ops/webgpu/concat16.d.ts +19 -0
  184. package/dist/ops/webgpu/concat16.js +128 -0
  185. package/dist/ops/webgpu/gatherSub.js +9 -7
  186. package/dist/ops/webgpu/gelu.js +78 -31
  187. package/dist/ops/webgpu/index.js +12 -0
  188. package/dist/ops/webgpu/matMul16.d.ts +1 -0
  189. package/dist/ops/webgpu/matMul16.js +58 -0
  190. package/dist/ops/webgpu/matMul16_program.d.ts +42 -0
  191. package/dist/ops/webgpu/matMul16_program.js +336 -0
  192. package/dist/ops/webgpu/mul16.d.ts +1 -0
  193. package/dist/ops/webgpu/mul16.js +14 -0
  194. package/dist/ops/webgpu/normRMS.js +21 -40
  195. package/dist/ops/webgpu/normRMS16_program.d.ts +9 -0
  196. package/dist/ops/webgpu/normRMS16_program.js +24 -0
  197. package/dist/ops/webgpu/normRMS32_program.d.ts +9 -0
  198. package/dist/ops/webgpu/normRMS32_program.js +24 -0
  199. package/dist/ops/webgpu/normRMSGrad.js +113 -64
  200. package/dist/ops/webgpu/pack16.d.ts +1 -0
  201. package/dist/ops/webgpu/pack16.js +19 -0
  202. package/dist/ops/webgpu/pack16_program.d.ts +19 -0
  203. package/dist/ops/webgpu/pack16_program.js +92 -0
  204. package/dist/ops/webgpu/qkv.js +20 -55
  205. package/dist/ops/webgpu/rope.js +77 -22
  206. package/dist/ops/webgpu/scatterSub.js +9 -7
  207. package/dist/ops/webgpu/slice16.d.ts +7 -0
  208. package/dist/ops/webgpu/slice16.js +71 -0
  209. package/dist/{variable-Bm2OFwGI.js → ops/webgpu/softmax16.d.ts} +2 -8
  210. package/dist/ops/webgpu/softmax16.js +23 -0
  211. package/dist/ops/webgpu/softmax16_program.d.ts +13 -0
  212. package/dist/ops/webgpu/softmax16_program.js +73 -0
  213. package/dist/ops/webgpu/softmax16_subgroup_program.d.ts +17 -0
  214. package/dist/ops/webgpu/softmax16_subgroup_program.js +75 -0
  215. package/dist/ops/webgpu/softmax16grad.d.ts +1 -0
  216. package/dist/ops/webgpu/softmax16grad.js +38 -0
  217. package/dist/ops/webgpu/sub16.d.ts +1 -0
  218. package/dist/ops/webgpu/sub16.js +14 -0
  219. package/dist/ops/webgpu/sum16.d.ts +1 -0
  220. package/dist/ops/webgpu/sum16.js +40 -0
  221. package/dist/ops/webgpu/transpose16.d.ts +1 -0
  222. package/dist/ops/webgpu/transpose16.js +35 -0
  223. package/dist/ops/webgpu/transpose16_program.d.ts +16 -0
  224. package/dist/ops/webgpu/transpose16_program.js +50 -0
  225. package/dist/ops/webgpu/transpose16_shared_program.d.ts +15 -0
  226. package/dist/ops/webgpu/transpose16_shared_program.js +71 -0
  227. package/dist/ops/webgpu/unpack16.d.ts +1 -0
  228. package/dist/ops/webgpu/unpack16.js +49 -0
  229. package/dist/ops/webgpu/utils/binary_op.d.ts +19 -0
  230. package/dist/ops/webgpu/utils/binary_op.js +79 -0
  231. package/dist/ops/webgpu/utils/deviceInfo.d.ts +7 -0
  232. package/dist/ops/webgpu/utils/deviceInfo.js +11 -0
  233. package/dist/ops/webgpu/utils/reductions.d.ts +32 -4
  234. package/dist/ops/webgpu/utils/reductions.js +236 -45
  235. package/dist/ops-CNI3TwqM.js +645 -0
  236. package/dist/pack16-CFUqumar.js +41 -0
  237. package/dist/{papaparse.min-C8l2Kvo1.js → papaparse.min-C0cScC2i.js} +2 -8
  238. package/dist/{parquet-C0Tlmv9c.js → parquet-BE8MU_ge.js} +201 -278
  239. package/dist/patches/PackedTensor.d.ts +12 -0
  240. package/dist/patches/PackedTensor.js +11 -0
  241. package/dist/patches/engine.d.ts +261 -0
  242. package/dist/patches/engine.js +10 -0
  243. package/dist/patches/tape.d.ts +12 -0
  244. package/dist/patches/tape.js +5 -0
  245. package/dist/patches/webgpu_backend.d.ts +18 -0
  246. package/dist/patches/webgpu_backend.js +57 -0
  247. package/dist/{tensor-CZr4dh61.js → patches/webgpu_base.d.ts} +5 -8
  248. package/dist/patches/webgpu_base.js +34 -0
  249. package/dist/patches/webgpu_program.d.ts +36 -0
  250. package/dist/patches/webgpu_program.js +401 -0
  251. package/dist/{pdf-kJD-f258.js → pdf-NIhmP3sq.js} +424 -428
  252. package/dist/random_width-DY6Kk2Dl.js +10051 -0
  253. package/dist/range-BMS52eQi.js +11 -0
  254. package/dist/reciprocal-CTmshQ9J.js +10 -0
  255. package/dist/{register_all_kernels-DIGpEwcf.js → register_all_kernels-Bwu1PTuU.js} +719 -9766
  256. package/dist/relu-yZ2-7WxU.js +10 -0
  257. package/dist/reshape-DevtBWtf.js +10 -0
  258. package/dist/rope-B5UUMsPi.js +32 -0
  259. package/dist/{scatter_nd_util-BQdz--Gn.js → scatter_nd_util-5EL-8VAQ.js} +1 -1
  260. package/dist/selu_util-D1w6yyTO.js +303 -0
  261. package/dist/{shared-DuP7ue-R.js → shared-BRksrJb3.js} +1 -17
  262. package/dist/shared-BuAXb4CI.js +2145 -0
  263. package/dist/sin-BGfy2HZo.js +16 -0
  264. package/dist/slice-D_gkkqZK.js +13 -0
  265. package/dist/slice_util-DtEldBfK.js +261 -0
  266. package/dist/softmax-ZHVebtR1.js +13 -0
  267. package/dist/split-DrfihRpZ.js +10 -0
  268. package/dist/squeeze-DZEpeblb.js +11 -0
  269. package/dist/stack-yOIAalTq.js +13 -0
  270. package/dist/sum-_fzj5ZTB.js +12 -0
  271. package/dist/tensor-DdQUJZlz.js +909 -0
  272. package/dist/tensor-f35l8Odg.js +8 -0
  273. package/dist/tensor1d-CeZuc-Rv.js +12 -0
  274. package/dist/tensor2d-G4Ys2GxX.js +15 -0
  275. package/dist/tensor4d-B8roDgtc.js +15 -0
  276. package/dist/tensor_util-DV-FP5Q3.js +523 -0
  277. package/dist/tfjs_backend-kNyO5L2d.js +653 -0
  278. package/dist/tile-BzyEiF-F.js +13 -0
  279. package/dist/tokeniser/CharTokeniser.js +1 -1
  280. package/dist/tokeniser/bpe.js +1 -1
  281. package/dist/training/Adam.d.ts +2 -1
  282. package/dist/training/Adam.js +12 -28
  283. package/dist/training/AdamExt.d.ts +1 -0
  284. package/dist/training/AdamExt.js +2 -2
  285. package/dist/training/DatasetBuilder.js +3 -20
  286. package/dist/training/FullTrainer.js +55 -48
  287. package/dist/training/Trainer.d.ts +11 -6
  288. package/dist/training/Trainer.js +51 -39
  289. package/dist/training/sparseCrossEntropy.js +3 -3
  290. package/dist/transpose-DKELTqhe.js +38 -0
  291. package/dist/utilities/arrayClose.js +7 -7
  292. package/dist/utilities/dummy.js +35 -27
  293. package/dist/utilities/multinomialCPU.js +2 -2
  294. package/dist/utilities/packed.d.ts +7 -0
  295. package/dist/utilities/packed.js +716 -0
  296. package/dist/utilities/performance.js +1 -1
  297. package/dist/utilities/profile.js +1 -1
  298. package/dist/utilities/safetensors.js +2 -2
  299. package/dist/utilities/sentences.d.ts +5 -0
  300. package/dist/utilities/sentences.js +41 -0
  301. package/dist/utilities/weights.js +2 -2
  302. package/dist/variable-Bhn5bHYv.js +7 -0
  303. package/dist/{webgpu_program-DkQJOJSd.js → webgpu_program-Cigz-7RF.js} +15 -44
  304. package/dist/webgpu_util-BBCnKm2X.js +65 -0
  305. package/dist/zeros-2gldETuK.js +14 -0
  306. package/package.json +4 -3
  307. package/dist/Reshape-Bowtk9BP.js +0 -127
  308. package/dist/Reshape-DUqYftGC.js +0 -30
  309. package/dist/backend_util-CJIiDoV1.js +0 -749
  310. package/dist/broadcast_to-DzlNweb8.js +0 -44
  311. package/dist/concat-B912vBbo.js +0 -33
  312. package/dist/dropout-C-csYCLj.js +0 -193
  313. package/dist/exports_initializers-B8iZMgQ0.js +0 -16
  314. package/dist/gather-Dnpgw-YQ.js +0 -25
  315. package/dist/index-BzFyqcy-.js +0 -4457
  316. package/dist/index-C1rx_Ajs.js +0 -12076
  317. package/dist/kernel_funcs_utils-DKLK0Mg3.js +0 -466
  318. package/dist/log_sum_exp-DO6z8tSE.js +0 -103
  319. package/dist/mat_mul-DzjTFx-u.js +0 -27
  320. package/dist/mod-Dobti4j4.js +0 -27
  321. package/dist/ones-tIJeHlq-.js +0 -29
  322. package/dist/ops/fusedSoftmax.d.ts +0 -2
  323. package/dist/ops/fusedSoftmax.js +0 -10
  324. package/dist/ops/grads/fusedSoftmax.js +0 -22
  325. package/dist/ops-LuCMAnmM.js +0 -1525
  326. package/dist/random_width-CXVRloNK.js +0 -13670
  327. package/dist/range-CWcz7xFA.js +0 -26
  328. package/dist/reciprocal-C4rNcM-S.js +0 -25
  329. package/dist/relu-BjCh_SYb.js +0 -25
  330. package/dist/reshape-CnIwVG1c.js +0 -25
  331. package/dist/selu_util-OtRzVwW5.js +0 -719
  332. package/dist/shared-DmRsFyaJ.js +0 -3134
  333. package/dist/sin-gpDNRxE0.js +0 -47
  334. package/dist/slice-d0Vo9XTN.js +0 -28
  335. package/dist/softmax-D7Jj3p_P.js +0 -28
  336. package/dist/split-DK2k5eHf.js +0 -25
  337. package/dist/stack-DFatutCx.js +0 -27
  338. package/dist/sum-CJ0ULhmt.js +0 -27
  339. package/dist/tensor1d-vML0r3q6.js +0 -27
  340. package/dist/tensor2d-D76QGjF3.js +0 -30
  341. package/dist/tensor4d-Df1WlVDY.js +0 -30
  342. package/dist/webgpu_util-pLEV9tks.js +0 -80
  343. package/dist/zeros-Bj5rMYA7.js +0 -52
package/dist/training/Adam.js
@@ -1,26 +1,10 @@
  import { adamAdjust as b } from "../ops/adamAdjust.js";
  import { adamMoments as d } from "../ops/adamMoments.js";
- import { O as g, e as h, t as o, d as B } from "../index-BzFyqcy-.js";
- import { z as M } from "../zeros-Bj5rMYA7.js";
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
+ import { O as g, e as h, t as o, d as B } from "../index-ZyQhjEPo.js";
+ import { z as M } from "../zeros-2gldETuK.js";
  class R extends g {
- constructor(t, a, e, s = null) {
- super(), this.learningRate = t, this.beta1 = a, this.beta2 = e, this.epsilon = s, this.accBeta1 = a, this.accBeta2 = e, s === null && (this.epsilon = h().backend.epsilon());
+ constructor(t, a, e, s, i = null) {
+ super(), this.learningRate = t, this.beta1 = a, this.beta2 = e, this.lossScaling = s, this.epsilon = i, this.accBeta1 = a, this.accBeta2 = e, i === null && (this.epsilon = h().backend.epsilon());
  }
  /** @nocollapse */
  static get className() {
@@ -33,19 +17,19 @@ class R extends g {
  const a = Array.isArray(t) ? t.map((e) => e.name) : Object.keys(t);
  o(() => {
  const e = 1 - this.accBeta1, s = 1 - this.accBeta2;
- a.forEach((n, i) => {
- const c = h().registeredVariables[n], u = !1;
- this.accumulatedMoments[i] == null && (this.accumulatedMoments[i] = {
- originalName: `${n}/m`,
+ a.forEach((i, n) => {
+ const c = h().registeredVariables[i], u = !1;
+ this.accumulatedMoments[n] == null && (this.accumulatedMoments[n] = {
+ originalName: `${i}/m`,
  variable: o(() => M([...c.shape, 2]).variable(u))
  });
- const r = Array.isArray(t) ? t[i].tensor : t[n];
+ const r = Array.isArray(t) ? t[n].tensor : t[i];
  if (r == null)
  return;
- const m = this.accumulatedMoments[i].variable, l = d(m, r, this.beta1, this.beta2);
- m.assign(l);
+ const l = this.accumulatedMoments[n].variable, m = d(l, r, this.beta1, this.beta2, this.lossScaling);
+ l.assign(m);
  const p = b(
- l,
+ m,
  c,
  e,
  s,
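
Note on the hunk above: the minified AdamOptimizer constructor gains a loss-scaling argument in fourth position (epsilon moves to fifth and still defaults from the backend), and adamMoments now receives that scale alongside beta1/beta2. A minimal usage sketch in TypeScript; the de-minified parameter names are inferred from the minified code, not from published API docs, and 128 is a hypothetical value:

  // R = AdamOptimizer (minified). Assumed 0.10.x signature:
  //   constructor(learningRate, beta1, beta2, lossScaling, epsilon = null)
  const optimizer = new AdamOptimizer(
    3e-4, // learningRate
    0.9,  // beta1
    0.99, // beta2
    128,  // lossScaling: gradients arrive scaled by this factor
    null  // epsilon: null falls back to backend.epsilon()
  );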
package/dist/training/AdamExt.d.ts
@@ -5,6 +5,7 @@ interface AdamExtConfig {
  decaySteps: number;
  minLearningRate: number;
  weightDecay?: number;
+ lossScaling: number;
  }
  /**
  * Extended Adam optimizer with warmup, cosine decay, and optional weight decay.
package/dist/training/AdamExt.js
@@ -1,8 +1,8 @@
- import { a as r, b as c, c as h, e as o } from "../index-BzFyqcy-.js";
+ import { m as r, b as c, c as h, e as o } from "../index-ZyQhjEPo.js";
  import { AdamOptimizer as g } from "./Adam.js";
  class y extends g {
  constructor(t, e, s, i, a) {
- super(t, e, s, i), this.config = a, this.startLearningRate = t;
+ super(t, e, s, a.lossScaling, i), this.config = a, this.startLearningRate = t;
  }
  step = 0;
  startLearningRate;
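
AdamExtConfig gains a required lossScaling field, and AdamExt now forwards config.lossScaling into the base AdamOptimizer. A sketch of the updated construction; the non-lossScaling values mirror the resetOptimizer defaults visible in Trainer.js below, while 128 is a placeholder:

  const optimizer = new AdamExt(3e-4, 0.9, 0.99, 1e-8, {
    warmupSteps: 100,
    decaySteps: 2e4,
    minLearningRate: 1e-4,
    weightDecay: 0,
    lossScaling: 128, // new required field, passed through to AdamOptimizer
  });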
package/dist/training/DatasetBuilder.js
@@ -1,23 +1,6 @@
- import { t as g } from "../index-BzFyqcy-.js";
- import { d as u, i as d } from "../dataset-DlZtKmBq.js";
- import "../index-Tf7vU29b.js";
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- *
- * =============================================================================
- */
+ import { t as g } from "../index-ZyQhjEPo.js";
+ import { d as u, i as d } from "../dataset-0xP8GjwI.js";
+ import "../index-Cp39cXWe.js";
  function z(r) {
  return u(async () => {
  const t = await r();
package/dist/training/FullTrainer.js
@@ -1,13 +1,14 @@
- import L from "./Trainer.js";
- import w from "./Evaluator.js";
- import { d as S } from "../index-BzFyqcy-.js";
- import f from "../utilities/profile.js";
- const y = {
+ import b from "./Trainer.js";
+ import L from "./Evaluator.js";
+ import { d as w } from "../index-ZyQhjEPo.js";
+ import y from "../utilities/profile.js";
+ import { createTensorStatistics as D } from "../checks/weights.js";
+ const T = {
  desiredLoss: 0.01,
  logInterval: 1,
  maxSteps: 1e3
  };
- class x extends L {
+ class V extends b {
  constructor(s, e, a = 3e-4) {
  super(s, e, a);
  }
@@ -42,89 +43,95 @@ class x extends L {
  }
  async stepDataset(s, e, a) {
  const { logInterval: l } = {
- ...y,
+ ...T,
  ...e
  }, c = Date.now(), r = this.createEmptyState();
- this.lastState = r, await this.dummyPass(), this.model.trainable = !0, e?.advancedMetrics && (this.model.getProfiler() || this.model.setProfiler(new f())), this.running = !0, r.logStartTime = c;
- const d = a ? new w(this.model, a) : void 0, t = await s.iterator();
+ this.lastState = r, await this.dummyPass(), this.model.trainable = !0, e?.advancedMetrics && (this.model.getProfiler() || this.model.setProfiler(new y())), this.running = !0, r.logStartTime = c;
+ const g = a ? new L(this.model, a) : void 0, t = await s.iterator();
  try {
  for (; this.running; ) {
- const i = await t.next();
- if (i.done) break;
- const m = i.value, o = this.trainBatch(r, m);
+ const n = await t.next();
+ if (n.done) break;
+ const m = n.value, o = this.trainBatch(r, m);
  if (r.step % l === 0) {
- const g = (await o.data())[0];
- r.lastLoss = g;
+ const p = (await o.data())[0];
+ r.lastLoss = p;
  const u = Date.now();
  r.trainingDuration += u - r.logStartTime;
- const p = this.createLogEntry(r, c, m.xs.shape[0], e?.advancedMetrics);
+ const d = this.createLogEntry(r, c, m.xs.shape[0], e?.advancedMetrics);
  if (this.model.trainingState = {
  steps: r.totalSteps,
  learningRate: this.optimizer.lr,
  batchSize: m.xs.shape[0],
  loss: r.lastLoss
- }, d)
+ }, g)
  try {
- const n = await d.evaluate(5);
- r.validationLosses.push(n), p.valLoss = n;
- } catch (n) {
- console.error("Validation error:", n);
+ const h = await g.evaluate(5);
+ r.validationLosses.push(h), d.valLoss = h;
+ } catch (h) {
+ console.error("Validation error:", h);
  }
- const v = this.createProgress(r, p, e?.advancedMetrics);
- return o.dispose(), this.stop(), { log: p, progress: v };
+ const S = this.createProgress(r, d, e?.advancedMetrics);
+ return o.dispose(), this.stop(), { log: d, progress: S };
  }
  o.dispose();
  }
- } catch (i) {
- throw console.error("Training error:", i), S(), i;
+ } catch (n) {
+ throw console.error("Training error:", n), w(), n;
  }
- throw S(), this.running = !1, new Error("No log returned before training stopped.");
+ throw w(), this.running = !1, new Error("No log returned before training stopped.");
  }
  // Train for multiple epochs using Dataset API - FIXED memory leaks
  async trainOnDataset(s, e, a) {
  const { logInterval: l, onStep: c, maxSteps: r } = {
- ...y,
+ ...T,
  ...e
- }, d = Date.now(), t = this.createEmptyState();
- this.lastState = t, await this.dummyPass(), this.model.trainable = !0, e?.advancedMetrics && (this.model.getProfiler() || this.model.setProfiler(new f())), this.running = !0, t.logStartTime = d;
- const i = a ? new w(this.model, a) : void 0, m = await s.iterator();
+ }, g = Date.now(), t = this.createEmptyState();
+ this.lastState = t, await this.dummyPass(), this.model.trainable = !0, e?.advancedMetrics && (this.model.getProfiler() || this.model.setProfiler(new y())), console.log("Training options", e), this.running = !0, t.logStartTime = g;
+ const n = a ? new L(this.model, a) : void 0, m = await s.iterator();
  try {
  for (; this.running; ) {
  const o = await m.next();
  if (o.done) break;
- const g = o.value, u = this.trainBatch(t, g);
- if (t.step % l === 0) {
- const p = (await u.data())[0];
- t.lastLoss = p;
- const v = Date.now();
- t.trainingDuration += v - t.logStartTime;
- const n = this.createLogEntry(t, d, g.xs.shape[0], e?.advancedMetrics);
+ const p = o.value, u = t.step % l === 0, d = (e?.gradientMetrics || !1) && u, S = this.trainBatch(t, p, d);
+ if (u) {
+ const h = (await S.data())[0];
+ t.lastLoss = h;
+ const P = Date.now();
+ t.trainingDuration += P - t.logStartTime;
+ const f = this.createLogEntry(t, g, p.xs.shape[0], e?.advancedMetrics);
  if (this.model.trainingState = {
  steps: t.totalSteps,
  learningRate: this.optimizer.lr,
- batchSize: g.xs.shape[0],
+ batchSize: p.xs.shape[0],
  loss: t.lastLoss
- }, i)
+ }, e?.gradientMetrics && d && t.gradients) {
+ const i = /* @__PURE__ */ new Map();
+ for (const [M, v] of Object.entries(t.gradients))
+ i.set(M, await D(v)), v.dispose();
+ f.gradientMetrics = i;
+ }
+ if (n)
  try {
- const h = await i.evaluate(5);
- t.validationLosses.push(h), n.valLoss = h;
- } catch (h) {
- console.error("Validation error:", h);
+ const i = await n.evaluate(5);
+ t.validationLosses.push(i), f.valLoss = i;
+ } catch (i) {
+ console.error("Validation error:", i);
  }
  if (c) {
- const h = this.createProgress(t, n, e?.advancedMetrics);
- await c(n, h);
+ const i = this.createProgress(t, f, e?.advancedMetrics);
+ await c(f, i);
  }
  t.logStartTime = Date.now();
  }
- u.dispose(), t.step >= r && this.stop();
+ S.dispose(), t.step >= r && this.stop();
  }
  } catch (o) {
- throw console.error("Training error:", o), S(), o;
+ throw console.error("Training error:", o), w(), o;
  }
- return S(), this.running = !1, { losses: t.losses, validationLosses: t.validationLosses };
+ return w(), this.running = !1, { losses: t.losses, validationLosses: t.validationLosses };
  }
  }
  export {
- x as default
+ V as default
  };
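
In the reworked trainOnDataset above, gradient statistics are gathered only on logging steps: keepGrads (d) is true only when options.gradientMetrics is set and step % logInterval === 0, and each retained gradient tensor is disposed as soon as createTensorStatistics has consumed it. A condensed sketch of that flow with de-minified names (the names are assumptions, not part of the published API):

  const shouldLog = state.step % logInterval === 0;
  const keepGrads = (options?.gradientMetrics ?? false) && shouldLog;
  const loss = this.trainBatch(state, batch, keepGrads);
  if (keepGrads && state.gradients) {
    const metrics = new Map<string, TensorStatistics>();
    for (const [name, grad] of Object.entries(state.gradients)) {
      metrics.set(name, await createTensorStatistics(grad));
      grad.dispose(); // free GPU memory as soon as the stats are captured
    }
    logEntry.gradientMetrics = metrics;
  }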
package/dist/training/Trainer.d.ts
@@ -1,10 +1,11 @@
  import { ITokeniser } from '../tokeniser/type';
  import { DatasetBuilder } from './DatasetBuilder';
  import { default as AdamExt } from './AdamExt';
- import { TensorContainer } from '@tensorflow/tfjs-core/dist/tensor_types';
+ import { NamedTensorMap, TensorContainer } from '@tensorflow/tfjs-core/dist/tensor_types';
  import { Scalar, Tensor } from '@tensorflow/tfjs-core';
  import { Dataset } from '@tensorflow/tfjs-data';
  import { default as Model, ModelForwardAttributes } from '../models/model';
+ import { TensorStatistics } from '../checks/weights';
  export interface TrainingLogEntry {
  loss: number;
  valLoss?: number;
@@ -12,7 +13,7 @@ export interface TrainingLogEntry {
  time: number;
  example?: string;
  batchSize: number;
- gradientNorm?: number;
+ gradientMetrics?: Map<string, TensorStatistics>;
  learningRate?: number;
  }
  export interface TrainingState {
@@ -21,7 +22,7 @@ export interface TrainingState {
  totalSteps: number;
  losses: number[];
  validationLosses: number[];
- gradientNorm?: Promise<number>;
+ gradients?: NamedTensorMap;
  }
  export interface TrainingProgress {
  duration: number;
@@ -41,6 +42,7 @@ export interface TrainingOptions {
  prompt?: string;
  maxSteps: number;
  advancedMetrics?: boolean;
+ gradientMetrics?: boolean;
  onStep?: (log: TrainingLogEntry, progress: TrainingProgress) => Promise<void> | void;
  }
  export default abstract class GPTTrainer {
@@ -52,22 +54,25 @@ export default abstract class GPTTrainer {
  protected running: boolean;
  protected lastState?: TrainingState;
  protected _gradientCheckpointing: boolean;
+ protected _mixedPrecision: boolean;
+ protected lossScaling: number;
  constructor(model: Model<ModelForwardAttributes>, tokenizer: ITokeniser, learningRate?: number);
  setGradientCheckpointing(enabled: boolean): void;
+ setMixedPrecision(enabled: boolean): void;
  setLearningRate(learningRate: number): void;
  reset(): void;
  stop(): void;
  getOptimizer(): AdamExt;
  resetOptimizer(config?: AdamConfig): void;
- protected trainStep(_state: Partial<TrainingState>, batch: {
+ protected trainStep(state: Partial<TrainingState>, batch: {
  xs: Tensor;
  ys: Tensor;
- }, dummy?: boolean): Scalar;
+ }, dummy?: boolean, keepGrads?: boolean): Scalar;
  protected dummyPass(): Promise<void>;
  protected trainBatch(state: TrainingState, batch: {
  xs: Tensor;
  ys: Tensor;
- }): Scalar;
+ }, keepGrads?: boolean): Scalar;
  abstract trainOnDataset(dataset: Dataset<{
  xs: Tensor;
  ys: Tensor;
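
The old scalar gradientNorm fields are replaced by a richer channel: TrainingState may hold the raw gradients as a NamedTensorMap, and each TrainingLogEntry may carry a Map of per-variable TensorStatistics. A hedged consumer sketch against these declarations (the fields inside TensorStatistics live in checks/weights and are not shown in this diff):

  await trainer.trainOnDataset(dataset, {
    maxSteps: 1000,
    gradientMetrics: true, // opt in; stats are only collected on logged steps
    onStep: (log, progress) => {
      for (const [name, stats] of log.gradientMetrics ?? []) {
        console.log(name, stats); // per-variable gradient statistics
      }
    },
  });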
package/dist/training/Trainer.js
@@ -1,10 +1,10 @@
- import { DatasetBuilder as m, flattenTokens as c, PAGE_FACTOR as g } from "./DatasetBuilder.js";
- import u from "./AdamExt.js";
- import { t as f, v as y, d as p } from "../index-BzFyqcy-.js";
- import { z as h } from "../zeros-Bj5rMYA7.js";
- class x {
- constructor(t, e, i = 1e-3) {
- this.tokenizer = e, this.model = t, this.learningRate = i, this.resetOptimizer(), this.datasetBuilder = new m(e, t.config.blockSize);
+ import { DatasetBuilder as f, flattenTokens as h, PAGE_FACTOR as y } from "./DatasetBuilder.js";
+ import z from "./AdamExt.js";
+ import { t as S, v as k, k as x, d as p, b as m } from "../index-ZyQhjEPo.js";
+ import { z as g } from "../zeros-2gldETuK.js";
+ class M {
+ constructor(t, e, s = 1e-3) {
+ this.tokenizer = e, this.model = t, this.lossScaling = t.lossScaling, this.learningRate = s, this.resetOptimizer(), this.datasetBuilder = new f(e, t.config.blockSize);
  }
  model;
  optimizer;
@@ -13,9 +13,14 @@ class x {
  running = !1;
  lastState;
  _gradientCheckpointing = !1;
+ _mixedPrecision = !1;
+ lossScaling;
  setGradientCheckpointing(t) {
  this._gradientCheckpointing = t;
  }
+ setMixedPrecision(t) {
+ this._mixedPrecision = t;
+ }
  setLearningRate(t) {
  this.learningRate = t, this.resetOptimizer({ learningRateFactor: 1, beta1: 0.9, beta2: 0.99, epsilon: 1e-8 });
  }
@@ -30,7 +35,7 @@
  }
  resetOptimizer(t = { learningRateFactor: 1, beta1: 0.9, beta2: 0.99, epsilon: 1e-8 }) {
  this.optimizer && this.optimizer.dispose();
- const e = new u(
+ const e = new z(
  t.learningRateFactor * this.learningRate,
  t.beta1,
  t.beta2,
@@ -39,69 +44,76 @@
  warmupSteps: 100,
  decaySteps: 2e4,
  minLearningRate: 1e-4,
- weightDecay: 0
+ weightDecay: 0,
+ lossScaling: this.lossScaling
  }
  );
  this.optimizer = e;
  }
- trainStep(t, e, i = !1) {
- return f(() => {
+ trainStep(t, e, s = !1, i = !1) {
+ return S(() => {
  this.model.getProfiler()?.startMemory();
- const { xs: a, ys: s } = e, n = () => {
- const [l, d] = this.model.forward(
- { training: !0, checkpointing: this._gradientCheckpointing },
+ const { xs: a, ys: l } = e, c = () => {
+ const [n, d] = this.model.forward(
+ {
+ training: !0,
+ checkpointing: this._gradientCheckpointing,
+ mixedPrecision: this._mixedPrecision
+ },
  a,
- s
+ l
  );
- return l.dispose(), d;
- }, { value: o, grads: r } = y(n);
- return i ? this.model.getProfiler()?.endMemory("Training") : (this.optimizer.applyGradients(r), this.model.getProfiler()?.endMemory("Training"), p(r)), o;
+ n.dispose();
+ const u = d.mul(m(this.lossScaling));
+ return d.dispose(), u;
+ }, { value: o, grads: r } = k(c);
+ return s ? this.model.getProfiler()?.endMemory("Training") : (this.optimizer.applyGradients(r), this.model.getProfiler()?.endMemory("Training"), i ? (t.gradients = r, Object.values(r).forEach((n) => x(n))) : p(r)), o.mul(m(1 / this.lossScaling));
  });
  }
  async dummyPass() {
- const t = h([1, this.model.config.blockSize], "int32"), e = h([1, this.model.config.blockSize], "int32");
+ const t = g([1, this.model.config.blockSize], "int32"), e = g([1, this.model.config.blockSize], "int32");
  try {
- const i = this.trainStep({}, { xs: t, ys: e }, !0);
- await i.data(), i.dispose();
- } catch (i) {
- console.error("Error during dummy pass:", i);
+ const s = this.trainStep({}, { xs: t, ys: e }, !0);
+ await s.data(), s.dispose();
+ } catch (s) {
+ console.error("Error during dummy pass:", s);
  } finally {
  t.dispose(), e.dispose();
  }
  }
- trainBatch(t, e) {
+ trainBatch(t, e, s = !1) {
  try {
- const i = this.trainStep(t, e, !1);
+ const i = this.trainStep(t, e, !1, s);
  return e.xs.dispose(), e.ys.dispose(), t.step++, t.totalSteps++, i;
  } catch (i) {
  throw console.error(`Error processing batch at step ${t.step}:`, i), p(), i;
  }
  }
- async createTrainValidationSplit(t, e = 32, i = 0.1) {
- const a = await c(t, this.tokenizer), s = /* @__PURE__ */ new Set();
- if (i > 0) {
- const r = Math.floor(a.length / (this.datasetBuilder.blockSize * g)), l = Math.max(1, Math.floor(r * i));
- for (; s.size < l; ) {
- const d = Math.floor(Math.random() * r);
- s.add(d);
+ async createTrainValidationSplit(t, e = 32, s = 0.1) {
+ const i = await h(t, this.tokenizer), a = /* @__PURE__ */ new Set();
+ if (s > 0) {
+ const o = Math.floor(i.length / (this.datasetBuilder.blockSize * y)), r = Math.max(1, Math.floor(o * s));
+ for (; a.size < r; ) {
+ const n = Math.floor(Math.random() * o);
+ a.add(n);
  }
  }
- const n = await this.datasetBuilder.createTextDataset(a, e, s, !1), o = await this.datasetBuilder.createTextDataset(
- a,
+ const l = await this.datasetBuilder.createTextDataset(i, e, a, !1), c = await this.datasetBuilder.createTextDataset(
+ i,
  e,
- s,
+ a,
  !0
  );
- return { trainDataset: n, validationDataset: o };
+ return { trainDataset: l, validationDataset: c };
  }
  async createDataset(t, e = 32) {
- const i = await c(t, this.tokenizer);
- return await this.datasetBuilder.createTextDataset(i, e);
+ const s = await h(t, this.tokenizer);
+ return await this.datasetBuilder.createTextDataset(s, e);
  }
  dispose() {
  this.optimizer && this.optimizer.dispose();
  }
  }
  export {
- x as default
+ M as default
  };
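
The rewritten trainStep implements static loss scaling for mixed-precision training: the loss is multiplied by this.lossScaling before gradients are taken (keeping small half-precision gradients above underflow), the optimizer compensates through the lossScaling it was constructed with, and the returned value is un-scaled so callers still see the true loss. A de-minified sketch, assuming the minified S/k/x/p/m map to tf.tidy, tf.variableGrads, tf.keep, tf.dispose and tf.scalar (profiler calls omitted):

  protected trainStep(state: Partial<TrainingState>, batch: { xs: Tensor; ys: Tensor },
                      dummy = false, keepGrads = false): Scalar {
    return tidy(() => {
      const lossFn = () => {
        const [logits, loss] = this.model.forward(
          {
            training: true,
            checkpointing: this._gradientCheckpointing,
            mixedPrecision: this._mixedPrecision,
          },
          batch.xs,
          batch.ys
        );
        logits.dispose();
        return loss.mul(scalar(this.lossScaling)); // scale up before backprop
      };
      const { value, grads } = variableGrads(lossFn);
      if (!dummy) {
        this.optimizer.applyGradients(grads); // optimizer un-scales via its lossScaling
        if (keepGrads) {
          state.gradients = grads;
          Object.values(grads).forEach((g) => keep(g)); // survive the tidy()
        } else {
          dispose(grads);
        }
      }
      return value.mul(scalar(1 / this.lossScaling)); // report the unscaled loss
    });
  }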
package/dist/training/sparseCrossEntropy.js
@@ -1,8 +1,8 @@
  import { gatherSub as x } from "../ops/gatherSub.js";
  import { scatterSub as L } from "../ops/scatterSub.js";
- import { z as C, t as u, A as E, c as G } from "../index-BzFyqcy-.js";
- import { s as y } from "../softmax-D7Jj3p_P.js";
- import { m as z, l as v } from "../log_sum_exp-DO6z8tSE.js";
+ import { w as C, t as u, z as E, c as G } from "../index-ZyQhjEPo.js";
+ import { s as y } from "../softmax-ZHVebtR1.js";
+ import { m as z, l as v } from "../log_sum_exp-DWI-76TI.js";
  function k(t, s) {
  return u(() => {
  const n = t.shape[t.shape.length - 1], c = t.shape.slice(0, -1).reduce((o, e) => o * e, 1), h = t.shape.length > 2 ? t.reshape([c, n]) : t, p = s.shape.length > 1 ? s.reshape([c]).cast("int32") : s.cast("int32"), r = z(h, -1, !0), a = G(h, r), d = v(a, -1);
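
The visible lines of k compute sparse cross-entropy in its numerically stable form: r is the per-row maximum of the logits, a the shifted logits, and d their log-sum-exp. Subtracting the row maximum before exponentiating avoids overflow without changing the result, since for logits x with target class y and m = max_j x_j:

  \ell(x, y) = -\log \operatorname{softmax}(x)_y = \log \sum_j e^{x_j - m} - (x_y - m)

(The target-logit term is presumably gathered via the imported gatherSub later in the function, outside this hunk.)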
package/dist/transpose-DKELTqhe.js
@@ -0,0 +1,38 @@
+ import { A as u, B as i, E as a, t as m } from "./index-ZyQhjEPo.js";
+ import { D as $, N as g, F as x, H as c } from "./tensor_util-DV-FP5Q3.js";
+ import { c as k } from "./complex-CSlYz-2T.js";
+ import { a as l } from "./tensor-DdQUJZlz.js";
+ function K(r) {
+ const e = { input: i(r, "input", "imag") };
+ return a.runKernel($, e);
+ }
+ const h = /* @__PURE__ */ u({ imag_: K });
+ function E(r) {
+ const e = { x: i(r, "x", "neg") };
+ return a.runKernel(g, e);
+ }
+ const N = /* @__PURE__ */ u({ neg_: E });
+ function _(r) {
+ const e = { input: i(r, "input", "real") };
+ return a.runKernel(x, e);
+ }
+ const b = /* @__PURE__ */ u({ real_: _ });
+ function d(r, t, e) {
+ const n = i(r, "x", "transpose");
+ if (t == null && (t = n.shape.map((s, o) => o).reverse()), l(n.rank === t.length, () => `Error in transpose: rank of input ${n.rank} must match length of perm ${t}.`), t.forEach((s) => {
+ l(s >= 0 && s < n.rank, () => `All entries in 'perm' must be between 0 and ${n.rank - 1} but got ${t}`);
+ }), n.rank <= 1)
+ return n.clone();
+ const f = { x: n }, p = { perm: t };
+ return n.dtype === "complex64" ? m(() => {
+ let s = b(n), o = h(n);
+ return s = a.runKernel(c, { x: s }, p), o = a.runKernel(c, { x: o }, p), e && (o = N(o)), k(s, o);
+ }) : a.runKernel(c, f, p);
+ }
+ const I = /* @__PURE__ */ u({ transpose_: d });
+ export {
+ h as i,
+ N as n,
+ b as r,
+ I as t
+ };
package/dist/utilities/arrayClose.js
@@ -1,20 +1,20 @@
- function n(r, e) {
+ function i(r, e) {
  let t = 0;
- if (Array.isArray(r) && Array.isArray(e)) {
+ if ((Array.isArray(r) || r instanceof Float32Array) && (Array.isArray(e) || e instanceof Float32Array)) {
  if (r.length !== e.length) return Number.POSITIVE_INFINITY;
- for (let i = 0; i < r.length; ++i)
- t = Math.max(t, n(r[i], e[i]));
+ for (let n = 0; n < r.length; ++n)
+ t = Math.max(t, i(r[n], e[n]));
  return t;
  } else if (typeof r == "number" && typeof e == "number") {
  if (isNaN(r) && isNaN(e))
  return 0;
  if (!isFinite(r) || !isFinite(e))
  return r === e ? 0 : Number.POSITIVE_INFINITY;
- const i = Math.abs(r - e);
- return t = Math.max(t, i), t;
+ const n = Math.abs(r - e);
+ return t = Math.max(t, n), t;
  } else
  return Number.POSITIVE_INFINITY;
  }
  export {
- n as arraysClose
+ i as arraysClose
  };
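
The updated arraysClose accepts Float32Array as well as nested number arrays, so tensor.data() output can be compared directly. Note that it returns the maximum absolute elementwise difference rather than a boolean: 0 for two NaNs, Infinity on a length or type mismatch.

  arraysClose(new Float32Array([1, 2, 3]), [1, 2.0001, 3]); // ~1e-4
  arraysClose([1, 2], [1, 2, 3]); // Infinity (length mismatch)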
package/dist/utilities/dummy.js
@@ -1,35 +1,43 @@
- import { m as y, v as P, e as S } from "../index-BzFyqcy-.js";
- import { z as i } from "../zeros-Bj5rMYA7.js";
- async function w(s) {
- const t = i([1, s.config.blockSize], "int32"), [e, n] = s.forward({ training: !1 }, t);
- await e.data(), e.dispose(), n && n.dispose(), t.dispose();
+ import { a as y, e as S, v as w } from "../index-ZyQhjEPo.js";
+ import { z as m } from "../zeros-2gldETuK.js";
+ import { o as P } from "../ones-CAMiP4I2.js";
+ async function b(s) {
+ const t = m([1, s.config.blockSize], "int32"), [n, o] = s.forward({ training: !1 }, t);
+ await n.data(), n.dispose(), o && o.dispose(), t.dispose();
  }
- async function k(s) {
- const t = y(), e = t.numBytesInGPUAllocated ?? t.numBytesAllocatedInGPU ?? t.numBytes;
- await w(s);
- const n = i([1, s.config.blockSize], "int32"), r = i([1, s.config.blockSize], "int32"), o = {
+ async function G(s) {
+ console.log("Starting dummy training pass for memory profiling...");
+ const t = y(), n = t.numBytesInGPUAllocated ?? t.numBytesAllocatedInGPU ?? t.numBytes;
+ await b(s), console.log("Forward pass complete. Starting backward pass...");
+ const o = m([1, s.config.blockSize], "int32"), a = P([1, s.config.blockSize], "int32"), e = {
  perBatch: 0,
  tapeSize: 0,
  gradients: s.getNumParams() * 4
- }, f = () => {
- const [c, g] = s.forward({ training: !0 }, n, r), u = S().state.activeTape;
- let p = 0;
- if (u)
- for (const z of u)
- p += z.saved?.reduce((B, I) => B + I.size * 4, 0) || 0;
- return o.tapeSize = p, c.dispose(), g;
- }, { value: m, grads: d } = P(f), a = y(), l = a.numBytesInGPUAllocated ?? a.numBytesAllocatedInGPU ?? a.numBytes;
- o.perBatch = l - e - o.gradients, console.log("Dummy training memory requirements:", o), await m.data(), m.dispose();
- for (const c in d)
- d[c].dispose();
- return n.dispose(), r.dispose(), o;
+ };
+ try {
+ const i = () => {
+ const [c, g] = s.forward({ training: !0 }, o, a), u = S().state.activeTape;
+ let l = 0;
+ if (u)
+ for (const z of u)
+ l += z.saved?.reduce((B, I) => B + I.size * 4, 0) || 0;
+ return e.tapeSize = l, c.dispose(), g;
+ }, { value: d, grads: p } = w(i), r = y(), f = r.numBytesInGPUAllocated ?? r.numBytesAllocatedInGPU ?? r.numBytes;
+ e.perBatch = f - n - e.gradients, console.log("Dummy training memory requirements:", e), await d.data(), d.dispose();
+ for (const c in p)
+ p[c].dispose();
+ o.dispose(), a.dispose();
+ } catch (i) {
+ console.error("Error during dummy training pass:", i), o.dispose(), a.dispose();
+ }
+ return e;
  }
- function v(s) {
- const t = i([1, s.config.blockSize], "int32"), [e, n] = s.forward({ training: !1 }, t);
- e.dispose(), n && n.dispose(), t.dispose();
+ function T(s) {
+ const t = m([1, s.config.blockSize], "int32"), [n, o] = s.forward({ training: !1 }, t);
+ n.dispose(), o && o.dispose(), t.dispose();
  }
  export {
- v as dummyPass,
- w as dummyPassAsync,
- k as dummyPassTrainAsync
+ T as dummyPass,
+ b as dummyPassAsync,
+ G as dummyPassTrainAsync
  };
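
dummyPassTrainAsync estimates training memory by differencing the backend's allocation counters around one synthetic step: gradients is fixed at getNumParams() * 4 bytes (fp32), tapeSize sums the activations saved on the autodiff tape, and the remainder is attributed to the batch. With bytesBefore/bytesAfter denoting the GPU byte counter read before and after the pass (names assumed for illustration):

  perBatch = bytesAfter - bytesBefore - gradients; // gradients = getNumParams() * 4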
package/dist/utilities/multinomialCPU.js
@@ -1,5 +1,5 @@
- import "../index-BzFyqcy-.js";
- import { t as e } from "../tensor2d-D76QGjF3.js";
+ import "../index-ZyQhjEPo.js";
+ import { t as e } from "../tensor2d-G4Ys2GxX.js";
  function l(n) {
  let r = 0;
  const i = Math.random();