@genai-fi/nanogpt 0.20.0 → 0.20.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (433) hide show
  1. package/dist/BaseTokeniser-DSg9zcYq.js +221 -0
  2. package/dist/DatasetBuilder-DgURD85T.js +712 -0
  3. package/dist/Generator.d.ts +82 -0
  4. package/dist/Generator.js +2 -0
  5. package/dist/RealDiv-DBu0FQqT.js +362 -0
  6. package/dist/Reshape-CABOPB9d.js +94 -0
  7. package/dist/Reshape-DqO3r8BC.js +17 -0
  8. package/dist/TeachableLLM.d.ts +70 -0
  9. package/dist/TeachableLLM.js +2 -0
  10. package/dist/Trainer.d.ts +43 -0
  11. package/dist/Trainer.js +2 -0
  12. package/dist/backend.d.ts +2 -0
  13. package/dist/backend.js +13 -0
  14. package/dist/backend_util-Cg-roD1p.js +399 -0
  15. package/dist/binary_op_util-CrYk9LXL.js +103 -0
  16. package/dist/checks/appendCache.d.ts +1 -0
  17. package/dist/checks/appendCache.js +55 -0
  18. package/dist/checks/attentionMask.d.ts +1 -0
  19. package/dist/checks/attentionMask.js +56 -0
  20. package/dist/checks/check.d.ts +9 -0
  21. package/dist/checks/check.js +32 -0
  22. package/dist/checks/gelu.d.ts +1 -0
  23. package/dist/checks/gelu.js +46 -0
  24. package/dist/checks/index.d.ts +26 -0
  25. package/dist/checks/index.js +28 -0
  26. package/dist/checks/matMulGelu.d.ts +1 -0
  27. package/dist/checks/matMulGelu.js +84 -0
  28. package/dist/checks/normRMS.d.ts +1 -0
  29. package/dist/checks/normRMS.js +28 -0
  30. package/dist/checks/normRMSGrad.d.ts +1 -0
  31. package/dist/checks/normRMSGrad.js +22 -0
  32. package/dist/checks/packUnpack.d.ts +1 -0
  33. package/dist/checks/packUnpack.js +46 -0
  34. package/dist/checks/qkv.d.ts +1 -0
  35. package/dist/checks/qkv.js +34 -0
  36. package/dist/checks/rope.d.ts +1 -0
  37. package/dist/checks/rope.js +30 -0
  38. package/dist/checks/weights.d.ts +14 -0
  39. package/dist/checks/weights.js +27 -0
  40. package/dist/chunk-BPntVaq0.js +23 -0
  41. package/dist/complex_util-CkazZsaH.js +60 -0
  42. package/dist/concat_util-CWDZCBlA.js +19 -0
  43. package/dist/data/docx.d.ts +2 -0
  44. package/dist/data/docx.js +3046 -0
  45. package/dist/data/pdf.d.ts +2 -0
  46. package/dist/data/pdf.js +17 -0
  47. package/dist/data/textLoader.d.ts +7 -0
  48. package/dist/data/textLoader.js +613 -0
  49. package/dist/dist-BewPQWjc.js +7572 -0
  50. package/dist/dist-DVmq73nz.js +8775 -0
  51. package/dist/dist-DXwIvKxl.js +896 -0
  52. package/dist/dist-VEU5mfO0.js +7545 -0
  53. package/dist/gelu-Bf1HW1RY.js +27 -0
  54. package/dist/gpgpu_math-DvLcCH6u.js +1612 -0
  55. package/dist/inference/types.d.ts +16 -0
  56. package/dist/inference/types.js +0 -0
  57. package/dist/kernel_funcs_utils-HiXOOx3f.js +229 -0
  58. package/dist/layers/BaseLayer.d.ts +44 -0
  59. package/dist/layers/BaseLayer.js +76 -0
  60. package/dist/layers/CausalSelfAttention.d.ts +39 -0
  61. package/dist/layers/CausalSelfAttention.js +99 -0
  62. package/dist/layers/LoRA.d.ts +14 -0
  63. package/dist/layers/LoRA.js +48 -0
  64. package/dist/layers/MLP.d.ts +17 -0
  65. package/dist/layers/MLP.js +34 -0
  66. package/dist/layers/PositionEmbedding.d.ts +8 -0
  67. package/dist/layers/PositionEmbedding.js +27 -0
  68. package/dist/layers/RMSNorm.d.ts +12 -0
  69. package/dist/layers/RMSNorm.js +20 -0
  70. package/dist/layers/RoPECache.d.ts +18 -0
  71. package/dist/layers/RoPECache.js +337 -0
  72. package/dist/layers/TiedEmbedding.d.ts +13 -0
  73. package/dist/layers/TiedEmbedding.js +32 -0
  74. package/dist/layers/TransformerBlock.d.ts +27 -0
  75. package/dist/layers/TransformerBlock.js +51 -0
  76. package/dist/layers/WeightStore.d.ts +20 -0
  77. package/dist/layers/WeightStore.js +69 -0
  78. package/dist/loader/load.d.ts +6 -0
  79. package/dist/loader/load.js +2 -0
  80. package/dist/loader/loadHF.d.ts +8 -0
  81. package/dist/loader/loadHF.js +2 -0
  82. package/dist/loader/loadTransformers.d.ts +4 -0
  83. package/dist/loader/loadTransformers.js +2 -0
  84. package/dist/loader/loadZipMeta.d.ts +3 -0
  85. package/dist/loader/loadZipMeta.js +16 -0
  86. package/dist/loader/newZipLoad.d.ts +3 -0
  87. package/dist/loader/newZipLoad.js +2 -0
  88. package/dist/loader/oldZipLoad.d.ts +9 -0
  89. package/dist/loader/oldZipLoad.js +2 -0
  90. package/dist/loader/save.d.ts +16 -0
  91. package/dist/loader/save.js +2 -0
  92. package/dist/loader/types.d.ts +68 -0
  93. package/dist/loader/types.js +0 -0
  94. package/dist/main-D5CbfCiV.js +13500 -0
  95. package/dist/main.d.ts +50 -0
  96. package/dist/main.js +16 -0
  97. package/dist/matMul16-BNfZSnNM.js +81 -0
  98. package/dist/matMulGelu-CPTntosE.js +162 -0
  99. package/dist/models/NanoGPTV1.d.ts +16 -0
  100. package/dist/models/NanoGPTV1.js +2 -0
  101. package/dist/models/NanoGPTV2.d.ts +16 -0
  102. package/dist/models/NanoGPTV2.js +2 -0
  103. package/dist/models/config.d.ts +27 -0
  104. package/dist/models/config.js +37 -0
  105. package/dist/models/factory.d.ts +3 -0
  106. package/dist/models/factory.js +2 -0
  107. package/dist/models/model.d.ts +44 -0
  108. package/dist/models/model.js +2 -0
  109. package/dist/ops/adamAdjust.d.ts +2 -0
  110. package/dist/ops/adamAdjust.js +18 -0
  111. package/dist/ops/adamMoments.d.ts +2 -0
  112. package/dist/ops/adamMoments.js +16 -0
  113. package/dist/ops/add16.d.ts +2 -0
  114. package/dist/ops/add16.js +12 -0
  115. package/dist/ops/appendCache.d.ts +2 -0
  116. package/dist/ops/appendCache.js +25 -0
  117. package/dist/ops/attentionMask.d.ts +2 -0
  118. package/dist/ops/attentionMask.js +16 -0
  119. package/dist/ops/concat16.d.ts +2 -0
  120. package/dist/ops/concat16.js +8 -0
  121. package/dist/ops/cpu/adamAdjust.d.ts +1 -0
  122. package/dist/ops/cpu/adamAdjust.js +16 -0
  123. package/dist/ops/cpu/adamMoments.d.ts +1 -0
  124. package/dist/ops/cpu/adamMoments.js +16 -0
  125. package/dist/ops/cpu/appendCache.d.ts +1 -0
  126. package/dist/ops/cpu/appendCache.js +65 -0
  127. package/dist/ops/cpu/attentionMask.d.ts +1 -0
  128. package/dist/ops/cpu/attentionMask.js +16 -0
  129. package/dist/ops/cpu/fusedSoftmax.d.ts +9 -0
  130. package/dist/ops/cpu/fusedSoftmax.js +22 -0
  131. package/dist/ops/cpu/gatherSub.d.ts +1 -0
  132. package/dist/ops/cpu/gatherSub.js +12 -0
  133. package/dist/ops/cpu/gelu.d.ts +1 -0
  134. package/dist/ops/cpu/gelu.js +36 -0
  135. package/dist/ops/cpu/matMul16.d.ts +1 -0
  136. package/dist/ops/cpu/matMul16.js +14 -0
  137. package/dist/ops/cpu/matMulGelu.d.ts +1 -0
  138. package/dist/ops/cpu/matMulGelu.js +41 -0
  139. package/dist/ops/cpu/matMulMul.d.ts +1 -0
  140. package/dist/ops/cpu/matMulMul.js +20 -0
  141. package/dist/ops/cpu/mulDropout.d.ts +1 -0
  142. package/dist/ops/cpu/mulDropout.js +20 -0
  143. package/dist/ops/cpu/normRMS.d.ts +1 -0
  144. package/dist/ops/cpu/normRMS.js +35 -0
  145. package/dist/ops/cpu/qkv.d.ts +5 -0
  146. package/dist/ops/cpu/qkv.js +73 -0
  147. package/dist/ops/cpu/rope.d.ts +6 -0
  148. package/dist/ops/cpu/rope.js +81 -0
  149. package/dist/ops/cpu/scatterSub.d.ts +1 -0
  150. package/dist/ops/cpu/scatterSub.js +12 -0
  151. package/dist/ops/dot16.d.ts +2 -0
  152. package/dist/ops/dot16.js +29 -0
  153. package/dist/ops/dropout.d.ts +2 -0
  154. package/dist/ops/dropout.js +11 -0
  155. package/dist/ops/dropout16.d.ts +2 -0
  156. package/dist/ops/dropout16.js +22 -0
  157. package/dist/ops/gatherSub.d.ts +2 -0
  158. package/dist/ops/gatherSub.js +13 -0
  159. package/dist/ops/gelu.d.ts +3 -0
  160. package/dist/ops/gelu.js +2 -0
  161. package/dist/ops/globalNorm.d.ts +2 -0
  162. package/dist/ops/globalNorm.js +19 -0
  163. package/dist/ops/grads/add16.d.ts +1 -0
  164. package/dist/ops/grads/add16.js +27 -0
  165. package/dist/ops/grads/attentionMask.d.ts +1 -0
  166. package/dist/ops/grads/attentionMask.js +26 -0
  167. package/dist/ops/grads/dropout16.d.ts +1 -0
  168. package/dist/ops/grads/dropout16.js +1 -0
  169. package/dist/ops/grads/gelu.d.ts +2 -0
  170. package/dist/ops/grads/gelu.js +2 -0
  171. package/dist/ops/grads/matMul16.d.ts +2 -0
  172. package/dist/ops/grads/matMul16.js +2 -0
  173. package/dist/ops/grads/matMulGelu.d.ts +1 -0
  174. package/dist/ops/grads/matMulGelu.js +22 -0
  175. package/dist/ops/grads/mul16.d.ts +1 -0
  176. package/dist/ops/grads/mul16.js +1 -0
  177. package/dist/ops/grads/normRMS.d.ts +3 -0
  178. package/dist/ops/grads/normRMS.js +37 -0
  179. package/dist/ops/grads/pack16.d.ts +2 -0
  180. package/dist/ops/grads/pack16.js +2 -0
  181. package/dist/ops/grads/qkv.d.ts +3 -0
  182. package/dist/ops/grads/qkv.js +46 -0
  183. package/dist/ops/grads/rope.d.ts +2 -0
  184. package/dist/ops/grads/rope.js +2 -0
  185. package/dist/ops/grads/softmax16.d.ts +2 -0
  186. package/dist/ops/grads/softmax16.js +23 -0
  187. package/dist/ops/grads/unpack16.d.ts +2 -0
  188. package/dist/ops/grads/unpack16.js +2 -0
  189. package/dist/ops/grads/utils.d.ts +4 -0
  190. package/dist/ops/grads/utils.js +12 -0
  191. package/dist/ops/log.d.ts +0 -0
  192. package/dist/ops/log.js +1 -0
  193. package/dist/ops/matMul16.d.ts +15 -0
  194. package/dist/ops/matMul16.js +2 -0
  195. package/dist/ops/matMulGelu.d.ts +3 -0
  196. package/dist/ops/matMulGelu.js +20 -0
  197. package/dist/ops/matMulMul.d.ts +2 -0
  198. package/dist/ops/matMulMul.js +16 -0
  199. package/dist/ops/mul16.d.ts +2 -0
  200. package/dist/ops/mul16.js +43 -0
  201. package/dist/ops/mulDrop.d.ts +2 -0
  202. package/dist/ops/mulDrop.js +15 -0
  203. package/dist/ops/normRMS.d.ts +2 -0
  204. package/dist/ops/normRMS.js +22 -0
  205. package/dist/ops/pack16.d.ts +2 -0
  206. package/dist/ops/pack16.js +2 -0
  207. package/dist/ops/qkv.d.ts +2 -0
  208. package/dist/ops/qkv.js +16 -0
  209. package/dist/ops/reshape16.d.ts +2 -0
  210. package/dist/ops/reshape16.js +33 -0
  211. package/dist/ops/rope.d.ts +3 -0
  212. package/dist/ops/rope.js +2 -0
  213. package/dist/ops/scatterSub.d.ts +2 -0
  214. package/dist/ops/scatterSub.js +13 -0
  215. package/dist/ops/slice16.d.ts +2 -0
  216. package/dist/ops/slice16.js +11 -0
  217. package/dist/ops/softmax16.d.ts +2 -0
  218. package/dist/ops/softmax16.js +9 -0
  219. package/dist/ops/sub16.d.ts +2 -0
  220. package/dist/ops/sub16.js +11 -0
  221. package/dist/ops/sum16.d.ts +2 -0
  222. package/dist/ops/sum16.js +13 -0
  223. package/dist/ops/transpose16.d.ts +3 -0
  224. package/dist/ops/transpose16.js +32 -0
  225. package/dist/ops/unpack16.d.ts +2 -0
  226. package/dist/ops/unpack16.js +2 -0
  227. package/dist/ops/webgl/adamAdjust.d.ts +1 -0
  228. package/dist/ops/webgl/adamAdjust.js +82 -0
  229. package/dist/ops/webgl/adamMoments.d.ts +1 -0
  230. package/dist/ops/webgl/adamMoments.js +44 -0
  231. package/dist/ops/webgl/appendCache.d.ts +1 -0
  232. package/dist/ops/webgl/appendCache.js +53 -0
  233. package/dist/ops/webgl/attentionMask.d.ts +1 -0
  234. package/dist/ops/webgl/attentionMask.js +64 -0
  235. package/dist/ops/webgl/dropout16.d.ts +1 -0
  236. package/dist/ops/webgl/dropout16.js +12 -0
  237. package/dist/ops/webgl/fusedSoftmax.d.ts +11 -0
  238. package/dist/ops/webgl/fusedSoftmax.js +70 -0
  239. package/dist/ops/webgl/gatherSub.d.ts +1 -0
  240. package/dist/ops/webgl/gatherSub.js +28 -0
  241. package/dist/ops/webgl/gelu.d.ts +2 -0
  242. package/dist/ops/webgl/gelu.js +48 -0
  243. package/dist/ops/webgl/log.d.ts +17 -0
  244. package/dist/ops/webgl/log.js +14 -0
  245. package/dist/ops/webgl/matMul16.d.ts +1 -0
  246. package/dist/ops/webgl/matMul16.js +37 -0
  247. package/dist/ops/webgl/matMulGelu.d.ts +21 -0
  248. package/dist/ops/webgl/matMulGelu.js +2 -0
  249. package/dist/ops/webgl/matMulMul.d.ts +14 -0
  250. package/dist/ops/webgl/matMulMul.js +24 -0
  251. package/dist/ops/webgl/mulDropout.d.ts +1 -0
  252. package/dist/ops/webgl/mulDropout.js +32 -0
  253. package/dist/ops/webgl/normRMS.d.ts +1 -0
  254. package/dist/ops/webgl/normRMS.js +114 -0
  255. package/dist/ops/webgl/qkv.d.ts +1 -0
  256. package/dist/ops/webgl/qkv.js +54 -0
  257. package/dist/ops/webgl/rope.d.ts +1 -0
  258. package/dist/ops/webgl/rope.js +72 -0
  259. package/dist/ops/webgl/scatterSub.d.ts +1 -0
  260. package/dist/ops/webgl/scatterSub.js +28 -0
  261. package/dist/ops/webgpu/adamAdjust.d.ts +1 -0
  262. package/dist/ops/webgpu/adamAdjust.js +77 -0
  263. package/dist/ops/webgpu/adamMoments.d.ts +1 -0
  264. package/dist/ops/webgpu/adamMoments.js +76 -0
  265. package/dist/ops/webgpu/add16.d.ts +1 -0
  266. package/dist/ops/webgpu/add16.js +14 -0
  267. package/dist/ops/webgpu/appendCache.d.ts +1 -0
  268. package/dist/ops/webgpu/appendCache.js +130 -0
  269. package/dist/ops/webgpu/attentionMask.d.ts +1 -0
  270. package/dist/ops/webgpu/attentionMask.js +42 -0
  271. package/dist/ops/webgpu/attentionMask32_program.d.ts +19 -0
  272. package/dist/ops/webgpu/attentionMask32_program.js +62 -0
  273. package/dist/ops/webgpu/clipScale.d.ts +1 -0
  274. package/dist/ops/webgpu/clipScale.js +45 -0
  275. package/dist/ops/webgpu/concat16.d.ts +19 -0
  276. package/dist/ops/webgpu/concat16.js +111 -0
  277. package/dist/ops/webgpu/dropout16.d.ts +1 -0
  278. package/dist/ops/webgpu/dropout16.js +59 -0
  279. package/dist/ops/webgpu/gatherSub.d.ts +1 -0
  280. package/dist/ops/webgpu/gatherSub.js +52 -0
  281. package/dist/ops/webgpu/gelu.d.ts +14 -0
  282. package/dist/ops/webgpu/gelu.js +147 -0
  283. package/dist/ops/webgpu/index.d.ts +0 -0
  284. package/dist/ops/webgpu/index.js +26 -0
  285. package/dist/ops/webgpu/matMul16.d.ts +1 -0
  286. package/dist/ops/webgpu/matMul16.js +70 -0
  287. package/dist/ops/webgpu/matMul16_program.d.ts +42 -0
  288. package/dist/ops/webgpu/matMul16_program.js +303 -0
  289. package/dist/ops/webgpu/mul16.d.ts +1 -0
  290. package/dist/ops/webgpu/mul16.js +14 -0
  291. package/dist/ops/webgpu/norm2.d.ts +1 -0
  292. package/dist/ops/webgpu/norm2.js +46 -0
  293. package/dist/ops/webgpu/normRMS.d.ts +1 -0
  294. package/dist/ops/webgpu/normRMS.js +26 -0
  295. package/dist/ops/webgpu/normRMS16_program.d.ts +10 -0
  296. package/dist/ops/webgpu/normRMS16_program.js +28 -0
  297. package/dist/ops/webgpu/normRMS32_program.d.ts +10 -0
  298. package/dist/ops/webgpu/normRMS32_program.js +28 -0
  299. package/dist/ops/webgpu/normRMSGrad.d.ts +1 -0
  300. package/dist/ops/webgpu/normRMSGrad.js +225 -0
  301. package/dist/ops/webgpu/pack16.d.ts +1 -0
  302. package/dist/ops/webgpu/pack16.js +21 -0
  303. package/dist/ops/webgpu/pack16_program.d.ts +19 -0
  304. package/dist/ops/webgpu/pack16_program.js +93 -0
  305. package/dist/ops/webgpu/qkv.d.ts +1 -0
  306. package/dist/ops/webgpu/qkv.js +64 -0
  307. package/dist/ops/webgpu/rope.d.ts +1 -0
  308. package/dist/ops/webgpu/rope.js +163 -0
  309. package/dist/ops/webgpu/scatterSub.d.ts +1 -0
  310. package/dist/ops/webgpu/scatterSub.js +53 -0
  311. package/dist/ops/webgpu/slice16.d.ts +7 -0
  312. package/dist/ops/webgpu/slice16.js +74 -0
  313. package/dist/ops/webgpu/softmax16.d.ts +17 -0
  314. package/dist/ops/webgpu/softmax16.js +18 -0
  315. package/dist/ops/webgpu/softmax16_program.d.ts +13 -0
  316. package/dist/ops/webgpu/softmax16_program.js +89 -0
  317. package/dist/ops/webgpu/softmax16_subgroup_program.d.ts +17 -0
  318. package/dist/ops/webgpu/softmax16_subgroup_program.js +70 -0
  319. package/dist/ops/webgpu/softmax16grad.d.ts +1 -0
  320. package/dist/ops/webgpu/softmax16grad.js +31 -0
  321. package/dist/ops/webgpu/sub16.d.ts +1 -0
  322. package/dist/ops/webgpu/sub16.js +14 -0
  323. package/dist/ops/webgpu/sum16.d.ts +1 -0
  324. package/dist/ops/webgpu/sum16.js +29 -0
  325. package/dist/ops/webgpu/transpose16.d.ts +1 -0
  326. package/dist/ops/webgpu/transpose16.js +37 -0
  327. package/dist/ops/webgpu/transpose16_program.d.ts +16 -0
  328. package/dist/ops/webgpu/transpose16_program.js +51 -0
  329. package/dist/ops/webgpu/transpose16_shared_program.d.ts +15 -0
  330. package/dist/ops/webgpu/transpose16_shared_program.js +79 -0
  331. package/dist/ops/webgpu/unpack16.d.ts +1 -0
  332. package/dist/ops/webgpu/unpack16.js +60 -0
  333. package/dist/ops/webgpu/utils/binary_op.d.ts +35 -0
  334. package/dist/ops/webgpu/utils/binary_op.js +141 -0
  335. package/dist/ops/webgpu/utils/deviceInfo.d.ts +7 -0
  336. package/dist/ops/webgpu/utils/deviceInfo.js +11 -0
  337. package/dist/ops/webgpu/utils/reductions.d.ts +43 -0
  338. package/dist/ops/webgpu/utils/reductions.js +263 -0
  339. package/dist/pack16-Ck-spx_F.js +39 -0
  340. package/dist/patches/webgpu_backend.d.ts +18 -0
  341. package/dist/patches/webgpu_backend.js +43 -0
  342. package/dist/patches/webgpu_base.d.ts +21 -0
  343. package/dist/patches/webgpu_base.js +22 -0
  344. package/dist/patches/webgpu_program.d.ts +36 -0
  345. package/dist/patches/webgpu_program.js +293 -0
  346. package/dist/pdf-UoDqCYzz.js +16726 -0
  347. package/dist/picomatch-3tUnMMbd.js +1063 -0
  348. package/dist/rope-CbeGlsV8.js +25 -0
  349. package/dist/selu_util-zkAx5doH.js +24 -0
  350. package/dist/shared-D1coEFea.js +1314 -0
  351. package/dist/shared-DOgWaqvL.js +5 -0
  352. package/dist/slice_util-Dgb3ANWI.js +208 -0
  353. package/dist/tfjs_backend-BjuQ5FqB.js +614 -0
  354. package/dist/tokeniser/BaseTokeniser.d.ts +33 -0
  355. package/dist/tokeniser/BaseTokeniser.js +2 -0
  356. package/dist/tokeniser/CharTokeniser.d.ts +24 -0
  357. package/dist/tokeniser/CharTokeniser.js +92 -0
  358. package/dist/tokeniser/bpe.d.ts +28 -0
  359. package/dist/tokeniser/bpe.js +170 -0
  360. package/dist/tokeniser/messages.d.ts +61 -0
  361. package/dist/tokeniser/messages.js +0 -0
  362. package/dist/tokeniser/type.d.ts +34 -0
  363. package/dist/tokeniser/type.js +0 -0
  364. package/dist/training/AdamW.d.ts +36 -0
  365. package/dist/training/AdamW.js +128 -0
  366. package/dist/training/BasicTrainer.d.ts +63 -0
  367. package/dist/training/BasicTrainer.js +265 -0
  368. package/dist/training/DatasetBuilder.d.ts +26 -0
  369. package/dist/training/DatasetBuilder.js +2 -0
  370. package/dist/training/Evaluator.d.ts +19 -0
  371. package/dist/training/Evaluator.js +48 -0
  372. package/dist/training/LRScheduler.d.ts +12 -0
  373. package/dist/training/LRScheduler.js +38 -0
  374. package/dist/training/PreTrainer.d.ts +11 -0
  375. package/dist/training/PreTrainer.js +22 -0
  376. package/dist/training/SFTTrainer.d.ts +12 -0
  377. package/dist/training/SFTTrainer.js +24 -0
  378. package/dist/training/loss.d.ts +3 -0
  379. package/dist/training/loss.js +19 -0
  380. package/dist/training/orthoGrad.d.ts +2 -0
  381. package/dist/training/orthoGrad.js +10 -0
  382. package/dist/training/sparseCrossEntropy.d.ts +7 -0
  383. package/dist/training/sparseCrossEntropy.js +47 -0
  384. package/dist/training/tasks/ConversationTask.d.ts +18 -0
  385. package/dist/training/tasks/ConversationTask.js +38 -0
  386. package/dist/training/tasks/PretrainingTask.d.ts +17 -0
  387. package/dist/training/tasks/PretrainingTask.js +42 -0
  388. package/dist/training/tasks/StartSentenceTask.d.ts +18 -0
  389. package/dist/training/tasks/StartSentenceTask.js +45 -0
  390. package/dist/training/tasks/Task.d.ts +22 -0
  391. package/dist/training/tasks/Task.js +55 -0
  392. package/dist/training/tasks/splitter.d.ts +5 -0
  393. package/dist/training/tasks/splitter.js +18 -0
  394. package/dist/training/types.d.ts +78 -0
  395. package/dist/training/types.js +0 -0
  396. package/dist/training/validation.d.ts +17 -0
  397. package/dist/training/validation.js +2 -0
  398. package/dist/utilities/arrayClose.d.ts +1 -0
  399. package/dist/utilities/arrayClose.js +16 -0
  400. package/dist/utilities/datasetID.d.ts +2 -0
  401. package/dist/utilities/datasetID.js +18 -0
  402. package/dist/utilities/dummy.d.ts +9 -0
  403. package/dist/utilities/dummy.js +36 -0
  404. package/dist/utilities/multinomialCPU.d.ts +2 -0
  405. package/dist/utilities/multinomialCPU.js +9 -0
  406. package/dist/utilities/naming.d.ts +4 -0
  407. package/dist/utilities/naming.js +0 -0
  408. package/dist/utilities/packed.d.ts +4 -0
  409. package/dist/utilities/packed.js +13 -0
  410. package/dist/utilities/parameters.d.ts +11 -0
  411. package/dist/utilities/parameters.js +38 -0
  412. package/dist/utilities/performance.d.ts +2 -0
  413. package/dist/utilities/performance.js +16 -0
  414. package/dist/utilities/profile.d.ts +17 -0
  415. package/dist/utilities/profile.js +33 -0
  416. package/dist/utilities/safetensors.d.ts +3 -0
  417. package/dist/utilities/safetensors.js +53 -0
  418. package/dist/utilities/sentences.d.ts +5 -0
  419. package/dist/utilities/sentences.js +32 -0
  420. package/dist/utilities/tokenParse.d.ts +1 -0
  421. package/dist/utilities/tokenParse.js +17 -0
  422. package/dist/utilities/topP.d.ts +1 -0
  423. package/dist/utilities/topP.js +12 -0
  424. package/dist/utilities/waitForModel.d.ts +2 -0
  425. package/dist/utilities/waitForModel.js +12 -0
  426. package/dist/utilities/weights.d.ts +12 -0
  427. package/dist/utilities/weights.js +40 -0
  428. package/dist/utilities/yielder.d.ts +1 -0
  429. package/dist/utilities/yielder.js +7 -0
  430. package/dist/webgpu-Dt7BMzWz.js +525 -0
  431. package/dist/webgpu_program-WOyIVMlZ.js +392 -0
  432. package/dist/webgpu_util-B_F3SShA.js +106 -0
  433. package/package.json +1 -1
@@ -0,0 +1,114 @@
1
+ import { Ii as e, ii as t, mn as n } from "../../dist-BewPQWjc.js";
2
+ //#region lib/ops/webgl/normRMS.ts
3
/**
 * WebGL shader program for the RMS-normalisation forward pass.
 *
 * Every output element is `x * inversesqrt(meanSquare + 1e-8)`, optionally
 * multiplied by `gamma`. `meanSquare` is sampled at `(coords.x, coords.y, 0)`,
 * so one mean-square value is shared by all elements along the last axis.
 */
var r = class {
  variableNames = ["x", "meanSquare"];
  outputShape;
  userCode;

  /**
   * @param dim0 First output dimension.
   * @param dim1 Second output dimension.
   * @param dim2 Third (normalised) output dimension.
   * @param useGamma When truthy, a third `gamma` input is declared and
   *   multiplied into the result. Defaults to `true`.
   */
  constructor(dim0, dim1, dim2, useGamma = true) {
    if (useGamma) {
      this.variableNames.push("gamma");
    }
    this.outputShape = [dim0, dim1, dim2];

    // Both snippets collapse to the empty string when gamma is not used.
    const gammaDecl = useGamma ? "float gamma = getGammaAtOutCoords();" : "";
    const gammaFactor = useGamma ? " * gamma" : "";

    this.userCode = `
void main() {
ivec3 coords = getOutputCoords();
float x = getXAtOutCoords();
float meanSquare = getMeanSquare(coords.x, coords.y, 0);
${gammaDecl}
float invRms = inversesqrt(meanSquare + 1e-8);
float normalized = x * invRms;
float outVal = normalized ${gammaFactor};
setOutput(outVal);
}
`;
  }
};
26
+ function i(e) {
27
+ let { x: t, gamma: n } = e.inputs, i = e.backend, a = t.shape[0], o = t.shape[1], s = t.shape[2], c = t.square().mean(-1, !0), l = new r(a, o, s, n !== void 0);
28
+ return i.runWebGLProgram(l, n ? [
29
+ t,
30
+ c,
31
+ n
32
+ ] : [t, c], "float32");
33
+ }
34
// Kernel configs for the WebGL backend. Both kernel names share the forward
// kernel function `i`, which decides per call whether gamma is applied.
var a = { kernelName: "RMSNorm", backendName: "webgl", kernelFunc: i };
var o = { kernelName: "RMSNormNoGamma", backendName: "webgl", kernelFunc: i };

// `e` is the imported registration helper (presumably tfjs `registerKernel`).
e(a);
e(o);
44
/**
 * WebGL shader program producing the input gradient of RMS normalisation:
 * `dx = dyGamma * invRms - x * dyXMean * invRms / meanSquare`, where
 * `meanSquare` already includes the 1e-8 epsilon and `dyXMean` is the sampled
 * `dyXMean` input divided by the size of the last axis.
 */
var s = class {
  variableNames = ["x", "meanSquare", "dyGamma", "dyXMean"];
  outputShape;
  userCode;

  /**
   * @param dim0 First output dimension.
   * @param dim1 Second output dimension.
   * @param dim2 Last output dimension; also the divisor baked into the shader
   *   as `${dim2}.0`.
   */
  constructor(dim0, dim1, dim2) {
    this.outputShape = [dim0, dim1, dim2];
    this.userCode = `
void main() {
ivec3 coords = getOutputCoords();
float x = getXAtOutCoords();
float meanSquare = getMeanSquare(coords.x, coords.y, 0) + 1e-8;
float dyGamma = getDyGammaAtOutCoords();
float dyXMean = getDyXMean(coords.x, coords.y, 0) / ${dim2}.0;
float invRms = inversesqrt(meanSquare);
float dx = dyGamma * invRms - x * dyXMean * invRms / meanSquare;
setOutput(dx);
}
`;
  }
};

/**
 * WebGL shader program producing the per-element gamma gradient contribution
 * `dy * (x * inversesqrt(meanSquare + 1e-8))`.
 */
var c = class {
  variableNames = ["x", "meanSquare", "dy"];
  outputShape;
  userCode;

  /**
   * @param dim0 First output dimension.
   * @param dim1 Second output dimension.
   * @param dim2 Last output dimension.
   */
  constructor(dim0, dim1, dim2) {
    this.outputShape = [dim0, dim1, dim2];
    // Joined with "\n"; the empty first entry and the single-space last entry
    // reproduce the leading newline and trailing "\n " of the shader source.
    this.userCode = [
      "",
      " void main() {",
      " ivec3 coords = getOutputCoords();",
      " float x = getXAtOutCoords();",
      " float meanSquare = getMeanSquare(coords.x, coords.y, 0) + 1e-8;",
      " float dy = getDyAtOutCoords();",
      " float invRms = inversesqrt(meanSquare);",
      " float dGamma = dy * (x * invRms);",
      " setOutput(dGamma);",
      " }",
      " ",
    ].join("\n");
  }
};
87
+ function l(e) {
88
+ let { dy: r, x: i, gamma: a } = e.inputs, o = e.backend, l = i.shape[0], u = i.shape[1], d = i.shape[2], f = a ? r.mul(a) : r, p = f.mul(i), m = p.sum(-1, !0);
89
+ p.dispose();
90
+ let h = i.square(), g = h.mean(-1, !0);
91
+ h.dispose();
92
+ let _ = new s(l, u, d), v = o.runWebGLProgram(_, [
93
+ i,
94
+ g,
95
+ f,
96
+ m
97
+ ], "float32");
98
+ if (a && f.dispose(), m.dispose(), a) {
99
+ let e = new c(l, u, d), a = o.runWebGLProgram(e, [
100
+ i,
101
+ g,
102
+ r
103
+ ], "float32");
104
+ g.dispose();
105
+ let s = n(t().makeTensorFromTensorInfo(a), [0, 1]);
106
+ return o.disposeIntermediateTensorInfo(a), [v, s];
107
+ } else return g.dispose(), [v];
108
+ }
109
// Hand the gradient kernel `l` to the imported registration helper `e`
// (presumably tfjs `registerKernel`) under the name "RMSNormGrad" for WebGL.
e({ kernelName: "RMSNormGrad", backendName: "webgl", kernelFunc: l });
114
+ //#endregion
@@ -0,0 +1 @@
1
+ export {};
@@ -0,0 +1,54 @@
1
+ import { Ii as e } from "../../dist-BewPQWjc.js";
2
+ //#region lib/ops/webgl/qkv.ts
3
/**
 * WebGL shader program for a fused Q/K/V projection.
 *
 * The output has shape `[batch, heads, seqLen, channels / heads]`. Each output
 * element is the dot product of one row of `x` with one column of the fused
 * `kernel`; the `mode` uniform selects which slice of the fused kernel columns
 * is read (offset `mode * heads * headDim`).
 *
 * The constructor now rejects a channel count that is not a positive multiple
 * of `heads`. Previously a fractional head size flowed into `outputShape` and
 * was interpolated into GLSL `int` arithmetic as a float literal.
 */
var t = class {
  variableNames = ["x", "kernel"];
  outputShape;
  userCode;
  customUniforms = [{
    name: "mode",
    type: "int"
  }];

  /**
   * @param batch Batch size (first output dimension).
   * @param heads Number of attention heads.
   * @param seqLen Sequence length (third output dimension).
   * @param channels Size of the last axis of `x`; must be divisible by `heads`.
   * @throws Error when `channels / heads` is not a positive integer.
   */
  constructor(batch, heads, seqLen, channels) {
    const headDim = channels / heads;
    if (!Number.isInteger(headDim) || headDim <= 0) {
      throw new Error(
        `QKV: channel count ${channels} must be a positive multiple of heads ${heads}`
      );
    }
    this.outputShape = [batch, heads, seqLen, headDim];
    this.userCode = `
void main() {
ivec4 coords = getOutputCoords(); // [b, h, t, d]
int b = coords.x;
int h = coords.y;
int t = coords.z;
int d = coords.w;

// Compute output channel index in fused kernel
int out_offset = mode * ${heads} * ${headDim} + h * ${headDim} + d;

float sum = 0.0;
for (int c = 0; c < ${channels}; ++c) {
float xval = getX(b, t, c); // fetch from x
float kval = getKernel(c, out_offset); // fetch from kernel
sum += xval * kval;
}

setOutput(sum);
}
`;
  }
};
41
+ function n(e) {
42
+ let { x: n, kernel: r } = e.inputs, { heads: i } = e.attrs, a = e.backend, o = n.shape[0], s = n.shape[1], c = n.shape[2], l = new t(o, i, s, c);
43
+ return [
44
+ a.runWebGLProgram(l, [n, r], "float32", [[0]]),
45
+ a.runWebGLProgram(l, [n, r], "float32", [[1]]),
46
+ a.runWebGLProgram(l, [n, r], "float32", [[2]])
47
+ ];
48
+ }
49
// Hand the QKV kernel `n` to the imported registration helper `e`
// (presumably tfjs `registerKernel`) for the WebGL backend.
e({ kernelName: "QKV", backendName: "webgl", kernelFunc: n });
54
+ //#endregion
@@ -0,0 +1 @@
1
+ export {};
@@ -0,0 +1,72 @@
1
+ import { Ii as e } from "../../dist-BewPQWjc.js";
2
+ //#region lib/ops/webgl/rope.ts
3
/**
 * WebGL shader program applying a rotary position embedding to a rank-4 input.
 *
 * Elements are rotated in (even, odd) pairs using `sin`/`cos` values looked up
 * at `(t + pastLen, d / 2, 0)`; `pastLen` is supplied as an int uniform. The
 * rotary width is the full last dimension, so the pass-through branch in the
 * shader is never taken for in-range coordinates.
 */
var t = class {
  variableNames = ["x", "sin", "cos"];
  outputShape;
  userCode;
  customUniforms = [{ name: "pastLen", type: "int" }];

  /**
   * @param dim0 First output dimension (shader coordinate `b`).
   * @param dim1 Second output dimension (shader coordinate `h`).
   * @param dim2 Third output dimension (shader coordinate `t`).
   * @param dim3 Last output dimension; also baked in as `rotaryDim`.
   */
  constructor(dim0, dim1, dim2, dim3) {
    this.outputShape = [dim0, dim1, dim2, dim3];
    this.userCode = `
void main() {
ivec4 coords = getOutputCoords(); // [b, h, t, d]
int b = coords.x;
int h = coords.y;
int t = coords.z;
int d = coords.w;

int rotaryDim = ${dim3};

float outVal = 0.0;

if (d < rotaryDim) {
int pairIdx = d / 2;
float cos = getCos(t + pastLen, pairIdx, 0);
float sin = getSin(t + pastLen, pairIdx, 0);

if (d % 2 == 0) {
// even index
float even = getX(b, h, t, d);
float odd = getX(b, h, t, d + 1);
outVal = even * cos - odd * sin;
} else {
// odd index
float even = getX(b, h, t, d - 1);
float odd = getX(b, h, t, d);
outVal = even * sin + odd * cos;
}
} else {
// pass through for non-rotary dims
outVal = getX(b, h, t, d);
}

setOutput(outVal);
}
`;
  }
};
59
+ function n(e) {
60
+ let { x: n } = e.inputs, { pastLen: r, ropeCache: i, negSin: a } = e.attrs, o = a ? i.getNegSin() : i.getSin(), s = i.getCos(), c = e.backend, l = n.shape[0], u = n.shape[1], d = n.shape[2], f = n.shape[3], p = new t(l, u, d, f);
61
+ return c.runWebGLProgram(p, [
62
+ n,
63
+ o,
64
+ s
65
+ ], "float32", [[r]]);
66
+ }
67
// Hand the RoPE kernel `n` to the imported registration helper `e`
// (presumably tfjs `registerKernel`) for the WebGL backend.
e({ kernelName: "Rope", backendName: "webgl", kernelFunc: n });
72
+ //#endregion
@@ -0,0 +1 @@
1
+ export {};
@@ -0,0 +1,28 @@
1
+ import { Ii as e } from "../../dist-BewPQWjc.js";
2
+ //#region lib/ops/webgl/scatterSub.ts
3
/**
 * WebGL shader program computing, for every output element,
 * `(label == column ? prob - 1 : prob) * dy[row]`, where `label` is read from
 * `labels[row]` and `prob` from `softmaxProbs` at the output coordinates.
 */
var t = class {
  variableNames = ["labels", "softmaxProbs", "dy"];
  outputShape;
  userCode;

  /**
   * @param rows First output dimension.
   * @param cols Second output dimension.
   */
  constructor(rows, cols) {
    this.outputShape = [rows, cols];
    // Joined with "\n"; the empty first entry and the single-space last entry
    // reproduce the leading newline and trailing "\n " of the shader source.
    this.userCode = [
      "",
      " void main() {",
      " ivec2 coords = getOutputCoords();",
      " int index = int(getLabels(coords.x));",
      " float prob = getSoftmaxProbsAtOutCoords();",
      " float dy = getDy(coords.x);",
      " setOutput((index == coords.y ? prob - 1.0 : prob) * dy);",
      " }",
      " ",
    ].join("\n");
  }
};
15
+ function n(e) {
16
+ let { logits: n, labels: r, dy: i } = e.inputs, a = e.backend, o = r.shape[0], s = n.shape[1], c = new t(o, s);
17
+ return a.runWebGLProgram(c, [
18
+ r,
19
+ n,
20
+ i
21
+ ], "float32");
22
+ }
23
// Hand the scatter-sub kernel `n` to the imported registration helper `e`
// (presumably tfjs `registerKernel`) for the WebGL backend.
e({ kernelName: "EfficientScatterSub", backendName: "webgl", kernelFunc: n });
28
+ //#endregion
@@ -0,0 +1 @@
1
+ export {};
@@ -0,0 +1,77 @@
1
+ import { Fs as e, Ii as t } from "../../dist-BewPQWjc.js";
2
+ import { s as n } from "../../webgpu_program-WOyIVMlZ.js";
3
+ import { c as r, i } from "../../webgpu_util-B_F3SShA.js";
4
+ //#region lib/ops/webgpu/adamAdjust.ts
5
+ var a = class {
6
+ variableNames = ["moments", "value"];
7
+ outputShape;
8
+ shaderKey = "AdamAdjust";
9
+ dispatchLayout;
10
+ dispatch;
11
+ workgroupSize = [
12
+ 64,
13
+ 1,
14
+ 1
15
+ ];
16
+ size = !0;
17
+ uniforms = "invbeta1: f32, invbeta2: f32, learningRate: f32, epsilon: f32";
18
+ outputComponent = 1;
19
+ variableComponents = [2, 1];
20
+ useWeightDecay;
21
+ constructor(e, t) {
22
+ this.outputShape = e, this.dispatchLayout = r(this.outputShape), this.dispatch = i(this.dispatchLayout, this.outputShape, this.workgroupSize), this.useWeightDecay = t, t && (this.uniforms += ", weightDecay: f32");
23
+ }
24
+ getUserCode() {
25
+ return `
26
+ ${n("index")} {
27
+ if (index < uniforms.size) {
28
+ let moments: vec2<f32> = moments[index];
29
+ let value: f32 = value[index];
30
+
31
+ let m1Hat = moments.x * uniforms.invbeta1;
32
+ let m2Hat = moments.y * uniforms.invbeta2;
33
+
34
+ let invSqrt = inverseSqrt(max(m2Hat, 1e-30));
35
+ let invDenom = invSqrt / fma(uniforms.epsilon, invSqrt, 1.0);
36
+ var adjustedValue = fma(-uniforms.learningRate * m1Hat, invDenom, value);
37
+
38
+ ${this.useWeightDecay ? "adjustedValue = adjustedValue - uniforms.learningRate * uniforms.weightDecay * value;" : ""}
39
+
40
+ setOutputAtIndex(index, adjustedValue);
41
+ }
42
+ }
43
+ `;
44
+ }
45
+ };
46
+ function o(t) {
47
+ let { moments: n, value: r } = t.inputs, { beta1: i, beta2: o, learningRate: s, epsilon: c, weightDecay: l } = t.attrs, u = t.backend;
48
+ e(n.shape, [...r.shape, 2], "Error in AdamAdjust: ");
49
+ let d = new a(r.shape, l > 0), f = [
50
+ {
51
+ type: "float32",
52
+ data: [1 / i]
53
+ },
54
+ {
55
+ type: "float32",
56
+ data: [1 / o]
57
+ },
58
+ {
59
+ type: "float32",
60
+ data: [s]
61
+ },
62
+ {
63
+ type: "float32",
64
+ data: [c]
65
+ }
66
+ ];
67
+ return l > 0 && f.push({
68
+ type: "float32",
69
+ data: [l]
70
+ }), u.runWebGPUProgram(d, [n, r], "float32", f);
71
+ }
72
// Hand the AdamAdjust kernel `o` to the imported registration helper `t`
// (presumably tfjs `registerKernel`) for the WebGPU backend.
t({ kernelName: "AdamAdjust", backendName: "webgpu", kernelFunc: o });
77
+ //#endregion
@@ -0,0 +1 @@
1
+ export {};
@@ -0,0 +1,76 @@
1
+ import { Fs as e, Ii as t } from "../../dist-BewPQWjc.js";
2
+ import { s as n } from "../../webgpu_program-WOyIVMlZ.js";
3
+ import { c as r, i } from "../../webgpu_util-B_F3SShA.js";
4
+ //#region lib/ops/webgpu/adamMoments.ts
5
/**
 * WebGPU program for the "AdamMoments" kernel.
 *
 * For every index it reads a vec2 from `moments` and a scalar from `gradient`,
 * scales the gradient by `scaling[0]`, and writes the updated moment pair:
 *   x <- x * beta1 + g * (1 - beta1)
 *   y <- y * beta2 + g * g * (1 - beta2)
 */
var a = class {
  // Inputs, in the order they are passed to runWebGPUProgram.
  variableNames = ["moments", "gradient", "scaling"];
  outputShape;
  shaderKey = "AdamMoments";
  dispatchLayout;
  dispatch;
  workgroupSize = [64, 1, 1];
  // The shader reads `uniforms.size`, which is not declared in `uniforms` below —
  // presumably supplied by the backend because this flag is set (TODO confirm).
  size = true;
  uniforms = "beta1: f32, beta2: f32";
  // Output and `moments` use 2 components per element; `gradient` and `scaling` use 1.
  outputComponent = 2;
  variableComponents = [2, 1, 1];

  /**
   * @param momentsShape Shape of the moments tensor; it is used as the output
   *   shape, and the layout/dispatch are computed without its last axis.
   */
  constructor(momentsShape) {
    this.outputShape = momentsShape;
    // Layout and dispatch are sized from the shape with the trailing axis dropped.
    this.dispatchLayout = r(this.outputShape.slice(0, -1));
    const dispatchShape = this.outputShape.slice(0, -1);
    this.dispatch = i(this.dispatchLayout, dispatchShape, this.workgroupSize, [1, 1, 1]);
  }

  /** Builds the WGSL body of the compute shader. */
  getUserCode() {
    const entryPoint = n("index");
    return `
      ${entryPoint} {
        if (index < uniforms.size) {
          let m: vec2<f32> = moments[index];

          // Loss and clip scaling.
          let g: f32 = gradient[index] * scaling[0];

          let newM1 = fma(m.x, uniforms.beta1, g * (1.0 - uniforms.beta1));
          let newM2 = fma(m.y, uniforms.beta2, g * g * (1.0 - uniforms.beta2));

          setOutputAtIndex(index, vec2<f32>(newM1, newM2));
        }
      }
    `;
  }
};
53
/**
 * WebGPU kernel function for the "AdamMoments" op.
 *
 * Validates the inputs and hyper-parameters, then runs the AdamMoments program
 * (class `a` in this module) over `moments`, `gradient` and `scaling`.
 *
 * @param t Kernel invocation object with `inputs` ({ moments, gradient, scaling }),
 *   `attrs` ({ beta1, beta2 }) and `backend`.
 * @returns Whatever `backend.runWebGPUProgram` returns for the program (float32 output).
 * @throws Error if `gradient` is not float32, if the shape check performed by `e`
 *   fails, or if `beta1` / `beta2` is not a number in [0, 1).
 */
function o(t) {
  const { moments, gradient, scaling } = t.inputs;
  const { beta1, beta2 } = t.attrs;
  const backend = t.backend;

  if (gradient.dtype !== "float32") {
    throw Error(`Gradient must be float32, but got ${gradient.dtype}`);
  }

  // `moments` must have the shape of `gradient` plus a trailing axis of length 2.
  e(moments.shape, [...gradient.shape, 2], "Error in AdamMoments: ");

  // Written as a negated positive check so that NaN / undefined are rejected too:
  // `x < 0 || x >= 1` is false for NaN and would let it through to the shader.
  if (!(beta1 >= 0 && beta1 < 1)) {
    throw Error(`Invalid beta1 value: ${beta1}. Must be in the range [0, 1).`);
  }
  if (!(beta2 >= 0 && beta2 < 1)) {
    throw Error(`Invalid beta2 value: ${beta2}. Must be in the range [0, 1).`);
  }

  const program = new a(moments.shape);
  // Order matches the program's uniform declaration: beta1, beta2.
  const uniformData = [
    { type: "float32", data: [beta1] },
    { type: "float32", data: [beta2] }
  ];

  return backend.runWebGPUProgram(program, [moments, gradient, scaling], "float32", uniformData);
}
71
// Kernel descriptor binding the "AdamMoments" kernel name to the kernel function `o`
// for the "webgpu" backend.
const adamMomentsKernelConfig = {
  kernelName: "AdamMoments",
  backendName: "webgpu",
  kernelFunc: o
};
// `t` is the `Ii` export of the shared dist bundle — presumably the tfjs kernel
// registration function (TODO confirm).
t(adamMomentsKernelConfig);
76
+ //#endregion
@@ -0,0 +1 @@
1
+ export {};
@@ -0,0 +1,14 @@
1
+ import { Ii as e } from "../../dist-BewPQWjc.js";
2
+ import { t } from "../../binary_op_util-CrYk9LXL.js";
3
+ import { BinaryOpProgram as n } from "./utils/binary_op.js";
4
+ //#region lib/ops/webgpu/add16.ts
5
/**
 * WebGPU kernel function for the "Add16" op.
 *
 * Builds a binary-op program of type `t.ADD` for inputs `a` and `b` and runs it
 * on the backend with output type "packedF16".
 *
 * @param e Kernel invocation object with `inputs` ({ a, b }) and `backend`.
 * @returns Whatever `backend.runWebGPUProgram` returns for the program.
 */
function r(e) {
  const { inputs, backend } = e;
  const lhs = inputs.a;
  const rhs = inputs.b;
  // `n` is BinaryOpProgram from ./utils/binary_op.js.
  const program = new n(t.ADD, lhs.shape, rhs.shape);
  return backend.runWebGPUProgram(program, [lhs, rhs], "packedF16");
}
9
// Kernel descriptor binding the "Add16" kernel name to the kernel function `r`
// for the "webgpu" backend.
const add16KernelConfig = {
  kernelName: "Add16",
  backendName: "webgpu",
  kernelFunc: r
};
// `e` is the `Ii` export of the shared dist bundle — presumably the tfjs kernel
// registration function (TODO confirm).
e(add16KernelConfig);
14
+ //#endregion
@@ -0,0 +1 @@
1
+ export {};
@@ -0,0 +1,130 @@
1
+ import { Fs as e, Ii as t } from "../../dist-BewPQWjc.js";
2
+ import { isPackedTensor as n } from "../../utilities/packed.js";
3
+ import { s as r } from "../../webgpu_program-WOyIVMlZ.js";
4
+ import { c as i, i as a } from "../../webgpu_util-B_F3SShA.js";
5
+ //#region lib/ops/webgpu/appendCache.ts
6
/**
 * WebGPU program for "AppendCache" on regular (non-packed) tensors.
 *
 * Output shape is [batch, heads, min(cacheLen + 1, maxSize), depth]. For each
 * output position (b, h, t, d) the shader computes srcT = t + start, where
 * start is 1 when `uniforms.cacheT + 1 >= <output time length>` and 0 otherwise,
 * then writes:
 *   - cache[b, h, srcT, d]  when srcT <  uniforms.cacheT
 *   - item[b, h, 0, d]      when srcT == uniforms.cacheT
 *   - 0.0                   otherwise
 *
 * NOTE(review): with `>=` the window already shifts by one when the appended
 * length equals the output time length exactly — confirm this is intended.
 */
var o = class {
  variableNames = ["cache", "item"];
  outputShape;
  // Replaced in the constructor by a key that includes the output time length,
  // because that length is baked into the generated shader source.
  shaderKey = "AppendCache";
  dispatchLayout;
  dispatch;
  workgroupSize = [64, 1, 1];
  // The shader reads `uniforms.size`, which is not declared in `uniforms` below —
  // presumably supplied by the backend because this flag is set (TODO confirm).
  size = true;
  uniforms = "cacheT: i32";

  /**
   * @param batch First output dimension.
   * @param heads Second output dimension.
   * @param cacheLen Current length of the cache along its third axis.
   * @param depth Fourth output dimension.
   * @param maxSize Upper bound for the output's third axis.
   */
  constructor(batch, heads, cacheLen, depth, maxSize) {
    const outLen = Math.min(cacheLen + 1, maxSize);
    this.shaderKey = `AppendCache_${outLen}`;
    this.outputShape = [batch, heads, outLen, depth];
    this.dispatchLayout = i(this.outputShape);
    this.dispatch = a(this.dispatchLayout, this.outputShape, this.workgroupSize);
  }

  /** Builds the WGSL body of the compute shader. */
  getUserCode() {
    const outLen = this.outputShape[2];
    const entryPoint = r("index");
    return `
      ${entryPoint} {
        if (index < uniforms.size) {
          let coords = getCoordsFromIndex(index); // [b, h, t, d]
          let b = coords[0];
          let h = coords[1];
          let t = coords[2];
          let d = coords[3];

          let itemT = 1;
          let maxSize = ${outLen};
          let totalT = uniforms.cacheT + itemT;
          let start = select(0, 1, totalT >= maxSize);

          let srcT = t + start;
          var val = 0.0;
          if (srcT < uniforms.cacheT) {
            val = getCache(b, h, srcT, d);
          }
          if (srcT == uniforms.cacheT) {
            val = getItem(b, h, 0, d);
          }

          setOutputAtIndex(index, val);
        }
      }
    `;
  }
};

/**
 * WebGPU program for "AppendCache" on packed tensors.
 *
 * Same selection logic as the class above, but values are handled as raw `i32`
 * words: the shader indexes the `cache` / `item` buffers directly through
 * getIndexFromCoords4D with `uniforms.cacheShape` / `uniforms.itemShape` and
 * stores into `result[index]`, with 0i as the fill value.
 */
var s = class {
  variableNames = ["cache", "item"];
  outputShape;
  // Replaced in the constructor; note it uses the same key format as the
  // non-packed program above.
  shaderKey = "AppendCache";
  dispatchLayout;
  dispatch;
  workgroupSize = [64, 1, 1];
  size = true;
  uniforms = "cacheT: i32";

  /**
   * @param batch First output dimension.
   * @param heads Second output dimension.
   * @param cacheLen Current length of the cache along its third axis.
   * @param depth Fourth output dimension.
   * @param maxSize Upper bound for the output's third axis.
   */
  constructor(batch, heads, cacheLen, depth, maxSize) {
    const outLen = Math.min(cacheLen + 1, maxSize);
    this.shaderKey = `AppendCache_${outLen}`;
    this.outputShape = [batch, heads, outLen, depth];
    this.dispatchLayout = i(this.outputShape);
    this.dispatch = a(this.dispatchLayout, this.outputShape, this.workgroupSize);
  }

  /** Builds the WGSL body of the compute shader. */
  getUserCode() {
    const outLen = this.outputShape[2];
    const entryPoint = r("index");
    return `
      ${entryPoint} {
        if (index < uniforms.size) {
          let coords = getCoordsFromIndex(index); // [b, h, t, d]
          let b = coords[0];
          let h = coords[1];
          let t = coords[2];
          let d = coords[3];

          let itemT = 1;
          let maxSize = ${outLen};
          let totalT = uniforms.cacheT + itemT;
          let start = select(0, 1, totalT >= maxSize);

          let srcT = t + start;
          var val: i32 = 0i;
          if (srcT < uniforms.cacheT) {
            val = cache[getIndexFromCoords4D(vec4<i32>(b, h, srcT, d), uniforms.cacheShape)];
          }
          if (srcT == uniforms.cacheT) {
            val = item[getIndexFromCoords4D(vec4<i32>(b, h, 0, d), uniforms.itemShape)];
          }

          result[index] = val;
        }
      }
    `;
  }
};
111
/**
 * WebGPU kernel function for the "AppendCache" op.
 *
 * Picks the packed program (`s`) or the regular one (`o`) depending on whether
 * `cache` is a packed tensor, passes `pastLen` as the single int32 uniform, and
 * runs the program over `cache` and `item`.
 *
 * @param t Kernel invocation object with `inputs` ({ cache, item }),
 *   `attrs` ({ maxSize, pastLen }) and `backend`.
 * @returns Whatever `backend.runWebGPUProgram` returns for the program.
 * @throws Error if the shape check performed by `e` fails, or if `pastLen`
 *   is below 0 or above `maxSize`.
 */
function c(t) {
  const { cache, item } = t.inputs;
  const { maxSize, pastLen } = t.attrs;
  const backend = t.backend;

  // `n` is isPackedTensor from utilities/packed.js.
  const packed = n(cache);
  const batch = cache.shape[0];
  const heads = cache.shape[1];
  const cacheLen = cache.shape[2];

  // `item` must match the cache's first two dimensions and hold exactly one step.
  e(item.shape, [batch, heads, 1, item.shape[3]], "Error in AppendCache: ");
  if (pastLen < 0 || pastLen > maxSize) {
    throw Error(`Invalid pastLen value: ${pastLen}. Must be in the range [0, ${maxSize}].`);
  }

  const ProgramClass = packed ? s : o;
  const program = new ProgramClass(batch, heads, cacheLen, item.shape[3], maxSize);
  // Single uniform declared by both programs: cacheT.
  const uniformData = [{
    type: "int32",
    data: [pastLen]
  }];
  // Packed inputs produce a packed result; otherwise the cache's dtype is kept.
  const outputType = packed ? "packedF16" : cache.dtype;

  return backend.runWebGPUProgram(program, [cache, item], outputType, uniformData);
}
125
// Kernel descriptor binding the "AppendCache" kernel name to the kernel function `c`
// for the "webgpu" backend.
const appendCacheKernelConfig = {
  kernelName: "AppendCache",
  backendName: "webgpu",
  kernelFunc: c
};
// `t` is the `Ii` export of the shared dist bundle — presumably the tfjs kernel
// registration function (TODO confirm).
t(appendCacheKernelConfig);
130
+ //#endregion
@@ -0,0 +1 @@
1
+ export {};
@@ -0,0 +1,42 @@
1
+ import { Fs as e, Ii as t } from "../../dist-BewPQWjc.js";
2
+ import { isPackedTensor as n } from "../../utilities/packed.js";
3
+ import { t as r } from "../../matMul16-BNfZSnNM.js";
4
+ import i from "./attentionMask32_program.js";
5
+ //#region lib/ops/webgpu/attentionMask.ts
6
/**
 * WebGPU kernel function for the "AttentionMask" op.
 *
 * When both `q` and `k` are packed tensors the work is delegated to `r`
 * (imported from the matMul16 bundle) with a causal mask, `pastLen`, and
 * `divisor` passed as `scale`. Otherwise the shapes and attributes are
 * validated and the 32-bit attention-mask program (`i`) is run.
 *
 * @param t Kernel invocation object with `inputs` ({ q, k }),
 *   `attrs` ({ divisor, pastLen }) and `backend`.
 * @returns The result of `r(...)` on the packed path, otherwise whatever
 *   `backend.runWebGPUProgram` returns, using the dtype of `q`.
 * @throws Error on the non-packed path if the shape check performed by `e`
 *   fails, if `divisor` is exactly 0, or if `pastLen` is negative.
 */
function a(t) {
  const { q, k } = t.inputs;
  const { divisor, pastLen } = t.attrs;
  const backend = t.backend;

  // Packed fast path. The two boolean flags are positional arguments of `r` —
  // presumably transpose flags for the first/second operand (TODO confirm).
  if (n(q) && n(k)) {
    return r(q, k, false, true, {
      causalMask: true,
      pastLen: pastLen,
      scale: divisor
    });
  }

  const batch = q.shape[0];
  const heads = q.shape[1];
  const qLen = q.shape[2];
  const kLen = k.shape[2];
  const headDim = q.shape[3];

  // `k` must agree with `q` on every axis except its own third (length) axis.
  e(k.shape, [batch, heads, kLen, headDim], "Error in AttentionMask: ");
  if (divisor === 0) {
    throw Error("Divisor must be non-zero in AttentionMask");
  }
  if (pastLen < 0) {
    throw Error("pastLen must be non-negative in AttentionMask");
  }

  const program = new i(batch, heads, qLen, kLen, headDim);
  // Uniform order is fixed by the program: divisor, pastLen, then -Infinity
  // (presumably the value written to masked positions — TODO confirm).
  const uniformData = [
    { type: "float32", data: [divisor] },
    { type: "int32", data: [pastLen] },
    { type: "float32", data: [-Infinity] }
  ];
  const outputType = q.dtype;

  return backend.runWebGPUProgram(program, [q, k], outputType, uniformData);
}
37
// Kernel descriptor binding the "AttentionMask" kernel name to the kernel function `a`
// for the "webgpu" backend.
const attentionMaskKernelConfig = {
  kernelName: "AttentionMask",
  backendName: "webgpu",
  kernelFunc: a
};
// `t` is the `Ii` export of the shared dist bundle — presumably the tfjs kernel
// registration function (TODO confirm).
t(attentionMaskKernelConfig);
42
+ //#endregion
@@ -0,0 +1,19 @@
1
+ import { WebGPUProgram } from '@tensorflow/tfjs-backend-webgpu';
2
/**
 * Type declaration for the 32-bit attention-mask WebGPU program.
 *
 * Implements the `WebGPUProgram` contract from `@tensorflow/tfjs-backend-webgpu`.
 */
export default class AttentionMaskProgram32 implements WebGPUProgram {
    /** Names of the program's input variables. */
    variableNames: Array<string>;
    /** Shape of the tensor the program writes. */
    outputShape: Array<number>;
    /** Key identifying the generated shader. */
    shaderKey: string;
    /** Dispatch layout; only the `x` dimension is described. */
    dispatchLayout: { x: Array<number> };
    /** Workgroup counts for the three dispatch dimensions. */
    dispatch: [number, number, number];
    /** Declaration string for the program's uniforms. */
    uniforms: string;
    /** Workgroup size for the three dimensions. */
    workgroupSize: [number, number, number];
    size: boolean;
    // The four fields below share their names with constructor arguments —
    // presumably head size, number of heads, and the two sequence lengths
    // (TODO confirm against the implementation).
    hs: number;
    nh: number;
    T1: number;
    T2: number;
    /**
     * @param batch Batch dimension.
     * @param nh See the `nh` field.
     * @param T1 See the `T1` field.
     * @param T2 See the `T2` field.
     * @param hs See the `hs` field.
     */
    constructor(batch: number, nh: number, T1: number, T2: number, hs: number);
    /** Returns the shader source for this program. */
    getUserCode(): string;
}