@genai-fi/nanogpt 0.9.0 → 0.10.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (343)
  1. package/README.md +352 -14
  2. package/dist/Generator.js +69 -78
  3. package/dist/{RealDiv-D4EzDsC0.js → RealDiv-DgA3z9oO.js} +32 -206
  4. package/dist/Reshape-CF6odzV4.js +16 -0
  5. package/dist/Reshape-_kILl6tK.js +81 -0
  6. package/dist/TeachableLLM.js +28 -22
  7. package/dist/Trainer.d.ts +2 -0
  8. package/dist/Trainer.js +3 -2
  9. package/dist/{axis_util-TbGYJ208.js → axis_util-BvHEw88j.js} +7 -23
  10. package/dist/backend.d.ts +2 -1
  11. package/dist/backend.js +10 -4
  12. package/dist/backend_util-D-rUb2ty.js +474 -0
  13. package/dist/backend_webgpu-B0u2ndUn.js +547 -0
  14. package/dist/binary_op_util-pKXltfxI.js +192 -0
  15. package/dist/broadcast_to-CwF7XIeu.js +30 -0
  16. package/dist/checks/appendCache.js +2 -2
  17. package/dist/checks/attentionMask.js +3 -3
  18. package/dist/checks/check.d.ts +1 -1
  19. package/dist/checks/check.js +8 -8
  20. package/dist/checks/gelu.js +2 -2
  21. package/dist/checks/index.d.ts +2 -0
  22. package/dist/checks/index.js +7 -5
  23. package/dist/checks/matMulGelu.js +6 -6
  24. package/dist/checks/normRMS.js +7 -7
  25. package/dist/checks/normRMSGrad.js +3 -3
  26. package/dist/checks/packUnpack.d.ts +1 -0
  27. package/dist/checks/packUnpack.js +18 -0
  28. package/dist/checks/qkv.js +12 -27
  29. package/dist/checks/rope.js +2 -2
  30. package/dist/checks/weights.js +18 -16
  31. package/dist/complex-CSlYz-2T.js +13 -0
  32. package/dist/complex_util-Yc1A_gV1.js +55 -0
  33. package/dist/concat-BHlIJeyT.js +19 -0
  34. package/dist/concat_util-DcJk7YHS.js +22 -0
  35. package/dist/data/docx.js +1 -1
  36. package/dist/data/parquet.js +2 -2
  37. package/dist/data/pdf.js +1 -1
  38. package/dist/data/textLoader.js +1 -1
  39. package/dist/{dataset-DlZtKmBq.js → dataset-0xP8GjwI.js} +136 -236
  40. package/dist/dropout-C1pM3f11.js +99 -0
  41. package/dist/expand_dims-BPG4fwBP.js +13 -0
  42. package/dist/exports_initializers-xuidcwI4.js +7 -0
  43. package/dist/gather-DykLGqmW.js +10 -0
  44. package/dist/{gelu-Bp_-935b.js → gelu-CNLFZWea.js} +11 -10
  45. package/dist/{gpgpu_math-CDaYiyE_.js → gpgpu_math-DDVJCn6-.js} +90 -265
  46. package/dist/{index-C4L8Cm77.js → index-CieiGp4Y.js} +14 -14
  47. package/dist/index-CjOj7j-u.js +7308 -0
  48. package/dist/{index-Tf7vU29b.js → index-Cp39cXWe.js} +3 -10
  49. package/dist/{index-Dwqa6Zy2.js → index-DvYrXKkX.js} +2 -2
  50. package/dist/index-ZyQhjEPo.js +2157 -0
  51. package/dist/{jszip.min-CjP2V1VV.js → jszip.min-Bz5-11Bk.js} +56 -57
  52. package/dist/kernel_funcs_utils-Dg_-E44D.js +308 -0
  53. package/dist/layers/BaseLayer.d.ts +1 -0
  54. package/dist/layers/BaseLayer.js +7 -6
  55. package/dist/layers/CausalSelfAttention.d.ts +0 -1
  56. package/dist/layers/CausalSelfAttention.js +56 -55
  57. package/dist/layers/MLP.js +15 -16
  58. package/dist/layers/PositionEmbedding.js +5 -14
  59. package/dist/layers/RMSNorm.js +3 -3
  60. package/dist/layers/RoPECache.d.ts +2 -0
  61. package/dist/layers/RoPECache.js +22 -17
  62. package/dist/layers/TiedEmbedding.js +22 -17
  63. package/dist/layers/TransformerBlock.js +21 -20
  64. package/dist/loader/load.js +1 -1
  65. package/dist/loader/loadTransformers.js +1 -1
  66. package/dist/loader/oldZipLoad.js +39 -33
  67. package/dist/loader/save.js +1 -1
  68. package/dist/log_sum_exp-DWI-76TI.js +41 -0
  69. package/dist/main.d.ts +8 -0
  70. package/dist/main.js +63 -52
  71. package/dist/matMul16--R5hOwDG.js +77 -0
  72. package/dist/mat_mul-DeAh4uTH.js +12 -0
  73. package/dist/mod-Gt1rMB4n.js +12 -0
  74. package/dist/models/NanoGPTV1.js +40 -31
  75. package/dist/models/model.d.ts +2 -0
  76. package/dist/models/model.js +37 -29
  77. package/dist/{mulmat_packed_gpu-BT60jmzP.js → mulmat_packed_gpu-BMFhLwta.js} +1 -17
  78. package/dist/{non_max_suppression_impl-CsEgBuMA.js → non_max_suppression_impl-B2W7YjZB.js} +0 -32
  79. package/dist/ones-CAMiP4I2.js +15 -0
  80. package/dist/ops/adamAdjust.js +1 -1
  81. package/dist/ops/adamMoments.d.ts +1 -1
  82. package/dist/ops/adamMoments.js +4 -4
  83. package/dist/ops/add16.d.ts +2 -0
  84. package/dist/ops/add16.js +9 -0
  85. package/dist/ops/appendCache.js +16 -9
  86. package/dist/ops/attentionMask.js +4 -4
  87. package/dist/ops/concat16.d.ts +2 -0
  88. package/dist/ops/concat16.js +9 -0
  89. package/dist/ops/cpu/adamAdjust.js +14 -13
  90. package/dist/ops/cpu/adamMoments.js +10 -9
  91. package/dist/ops/cpu/appendCache.js +9 -8
  92. package/dist/ops/cpu/attentionMask.js +15 -14
  93. package/dist/ops/cpu/fusedSoftmax.js +13 -12
  94. package/dist/ops/cpu/gatherSub.js +9 -24
  95. package/dist/ops/cpu/gelu.js +13 -12
  96. package/dist/ops/cpu/matMul16.d.ts +1 -0
  97. package/dist/ops/cpu/matMul16.js +16 -0
  98. package/dist/ops/cpu/matMulGelu.js +18 -16
  99. package/dist/ops/cpu/matMulMul.js +8 -7
  100. package/dist/ops/cpu/mulDropout.js +4 -3
  101. package/dist/ops/cpu/normRMS.js +11 -10
  102. package/dist/ops/cpu/qkv.js +17 -13
  103. package/dist/ops/cpu/rope.js +23 -22
  104. package/dist/ops/cpu/scatterSub.js +16 -30
  105. package/dist/ops/dot16.d.ts +2 -0
  106. package/dist/ops/dot16.js +42 -0
  107. package/dist/ops/gatherSub.js +1 -1
  108. package/dist/ops/gelu.js +2 -2
  109. package/dist/ops/grads/add16.d.ts +1 -0
  110. package/dist/ops/grads/add16.js +27 -0
  111. package/dist/ops/grads/attentionMask.js +12 -19
  112. package/dist/ops/grads/gelu.js +4 -3
  113. package/dist/ops/grads/matMul16.d.ts +2 -0
  114. package/dist/ops/grads/matMul16.js +9 -0
  115. package/dist/ops/grads/matMulGelu.js +8 -7
  116. package/dist/ops/grads/normRMS.js +8 -7
  117. package/dist/ops/grads/{fusedSoftmax.d.ts → pack16.d.ts} +1 -1
  118. package/dist/ops/grads/pack16.js +7 -0
  119. package/dist/ops/grads/qkv.d.ts +3 -1
  120. package/dist/ops/grads/qkv.js +28 -22
  121. package/dist/ops/grads/rope.d.ts +2 -1
  122. package/dist/ops/grads/rope.js +6 -13
  123. package/dist/ops/grads/softmax16.d.ts +2 -0
  124. package/dist/ops/grads/softmax16.js +26 -0
  125. package/dist/ops/grads/unpack16.d.ts +2 -0
  126. package/dist/ops/grads/unpack16.js +6 -0
  127. package/dist/ops/grads/utils.d.ts +3 -0
  128. package/dist/ops/grads/utils.js +10 -0
  129. package/dist/ops/matMul16.d.ts +15 -0
  130. package/dist/ops/matMul16.js +13 -0
  131. package/dist/ops/matMulGelu.js +1 -1
  132. package/dist/ops/matMulMul.js +1 -1
  133. package/dist/ops/mul16.d.ts +2 -0
  134. package/dist/ops/mul16.js +8 -0
  135. package/dist/ops/mulDrop.js +1 -1
  136. package/dist/ops/normRMS.js +1 -1
  137. package/dist/ops/pack16.d.ts +2 -0
  138. package/dist/ops/pack16.js +6 -0
  139. package/dist/ops/qkv.d.ts +1 -1
  140. package/dist/ops/qkv.js +8 -4
  141. package/dist/ops/reshape16.d.ts +2 -0
  142. package/dist/ops/reshape16.js +43 -0
  143. package/dist/ops/rope.d.ts +1 -1
  144. package/dist/ops/rope.js +8 -10
  145. package/dist/ops/scatterSub.js +1 -1
  146. package/dist/ops/slice16.d.ts +2 -0
  147. package/dist/ops/slice16.js +9 -0
  148. package/dist/ops/softmax16.d.ts +2 -0
  149. package/dist/ops/softmax16.js +12 -0
  150. package/dist/ops/sub16.d.ts +2 -0
  151. package/dist/ops/sub16.js +8 -0
  152. package/dist/ops/sum16.d.ts +2 -0
  153. package/dist/ops/sum16.js +13 -0
  154. package/dist/ops/transpose16.d.ts +3 -0
  155. package/dist/ops/transpose16.js +41 -0
  156. package/dist/ops/unpack16.d.ts +2 -0
  157. package/dist/ops/unpack16.js +6 -0
  158. package/dist/ops/webgl/adamAdjust.js +3 -2
  159. package/dist/ops/webgl/adamMoments.js +2 -1
  160. package/dist/ops/webgl/appendCache.js +2 -1
  161. package/dist/ops/webgl/attentionMask.js +5 -4
  162. package/dist/ops/webgl/fusedSoftmax.js +6 -4
  163. package/dist/ops/webgl/gatherSub.js +7 -6
  164. package/dist/ops/webgl/gelu.js +3 -2
  165. package/dist/ops/webgl/log.js +12 -27
  166. package/dist/ops/webgl/matMul16.d.ts +1 -0
  167. package/dist/ops/webgl/matMul16.js +37 -0
  168. package/dist/ops/webgl/matMulGelu.js +17 -15
  169. package/dist/ops/webgl/matMulMul.js +13 -12
  170. package/dist/ops/webgl/mulDropout.js +9 -8
  171. package/dist/ops/webgl/normRMS.js +8 -7
  172. package/dist/ops/webgl/qkv.js +6 -5
  173. package/dist/ops/webgl/rope.js +11 -10
  174. package/dist/ops/webgl/scatterSub.js +6 -5
  175. package/dist/ops/webgpu/adamAdjust.js +12 -10
  176. package/dist/ops/webgpu/adamMoments.js +27 -22
  177. package/dist/ops/webgpu/add16.d.ts +1 -0
  178. package/dist/ops/webgpu/add16.js +14 -0
  179. package/dist/ops/webgpu/appendCache.js +64 -17
  180. package/dist/ops/webgpu/attentionMask.js +19 -62
  181. package/dist/ops/webgpu/attentionMask32_program.d.ts +19 -0
  182. package/dist/ops/webgpu/attentionMask32_program.js +54 -0
  183. package/dist/ops/webgpu/concat16.d.ts +19 -0
  184. package/dist/ops/webgpu/concat16.js +128 -0
  185. package/dist/ops/webgpu/gatherSub.js +9 -7
  186. package/dist/ops/webgpu/gelu.js +78 -31
  187. package/dist/ops/webgpu/index.js +12 -0
  188. package/dist/ops/webgpu/matMul16.d.ts +1 -0
  189. package/dist/ops/webgpu/matMul16.js +58 -0
  190. package/dist/ops/webgpu/matMul16_program.d.ts +42 -0
  191. package/dist/ops/webgpu/matMul16_program.js +336 -0
  192. package/dist/ops/webgpu/mul16.d.ts +1 -0
  193. package/dist/ops/webgpu/mul16.js +14 -0
  194. package/dist/ops/webgpu/normRMS.js +21 -40
  195. package/dist/ops/webgpu/normRMS16_program.d.ts +9 -0
  196. package/dist/ops/webgpu/normRMS16_program.js +24 -0
  197. package/dist/ops/webgpu/normRMS32_program.d.ts +9 -0
  198. package/dist/ops/webgpu/normRMS32_program.js +24 -0
  199. package/dist/ops/webgpu/normRMSGrad.js +113 -64
  200. package/dist/ops/webgpu/pack16.d.ts +1 -0
  201. package/dist/ops/webgpu/pack16.js +19 -0
  202. package/dist/ops/webgpu/pack16_program.d.ts +19 -0
  203. package/dist/ops/webgpu/pack16_program.js +92 -0
  204. package/dist/ops/webgpu/qkv.js +20 -55
  205. package/dist/ops/webgpu/rope.js +77 -22
  206. package/dist/ops/webgpu/scatterSub.js +9 -7
  207. package/dist/ops/webgpu/slice16.d.ts +7 -0
  208. package/dist/ops/webgpu/slice16.js +71 -0
  209. package/dist/{variable-Bm2OFwGI.js → ops/webgpu/softmax16.d.ts} +2 -8
  210. package/dist/ops/webgpu/softmax16.js +23 -0
  211. package/dist/ops/webgpu/softmax16_program.d.ts +13 -0
  212. package/dist/ops/webgpu/softmax16_program.js +73 -0
  213. package/dist/ops/webgpu/softmax16_subgroup_program.d.ts +17 -0
  214. package/dist/ops/webgpu/softmax16_subgroup_program.js +75 -0
  215. package/dist/ops/webgpu/softmax16grad.d.ts +1 -0
  216. package/dist/ops/webgpu/softmax16grad.js +38 -0
  217. package/dist/ops/webgpu/sub16.d.ts +1 -0
  218. package/dist/ops/webgpu/sub16.js +14 -0
  219. package/dist/ops/webgpu/sum16.d.ts +1 -0
  220. package/dist/ops/webgpu/sum16.js +40 -0
  221. package/dist/ops/webgpu/transpose16.d.ts +1 -0
  222. package/dist/ops/webgpu/transpose16.js +35 -0
  223. package/dist/ops/webgpu/transpose16_program.d.ts +16 -0
  224. package/dist/ops/webgpu/transpose16_program.js +50 -0
  225. package/dist/ops/webgpu/transpose16_shared_program.d.ts +15 -0
  226. package/dist/ops/webgpu/transpose16_shared_program.js +71 -0
  227. package/dist/ops/webgpu/unpack16.d.ts +1 -0
  228. package/dist/ops/webgpu/unpack16.js +49 -0
  229. package/dist/ops/webgpu/utils/binary_op.d.ts +19 -0
  230. package/dist/ops/webgpu/utils/binary_op.js +79 -0
  231. package/dist/ops/webgpu/utils/deviceInfo.d.ts +7 -0
  232. package/dist/ops/webgpu/utils/deviceInfo.js +11 -0
  233. package/dist/ops/webgpu/utils/reductions.d.ts +32 -4
  234. package/dist/ops/webgpu/utils/reductions.js +236 -45
  235. package/dist/ops-CNI3TwqM.js +645 -0
  236. package/dist/pack16-CFUqumar.js +41 -0
  237. package/dist/{papaparse.min-C8l2Kvo1.js → papaparse.min-C0cScC2i.js} +2 -8
  238. package/dist/{parquet-C0Tlmv9c.js → parquet-BE8MU_ge.js} +201 -278
  239. package/dist/patches/PackedTensor.d.ts +12 -0
  240. package/dist/patches/PackedTensor.js +11 -0
  241. package/dist/patches/engine.d.ts +261 -0
  242. package/dist/patches/engine.js +10 -0
  243. package/dist/patches/tape.d.ts +12 -0
  244. package/dist/patches/tape.js +5 -0
  245. package/dist/patches/webgpu_backend.d.ts +18 -0
  246. package/dist/patches/webgpu_backend.js +57 -0
  247. package/dist/{tensor-CZr4dh61.js → patches/webgpu_base.d.ts} +5 -8
  248. package/dist/patches/webgpu_base.js +34 -0
  249. package/dist/patches/webgpu_program.d.ts +36 -0
  250. package/dist/patches/webgpu_program.js +401 -0
  251. package/dist/{pdf-kJD-f258.js → pdf-NIhmP3sq.js} +424 -428
  252. package/dist/random_width-DY6Kk2Dl.js +10051 -0
  253. package/dist/range-BMS52eQi.js +11 -0
  254. package/dist/reciprocal-CTmshQ9J.js +10 -0
  255. package/dist/{register_all_kernels-DIGpEwcf.js → register_all_kernels-Bwu1PTuU.js} +719 -9766
  256. package/dist/relu-yZ2-7WxU.js +10 -0
  257. package/dist/reshape-DevtBWtf.js +10 -0
  258. package/dist/rope-B5UUMsPi.js +32 -0
  259. package/dist/{scatter_nd_util-BQdz--Gn.js → scatter_nd_util-5EL-8VAQ.js} +1 -1
  260. package/dist/selu_util-D1w6yyTO.js +303 -0
  261. package/dist/{shared-DuP7ue-R.js → shared-BRksrJb3.js} +1 -17
  262. package/dist/shared-BuAXb4CI.js +2145 -0
  263. package/dist/sin-BGfy2HZo.js +16 -0
  264. package/dist/slice-D_gkkqZK.js +13 -0
  265. package/dist/slice_util-DtEldBfK.js +261 -0
  266. package/dist/softmax-ZHVebtR1.js +13 -0
  267. package/dist/split-DrfihRpZ.js +10 -0
  268. package/dist/squeeze-DZEpeblb.js +11 -0
  269. package/dist/stack-yOIAalTq.js +13 -0
  270. package/dist/sum-_fzj5ZTB.js +12 -0
  271. package/dist/tensor-DdQUJZlz.js +909 -0
  272. package/dist/tensor-f35l8Odg.js +8 -0
  273. package/dist/tensor1d-CeZuc-Rv.js +12 -0
  274. package/dist/tensor2d-G4Ys2GxX.js +15 -0
  275. package/dist/tensor4d-B8roDgtc.js +15 -0
  276. package/dist/tensor_util-DV-FP5Q3.js +523 -0
  277. package/dist/tfjs_backend-kNyO5L2d.js +653 -0
  278. package/dist/tile-BzyEiF-F.js +13 -0
  279. package/dist/tokeniser/CharTokeniser.js +1 -1
  280. package/dist/tokeniser/bpe.js +1 -1
  281. package/dist/training/Adam.d.ts +2 -1
  282. package/dist/training/Adam.js +12 -28
  283. package/dist/training/AdamExt.d.ts +1 -0
  284. package/dist/training/AdamExt.js +2 -2
  285. package/dist/training/DatasetBuilder.js +3 -20
  286. package/dist/training/FullTrainer.js +82 -64
  287. package/dist/training/Trainer.d.ts +11 -6
  288. package/dist/training/Trainer.js +51 -39
  289. package/dist/training/sparseCrossEntropy.js +3 -3
  290. package/dist/transpose-DKELTqhe.js +38 -0
  291. package/dist/utilities/arrayClose.js +7 -7
  292. package/dist/utilities/dummy.js +35 -27
  293. package/dist/utilities/multinomialCPU.js +2 -2
  294. package/dist/utilities/packed.d.ts +7 -0
  295. package/dist/utilities/packed.js +716 -0
  296. package/dist/utilities/performance.js +1 -1
  297. package/dist/utilities/profile.js +1 -1
  298. package/dist/utilities/safetensors.js +2 -2
  299. package/dist/utilities/sentences.d.ts +5 -0
  300. package/dist/utilities/sentences.js +41 -0
  301. package/dist/utilities/weights.js +2 -2
  302. package/dist/variable-Bhn5bHYv.js +7 -0
  303. package/dist/{webgpu_program-DkQJOJSd.js → webgpu_program-Cigz-7RF.js} +15 -44
  304. package/dist/webgpu_util-BBCnKm2X.js +65 -0
  305. package/dist/zeros-2gldETuK.js +14 -0
  306. package/package.json +4 -3
  307. package/dist/Reshape-Bowtk9BP.js +0 -127
  308. package/dist/Reshape-DUqYftGC.js +0 -30
  309. package/dist/backend_util-CJIiDoV1.js +0 -749
  310. package/dist/broadcast_to-DzlNweb8.js +0 -44
  311. package/dist/concat-B912vBbo.js +0 -33
  312. package/dist/dropout-C-csYCLj.js +0 -193
  313. package/dist/exports_initializers-B8iZMgQ0.js +0 -16
  314. package/dist/gather-Dnpgw-YQ.js +0 -25
  315. package/dist/index-BzFyqcy-.js +0 -4457
  316. package/dist/index-C1rx_Ajs.js +0 -12076
  317. package/dist/kernel_funcs_utils-DKLK0Mg3.js +0 -466
  318. package/dist/log_sum_exp-DO6z8tSE.js +0 -103
  319. package/dist/mat_mul-DzjTFx-u.js +0 -27
  320. package/dist/mod-Dobti4j4.js +0 -27
  321. package/dist/ones-tIJeHlq-.js +0 -29
  322. package/dist/ops/fusedSoftmax.d.ts +0 -2
  323. package/dist/ops/fusedSoftmax.js +0 -10
  324. package/dist/ops/grads/fusedSoftmax.js +0 -22
  325. package/dist/ops-LuCMAnmM.js +0 -1525
  326. package/dist/random_width-CXVRloNK.js +0 -13670
  327. package/dist/range-CWcz7xFA.js +0 -26
  328. package/dist/reciprocal-C4rNcM-S.js +0 -25
  329. package/dist/relu-BjCh_SYb.js +0 -25
  330. package/dist/reshape-CnIwVG1c.js +0 -25
  331. package/dist/selu_util-OtRzVwW5.js +0 -719
  332. package/dist/shared-DmRsFyaJ.js +0 -3134
  333. package/dist/sin-gpDNRxE0.js +0 -47
  334. package/dist/slice-d0Vo9XTN.js +0 -28
  335. package/dist/softmax-D7Jj3p_P.js +0 -28
  336. package/dist/split-DK2k5eHf.js +0 -25
  337. package/dist/stack-DFatutCx.js +0 -27
  338. package/dist/sum-CJ0ULhmt.js +0 -27
  339. package/dist/tensor1d-vML0r3q6.js +0 -27
  340. package/dist/tensor2d-D76QGjF3.js +0 -30
  341. package/dist/tensor4d-Df1WlVDY.js +0 -30
  342. package/dist/webgpu_util-pLEV9tks.js +0 -80
  343. package/dist/zeros-Bj5rMYA7.js +0 -52
@@ -1,12 +1,14 @@
- import { defaultConfig as u } from "./config.js";
- import a from "../layers/TransformerBlock.js";
- import f from "../layers/TiedEmbedding.js";
- import g from "../layers/RoPECache.js";
- import b from "../layers/RMSNorm.js";
- import { t as c, k as p } from "../index-BzFyqcy-.js";
- import w from "./model.js";
- import k from "../layers/PositionEmbedding.js";
- class R extends w {
+ import { defaultConfig as g } from "./config.js";
+ import b from "../layers/TransformerBlock.js";
+ import k from "../layers/TiedEmbedding.js";
+ import w from "../layers/RoPECache.js";
+ import E from "../layers/RMSNorm.js";
+ import { t as l, k as u } from "../index-ZyQhjEPo.js";
+ import C from "./model.js";
+ import P from "../layers/PositionEmbedding.js";
+ import { packingSupported as _ } from "../utilities/packed.js";
+ import { p as y, u as M } from "../pack16-CFUqumar.js";
+ class I extends C {
  wte;
  // Token embeddings
  wpe;
@@ -17,56 +19,63 @@ class R extends w {
  // Final layer norm
  ropeCache;
  constructor(e = {}) {
- super({ ...u, ...e }), this.wte = new f(this.config, "token_embedding", this), this.config.useRope === !1 ? this.wpe = new k(this.config, "positional_embedding", this) : this.ropeCache = new g(this.config), this.blocks = [];
+ super({ ...g, ...e }), this.wte = new k(this.config, "token_embedding", this), this.config.useRope === !1 ? this.wpe = new P(this.config, "positional_embedding", this) : this.ropeCache = new w(this.config), this.blocks = [];
  for (let i = 0; i < this.config.nLayer; i++)
- this.blocks.push(new a(i, this.config, this));
- this.lnF = new b(this.config, "final_rms_norm", this);
+ this.blocks.push(new b(i, this.config, this));
+ this.lnF = new E(this.config, "final_rms_norm", this);
  }
  getClassName() {
  return "GenAI_NanoGPT_v1";
  }
  inputPhase(e, i) {
- return c(() => {
- const s = this.wte.embed(e);
+ return l(() => {
+ const n = this.wte.embed(e);
  if (this.config.useRope === !1) {
- const o = this.wpe.call(i, s);
+ const o = this.wpe.call(i, n);
  if (Array.isArray(o))
  throw new Error("PositionEmbedding output should not be an array");
  return o;
  }
- return s;
+ return n;
  });
  }
- forward(e, i, s) {
- return this.validateInput(i), e.ropeCache = this.ropeCache, e.outputEmbeddings && (e.embeddings = []), c(() => {
+ forward(e, i, n) {
+ return this.validateInput(i), e.ropeCache = this.ropeCache, e.outputEmbeddings && (e.embeddings = []), l(() => {
  this.startMemory();
  let o = this.inputPhase(i, e);
  if (e.cache && e.cache.length !== this.blocks.length)
  throw console.error("Cache", e.cache), new Error(
  `Cache length ${e.cache.length} does not match number of blocks ${this.blocks.length}`
  );
- for (let t = 0; t < this.blocks.length; t++) {
- const r = this.blocks[t], d = Math.random() * 1e9, l = {
+ const t = e.mixedPrecision === !0 && _();
+ let s = t ? y(o) : o;
+ t && o !== s && o.dispose();
+ for (let r = 0; r < this.blocks.length; r++) {
+ const d = this.blocks[r], a = Math.random() * 1e9, m = {
  ...e,
- seed: d,
- pastKV: e.cache ? e.cache[t] : void 0
- }, m = e.checkpointing && e.training ? r.callCheckpoint(l, o) : r.call(l, o);
- e.outputEmbeddings ? (p(o), e.embeddings.push({ name: `block_output_${t}`, tensor: o })) : o.dispose(), o = m;
+ seed: a,
+ pastKV: e.cache ? e.cache[r] : void 0,
+ mixedPrecision: t
+ }, f = e.checkpointing && e.training ? d.callCheckpoint(m, s) : d.call(m, s);
+ e.outputEmbeddings ? (u(s), e.embeddings.push({ name: `block_output_${r}`, tensor: s })) : s.dispose(), s = f;
  }
- o = this.lnF.call(e, o);
- const n = this.wte.project(o);
- e.outputEmbeddings ? (p(o), e.embeddings.push({ name: "final_norm_output", tensor: o })) : o.dispose();
- let h;
- return s && (h = this.calculateLoss(n, s)), this.endMemory("Forward"), h ? [n, h] : [n];
+ if (o = this.lnF.call({ ...e, mixedPrecision: t }, s), s.dispose(), e.skipLogits)
+ return this.endMemory("Forward"), [o];
+ const c = this.wte.project(o);
+ e.outputEmbeddings ? (u(o), e.embeddings.push({ name: "final_norm_output", tensor: o })) : o.dispose();
+ const h = t ? M(c) : c;
+ t && c !== h && c.dispose();
+ let p;
+ return n && (p = this.calculateLoss(h, n)), this.endMemory("Forward"), p ? [h, p] : [h];
  });
  }
  project(e) {
- return c(() => this.wte.project(e));
+ return l(() => this.wte.project(e));
  }
  dispose() {
  this.wte.dispose(), this.wpe && this.wpe.dispose(), this.blocks.forEach((e) => e.dispose()), this.lnF.dispose();
  }
  }
  export {
- R as default
+ I as default
  };
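
The forward() rewrite above is the heart of the new mixed-precision path: activations are packed once on entry when mixedPrecision is requested and the backend supports packing, every block and the final norm receive the resolved flag, and logits are unpacked before the loss. The sketch below is a hedged de-minified reading of that flow; `_` is explicitly packingSupported, but treating `y` as pack16 and `M` as unpack16 is inferred from the pack16-CFUqumar.js import, and the helper shapes here are illustrative assumptions, not the package's actual API.

```ts
// Hedged de-minified sketch of the new forward() flow; helper names and
// signatures are inferred from the imports above, not guaranteed by the package.
interface Env<X> {
  packingSupported(): boolean;      // `_` from utilities/packed.js
  pack16(t: X): X;                  // assumed: `y` from pack16-CFUqumar.js
  unpack16(t: X): X;                // assumed: `M` from pack16-CFUqumar.js
  blocks: Array<(t: X, mixedPrecision: boolean) => X>;
  finalNorm(t: X, mixedPrecision: boolean): X;
  project(t: X): X;                 // tied-embedding projection to logits
}

function forwardSketch<X>(
  opts: { mixedPrecision?: boolean; skipLogits?: boolean },
  x: X,
  env: Env<X>,
): X {
  const packed = opts.mixedPrecision === true && env.packingSupported();
  let h = packed ? env.pack16(x) : x;            // pack activations once on entry
  for (const block of env.blocks) h = block(h, packed);
  const normed = env.finalNorm(h, packed);
  if (opts.skipLogits) return normed;            // new fast path: embeddings only
  const logits = env.project(normed);
  return packed ? env.unpack16(logits) : logits; // unpack before loss/sampling
}
```

Tensor disposal, which the real code performs eagerly after every block, is elided here for readability.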
@@ -5,6 +5,7 @@ export interface ModelForwardAttributes extends ForwardAttributes {
  cache?: KVCache[];
  attentionScores?: AttentionScores;
  seed?: number;
+ skipLogits?: boolean;
  }
  interface TrainingState {
  steps: number;
@@ -13,6 +14,7 @@ interface TrainingState {
  loss: number;
  }
  export default abstract class Model<T extends ModelForwardAttributes> extends BaseLayer<T> {
+ lossScaling: number;
  trainingState: TrainingState | null;
  abstract getClassName(): string;
  abstract forward(attrs: T, idx: Tensor, targets?: Tensor): Tensor[];
@@ -1,53 +1,60 @@
- import i from "../layers/BaseLayer.js";
- import "../index-BzFyqcy-.js";
+ import m from "../layers/BaseLayer.js";
+ import "../utilities/packed.js";
+ import "../index-ZyQhjEPo.js";
  import "../ops/cpu/attentionMask.js";
  import "../ops/webgl/attentionMask.js";
  import "../ops/grads/attentionMask.js";
- import "../ops/cpu/qkv.js";
- import "../ops/webgl/qkv.js";
- import "../ops/grads/qkv.js";
- import "../random_width-CXVRloNK.js";
- import "../register_all_kernels-DIGpEwcf.js";
- import "../index-Tf7vU29b.js";
- import "../dataset-DlZtKmBq.js";
+ import "../random_width-DY6Kk2Dl.js";
+ import "../register_all_kernels-Bwu1PTuU.js";
+ import "../index-Cp39cXWe.js";
+ import "../dataset-0xP8GjwI.js";
  import "../ops/cpu/rope.js";
  import "../ops/webgl/rope.js";
- import "../ops/grads/rope.js";
+ import "../rope-B5UUMsPi.js";
  import "../ops/cpu/appendCache.js";
  import "../ops/webgl/appendCache.js";
- import "../ops/cpu/fusedSoftmax.js";
- import "../ops/webgl/fusedSoftmax.js";
- import "../ops/grads/fusedSoftmax.js";
- import "../ops/cpu/matMulGelu.js";
- import "../ops/webgl/matMulGelu.js";
- import "../ops/grads/matMulGelu.js";
+ import "../ops/grads/softmax16.js";
+ import "../matMul16--R5hOwDG.js";
+ import "../ops/webgl/matMul16.js";
+ import "../ops/cpu/matMul16.js";
+ import "../pack16-CFUqumar.js";
+ import "../ops/transpose16.js";
+ import "../ops/reshape16.js";
+ import "../ops/cpu/qkv.js";
+ import "../ops/webgl/qkv.js";
+ import "../ops/grads/qkv.js";
  import "../ops/cpu/normRMS.js";
  import "../ops/webgl/normRMS.js";
  import "../ops/grads/normRMS.js";
- import "../jszip.min-CjP2V1VV.js";
- import "../index-Dwqa6Zy2.js";
+ import "../ops/grads/add16.js";
+ import "../jszip.min-Bz5-11Bk.js";
+ import "../index-DvYrXKkX.js";
  import "../Generator.js";
  import "../ops/cpu/adamAdjust.js";
  import "../ops/webgl/adamAdjust.js";
  import "../ops/cpu/adamMoments.js";
  import "../ops/webgl/adamMoments.js";
- import "../papaparse.min-C8l2Kvo1.js";
- import { estimateParameterCount as p } from "../utilities/parameters.js";
+ import "../papaparse.min-C0cScC2i.js";
+ import { estimateParameterCount as e } from "../utilities/parameters.js";
  import "../ops/cpu/scatterSub.js";
  import "../ops/webgl/scatterSub.js";
  import "../ops/cpu/gatherSub.js";
  import "../ops/webgl/gatherSub.js";
+ import "../ops/cpu/matMulGelu.js";
+ import "../ops/webgl/matMulGelu.js";
+ import "../ops/grads/matMulGelu.js";
  import "../ops/cpu/gelu.js";
  import "../ops/webgl/gelu.js";
- import "../gelu-Bp_-935b.js";
+ import "../gelu-CNLFZWea.js";
  import "../ops/webgl/log.js";
  import "../checks/normRMS.js";
  import "../checks/normRMSGrad.js";
- import { createSoftmaxCrossEntropyWithGrad as m } from "../training/sparseCrossEntropy.js";
- class x extends i {
+ import { createSoftmaxCrossEntropyWithGrad as s } from "../training/sparseCrossEntropy.js";
+ class st extends m {
+ lossScaling = 128;
  trainingState = null;
  getNumParams() {
- return p(this.config);
+ return e(this.config);
  }
  validateInput(t) {
  if (t.shape.length !== 2)
@@ -57,14 +64,15 @@ class x extends i {
  if (t.dtype !== "int32")
  throw new Error(`Input tensor must be of type int32, got ${t.dtype}`);
  }
- calculateLoss(t, o) {
+ calculateLoss(t, i) {
  try {
- return m()(t, o).mean();
- } catch (r) {
- throw console.error("Error computing loss:", r), new Error(`Loss computation failed: ${r}`);
+ const r = s()(t, i), p = r.mean();
+ return r.dispose(), p;
+ } catch (o) {
+ throw console.error("Error computing loss:", o), new Error(`Loss computation failed: ${o}`);
  }
  }
  }
  export {
- x as default
+ st as default
  };
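
The calculateLoss change is a leak fix rather than a behaviour change: the per-token loss tensor returned by the cross-entropy is now disposed as soon as its mean has been taken. A minimal sketch of the pattern, assuming a tf.Tensor-style API:

```ts
import { Tensor } from "@tensorflow/tfjs-core";

// Reduce a per-token loss tensor to a scalar and free the intermediate
// immediately, instead of leaving it to leak until the next tidy scope.
function meanAndDispose(perTokenLoss: Tensor): Tensor {
  const scalarLoss = perTokenLoss.mean();
  perTokenLoss.dispose();
  return scalarLoss;
}
```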
@@ -1,20 +1,4 @@
- import { u as z } from "./gpgpu_math-CDaYiyE_.js";
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
+ import { u as z } from "./gpgpu_math-DDVJCn6-.js";
  class g {
  constructor(e, s, v, a = !1, r = !1, c = !1, t = null, o = !1, l = !1) {
  this.variableNames = ["matrixA", "matrixB"], this.packedInputs = !0, this.packedOutput = !0, this.outputShape = v, this.enableShapeUniforms = z(this.outputShape.length);
@@ -1,19 +1,3 @@
- /**
- * @license
- * Copyright 2019 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
  function J(e, t, o) {
  const n = k(e, t, o), s = n < 0 ? -(n + 1) : n;
  e.splice(s, 0, t);
@@ -33,22 +17,6 @@ function A(e, t, o) {
  }
  return u ? n : -n - 1;
  }
- /**
- * @license
- * Copyright 2020 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
  function O(e, t, o, n, s) {
  return y(
  e,
@@ -0,0 +1,15 @@
+ import { E as a } from "./index-ZyQhjEPo.js";
+ import { d as t, m as n, s as i } from "./tensor-DdQUJZlz.js";
+ import { c as f } from "./complex-CSlYz-2T.js";
+ import { z as c } from "./zeros-2gldETuK.js";
+ function l(o, r = "float32") {
+ if (t(o), r === "complex64") {
+ const e = l(o, "float32"), m = c(o, "float32");
+ return f(e, m);
+ }
+ const s = n(i(o), r);
+ return a.makeTensor(s, o, r);
+ }
+ export {
+ l as o
+ };
@@ -1,4 +1,4 @@
- import { e as i } from "../index-BzFyqcy-.js";
+ import { e as i } from "../index-ZyQhjEPo.js";
  import "./cpu/adamAdjust.js";
  import "./webgl/adamAdjust.js";
  function p(r, t, e, n, m, o) {
@@ -1,2 +1,2 @@
  import { Tensor } from '@tensorflow/tfjs-core';
- export declare function adamMoments(moments: Tensor, gradient: Tensor, beta1: number, beta2: number): Tensor;
+ export declare function adamMoments(moments: Tensor, gradient: Tensor, beta1: number, beta2: number, lossScaling: number): Tensor;
@@ -1,9 +1,9 @@
- import { e as o } from "../index-BzFyqcy-.js";
+ import { e as t } from "../index-ZyQhjEPo.js";
  import "./cpu/adamMoments.js";
  import "./webgl/adamMoments.js";
- function p(e, n, r, m) {
- return o().runKernel("AdamMoments", { moments: e, gradient: n }, { beta1: r, beta2: m });
+ function s(e, n, r, m, o) {
+ return t().runKernel("AdamMoments", { moments: e, gradient: n }, { beta1: r, beta2: m, lossScaling: o });
  }
  export {
- p as adamMoments
+ s as adamMoments
  };
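
adamMoments now threads a lossScaling attribute through to the kernels, matching the lossScaling = 128 default added to the model class above. This is the standard static loss scaling used with fp16 training: the loss is multiplied by a constant so small gradients survive the fp16 range, and the optimizer divides the scale back out before updating its moments. A hedged scalar sketch of the arithmetic; where exactly the division happens inside the real kernels is an assumption:

```ts
// Static loss scaling, scalar sketch. grad arrives multiplied by lossScaling
// because the loss itself was scaled before backprop.
function adamMomentsSketch(
  m: number, v: number, grad: number,
  beta1: number, beta2: number, lossScaling: number,
): [number, number] {
  const g = grad / lossScaling;        // undo the loss scale (assumed placement)
  return [
    m * beta1 + g * (1 - beta1),       // first moment (mean of gradients)
    v * beta2 + g * g * (1 - beta2),   // second moment (uncentered variance)
  ];
}
```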
@@ -0,0 +1,2 @@
+ import { Tensor } from '@tensorflow/tfjs-core';
+ export declare function add16(a: Tensor, b: Tensor): Tensor;
@@ -0,0 +1,9 @@
+ import { n as t, e as o } from "../index-ZyQhjEPo.js";
+ import { isPackedTensor as n } from "../utilities/packed.js";
+ import "./grads/add16.js";
+ function m(r, e) {
+ return !n(r) && !n(e) ? t(r, e) : o().runKernel("Add16", { a: r, b: e });
+ }
+ export {
+ m as add16
+ };
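
The new add16 wrapper shows the dispatch pattern shared by the whole *16 family (concat16, mul16, sub16, sum16, and so on): use the stock TFJS op when no input is packed, otherwise route to a custom registered kernel. A minimal sketch, under the assumption that packed tensors can be detected by a predicate like isPackedTensor; its real implementation lives in utilities/packed.js and is not shown in this diff:

```ts
import { Tensor, add, engine } from "@tensorflow/tfjs-core";

// Hypothetical predicate; the real check is in utilities/packed.js.
function isPackedTensor(t: Tensor): boolean {
  return (t as unknown as { packed?: boolean }).packed === true;
}

export function add16(a: Tensor, b: Tensor): Tensor {
  if (!isPackedTensor(a) && !isPackedTensor(b)) {
    return add(a, b); // plain float32 path, stock TFJS op
  }
  // Packed path: defer to the "Add16" kernel registered per backend.
  return engine().runKernel("Add16", { a, b }) as Tensor;
}
```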
@@ -1,15 +1,22 @@
- import { e as a } from "../index-BzFyqcy-.js";
+ import { e as a } from "../index-ZyQhjEPo.js";
  import "./cpu/appendCache.js";
  import "./webgl/appendCache.js";
- import { c as s } from "../concat-B912vBbo.js";
- import { z as c } from "../zeros-Bj5rMYA7.js";
- function i(r, p, n, o) {
- if (!o) {
- const e = r.shape[2];
- return s([r, c([r.shape[0], r.shape[1], p - e, r.shape[3]])], 2);
+ import { isPackedTensor as c } from "../utilities/packed.js";
+ import { c as t } from "../concat-BHlIJeyT.js";
+ import { z as f } from "../zeros-2gldETuK.js";
+ function C(r, o, n, p) {
+ if (!p) {
+ const e = r.shape[2], s = c(r);
+ return t(
+ [
+ r,
+ f([r.shape[0], r.shape[1], o - e, r.shape[3]], s ? "int32" : r.dtype)
+ ],
+ 2
+ );
  }
- return a().runKernel("AppendCache", { cache: o, item: r }, { maxSize: p, pastLen: n });
+ return a().runKernel("AppendCache", { cache: p, item: r }, { maxSize: o, pastLen: n });
  }
  export {
- i as appendCache
+ C as appendCache
  };
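
Two things changed in appendCache: the zero-padded bootstrap path now allocates "int32" zeros for packed caches (presumably because two fp16 values occupy one 32-bit slot), and the kernel itself implements a sliding KV cache, spelled out by the CPU kernel further down. A plain-array sketch of that logic:

```ts
// Plain-array sketch of the AppendCache semantics in the CPU kernel: the cache
// holds maxSize timesteps; a new item is spliced in at pastLen while it fits,
// otherwise the oldest item.length entries are shifted out.
function appendCacheSketch(
  cache: number[], item: number[], maxSize: number, pastLen: number,
): number[] {
  if (pastLen + item.length <= maxSize) {
    // Overwrite the slots [pastLen, pastLen + item.length), keep the rest.
    return [
      ...cache.slice(0, pastLen),
      ...item,
      ...cache.slice(pastLen + item.length, maxSize),
    ];
  }
  // Cache full: drop the oldest entries and append the new item at the end.
  return [...cache.slice(item.length, maxSize), ...item];
}
```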
@@ -1,10 +1,10 @@
- import { e as o } from "../index-BzFyqcy-.js";
+ import { e as r } from "../index-ZyQhjEPo.js";
  import "./cpu/attentionMask.js";
  import "./webgl/attentionMask.js";
  import "./grads/attentionMask.js";
- function s(t, n, e, r) {
- return o().runKernel("AttentionMask", { q: t, k: n }, { divisor: e, pastLen: r || 0 });
+ function u(t, n, e, o) {
+ return r().runKernel("AttentionMask", { q: t, k: n }, { divisor: e, pastLen: o || 0 });
  }
  export {
- s as attentionMask
+ u as attentionMask
  };
@@ -0,0 +1,2 @@
+ import { Rank, Tensor } from '@tensorflow/tfjs-core';
+ export declare function concat16<R extends Rank = Rank>(x: Tensor<R>[], axis?: number): Tensor<R>;
@@ -0,0 +1,9 @@
+ import { isPackedTensor as o } from "../utilities/packed.js";
+ import { e } from "../index-ZyQhjEPo.js";
+ import { c } from "../concat-BHlIJeyT.js";
+ function p(r, n) {
+ return o(r[0]) ? e().runKernel("Concat16", r, { axis: n ?? -1 }) : c(r, n);
+ }
+ export {
+ p as concat16
+ };
@@ -1,18 +1,19 @@
- import { f as k, o as t, q as i, a as q, w } from "../../index-BzFyqcy-.js";
- function z(a) {
- const { moments: s, value: r } = a.inputs, { beta1: l, beta2: u, epsilon: m, learningRate: d } = a.attrs, e = s.shape.length, c = new Array(e).fill(0), n = s.shape.slice();
- n[e - 1] = 1;
- const o = c.slice();
- o[e - 1] = 1;
- const b = n.slice(), p = s.slice(c, n).squeeze([e - 1]), M = s.slice(o, b).squeeze([e - 1]), f = t(p, l), g = t(M, u);
- return i(
- q(t(f, i(w(g), m ?? 1e-8)), -d),
- r
+ import { l as t, n as r, m as k, o as z } from "../../index-ZyQhjEPo.js";
+ import { r as A } from "../../tensor_util-DV-FP5Q3.js";
+ function C(o) {
+ const { moments: n, value: i } = o.inputs, { beta1: l, beta2: m, epsilon: u, learningRate: d } = o.attrs, e = n.shape.length, c = new Array(e).fill(0), s = n.shape.slice();
+ s[e - 1] = 1;
+ const a = c.slice();
+ a[e - 1] = 1;
+ const p = s.slice(), b = n.slice(c, s).squeeze([e - 1]), M = n.slice(a, p).squeeze([e - 1]), f = t(b, l), g = t(M, m);
+ return r(
+ k(t(f, r(z(g), u ?? 1e-8)), -d),
+ i
  );
  }
- const A = {
+ const h = {
  kernelName: "AdamAdjust",
  backendName: "cpu",
- kernelFunc: z
+ kernelFunc: C
  };
- k(A);
+ A(h);
@@ -1,16 +1,17 @@
- import { f as p } from "../../index-BzFyqcy-.js";
- import { s as f } from "../../stack-DFatutCx.js";
- function b(t) {
- const { moments: n, gradient: c } = t.inputs, { beta1: o, beta2: m } = t.attrs, e = n.shape.length, a = new Array(e).fill(0), s = n.shape.slice();
+ import "../../index-ZyQhjEPo.js";
+ import { r as p } from "../../tensor_util-DV-FP5Q3.js";
+ import { s as b } from "../../stack-yOIAalTq.js";
+ function f(t) {
+ const { moments: n, gradient: o } = t.inputs, { beta1: c, beta2: m } = t.attrs, e = n.shape.length, a = new Array(e).fill(0), s = n.shape.slice();
  s[e - 1] = 1;
- const i = a.slice();
- i[e - 1] = 1;
- const r = s.slice(), l = n.slice(a, s).squeeze([e - 1]), u = n.slice(i, r).squeeze([e - 1]), M = l.mul(o).add(c.mul(1 - o)), d = u.mul(m).add(c.square().mul(1 - m));
- return f([M, d], -1);
+ const r = a.slice();
+ r[e - 1] = 1;
+ const i = s.slice(), l = n.slice(a, s).squeeze([e - 1]), u = n.slice(r, i).squeeze([e - 1]), M = l.mul(c).add(o.mul(1 - c)), d = u.mul(m).add(o.square().mul(1 - m));
+ return b([M, d], -1);
  }
  const g = {
  kernelName: "AdamMoments",
  backendName: "cpu",
- kernelFunc: b
+ kernelFunc: f
  };
  p(g);
@@ -1,13 +1,14 @@
- import { f as d } from "../../index-BzFyqcy-.js";
- import { c as h } from "../../concat-B912vBbo.js";
+ import "../../index-ZyQhjEPo.js";
+ import { r as d } from "../../tensor_util-DV-FP5Q3.js";
+ import { c as h } from "../../concat-BHlIJeyT.js";
  function u(p) {
- const { cache: n, item: s } = p.inputs, { maxSize: i, pastLen: c } = p.attrs, t = n.shape[0], o = n.shape[1], a = n.shape[3], e = s.shape[2];
- if (c + e <= i) {
- const l = n.slice([0, 0, 0, 0], [t, o, c, a]), m = n.slice([0, 0, c + e, 0], [t, o, i - c - e, a]), r = e < e ? s.slice([0, 0, 0, 0], [t, o, e, a]) : s, k = h([l, r, m], 2);
- return l.dispose(), m.dispose(), r !== s && r.dispose(), k;
+ const { cache: n, item: s } = p.inputs, { maxSize: a, pastLen: c } = p.attrs, t = n.shape[0], o = n.shape[1], r = n.shape[3], e = s.shape[2];
+ if (c + e <= a) {
+ const m = n.slice([0, 0, 0, 0], [t, o, c, r]), f = n.slice([0, 0, c + e, 0], [t, o, a - c - e, r]), i = e < e ? s.slice([0, 0, 0, 0], [t, o, e, r]) : s, k = h([m, i, f], 2);
+ return m.dispose(), f.dispose(), i !== s && i.dispose(), k;
  }
- const f = n.slice([0, 0, e, 0], [t, o, i - e, a]), C = h([f, s], 2);
- return f.dispose(), C;
+ const l = n.slice([0, 0, e, 0], [t, o, a - e, r]), C = h([l, s], 2);
+ return l.dispose(), C;
  }
  const w = {
  kernelName: "AppendCache",
@@ -1,21 +1,22 @@
- import { f as a, h as p, b as u } from "../../index-BzFyqcy-.js";
- import { l as N, w as b } from "../../ops-LuCMAnmM.js";
- import { o as A } from "../../ones-tIJeHlq-.js";
- import { z as I } from "../../zeros-Bj5rMYA7.js";
- import { m as g } from "../../mat_mul-DzjTFx-u.js";
- function o(n) {
- const { q: s, k: e } = n.inputs, { divisor: r } = n.attrs, c = s.shape[2], t = e.shape[2], m = N.bandPart(A([t, t]), -1, 0).cast("bool"), l = I([t, t]), i = p([t, t], Number.NEGATIVE_INFINITY), f = b(m, l, i), k = g(s, e, !1, !0).mul(u(r)), d = f.slice([0, 0], [c, t]).expandDims(0).expandDims(0);
- return k.add(d);
+ import { i as d, b as u } from "../../index-ZyQhjEPo.js";
+ import { r as o } from "../../tensor_util-DV-FP5Q3.js";
+ import { l as N, w as b } from "../../ops-CNI3TwqM.js";
+ import { o as A } from "../../ones-CAMiP4I2.js";
+ import { z as I } from "../../zeros-2gldETuK.js";
+ import { m as g } from "../../mat_mul-DeAh4uTH.js";
+ function a(n) {
+ const { q: s, k: e } = n.inputs, { divisor: r } = n.attrs, c = s.shape[2], t = e.shape[2], m = N.bandPart(A([t, t]), -1, 0).cast("bool"), i = I([t, t]), l = d([t, t], Number.NEGATIVE_INFINITY), f = b(m, i, l), k = g(s, e, !1, !0).mul(u(r)), p = f.slice([0, 0], [c, t]).expandDims(0).expandDims(0);
+ return k.add(p);
  }
- const h = {
+ const w = {
  kernelName: "AttentionMask",
  backendName: "cpu",
- kernelFunc: o
+ kernelFunc: a
  };
- a(h);
- const w = {
+ o(w);
+ const M = {
  kernelName: "AttentionMask",
  backendName: "tensorflow",
- kernelFunc: o
+ kernelFunc: a
  };
- a(w);
+ o(M);
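
The CPU AttentionMask kernel above spells out what the fused op computes: a lower-triangular causal mask (bandPart of a ones matrix) contributes 0 where attention is allowed and -Infinity above the diagonal, added onto the scaled q·kᵀ scores. A plain-array sketch of the same arithmetic:

```ts
// scores[i][j] is the raw dot product of query i with key j; divisor is the
// 1/sqrt(headDim)-style scale passed as the kernel attribute.
function causalMaskSketch(scores: number[][], divisor: number): number[][] {
  return scores.map((row, i) =>
    row.map((s, j) => s * divisor + (j <= i ? 0 : Number.NEGATIVE_INFINITY)),
  );
}
```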
@@ -1,29 +1,30 @@
- import { f as e } from "../../index-BzFyqcy-.js";
- import { s as f } from "../../softmax-D7Jj3p_P.js";
- function n(t) {
- const { inputs: s, attrs: a } = t, { logits: o } = s, { dim: i, dropoutRate: r } = a;
- if (!o)
+ import "../../index-ZyQhjEPo.js";
+ import { r as e } from "../../tensor_util-DV-FP5Q3.js";
+ import { s as m } from "../../softmax-ZHVebtR1.js";
+ function o(t) {
+ const { inputs: s, attrs: a } = t, { logits: n } = s, { dim: i, dropoutRate: r } = a;
+ if (!n)
  throw new Error("Error in softmax: input logits is null");
- return r !== void 0 && r > 0 && console.warn("Dropout in fusedSoftmax not implemented for CPU backend, skipping dropout."), f(o, i);
+ return r !== void 0 && r > 0 && console.warn("Dropout in fusedSoftmax not implemented for CPU backend, skipping dropout."), m(n, i);
  }
- const m = {
+ const f = {
  kernelName: "FusedSoftmax",
  backendName: "cpu",
- kernelFunc: n
+ kernelFunc: o
  };
- e(m);
+ e(f);
  const u = {
  kernelName: "FusedSoftmax",
  backendName: "tensorflow",
- kernelFunc: n
+ kernelFunc: o
  };
  e(u);
  const l = {
  kernelName: "FusedSoftmax",
  backendName: "webgpu",
- kernelFunc: n
+ kernelFunc: o
  };
  e(l);
  export {
- n as softmaxCPU
+ o as softmaxCPU
  };
@@ -1,34 +1,19 @@
- import { E as u, F as c, G as g, a8 as h, f as m, c as p } from "../../index-BzFyqcy-.js";
- import { r as f } from "../../range-CWcz7xFA.js";
- import { s as l } from "../../stack-DFatutCx.js";
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
+ import { A as u, B as c, E as m, c as g } from "../../index-ZyQhjEPo.js";
+ import { k as p, r as h } from "../../tensor_util-DV-FP5Q3.js";
+ import { r as f } from "../../range-BMS52eQi.js";
+ import { s as l } from "../../stack-yOIAalTq.js";
  function N(e, t) {
- const n = c(t, "indices", "gatherND", "int32"), s = { params: c(e, "x", "gatherND", "string_or_numeric"), indices: n };
- return g.runKernel(h, s);
+ const n = c(t, "indices", "gatherND", "int32"), r = { params: c(e, "x", "gatherND", "string_or_numeric"), indices: n };
+ return m.runKernel(p, r);
  }
  const b = /* @__PURE__ */ u({ gatherND_: N });
  function d(e) {
- const { values: t, labels: n, logits: r } = e.inputs, s = n.shape[0], a = f(0, s, 1, "int32"), i = l([a, n], 1), o = b(r, i);
- return p(t, o);
+ const { values: t, labels: n, logits: s } = e.inputs, r = n.shape[0], o = f(0, r, 1, "int32"), i = l([o, n], 1), a = b(s, i);
+ return g(t, a);
  }
  const k = {
  kernelName: "EfficientGatherSub",
  backendName: "cpu",
  kernelFunc: d
  };
- m(k);
+ h(k);
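
EfficientGatherSub is the gather step of the sparse cross-entropy: for each row i it looks up logits[i][labels[i]] via gatherND and subtracts it from values. A plain-array sketch, assuming the minified `g` is a subtract, consistent with the `c as p` import it replaces:

```ts
// values[i] - logits[i][labels[i]] for every row, mirroring the gatherND +
// subtract pair in the kernel above.
function gatherSubSketch(
  values: number[], labels: number[], logits: number[][],
): number[] {
  return values.map((v, i) => v - logits[i][labels[i]]);
}
```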