@genai-fi/nanogpt 0.9.0 → 0.10.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (343)
  1. package/README.md +352 -14
  2. package/dist/Generator.js +69 -78
  3. package/dist/{RealDiv-D4EzDsC0.js → RealDiv-DgA3z9oO.js} +32 -206
  4. package/dist/Reshape-CF6odzV4.js +16 -0
  5. package/dist/Reshape-_kILl6tK.js +81 -0
  6. package/dist/TeachableLLM.js +28 -22
  7. package/dist/Trainer.d.ts +2 -0
  8. package/dist/Trainer.js +3 -2
  9. package/dist/{axis_util-TbGYJ208.js → axis_util-BvHEw88j.js} +7 -23
  10. package/dist/backend.d.ts +2 -1
  11. package/dist/backend.js +10 -4
  12. package/dist/backend_util-D-rUb2ty.js +474 -0
  13. package/dist/backend_webgpu-B0u2ndUn.js +547 -0
  14. package/dist/binary_op_util-pKXltfxI.js +192 -0
  15. package/dist/broadcast_to-CwF7XIeu.js +30 -0
  16. package/dist/checks/appendCache.js +2 -2
  17. package/dist/checks/attentionMask.js +3 -3
  18. package/dist/checks/check.d.ts +1 -1
  19. package/dist/checks/check.js +8 -8
  20. package/dist/checks/gelu.js +2 -2
  21. package/dist/checks/index.d.ts +2 -0
  22. package/dist/checks/index.js +7 -5
  23. package/dist/checks/matMulGelu.js +6 -6
  24. package/dist/checks/normRMS.js +7 -7
  25. package/dist/checks/normRMSGrad.js +3 -3
  26. package/dist/checks/packUnpack.d.ts +1 -0
  27. package/dist/checks/packUnpack.js +18 -0
  28. package/dist/checks/qkv.js +12 -27
  29. package/dist/checks/rope.js +2 -2
  30. package/dist/checks/weights.js +18 -16
  31. package/dist/complex-CSlYz-2T.js +13 -0
  32. package/dist/complex_util-Yc1A_gV1.js +55 -0
  33. package/dist/concat-BHlIJeyT.js +19 -0
  34. package/dist/concat_util-DcJk7YHS.js +22 -0
  35. package/dist/data/docx.js +1 -1
  36. package/dist/data/parquet.js +2 -2
  37. package/dist/data/pdf.js +1 -1
  38. package/dist/data/textLoader.js +1 -1
  39. package/dist/{dataset-DlZtKmBq.js → dataset-0xP8GjwI.js} +136 -236
  40. package/dist/dropout-C1pM3f11.js +99 -0
  41. package/dist/expand_dims-BPG4fwBP.js +13 -0
  42. package/dist/exports_initializers-xuidcwI4.js +7 -0
  43. package/dist/gather-DykLGqmW.js +10 -0
  44. package/dist/{gelu-Bp_-935b.js → gelu-CNLFZWea.js} +11 -10
  45. package/dist/{gpgpu_math-CDaYiyE_.js → gpgpu_math-DDVJCn6-.js} +90 -265
  46. package/dist/{index-C4L8Cm77.js → index-CieiGp4Y.js} +14 -14
  47. package/dist/index-CjOj7j-u.js +7308 -0
  48. package/dist/{index-Tf7vU29b.js → index-Cp39cXWe.js} +3 -10
  49. package/dist/{index-Dwqa6Zy2.js → index-DvYrXKkX.js} +2 -2
  50. package/dist/index-ZyQhjEPo.js +2157 -0
  51. package/dist/{jszip.min-CjP2V1VV.js → jszip.min-Bz5-11Bk.js} +56 -57
  52. package/dist/kernel_funcs_utils-Dg_-E44D.js +308 -0
  53. package/dist/layers/BaseLayer.d.ts +1 -0
  54. package/dist/layers/BaseLayer.js +7 -6
  55. package/dist/layers/CausalSelfAttention.d.ts +0 -1
  56. package/dist/layers/CausalSelfAttention.js +56 -55
  57. package/dist/layers/MLP.js +15 -16
  58. package/dist/layers/PositionEmbedding.js +5 -14
  59. package/dist/layers/RMSNorm.js +3 -3
  60. package/dist/layers/RoPECache.d.ts +2 -0
  61. package/dist/layers/RoPECache.js +22 -17
  62. package/dist/layers/TiedEmbedding.js +22 -17
  63. package/dist/layers/TransformerBlock.js +21 -20
  64. package/dist/loader/load.js +1 -1
  65. package/dist/loader/loadTransformers.js +1 -1
  66. package/dist/loader/oldZipLoad.js +39 -33
  67. package/dist/loader/save.js +1 -1
  68. package/dist/log_sum_exp-DWI-76TI.js +41 -0
  69. package/dist/main.d.ts +8 -0
  70. package/dist/main.js +63 -52
  71. package/dist/matMul16--R5hOwDG.js +77 -0
  72. package/dist/mat_mul-DeAh4uTH.js +12 -0
  73. package/dist/mod-Gt1rMB4n.js +12 -0
  74. package/dist/models/NanoGPTV1.js +40 -31
  75. package/dist/models/model.d.ts +2 -0
  76. package/dist/models/model.js +37 -29
  77. package/dist/{mulmat_packed_gpu-BT60jmzP.js → mulmat_packed_gpu-BMFhLwta.js} +1 -17
  78. package/dist/{non_max_suppression_impl-CsEgBuMA.js → non_max_suppression_impl-B2W7YjZB.js} +0 -32
  79. package/dist/ones-CAMiP4I2.js +15 -0
  80. package/dist/ops/adamAdjust.js +1 -1
  81. package/dist/ops/adamMoments.d.ts +1 -1
  82. package/dist/ops/adamMoments.js +4 -4
  83. package/dist/ops/add16.d.ts +2 -0
  84. package/dist/ops/add16.js +9 -0
  85. package/dist/ops/appendCache.js +16 -9
  86. package/dist/ops/attentionMask.js +4 -4
  87. package/dist/ops/concat16.d.ts +2 -0
  88. package/dist/ops/concat16.js +9 -0
  89. package/dist/ops/cpu/adamAdjust.js +14 -13
  90. package/dist/ops/cpu/adamMoments.js +10 -9
  91. package/dist/ops/cpu/appendCache.js +9 -8
  92. package/dist/ops/cpu/attentionMask.js +15 -14
  93. package/dist/ops/cpu/fusedSoftmax.js +13 -12
  94. package/dist/ops/cpu/gatherSub.js +9 -24
  95. package/dist/ops/cpu/gelu.js +13 -12
  96. package/dist/ops/cpu/matMul16.d.ts +1 -0
  97. package/dist/ops/cpu/matMul16.js +16 -0
  98. package/dist/ops/cpu/matMulGelu.js +18 -16
  99. package/dist/ops/cpu/matMulMul.js +8 -7
  100. package/dist/ops/cpu/mulDropout.js +4 -3
  101. package/dist/ops/cpu/normRMS.js +11 -10
  102. package/dist/ops/cpu/qkv.js +17 -13
  103. package/dist/ops/cpu/rope.js +23 -22
  104. package/dist/ops/cpu/scatterSub.js +16 -30
  105. package/dist/ops/dot16.d.ts +2 -0
  106. package/dist/ops/dot16.js +42 -0
  107. package/dist/ops/gatherSub.js +1 -1
  108. package/dist/ops/gelu.js +2 -2
  109. package/dist/ops/grads/add16.d.ts +1 -0
  110. package/dist/ops/grads/add16.js +27 -0
  111. package/dist/ops/grads/attentionMask.js +12 -19
  112. package/dist/ops/grads/gelu.js +4 -3
  113. package/dist/ops/grads/matMul16.d.ts +2 -0
  114. package/dist/ops/grads/matMul16.js +9 -0
  115. package/dist/ops/grads/matMulGelu.js +8 -7
  116. package/dist/ops/grads/normRMS.js +8 -7
  117. package/dist/ops/grads/{fusedSoftmax.d.ts → pack16.d.ts} +1 -1
  118. package/dist/ops/grads/pack16.js +7 -0
  119. package/dist/ops/grads/qkv.d.ts +3 -1
  120. package/dist/ops/grads/qkv.js +28 -22
  121. package/dist/ops/grads/rope.d.ts +2 -1
  122. package/dist/ops/grads/rope.js +6 -13
  123. package/dist/ops/grads/softmax16.d.ts +2 -0
  124. package/dist/ops/grads/softmax16.js +26 -0
  125. package/dist/ops/grads/unpack16.d.ts +2 -0
  126. package/dist/ops/grads/unpack16.js +6 -0
  127. package/dist/ops/grads/utils.d.ts +3 -0
  128. package/dist/ops/grads/utils.js +10 -0
  129. package/dist/ops/matMul16.d.ts +15 -0
  130. package/dist/ops/matMul16.js +13 -0
  131. package/dist/ops/matMulGelu.js +1 -1
  132. package/dist/ops/matMulMul.js +1 -1
  133. package/dist/ops/mul16.d.ts +2 -0
  134. package/dist/ops/mul16.js +8 -0
  135. package/dist/ops/mulDrop.js +1 -1
  136. package/dist/ops/normRMS.js +1 -1
  137. package/dist/ops/pack16.d.ts +2 -0
  138. package/dist/ops/pack16.js +6 -0
  139. package/dist/ops/qkv.d.ts +1 -1
  140. package/dist/ops/qkv.js +8 -4
  141. package/dist/ops/reshape16.d.ts +2 -0
  142. package/dist/ops/reshape16.js +43 -0
  143. package/dist/ops/rope.d.ts +1 -1
  144. package/dist/ops/rope.js +8 -10
  145. package/dist/ops/scatterSub.js +1 -1
  146. package/dist/ops/slice16.d.ts +2 -0
  147. package/dist/ops/slice16.js +9 -0
  148. package/dist/ops/softmax16.d.ts +2 -0
  149. package/dist/ops/softmax16.js +12 -0
  150. package/dist/ops/sub16.d.ts +2 -0
  151. package/dist/ops/sub16.js +8 -0
  152. package/dist/ops/sum16.d.ts +2 -0
  153. package/dist/ops/sum16.js +13 -0
  154. package/dist/ops/transpose16.d.ts +3 -0
  155. package/dist/ops/transpose16.js +41 -0
  156. package/dist/ops/unpack16.d.ts +2 -0
  157. package/dist/ops/unpack16.js +6 -0
  158. package/dist/ops/webgl/adamAdjust.js +3 -2
  159. package/dist/ops/webgl/adamMoments.js +2 -1
  160. package/dist/ops/webgl/appendCache.js +2 -1
  161. package/dist/ops/webgl/attentionMask.js +5 -4
  162. package/dist/ops/webgl/fusedSoftmax.js +6 -4
  163. package/dist/ops/webgl/gatherSub.js +7 -6
  164. package/dist/ops/webgl/gelu.js +3 -2
  165. package/dist/ops/webgl/log.js +12 -27
  166. package/dist/ops/webgl/matMul16.d.ts +1 -0
  167. package/dist/ops/webgl/matMul16.js +37 -0
  168. package/dist/ops/webgl/matMulGelu.js +17 -15
  169. package/dist/ops/webgl/matMulMul.js +13 -12
  170. package/dist/ops/webgl/mulDropout.js +9 -8
  171. package/dist/ops/webgl/normRMS.js +8 -7
  172. package/dist/ops/webgl/qkv.js +6 -5
  173. package/dist/ops/webgl/rope.js +11 -10
  174. package/dist/ops/webgl/scatterSub.js +6 -5
  175. package/dist/ops/webgpu/adamAdjust.js +12 -10
  176. package/dist/ops/webgpu/adamMoments.js +27 -22
  177. package/dist/ops/webgpu/add16.d.ts +1 -0
  178. package/dist/ops/webgpu/add16.js +14 -0
  179. package/dist/ops/webgpu/appendCache.js +64 -17
  180. package/dist/ops/webgpu/attentionMask.js +19 -62
  181. package/dist/ops/webgpu/attentionMask32_program.d.ts +19 -0
  182. package/dist/ops/webgpu/attentionMask32_program.js +54 -0
  183. package/dist/ops/webgpu/concat16.d.ts +19 -0
  184. package/dist/ops/webgpu/concat16.js +128 -0
  185. package/dist/ops/webgpu/gatherSub.js +9 -7
  186. package/dist/ops/webgpu/gelu.js +78 -31
  187. package/dist/ops/webgpu/index.js +12 -0
  188. package/dist/ops/webgpu/matMul16.d.ts +1 -0
  189. package/dist/ops/webgpu/matMul16.js +58 -0
  190. package/dist/ops/webgpu/matMul16_program.d.ts +42 -0
  191. package/dist/ops/webgpu/matMul16_program.js +336 -0
  192. package/dist/ops/webgpu/mul16.d.ts +1 -0
  193. package/dist/ops/webgpu/mul16.js +14 -0
  194. package/dist/ops/webgpu/normRMS.js +21 -40
  195. package/dist/ops/webgpu/normRMS16_program.d.ts +9 -0
  196. package/dist/ops/webgpu/normRMS16_program.js +24 -0
  197. package/dist/ops/webgpu/normRMS32_program.d.ts +9 -0
  198. package/dist/ops/webgpu/normRMS32_program.js +24 -0
  199. package/dist/ops/webgpu/normRMSGrad.js +113 -64
  200. package/dist/ops/webgpu/pack16.d.ts +1 -0
  201. package/dist/ops/webgpu/pack16.js +19 -0
  202. package/dist/ops/webgpu/pack16_program.d.ts +19 -0
  203. package/dist/ops/webgpu/pack16_program.js +92 -0
  204. package/dist/ops/webgpu/qkv.js +20 -55
  205. package/dist/ops/webgpu/rope.js +77 -22
  206. package/dist/ops/webgpu/scatterSub.js +9 -7
  207. package/dist/ops/webgpu/slice16.d.ts +7 -0
  208. package/dist/ops/webgpu/slice16.js +71 -0
  209. package/dist/{variable-Bm2OFwGI.js → ops/webgpu/softmax16.d.ts} +2 -8
  210. package/dist/ops/webgpu/softmax16.js +23 -0
  211. package/dist/ops/webgpu/softmax16_program.d.ts +13 -0
  212. package/dist/ops/webgpu/softmax16_program.js +73 -0
  213. package/dist/ops/webgpu/softmax16_subgroup_program.d.ts +17 -0
  214. package/dist/ops/webgpu/softmax16_subgroup_program.js +75 -0
  215. package/dist/ops/webgpu/softmax16grad.d.ts +1 -0
  216. package/dist/ops/webgpu/softmax16grad.js +38 -0
  217. package/dist/ops/webgpu/sub16.d.ts +1 -0
  218. package/dist/ops/webgpu/sub16.js +14 -0
  219. package/dist/ops/webgpu/sum16.d.ts +1 -0
  220. package/dist/ops/webgpu/sum16.js +40 -0
  221. package/dist/ops/webgpu/transpose16.d.ts +1 -0
  222. package/dist/ops/webgpu/transpose16.js +35 -0
  223. package/dist/ops/webgpu/transpose16_program.d.ts +16 -0
  224. package/dist/ops/webgpu/transpose16_program.js +50 -0
  225. package/dist/ops/webgpu/transpose16_shared_program.d.ts +15 -0
  226. package/dist/ops/webgpu/transpose16_shared_program.js +71 -0
  227. package/dist/ops/webgpu/unpack16.d.ts +1 -0
  228. package/dist/ops/webgpu/unpack16.js +49 -0
  229. package/dist/ops/webgpu/utils/binary_op.d.ts +19 -0
  230. package/dist/ops/webgpu/utils/binary_op.js +79 -0
  231. package/dist/ops/webgpu/utils/deviceInfo.d.ts +7 -0
  232. package/dist/ops/webgpu/utils/deviceInfo.js +11 -0
  233. package/dist/ops/webgpu/utils/reductions.d.ts +32 -4
  234. package/dist/ops/webgpu/utils/reductions.js +236 -45
  235. package/dist/ops-CNI3TwqM.js +645 -0
  236. package/dist/pack16-CFUqumar.js +41 -0
  237. package/dist/{papaparse.min-C8l2Kvo1.js → papaparse.min-C0cScC2i.js} +2 -8
  238. package/dist/{parquet-C0Tlmv9c.js → parquet-BE8MU_ge.js} +201 -278
  239. package/dist/patches/PackedTensor.d.ts +12 -0
  240. package/dist/patches/PackedTensor.js +11 -0
  241. package/dist/patches/engine.d.ts +261 -0
  242. package/dist/patches/engine.js +10 -0
  243. package/dist/patches/tape.d.ts +12 -0
  244. package/dist/patches/tape.js +5 -0
  245. package/dist/patches/webgpu_backend.d.ts +18 -0
  246. package/dist/patches/webgpu_backend.js +57 -0
  247. package/dist/{tensor-CZr4dh61.js → patches/webgpu_base.d.ts} +5 -8
  248. package/dist/patches/webgpu_base.js +34 -0
  249. package/dist/patches/webgpu_program.d.ts +36 -0
  250. package/dist/patches/webgpu_program.js +401 -0
  251. package/dist/{pdf-kJD-f258.js → pdf-NIhmP3sq.js} +424 -428
  252. package/dist/random_width-DY6Kk2Dl.js +10051 -0
  253. package/dist/range-BMS52eQi.js +11 -0
  254. package/dist/reciprocal-CTmshQ9J.js +10 -0
  255. package/dist/{register_all_kernels-DIGpEwcf.js → register_all_kernels-Bwu1PTuU.js} +719 -9766
  256. package/dist/relu-yZ2-7WxU.js +10 -0
  257. package/dist/reshape-DevtBWtf.js +10 -0
  258. package/dist/rope-B5UUMsPi.js +32 -0
  259. package/dist/{scatter_nd_util-BQdz--Gn.js → scatter_nd_util-5EL-8VAQ.js} +1 -1
  260. package/dist/selu_util-D1w6yyTO.js +303 -0
  261. package/dist/{shared-DuP7ue-R.js → shared-BRksrJb3.js} +1 -17
  262. package/dist/shared-BuAXb4CI.js +2145 -0
  263. package/dist/sin-BGfy2HZo.js +16 -0
  264. package/dist/slice-D_gkkqZK.js +13 -0
  265. package/dist/slice_util-DtEldBfK.js +261 -0
  266. package/dist/softmax-ZHVebtR1.js +13 -0
  267. package/dist/split-DrfihRpZ.js +10 -0
  268. package/dist/squeeze-DZEpeblb.js +11 -0
  269. package/dist/stack-yOIAalTq.js +13 -0
  270. package/dist/sum-_fzj5ZTB.js +12 -0
  271. package/dist/tensor-DdQUJZlz.js +909 -0
  272. package/dist/tensor-f35l8Odg.js +8 -0
  273. package/dist/tensor1d-CeZuc-Rv.js +12 -0
  274. package/dist/tensor2d-G4Ys2GxX.js +15 -0
  275. package/dist/tensor4d-B8roDgtc.js +15 -0
  276. package/dist/tensor_util-DV-FP5Q3.js +523 -0
  277. package/dist/tfjs_backend-kNyO5L2d.js +653 -0
  278. package/dist/tile-BzyEiF-F.js +13 -0
  279. package/dist/tokeniser/CharTokeniser.js +1 -1
  280. package/dist/tokeniser/bpe.js +1 -1
  281. package/dist/training/Adam.d.ts +2 -1
  282. package/dist/training/Adam.js +12 -28
  283. package/dist/training/AdamExt.d.ts +1 -0
  284. package/dist/training/AdamExt.js +2 -2
  285. package/dist/training/DatasetBuilder.js +3 -20
  286. package/dist/training/FullTrainer.js +82 -64
  287. package/dist/training/Trainer.d.ts +11 -6
  288. package/dist/training/Trainer.js +51 -39
  289. package/dist/training/sparseCrossEntropy.js +3 -3
  290. package/dist/transpose-DKELTqhe.js +38 -0
  291. package/dist/utilities/arrayClose.js +7 -7
  292. package/dist/utilities/dummy.js +35 -27
  293. package/dist/utilities/multinomialCPU.js +2 -2
  294. package/dist/utilities/packed.d.ts +7 -0
  295. package/dist/utilities/packed.js +716 -0
  296. package/dist/utilities/performance.js +1 -1
  297. package/dist/utilities/profile.js +1 -1
  298. package/dist/utilities/safetensors.js +2 -2
  299. package/dist/utilities/sentences.d.ts +5 -0
  300. package/dist/utilities/sentences.js +41 -0
  301. package/dist/utilities/weights.js +2 -2
  302. package/dist/variable-Bhn5bHYv.js +7 -0
  303. package/dist/{webgpu_program-DkQJOJSd.js → webgpu_program-Cigz-7RF.js} +15 -44
  304. package/dist/webgpu_util-BBCnKm2X.js +65 -0
  305. package/dist/zeros-2gldETuK.js +14 -0
  306. package/package.json +4 -3
  307. package/dist/Reshape-Bowtk9BP.js +0 -127
  308. package/dist/Reshape-DUqYftGC.js +0 -30
  309. package/dist/backend_util-CJIiDoV1.js +0 -749
  310. package/dist/broadcast_to-DzlNweb8.js +0 -44
  311. package/dist/concat-B912vBbo.js +0 -33
  312. package/dist/dropout-C-csYCLj.js +0 -193
  313. package/dist/exports_initializers-B8iZMgQ0.js +0 -16
  314. package/dist/gather-Dnpgw-YQ.js +0 -25
  315. package/dist/index-BzFyqcy-.js +0 -4457
  316. package/dist/index-C1rx_Ajs.js +0 -12076
  317. package/dist/kernel_funcs_utils-DKLK0Mg3.js +0 -466
  318. package/dist/log_sum_exp-DO6z8tSE.js +0 -103
  319. package/dist/mat_mul-DzjTFx-u.js +0 -27
  320. package/dist/mod-Dobti4j4.js +0 -27
  321. package/dist/ones-tIJeHlq-.js +0 -29
  322. package/dist/ops/fusedSoftmax.d.ts +0 -2
  323. package/dist/ops/fusedSoftmax.js +0 -10
  324. package/dist/ops/grads/fusedSoftmax.js +0 -22
  325. package/dist/ops-LuCMAnmM.js +0 -1525
  326. package/dist/random_width-CXVRloNK.js +0 -13670
  327. package/dist/range-CWcz7xFA.js +0 -26
  328. package/dist/reciprocal-C4rNcM-S.js +0 -25
  329. package/dist/relu-BjCh_SYb.js +0 -25
  330. package/dist/reshape-CnIwVG1c.js +0 -25
  331. package/dist/selu_util-OtRzVwW5.js +0 -719
  332. package/dist/shared-DmRsFyaJ.js +0 -3134
  333. package/dist/sin-gpDNRxE0.js +0 -47
  334. package/dist/slice-d0Vo9XTN.js +0 -28
  335. package/dist/softmax-D7Jj3p_P.js +0 -28
  336. package/dist/split-DK2k5eHf.js +0 -25
  337. package/dist/stack-DFatutCx.js +0 -27
  338. package/dist/sum-CJ0ULhmt.js +0 -27
  339. package/dist/tensor1d-vML0r3q6.js +0 -27
  340. package/dist/tensor2d-D76QGjF3.js +0 -30
  341. package/dist/tensor4d-Df1WlVDY.js +0 -30
  342. package/dist/webgpu_util-pLEV9tks.js +0 -80
  343. package/dist/zeros-Bj5rMYA7.js +0 -52
package/dist/ops/webgpu/attentionMask.js
@@ -1,71 +1,28 @@
- import { f, a4 as m } from "../../index-BzFyqcy-.js";
- import { g as k } from "../../webgpu_program-DkQJOJSd.js";
- import { f as l, c as v } from "../../webgpu_util-pLEV9tks.js";
- class g {
- variableNames = ["q", "k"];
- outputShape;
- shaderKey = "AttentionMask";
- dispatchLayout;
- dispatch;
- uniforms = "divisor: f32, pastLen: i32, inf: f32";
- workgroupSize = [64, 1, 1];
- size = !0;
- hs;
- nh;
- T1;
- T2;
- constructor(t, e, s, o, i) {
- if (this.shaderKey = `AttentionMask_${i}`, this.outputShape = [t, e, s, o], this.hs = i, this.nh = e, this.T1 = s, this.T2 = o, this.dispatchLayout = l(this.outputShape), this.dispatch = v(this.dispatchLayout, this.outputShape, this.workgroupSize), i % 4 !== 0)
- throw new Error("Head size must be a multiple of 4 for AttentionMaskProgram");
- }
- getUserCode() {
- return `
- ${k("index")} {
-
- let coords = getCoordsFromIndex(index);
- let b = coords[0];
- let h = coords[1];
- let t1 = coords[2];
- let t2 = coords[3];
-
- if (index < uniforms.size) {
- if (t2 > t1 + uniforms.pastLen) {
- setOutputAtIndex(index, uniforms.inf);
- return;
- }
-
- let q0 = getIndexFromCoords4D(vec4<i32>(b, h, t1, 0), uniforms.qShape);
- let k0 = getIndexFromCoords4D(vec4<i32>(b, h, t2, 0), uniforms.kShape);
-
- var sum: f32 = 0.0;
- for (var i: i32 = 0; i < ${this.hs}; i = i + 4) {
- let qv = vec4<f32>(q[q0 + i], q[q0 + i + 1], q[q0 + i + 2], q[q0 + i + 3]);
- let kv = vec4<f32>(k[k0 + i], k[k0 + i + 1], k[k0 + i + 2], k[k0 + i + 3]);
- sum = sum + dot(qv, kv);
- }
- let scaled = sum * uniforms.divisor;
- setOutputAtIndex(index, scaled);
- }
- }
- `;
- }
- }
- function b(n) {
- const { q: t, k: e } = n.inputs, { divisor: s, pastLen: o } = n.attrs, i = n.backend, r = t.shape[0], p = t.shape[2], a = e.shape[2], u = t.shape[1], h = t.shape[3];
- if (m(e.shape, [r, u, a, h], "Error in AttentionMask: "), s === 0)
+ import "../../index-ZyQhjEPo.js";
+ import { j as d } from "../../tensor-DdQUJZlz.js";
+ import { isPackedTensor as p } from "../../utilities/packed.js";
+ import { b } from "../../matMul16--R5hOwDG.js";
+ import l from "./attentionMask32_program.js";
+ import { r as M } from "../../tensor_util-DV-FP5Q3.js";
+ function w(n) {
+ const { q: t, k: e } = n.inputs, { divisor: a, pastLen: o } = n.attrs, m = n.backend;
+ if (p(t) && p(e))
+ return b(t, e, !1, !0, { causalMask: !0, pastLen: o, scale: a });
+ const r = t.shape[0], k = t.shape[2], s = e.shape[2], i = t.shape[1], c = t.shape[3];
+ if (d(e.shape, [r, i, s, c], "Error in AttentionMask: "), a === 0)
  throw new Error("Divisor must be non-zero in AttentionMask");
  if (o < 0)
  throw new Error("pastLen must be non-negative in AttentionMask");
- const c = new g(r, u, p, a, h), d = [
- { type: "float32", data: [s] },
+ const u = new l(r, i, k, s, c), f = [
+ { type: "float32", data: [a] },
  { type: "int32", data: [o] },
  { type: "float32", data: [Number.NEGATIVE_INFINITY] }
- ];
- return i.runWebGPUProgram(c, [t, e], "float32", d);
+ ], h = t.dtype;
+ return m.runWebGPUProgram(u, [t, e], h, f);
  }
- const q = {
+ const A = {
  kernelName: "AttentionMask",
  backendName: "webgpu",
- kernelFunc: b
+ kernelFunc: w
  };
- f(q);
+ M(A);
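Note: the rewritten kernel takes a fast path when both q and k are packed fp16 tensors, deferring to MatMul16 with { causalMask: true, pastLen, scale }; otherwise it falls back to the f32 program now housed in attentionMask32_program.js. For reference, a plain-TypeScript sketch of what that program computes (a hypothetical helper, not shipped in the package): q is [B, NH, T1, HS], k is [B, NH, T2, HS], and divisor is applied as a multiplier despite its name.

    // Hypothetical reference implementation of the fused AttentionMask op:
    // out[b,h,t1,t2] = -inf where t2 > t1 + pastLen, else divisor * dot(q[b,h,t1,:], k[b,h,t2,:]).
    function attentionMaskRef(
      q: Float32Array, k: Float32Array,
      B: number, NH: number, T1: number, T2: number, HS: number,
      divisor: number, pastLen: number
    ): Float32Array {
      const out = new Float32Array(B * NH * T1 * T2);
      for (let b = 0; b < B; b++)
        for (let h = 0; h < NH; h++)
          for (let t1 = 0; t1 < T1; t1++)
            for (let t2 = 0; t2 < T2; t2++) {
              const o = ((b * NH + h) * T1 + t1) * T2 + t2;
              // Causal mask: a query may only attend to keys up to t1 + pastLen.
              if (t2 > t1 + pastLen) { out[o] = Number.NEGATIVE_INFINITY; continue; }
              const q0 = ((b * NH + h) * T1 + t1) * HS;
              const k0 = ((b * NH + h) * T2 + t2) * HS;
              let sum = 0;
              for (let i = 0; i < HS; i++) sum += q[q0 + i] * k[k0 + i];
              out[o] = sum * divisor; // e.g. divisor = 1 / sqrt(HS)
            }
      return out;
    }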
package/dist/ops/webgpu/attentionMask32_program.d.ts
@@ -0,0 +1,19 @@
+ import { WebGPUProgram } from '@tensorflow/tfjs-backend-webgpu';
+ export default class AttentionMaskProgram32 implements WebGPUProgram {
+ variableNames: string[];
+ outputShape: number[];
+ shaderKey: string;
+ dispatchLayout: {
+ x: number[];
+ };
+ dispatch: [number, number, number];
+ uniforms: string;
+ workgroupSize: [number, number, number];
+ size: boolean;
+ hs: number;
+ nh: number;
+ T1: number;
+ T2: number;
+ constructor(batch: number, nh: number, T1: number, T2: number, hs: number);
+ getUserCode(): string;
+ }
package/dist/ops/webgpu/attentionMask32_program.js
@@ -0,0 +1,54 @@
+ import { e as r } from "../../webgpu_program-Cigz-7RF.js";
+ import { f as a, c as u } from "../../webgpu_util-BBCnKm2X.js";
+ class p {
+ variableNames = ["q", "k"];
+ outputShape;
+ shaderKey = "AttentionMask";
+ dispatchLayout;
+ dispatch;
+ uniforms = "divisor: f32, pastLen: i32, inf: f32";
+ workgroupSize = [64, 1, 1];
+ size = !0;
+ hs;
+ nh;
+ T1;
+ T2;
+ constructor(e, i, s, o, t) {
+ if (this.shaderKey = `AttentionMask_${t}`, this.outputShape = [e, i, s, o], this.hs = t, this.nh = i, this.T1 = s, this.T2 = o, this.dispatchLayout = a(this.outputShape), this.dispatch = u(this.dispatchLayout, this.outputShape, this.workgroupSize), t % 4 !== 0)
+ throw new Error("Head size must be a multiple of 4 for AttentionMaskProgram");
+ }
+ getUserCode() {
+ return `
+ ${r("index")} {
+
+ let coords = getCoordsFromIndex(index);
+ let b = coords[0];
+ let h = coords[1];
+ let t1 = coords[2];
+ let t2 = coords[3];
+
+ if (index < uniforms.size) {
+ if (t2 > t1 + uniforms.pastLen) {
+ setOutputAtIndex(index, uniforms.inf);
+ return;
+ }
+
+ let q0 = getIndexFromCoords4D(vec4<i32>(b, h, t1, 0), uniforms.qShape);
+ let k0 = getIndexFromCoords4D(vec4<i32>(b, h, t2, 0), uniforms.kShape);
+
+ var sum: f32 = 0.0;
+ for (var i: i32 = 0; i < ${this.hs}; i = i + 4) {
+ let qv = vec4<f32>(q[q0 + i], q[q0 + i + 1], q[q0 + i + 2], q[q0 + i + 3]);
+ let kv = vec4<f32>(k[k0 + i], k[k0 + i + 1], k[k0 + i + 2], k[k0 + i + 3]);
+ sum = sum + dot(qv, kv);
+ }
+ let scaled = sum * uniforms.divisor;
+ setOutputAtIndex(index, scaled);
+ }
+ }
+ `;
+ }
+ }
+ export {
+ p as default
+ };
package/dist/ops/webgpu/concat16.d.ts
@@ -0,0 +1,19 @@
+ import { KernelConfig } from '@tensorflow/tfjs-core';
+ import { WebGPUProgram } from '@tensorflow/tfjs-backend-webgpu/dist/webgpu_program';
+ export declare class ConcatProgram implements WebGPUProgram {
+ outputShape: number[];
+ shaderKey: string;
+ dispatchLayout: {
+ x: number[];
+ };
+ dispatch: [number, number, number];
+ variableNames: string[];
+ uniforms: string;
+ workPerThread: number;
+ workgroupSize: [number, number, number];
+ size: boolean;
+ offsetLength: number;
+ constructor(shapes: Array<[number, number]>);
+ getUserCode(): string;
+ }
+ export declare const concatConfig: KernelConfig;
package/dist/ops/webgpu/concat16.js
@@ -0,0 +1,128 @@
+ import "../../index-ZyQhjEPo.js";
+ import { e as x } from "../../webgpu_program-Cigz-7RF.js";
+ import { f as I, c as D } from "../../webgpu_util-BBCnKm2X.js";
+ import { r as y } from "../../Reshape-CF6odzV4.js";
+ import { r as $ } from "../../tensor_util-DV-FP5Q3.js";
+ import { p as F, s as c } from "../../tensor-DdQUJZlz.js";
+ import { a as L, c as d } from "../../concat_util-DcJk7YHS.js";
+ class T {
+ outputShape;
+ shaderKey;
+ dispatchLayout;
+ dispatch;
+ variableNames;
+ uniforms = "";
+ workPerThread = 1;
+ workgroupSize = [64, 1, 1];
+ size = !0;
+ offsetLength;
+ constructor(t) {
+ this.outputShape = d(
+ t,
+ 1
+ /* axis */
+ ), this.variableNames = t.map((e, a) => `T${a}`), this.dispatchLayout = I(this.outputShape), this.dispatch = D(this.dispatchLayout, this.outputShape, this.workgroupSize, [
+ this.workPerThread,
+ 1,
+ 1
+ ]), this.offsetLength = t.length - 1;
+ for (let e = 0; e < this.offsetLength; e++)
+ this.uniforms += `offset${e} : i32,`;
+ this.shaderKey = "concat16";
+ }
+ getUserCode() {
+ const t = [];
+ if (this.offsetLength > 0) {
+ t.push(
+ "if (yC < uniforms.offset0){ result[getIndexFromCoords2D(coords, uniforms.outShape)] = T0[getIndexFromCoords2D(vec2<i32>(yR, yC), uniforms.t0Shape)]; }"
+ );
+ for (let s = 1; s < this.offsetLength; s++)
+ t.push(
+ `else if (yC < uniforms.offset${[s]}){ result[getIndexFromCoords2D(coords, uniforms.outShape)] = T${s}[getIndexFromCoords2D(vec2<i32>(yR, yC - uniforms.offset${s - 1}), uniforms.t${s}Shape)]; }`
+ );
+ const a = this.offsetLength, i = this.offsetLength - 1;
+ t.push(
+ `else { result[getIndexFromCoords2D(coords, uniforms.outShape)] = T${a}[getIndexFromCoords2D(vec2<i32>(yR, yC - uniforms.offset${i}), uniforms.t${a}Shape)]; }`
+ );
+ } else
+ t.push(
+ "result[getIndexFromCoords2D(coords, uniforms.outShape)] = T0[getIndexFromCoords2D(vec2<i32>(yR, yC), uniforms.t0Shape)];"
+ );
+ return `
+ ${x("index")} {
+ for(var i = 0; i < ${this.workPerThread}; i = i + 1) {
+ let flatIndex = index * ${this.workPerThread} + i;
+ if(flatIndex < uniforms.size) {
+ let coords = getCoordsFromIndex(flatIndex);
+ let yR = coords.x;
+ let yC = coords.y;
+
+ ${t.join(`
+ `)}
+ }
+ }
+ }
+ `;
+ }
+ }
+ function m(n, t, e) {
+ const a = e.device.limits.maxStorageBuffersPerShaderStage - 1;
+ if (n.length > a) {
+ const o = [];
+ for (let p = 0; p < n.length; p += a) {
+ const C = n.slice(p, p + a);
+ o.push(m(C, t, e));
+ }
+ const S = m(o, t, e);
+ for (const p of o)
+ e.disposeData(p.dataId);
+ return S;
+ }
+ const { tensors2D: i, outShape: s } = P(n, t, e), h = i.map((o) => o.shape), u = new T(h), f = [], r = new Array(h.length - 1);
+ if (r.length > 0) {
+ r[0] = h[0][1], f.push({ type: "int32", data: [r[0]] });
+ for (let o = 1; o < r.length; o++)
+ r[o] = r[o - 1] + h[o][1], f.push({ type: "int32", data: [r[o]] });
+ }
+ const l = e.runWebGPUProgram(u, i, i[0].dtype, f);
+ i.forEach((o) => e.disposeData(o.dataId));
+ const g = y({ inputs: { x: l }, backend: e, attrs: { shape: s } });
+ return e.disposeData(l.dataId), g.packed = !0, g;
+ }
+ function P(n, t, e) {
+ const a = d(
+ n.map((s) => s.shape),
+ t
+ );
+ return { tensors2D: n.map(
+ (s) => y({
+ inputs: { x: s },
+ backend: e,
+ attrs: {
+ shape: [c(s.shape.slice(0, t)), c(s.shape.slice(t))]
+ }
+ })
+ ), outShape: a };
+ }
+ function w(n) {
+ const { inputs: t, backend: e, attrs: a } = n, { axis: i } = a, s = F(i, t[0].shape)[0], h = t.map((r) => r.shape);
+ L(h, s);
+ const u = d(
+ t.map((r) => r.shape),
+ s
+ );
+ if (c(u) === 0)
+ return e.makeTensorInfo(u, t[0].dtype, []);
+ const f = t.filter((r) => c(r.shape) > 0);
+ return m(f, s, e);
+ }
+ const v = {
+ kernelName: "Concat16",
+ backendName: "webgpu",
+ kernelFunc: w
+ };
+ $(v);
+ export {
+ T as ConcatProgram,
+ v as concatConfig
+ };
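Note: Concat16 flattens each input to 2D and stitches columns, and it sidesteps WebGPU's per-stage storage-buffer limit by recursing: concatenate in chunks, then concatenate the chunk results. A simplified sketch of that control flow (hypothetical names, omitting the 2D reshape, uniform setup, and the real kernel's buffer disposal; assumes maxBindings >= 3 so the recursion terminates):

    // Simplified sketch, not the shipped code: one shader invocation never
    // binds more inputs than the device allows (one slot is reserved for the
    // output buffer, mirroring maxStorageBuffersPerShaderStage - 1 above).
    function concatChunked<T>(
      tensors: T[],
      concatOnce: (xs: T[]) => T,   // runs the ConcatProgram on <= maxInputs tensors
      maxBindings: number
    ): T {
      const maxInputs = maxBindings - 1; // reserve a binding for the result
      if (tensors.length <= maxInputs) return concatOnce(tensors);
      const partials: T[] = [];
      for (let i = 0; i < tensors.length; i += maxInputs) {
        partials.push(concatChunked(tensors.slice(i, i + maxInputs), concatOnce, maxBindings));
      }
      // Concatenate the partial results (the real kernel disposes them afterwards).
      return concatChunked(partials, concatOnce, maxBindings);
    }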
package/dist/ops/webgpu/gatherSub.js
@@ -1,6 +1,8 @@
- import { g as u } from "../../webgpu_program-DkQJOJSd.js";
- import { f as h, c as p } from "../../webgpu_util-pLEV9tks.js";
- import { f as c, a4 as r } from "../../index-BzFyqcy-.js";
+ import { e as u } from "../../webgpu_program-Cigz-7RF.js";
+ import { f as p, c as h } from "../../webgpu_util-BBCnKm2X.js";
+ import "../../index-ZyQhjEPo.js";
+ import { j as s } from "../../tensor-DdQUJZlz.js";
+ import { r as c } from "../../tensor_util-DV-FP5Q3.js";
  class l {
  variableNames = ["labels", "logits", "values"];
  outputShape;
@@ -10,7 +12,7 @@ class l {
  workgroupSize = [64, 1, 1];
  size = !0;
  constructor(e) {
- this.outputShape = [e], this.dispatchLayout = h(this.outputShape), this.dispatch = p(this.dispatchLayout, this.outputShape, this.workgroupSize);
+ this.outputShape = [e], this.dispatchLayout = p(this.outputShape), this.dispatch = h(this.dispatchLayout, this.outputShape, this.workgroupSize);
  }
  getUserCode() {
  return `
@@ -26,10 +28,10 @@ class l {
  }
  }
  function d(t) {
- const { logits: e, labels: a, values: s } = t.inputs, o = t.backend, i = a.shape[0];
- r(s.shape, [i], "Error in EfficientGatherSub: "), r(a.shape, [i], "Error in EfficientGatherSub: ");
+ const { logits: e, labels: a, values: r } = t.inputs, o = t.backend, i = a.shape[0];
+ s(r.shape, [i], "Error in EfficientGatherSub: "), s(a.shape, [i], "Error in EfficientGatherSub: ");
  const n = new l(i);
- return o.runWebGPUProgram(n, [a, e, s], "float32");
+ return o.runWebGPUProgram(n, [a, e, r], "float32");
  }
  const f = {
  kernelName: "EfficientGatherSub",
package/dist/ops/webgpu/gelu.js
@@ -1,8 +1,10 @@
- import { f as s } from "../../index-BzFyqcy-.js";
- import { g as a } from "../../webgpu_program-DkQJOJSd.js";
- import { f as o, c as p } from "../../webgpu_util-pLEV9tks.js";
- const u = 0.7978845608028654, i = 0.044715;
- class d {
+ import "../../index-ZyQhjEPo.js";
+ import { e as s } from "../../webgpu_program-Cigz-7RF.js";
+ import { f as o, c as p } from "../../webgpu_util-BBCnKm2X.js";
+ import { isPackedTensor as l } from "../../utilities/packed.js";
+ import { r as h } from "../../tensor_util-DV-FP5Q3.js";
+ const r = 0.7978845608028654, u = 0.044715;
+ class x {
  outputShape;
  shaderKey;
  dispatchLayout;
@@ -21,13 +23,13 @@ class d {
  }
  fn unaryOperation(x : f32) -> f32 {
  let x3 = x * x * x;
- var inner = fma(${i}, x3, x);
- inner = ${u} * inner;
+ var inner = fma(${u}, x3, x);
+ inner = ${r} * inner;
  inner = tanhComplete(inner);
  inner = 0.5 * (1.0 + inner);
  return x * inner;
  }
- ${a("index")} {
+ ${s("index")} {
  if (index < uniforms.size) {
  let a = getAByOutputIndex(index);
  setOutputAtIndex(index, unaryOperation(a));
@@ -36,17 +38,58 @@
  `;
  }
  }
- function c(t) {
- const { x: e } = t.inputs, n = t.backend, r = new d(e.shape);
- return n.runWebGPUProgram(r, [e], "float32");
+ function g(t) {
+ const { x: e } = t.inputs, a = t.backend, i = new x(e.shape);
+ return a.runWebGPUProgram(i, [e], "float32");
  }
- const l = {
+ const f = {
  kernelName: "Gelu",
  backendName: "webgpu",
- kernelFunc: c
+ kernelFunc: g
  };
- s(l);
- class x {
+ h(f);
+ class m {
+ // Inputs: dy, x
+ variableNames = ["dy", "x"];
+ outputShape;
+ shaderKey = "GeluGrad";
+ dispatchLayout;
+ dispatch;
+ workgroupSize = [128, 1, 1];
+ size = !0;
+ constructor(e) {
+ this.outputShape = e, this.dispatchLayout = o(this.outputShape), this.dispatch = p(this.dispatchLayout, this.outputShape, this.workgroupSize);
+ }
+ getUserCode() {
+ return `
+ // TODO: revisit after https://github.com/gpuweb/gpuweb/issues/4458 is resolved
+ fn tanhComplete(x: f32) -> f32 {
+ return select(tanh(x), sign(x), abs(x) > 15.0);
+ }
+ fn activationGrad(dy: f32, X: f32) -> f32 {
+ let x2 = X * X;
+ let x3 = x2 * X;
+ let u = ${r} * (X + ${u} * x3);
+ let t = tanhComplete(u);
+ let sech2 = 1.0 - t * t;
+ let du_dx = ${r} * (1.0 + 3.0 * ${u} * x2);
+ let dgelu = 0.5 * (1.0 + t) + 0.5 * X * sech2 * du_dx;
+ return dy *dgelu;
+ }
+ ${s("index")} {
+ if (index < uniforms.size) {
+ let X = unpack2x16float(u32(x[index]));
+ let DY = unpack2x16float(u32(dy[index]));
+ let dgelu = vec2<f32>(
+ activationGrad(DY.x, X.x),
+ activationGrad(DY.y, X.y)
+ );
+ result[index] = i32(pack2x16float(dgelu));
+ }
+ }`;
+ }
+ }
+ class y {
  // Inputs: dy, x
  variableNames = ["dy", "x"];
  outputShape;
@@ -64,32 +107,36 @@ class x {
  fn tanhComplete(x: f32) -> f32 {
  return select(tanh(x), sign(x), abs(x) > 15.0);
  }
- ${a("index")} {
+ fn activationGrad(dy: f32, X: f32) -> f32 {
+ let x2 = X * X;
+ let x3 = x2 * X;
+ let u = ${r} * (X + ${u} * x3);
+ let t = tanhComplete(u);
+ let sech2 = 1.0 - t * t;
+ let du_dx = ${r} * (1.0 + 3.0 * ${u} * x2);
+ let dgelu = 0.5 * (1.0 + t) + 0.5 * X * sech2 * du_dx;
+ return dy *dgelu;
+ }
+ ${s("index")} {
  if (index < uniforms.size) {
  let X = getXByOutputIndex(index);
- let x2 = X * X;
- let x3 = x2 * X;
- let u = ${u} * (X + ${i} * x3);
- let t = tanhComplete(u);
- let sech2 = 1.0 - t * t;
- let du_dx = ${u} * (1.0 + 3.0 * ${i} * x2);
- let dgelu = 0.5 * (1.0 + t) + 0.5 * X * sech2 * du_dx;
  let DY = getDyByOutputIndex(index);
- setOutputAtIndex(index, DY * dgelu);
+ let dgelu = activationGrad(DY, X);
+ setOutputAtIndex(index, dgelu);
  }
  }`;
  }
  }
- function g(t) {
- const { dy: e, x: n } = t.inputs, r = t.backend, h = new x(n.shape);
- return r.runWebGPUProgram(h, [e, n], "float32");
+ function b(t) {
+ const { dy: e, x: a } = t.inputs, i = t.backend, n = l(e), c = n ? new m(a.shape) : new y(a.shape), d = i.runWebGPUProgram(c, [e, a], n ? "int32" : "float32");
+ return d.packed = n, d;
  }
- const f = {
+ const k = {
  kernelName: "GeluGrad",
  backendName: "webgpu",
- kernelFunc: g
+ kernelFunc: b
  };
- s(f);
+ h(k);
  export {
- d as GeluProgram
+ x as GeluProgram
  };
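Note: both GeluGrad variants now factor the derivative into a shared activationGrad helper; 0.7978845608028654 is sqrt(2/pi) and 0.044715 is the cubic coefficient of the tanh GELU approximation. For reference, the forward and backward math the shaders implement (a hypothetical helper, not part of the package):

    // Reference math for the shaders above.
    const SQRT_2_OVER_PI = 0.7978845608028654; // sqrt(2 / pi)
    const C3 = 0.044715;

    // gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    function gelu(x: number): number {
      const u = SQRT_2_OVER_PI * (x + C3 * x * x * x);
      return 0.5 * x * (1 + Math.tanh(u));
    }

    // d/dx gelu(x) = 0.5 * (1 + tanh(u)) + 0.5 * x * sech^2(u) * du/dx,
    // with du/dx = sqrt(2/pi) * (1 + 3 * 0.044715 * x^2).
    function geluGrad(dy: number, x: number): number {
      const u = SQRT_2_OVER_PI * (x + C3 * x * x * x);
      const t = Math.tanh(u);
      const sech2 = 1 - t * t;
      const dudx = SQRT_2_OVER_PI * (1 + 3 * C3 * x * x);
      return dy * (0.5 * (1 + t) + 0.5 * x * sech2 * dudx);
    }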
package/dist/ops/webgpu/index.js
@@ -9,3 +9,15 @@ import "./qkv.js";
  import "./gelu.js";
  import "./adamMoments.js";
  import "./adamAdjust.js";
+ import "./pack16.js";
+ import "./unpack16.js";
+ import "./softmax16.js";
+ import "./matMul16.js";
+ import "./transpose16.js";
+ import "./sum16.js";
+ import "./slice16.js";
+ import "./add16.js";
+ import "./concat16.js";
+ import "./mul16.js";
+ import "./sub16.js";
+ import "./softmax16grad.js";
package/dist/ops/webgpu/matMul16.d.ts
@@ -0,0 +1 @@
+ export {};
package/dist/ops/webgpu/matMul16.js
@@ -0,0 +1,58 @@
+ import { m as y, b as B, j as Q } from "../../index-ZyQhjEPo.js";
+ import { isPackedTensor as R } from "../../utilities/packed.js";
+ import { reshape16 as U } from "../reshape16.js";
+ import { matMulMul as V } from "../matMulMul.js";
+ import { matMulGelu as X } from "../matMulGelu.js";
+ import Y from "./matMul16_program.js";
+ import { r as Z } from "../../tensor_util-DV-FP5Q3.js";
+ import { m as _ } from "../../mat_mul-DeAh4uTH.js";
+ import { r as x } from "../../reshape-DevtBWtf.js";
+ import { t as C } from "../../transpose-DKELTqhe.js";
+ import { s as E } from "../../tensor-DdQUJZlz.js";
+ function $(p) {
+ const { A: e, B: s } = p.inputs, { transposeA: d, transposeB: f, scale: i, activation: k, scaleA: c, scaleB: u, forceOutputShape: o, perm: m, causalMask: g, pastLen: W } = p.attrs, z = p.backend, S = !R(e), M = !R(s);
+ if (S && M) {
+ const A = c !== void 0 ? y(e, B(c)) : e, b = u !== void 0 ? y(s, B(u)) : s;
+ if (g)
+ throw new Error("Causal mask is not supported for unpacked MatMul16.");
+ let a;
+ if (i !== void 0 ? a = V(A, b, B(i), d, f) : k === "gelu" ? a = X(A, b) : a = _(A, b, d, f), m)
+ if (o) {
+ const n = x(a, o);
+ a.dispose();
+ const J = C(n, m);
+ return n.dispose(), J;
+ } else {
+ const n = C(a, m);
+ return a.dispose(), n;
+ }
+ else if (o) {
+ const n = x(a, o);
+ return a.dispose(), n;
+ } else
+ return a;
+ }
+ if (S && !M)
+ throw new Error("When using mixed precision, A must be packed if B is packed.");
+ if (!S && M)
+ throw new Error("When using mixed precision, B must be packed if A is packed.");
+ const h = e.shape.length, l = s.shape.length, F = e.shape.slice(0, -2), I = s.shape.slice(0, -2), v = E(F), w = E(I), N = Q(e.shape.slice(0, -2), s.shape.slice(0, -2)), j = Math.max(v, w), K = e.shape[h - 2], L = s.shape[l - 2], T = e.shape[h - 1] * 2, q = s.shape[l - 1] * 2, D = U(e, [v, e.shape[h - 2], e.shape[h - 1]]), G = U(s, [w, s.shape[l - 2], s.shape[l - 1]]), t = new Y(j, K, L, T, q, d, f), r = [];
+ i !== void 0 && (t.useScale(), r.push({ type: "float32", data: [i] })), c !== void 0 && (t.useScaleA(), r.push({ type: "float32", data: [c] })), u !== void 0 && (t.useScaleB(), r.push({ type: "float32", data: [u] })), k !== void 0 && t.useActivation(k), g && (t.useCausalMask(), r.push({ type: "int32", data: [W || 0] }));
+ const O = t.outputShape.length;
+ o && (p.attrs.originalShape = t.outputShape);
+ const H = o ?? N.concat([t.outputShape[O - 2], t.outputShape[O - 1]]);
+ t.setOutputShape(H, m);
+ const P = z.runWebGPUProgram(
+ t,
+ [D, G],
+ "int32",
+ r.length > 0 ? r : void 0
+ );
+ return P.packed = !0, D.dispose(), G.dispose(), P;
+ }
+ const ee = {
+ kernelName: "MatMul16",
+ backendName: "webgpu",
+ kernelFunc: $
+ };
+ Z(ee);
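Note: for packed operands the kernel derives the logical fp16 inner dimensions as shape[last] * 2, because each stored i32 element carries two half floats (WGSL pack2x16float layout: first value in the low 16 bits, second in the high 16). A TypeScript sketch of that packing convention (a hypothetical helper, not shipped in the package; truncating rather than rounding the mantissa for brevity):

    // Emulates WGSL pack2x16float: two f16 bit patterns share one 32-bit int.
    function f32ToF16Bits(val: number): number {
      const f32 = new Float32Array(1);
      const u32 = new Uint32Array(f32.buffer);
      f32[0] = val;
      const x = u32[0];
      const sign = (x >>> 16) & 0x8000;
      const exp = (x >>> 23) & 0xff;
      let mant = x & 0x7fffff;
      if (exp === 0xff) return sign | 0x7c00 | (mant ? 0x200 : 0); // Inf / NaN
      const e = exp - 127 + 15;                                    // rebias exponent
      if (e >= 0x1f) return sign | 0x7c00;                         // overflow -> Inf
      if (e <= 0) {                                                // subnormal or zero
        if (e < -10) return sign;
        mant = (mant | 0x800000) >>> (1 - e);
        return sign | (mant >>> 13);
      }
      return sign | (e << 10) | (mant >>> 13);
    }

    // i32 bit pattern holding [a, b]: a in bits 0-15, b in bits 16-31.
    function pack2x16float(a: number, b: number): number {
      return ((f32ToF16Bits(b) << 16) | f32ToF16Bits(a)) | 0;
    }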
package/dist/ops/webgpu/matMul16_program.d.ts
@@ -0,0 +1,42 @@
+ import { WebGPUProgram } from '@tensorflow/tfjs-backend-webgpu';
+ export default class MatMul16ProgramGeneric implements WebGPUProgram {
+ variableNames: string[];
+ outputShape: number[];
+ shaderKey: string;
+ dispatchLayout: {
+ x: number[];
+ y: number[];
+ z: number[];
+ };
+ dispatch: [number, number, number];
+ workgroupSize: [number, number, number];
+ dimInner: number;
+ transposeA: boolean;
+ transposeB: boolean;
+ broadcastBatch: boolean;
+ tileInner: number;
+ uniforms?: string;
+ scale: boolean;
+ scaleA: boolean;
+ scaleB: boolean;
+ activation?: 'gelu';
+ causalMask: boolean;
+ outputComponent?: number | undefined;
+ variableComponents?: number[];
+ outputIndexSnippet?: string;
+ outputStrideSnippet?: string;
+ constructor(batch: number, O1: number, O2: number, I1: number, I2: number, transposeA?: boolean, transposeB?: boolean);
+ private addUniform;
+ setOutputShape(shape: number[], perm?: number[]): void;
+ useScale(): void;
+ useScaleA(): void;
+ useScaleB(): void;
+ useActivation(activation: 'gelu'): void;
+ useCausalMask(): void;
+ private activationSnippet;
+ private readASnippet;
+ private readBSnippet;
+ private baseIndexSnippets;
+ private offsetSnippets;
+ getUserCode(): string;
+ }
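Note: this declaration pairs with the MatMul16 kernel above; O1/O2 are the output rows/cols and I1/I2 the logical fp16 inner dimensions (stored packed width times 2). A deminified sketch of how the kernel drives the builder, reconstructed from the minified matMul16.js above (names, shapes, and the import path are illustrative, not an official API):

    // Illustrative only: mirrors the call sequence in matMul16.js above.
    // A3d and B3d are packed [batch, rows, cols/2] int32 tensors on the webgpu backend.
    import MatMul16ProgramGeneric from '@genai-fi/nanogpt/dist/ops/webgpu/matMul16_program.js';

    declare const backend: {
      runWebGPUProgram(p: unknown, inputs: unknown[], dtype: string, uniforms?: unknown[]): any;
    };
    declare const A3d: { shape: number[] }, B3d: { shape: number[] };

    const [batch, M, Kp] = A3d.shape;            // Kp = K / 2 (two fp16 per i32)
    const N = B3d.shape[1];
    const program = new MatMul16ProgramGeneric(
      batch, M, N, Kp * 2, B3d.shape[2] * 2,     // logical fp16 dims, per the .d.ts
      /* transposeA */ false, /* transposeB */ true);

    // Uniforms are pushed in the same order the kernel enables features.
    const uniforms: Array<{ type: string; data: number[] }> = [];
    program.useScale();                          // fused scaling, e.g. 1/sqrt(headSize)
    uniforms.push({ type: 'float32', data: [1 / Math.sqrt(64)] });
    program.useCausalMask();                     // fused causal mask; adds a pastLen uniform
    uniforms.push({ type: 'int32', data: [0] });

    const rank = program.outputShape.length;
    program.setOutputShape([batch, program.outputShape[rank - 2], program.outputShape[rank - 1]]);

    // dtype is "int32" because each output element packs two fp16 values.
    const out = backend.runWebGPUProgram(program, [A3d, B3d], 'int32', uniforms);
    out.packed = true;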