@genai-fi/nanogpt 0.9.0 → 0.10.0

This diff shows the contents of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (343)
  1. package/README.md +352 -14
  2. package/dist/Generator.js +69 -78
  3. package/dist/{RealDiv-D4EzDsC0.js → RealDiv-DgA3z9oO.js} +32 -206
  4. package/dist/Reshape-CF6odzV4.js +16 -0
  5. package/dist/Reshape-_kILl6tK.js +81 -0
  6. package/dist/TeachableLLM.js +28 -22
  7. package/dist/Trainer.d.ts +2 -0
  8. package/dist/Trainer.js +3 -2
  9. package/dist/{axis_util-TbGYJ208.js → axis_util-BvHEw88j.js} +7 -23
  10. package/dist/backend.d.ts +2 -1
  11. package/dist/backend.js +10 -4
  12. package/dist/backend_util-D-rUb2ty.js +474 -0
  13. package/dist/backend_webgpu-B0u2ndUn.js +547 -0
  14. package/dist/binary_op_util-pKXltfxI.js +192 -0
  15. package/dist/broadcast_to-CwF7XIeu.js +30 -0
  16. package/dist/checks/appendCache.js +2 -2
  17. package/dist/checks/attentionMask.js +3 -3
  18. package/dist/checks/check.d.ts +1 -1
  19. package/dist/checks/check.js +8 -8
  20. package/dist/checks/gelu.js +2 -2
  21. package/dist/checks/index.d.ts +2 -0
  22. package/dist/checks/index.js +7 -5
  23. package/dist/checks/matMulGelu.js +6 -6
  24. package/dist/checks/normRMS.js +7 -7
  25. package/dist/checks/normRMSGrad.js +3 -3
  26. package/dist/checks/packUnpack.d.ts +1 -0
  27. package/dist/checks/packUnpack.js +18 -0
  28. package/dist/checks/qkv.js +12 -27
  29. package/dist/checks/rope.js +2 -2
  30. package/dist/checks/weights.js +18 -16
  31. package/dist/complex-CSlYz-2T.js +13 -0
  32. package/dist/complex_util-Yc1A_gV1.js +55 -0
  33. package/dist/concat-BHlIJeyT.js +19 -0
  34. package/dist/concat_util-DcJk7YHS.js +22 -0
  35. package/dist/data/docx.js +1 -1
  36. package/dist/data/parquet.js +2 -2
  37. package/dist/data/pdf.js +1 -1
  38. package/dist/data/textLoader.js +1 -1
  39. package/dist/{dataset-DlZtKmBq.js → dataset-0xP8GjwI.js} +136 -236
  40. package/dist/dropout-C1pM3f11.js +99 -0
  41. package/dist/expand_dims-BPG4fwBP.js +13 -0
  42. package/dist/exports_initializers-xuidcwI4.js +7 -0
  43. package/dist/gather-DykLGqmW.js +10 -0
  44. package/dist/{gelu-Bp_-935b.js → gelu-CNLFZWea.js} +11 -10
  45. package/dist/{gpgpu_math-CDaYiyE_.js → gpgpu_math-DDVJCn6-.js} +90 -265
  46. package/dist/{index-C4L8Cm77.js → index-CieiGp4Y.js} +14 -14
  47. package/dist/index-CjOj7j-u.js +7308 -0
  48. package/dist/{index-Tf7vU29b.js → index-Cp39cXWe.js} +3 -10
  49. package/dist/{index-Dwqa6Zy2.js → index-DvYrXKkX.js} +2 -2
  50. package/dist/index-ZyQhjEPo.js +2157 -0
  51. package/dist/{jszip.min-CjP2V1VV.js → jszip.min-Bz5-11Bk.js} +56 -57
  52. package/dist/kernel_funcs_utils-Dg_-E44D.js +308 -0
  53. package/dist/layers/BaseLayer.d.ts +1 -0
  54. package/dist/layers/BaseLayer.js +7 -6
  55. package/dist/layers/CausalSelfAttention.d.ts +0 -1
  56. package/dist/layers/CausalSelfAttention.js +56 -55
  57. package/dist/layers/MLP.js +15 -16
  58. package/dist/layers/PositionEmbedding.js +5 -14
  59. package/dist/layers/RMSNorm.js +3 -3
  60. package/dist/layers/RoPECache.d.ts +2 -0
  61. package/dist/layers/RoPECache.js +22 -17
  62. package/dist/layers/TiedEmbedding.js +22 -17
  63. package/dist/layers/TransformerBlock.js +21 -20
  64. package/dist/loader/load.js +1 -1
  65. package/dist/loader/loadTransformers.js +1 -1
  66. package/dist/loader/oldZipLoad.js +39 -33
  67. package/dist/loader/save.js +1 -1
  68. package/dist/log_sum_exp-DWI-76TI.js +41 -0
  69. package/dist/main.d.ts +8 -0
  70. package/dist/main.js +63 -52
  71. package/dist/matMul16--R5hOwDG.js +77 -0
  72. package/dist/mat_mul-DeAh4uTH.js +12 -0
  73. package/dist/mod-Gt1rMB4n.js +12 -0
  74. package/dist/models/NanoGPTV1.js +40 -31
  75. package/dist/models/model.d.ts +2 -0
  76. package/dist/models/model.js +37 -29
  77. package/dist/{mulmat_packed_gpu-BT60jmzP.js → mulmat_packed_gpu-BMFhLwta.js} +1 -17
  78. package/dist/{non_max_suppression_impl-CsEgBuMA.js → non_max_suppression_impl-B2W7YjZB.js} +0 -32
  79. package/dist/ones-CAMiP4I2.js +15 -0
  80. package/dist/ops/adamAdjust.js +1 -1
  81. package/dist/ops/adamMoments.d.ts +1 -1
  82. package/dist/ops/adamMoments.js +4 -4
  83. package/dist/ops/add16.d.ts +2 -0
  84. package/dist/ops/add16.js +9 -0
  85. package/dist/ops/appendCache.js +16 -9
  86. package/dist/ops/attentionMask.js +4 -4
  87. package/dist/ops/concat16.d.ts +2 -0
  88. package/dist/ops/concat16.js +9 -0
  89. package/dist/ops/cpu/adamAdjust.js +14 -13
  90. package/dist/ops/cpu/adamMoments.js +10 -9
  91. package/dist/ops/cpu/appendCache.js +9 -8
  92. package/dist/ops/cpu/attentionMask.js +15 -14
  93. package/dist/ops/cpu/fusedSoftmax.js +13 -12
  94. package/dist/ops/cpu/gatherSub.js +9 -24
  95. package/dist/ops/cpu/gelu.js +13 -12
  96. package/dist/ops/cpu/matMul16.d.ts +1 -0
  97. package/dist/ops/cpu/matMul16.js +16 -0
  98. package/dist/ops/cpu/matMulGelu.js +18 -16
  99. package/dist/ops/cpu/matMulMul.js +8 -7
  100. package/dist/ops/cpu/mulDropout.js +4 -3
  101. package/dist/ops/cpu/normRMS.js +11 -10
  102. package/dist/ops/cpu/qkv.js +17 -13
  103. package/dist/ops/cpu/rope.js +23 -22
  104. package/dist/ops/cpu/scatterSub.js +16 -30
  105. package/dist/ops/dot16.d.ts +2 -0
  106. package/dist/ops/dot16.js +42 -0
  107. package/dist/ops/gatherSub.js +1 -1
  108. package/dist/ops/gelu.js +2 -2
  109. package/dist/ops/grads/add16.d.ts +1 -0
  110. package/dist/ops/grads/add16.js +27 -0
  111. package/dist/ops/grads/attentionMask.js +12 -19
  112. package/dist/ops/grads/gelu.js +4 -3
  113. package/dist/ops/grads/matMul16.d.ts +2 -0
  114. package/dist/ops/grads/matMul16.js +9 -0
  115. package/dist/ops/grads/matMulGelu.js +8 -7
  116. package/dist/ops/grads/normRMS.js +8 -7
  117. package/dist/ops/grads/{fusedSoftmax.d.ts → pack16.d.ts} +1 -1
  118. package/dist/ops/grads/pack16.js +7 -0
  119. package/dist/ops/grads/qkv.d.ts +3 -1
  120. package/dist/ops/grads/qkv.js +28 -22
  121. package/dist/ops/grads/rope.d.ts +2 -1
  122. package/dist/ops/grads/rope.js +6 -13
  123. package/dist/ops/grads/softmax16.d.ts +2 -0
  124. package/dist/ops/grads/softmax16.js +26 -0
  125. package/dist/ops/grads/unpack16.d.ts +2 -0
  126. package/dist/ops/grads/unpack16.js +6 -0
  127. package/dist/ops/grads/utils.d.ts +3 -0
  128. package/dist/ops/grads/utils.js +10 -0
  129. package/dist/ops/matMul16.d.ts +15 -0
  130. package/dist/ops/matMul16.js +13 -0
  131. package/dist/ops/matMulGelu.js +1 -1
  132. package/dist/ops/matMulMul.js +1 -1
  133. package/dist/ops/mul16.d.ts +2 -0
  134. package/dist/ops/mul16.js +8 -0
  135. package/dist/ops/mulDrop.js +1 -1
  136. package/dist/ops/normRMS.js +1 -1
  137. package/dist/ops/pack16.d.ts +2 -0
  138. package/dist/ops/pack16.js +6 -0
  139. package/dist/ops/qkv.d.ts +1 -1
  140. package/dist/ops/qkv.js +8 -4
  141. package/dist/ops/reshape16.d.ts +2 -0
  142. package/dist/ops/reshape16.js +43 -0
  143. package/dist/ops/rope.d.ts +1 -1
  144. package/dist/ops/rope.js +8 -10
  145. package/dist/ops/scatterSub.js +1 -1
  146. package/dist/ops/slice16.d.ts +2 -0
  147. package/dist/ops/slice16.js +9 -0
  148. package/dist/ops/softmax16.d.ts +2 -0
  149. package/dist/ops/softmax16.js +12 -0
  150. package/dist/ops/sub16.d.ts +2 -0
  151. package/dist/ops/sub16.js +8 -0
  152. package/dist/ops/sum16.d.ts +2 -0
  153. package/dist/ops/sum16.js +13 -0
  154. package/dist/ops/transpose16.d.ts +3 -0
  155. package/dist/ops/transpose16.js +41 -0
  156. package/dist/ops/unpack16.d.ts +2 -0
  157. package/dist/ops/unpack16.js +6 -0
  158. package/dist/ops/webgl/adamAdjust.js +3 -2
  159. package/dist/ops/webgl/adamMoments.js +2 -1
  160. package/dist/ops/webgl/appendCache.js +2 -1
  161. package/dist/ops/webgl/attentionMask.js +5 -4
  162. package/dist/ops/webgl/fusedSoftmax.js +6 -4
  163. package/dist/ops/webgl/gatherSub.js +7 -6
  164. package/dist/ops/webgl/gelu.js +3 -2
  165. package/dist/ops/webgl/log.js +12 -27
  166. package/dist/ops/webgl/matMul16.d.ts +1 -0
  167. package/dist/ops/webgl/matMul16.js +37 -0
  168. package/dist/ops/webgl/matMulGelu.js +17 -15
  169. package/dist/ops/webgl/matMulMul.js +13 -12
  170. package/dist/ops/webgl/mulDropout.js +9 -8
  171. package/dist/ops/webgl/normRMS.js +8 -7
  172. package/dist/ops/webgl/qkv.js +6 -5
  173. package/dist/ops/webgl/rope.js +11 -10
  174. package/dist/ops/webgl/scatterSub.js +6 -5
  175. package/dist/ops/webgpu/adamAdjust.js +12 -10
  176. package/dist/ops/webgpu/adamMoments.js +27 -22
  177. package/dist/ops/webgpu/add16.d.ts +1 -0
  178. package/dist/ops/webgpu/add16.js +14 -0
  179. package/dist/ops/webgpu/appendCache.js +64 -17
  180. package/dist/ops/webgpu/attentionMask.js +19 -62
  181. package/dist/ops/webgpu/attentionMask32_program.d.ts +19 -0
  182. package/dist/ops/webgpu/attentionMask32_program.js +54 -0
  183. package/dist/ops/webgpu/concat16.d.ts +19 -0
  184. package/dist/ops/webgpu/concat16.js +128 -0
  185. package/dist/ops/webgpu/gatherSub.js +9 -7
  186. package/dist/ops/webgpu/gelu.js +78 -31
  187. package/dist/ops/webgpu/index.js +12 -0
  188. package/dist/ops/webgpu/matMul16.d.ts +1 -0
  189. package/dist/ops/webgpu/matMul16.js +58 -0
  190. package/dist/ops/webgpu/matMul16_program.d.ts +42 -0
  191. package/dist/ops/webgpu/matMul16_program.js +336 -0
  192. package/dist/ops/webgpu/mul16.d.ts +1 -0
  193. package/dist/ops/webgpu/mul16.js +14 -0
  194. package/dist/ops/webgpu/normRMS.js +21 -40
  195. package/dist/ops/webgpu/normRMS16_program.d.ts +9 -0
  196. package/dist/ops/webgpu/normRMS16_program.js +24 -0
  197. package/dist/ops/webgpu/normRMS32_program.d.ts +9 -0
  198. package/dist/ops/webgpu/normRMS32_program.js +24 -0
  199. package/dist/ops/webgpu/normRMSGrad.js +113 -64
  200. package/dist/ops/webgpu/pack16.d.ts +1 -0
  201. package/dist/ops/webgpu/pack16.js +19 -0
  202. package/dist/ops/webgpu/pack16_program.d.ts +19 -0
  203. package/dist/ops/webgpu/pack16_program.js +92 -0
  204. package/dist/ops/webgpu/qkv.js +20 -55
  205. package/dist/ops/webgpu/rope.js +77 -22
  206. package/dist/ops/webgpu/scatterSub.js +9 -7
  207. package/dist/ops/webgpu/slice16.d.ts +7 -0
  208. package/dist/ops/webgpu/slice16.js +71 -0
  209. package/dist/{variable-Bm2OFwGI.js → ops/webgpu/softmax16.d.ts} +2 -8
  210. package/dist/ops/webgpu/softmax16.js +23 -0
  211. package/dist/ops/webgpu/softmax16_program.d.ts +13 -0
  212. package/dist/ops/webgpu/softmax16_program.js +73 -0
  213. package/dist/ops/webgpu/softmax16_subgroup_program.d.ts +17 -0
  214. package/dist/ops/webgpu/softmax16_subgroup_program.js +75 -0
  215. package/dist/ops/webgpu/softmax16grad.d.ts +1 -0
  216. package/dist/ops/webgpu/softmax16grad.js +38 -0
  217. package/dist/ops/webgpu/sub16.d.ts +1 -0
  218. package/dist/ops/webgpu/sub16.js +14 -0
  219. package/dist/ops/webgpu/sum16.d.ts +1 -0
  220. package/dist/ops/webgpu/sum16.js +40 -0
  221. package/dist/ops/webgpu/transpose16.d.ts +1 -0
  222. package/dist/ops/webgpu/transpose16.js +35 -0
  223. package/dist/ops/webgpu/transpose16_program.d.ts +16 -0
  224. package/dist/ops/webgpu/transpose16_program.js +50 -0
  225. package/dist/ops/webgpu/transpose16_shared_program.d.ts +15 -0
  226. package/dist/ops/webgpu/transpose16_shared_program.js +71 -0
  227. package/dist/ops/webgpu/unpack16.d.ts +1 -0
  228. package/dist/ops/webgpu/unpack16.js +49 -0
  229. package/dist/ops/webgpu/utils/binary_op.d.ts +19 -0
  230. package/dist/ops/webgpu/utils/binary_op.js +79 -0
  231. package/dist/ops/webgpu/utils/deviceInfo.d.ts +7 -0
  232. package/dist/ops/webgpu/utils/deviceInfo.js +11 -0
  233. package/dist/ops/webgpu/utils/reductions.d.ts +32 -4
  234. package/dist/ops/webgpu/utils/reductions.js +236 -45
  235. package/dist/ops-CNI3TwqM.js +645 -0
  236. package/dist/pack16-CFUqumar.js +41 -0
  237. package/dist/{papaparse.min-C8l2Kvo1.js → papaparse.min-C0cScC2i.js} +2 -8
  238. package/dist/{parquet-C0Tlmv9c.js → parquet-BE8MU_ge.js} +201 -278
  239. package/dist/patches/PackedTensor.d.ts +12 -0
  240. package/dist/patches/PackedTensor.js +11 -0
  241. package/dist/patches/engine.d.ts +261 -0
  242. package/dist/patches/engine.js +10 -0
  243. package/dist/patches/tape.d.ts +12 -0
  244. package/dist/patches/tape.js +5 -0
  245. package/dist/patches/webgpu_backend.d.ts +18 -0
  246. package/dist/patches/webgpu_backend.js +57 -0
  247. package/dist/{tensor-CZr4dh61.js → patches/webgpu_base.d.ts} +5 -8
  248. package/dist/patches/webgpu_base.js +34 -0
  249. package/dist/patches/webgpu_program.d.ts +36 -0
  250. package/dist/patches/webgpu_program.js +401 -0
  251. package/dist/{pdf-kJD-f258.js → pdf-NIhmP3sq.js} +424 -428
  252. package/dist/random_width-DY6Kk2Dl.js +10051 -0
  253. package/dist/range-BMS52eQi.js +11 -0
  254. package/dist/reciprocal-CTmshQ9J.js +10 -0
  255. package/dist/{register_all_kernels-DIGpEwcf.js → register_all_kernels-Bwu1PTuU.js} +719 -9766
  256. package/dist/relu-yZ2-7WxU.js +10 -0
  257. package/dist/reshape-DevtBWtf.js +10 -0
  258. package/dist/rope-B5UUMsPi.js +32 -0
  259. package/dist/{scatter_nd_util-BQdz--Gn.js → scatter_nd_util-5EL-8VAQ.js} +1 -1
  260. package/dist/selu_util-D1w6yyTO.js +303 -0
  261. package/dist/{shared-DuP7ue-R.js → shared-BRksrJb3.js} +1 -17
  262. package/dist/shared-BuAXb4CI.js +2145 -0
  263. package/dist/sin-BGfy2HZo.js +16 -0
  264. package/dist/slice-D_gkkqZK.js +13 -0
  265. package/dist/slice_util-DtEldBfK.js +261 -0
  266. package/dist/softmax-ZHVebtR1.js +13 -0
  267. package/dist/split-DrfihRpZ.js +10 -0
  268. package/dist/squeeze-DZEpeblb.js +11 -0
  269. package/dist/stack-yOIAalTq.js +13 -0
  270. package/dist/sum-_fzj5ZTB.js +12 -0
  271. package/dist/tensor-DdQUJZlz.js +909 -0
  272. package/dist/tensor-f35l8Odg.js +8 -0
  273. package/dist/tensor1d-CeZuc-Rv.js +12 -0
  274. package/dist/tensor2d-G4Ys2GxX.js +15 -0
  275. package/dist/tensor4d-B8roDgtc.js +15 -0
  276. package/dist/tensor_util-DV-FP5Q3.js +523 -0
  277. package/dist/tfjs_backend-kNyO5L2d.js +653 -0
  278. package/dist/tile-BzyEiF-F.js +13 -0
  279. package/dist/tokeniser/CharTokeniser.js +1 -1
  280. package/dist/tokeniser/bpe.js +1 -1
  281. package/dist/training/Adam.d.ts +2 -1
  282. package/dist/training/Adam.js +12 -28
  283. package/dist/training/AdamExt.d.ts +1 -0
  284. package/dist/training/AdamExt.js +2 -2
  285. package/dist/training/DatasetBuilder.js +3 -20
  286. package/dist/training/FullTrainer.js +82 -64
  287. package/dist/training/Trainer.d.ts +11 -6
  288. package/dist/training/Trainer.js +51 -39
  289. package/dist/training/sparseCrossEntropy.js +3 -3
  290. package/dist/transpose-DKELTqhe.js +38 -0
  291. package/dist/utilities/arrayClose.js +7 -7
  292. package/dist/utilities/dummy.js +35 -27
  293. package/dist/utilities/multinomialCPU.js +2 -2
  294. package/dist/utilities/packed.d.ts +7 -0
  295. package/dist/utilities/packed.js +716 -0
  296. package/dist/utilities/performance.js +1 -1
  297. package/dist/utilities/profile.js +1 -1
  298. package/dist/utilities/safetensors.js +2 -2
  299. package/dist/utilities/sentences.d.ts +5 -0
  300. package/dist/utilities/sentences.js +41 -0
  301. package/dist/utilities/weights.js +2 -2
  302. package/dist/variable-Bhn5bHYv.js +7 -0
  303. package/dist/{webgpu_program-DkQJOJSd.js → webgpu_program-Cigz-7RF.js} +15 -44
  304. package/dist/webgpu_util-BBCnKm2X.js +65 -0
  305. package/dist/zeros-2gldETuK.js +14 -0
  306. package/package.json +4 -3
  307. package/dist/Reshape-Bowtk9BP.js +0 -127
  308. package/dist/Reshape-DUqYftGC.js +0 -30
  309. package/dist/backend_util-CJIiDoV1.js +0 -749
  310. package/dist/broadcast_to-DzlNweb8.js +0 -44
  311. package/dist/concat-B912vBbo.js +0 -33
  312. package/dist/dropout-C-csYCLj.js +0 -193
  313. package/dist/exports_initializers-B8iZMgQ0.js +0 -16
  314. package/dist/gather-Dnpgw-YQ.js +0 -25
  315. package/dist/index-BzFyqcy-.js +0 -4457
  316. package/dist/index-C1rx_Ajs.js +0 -12076
  317. package/dist/kernel_funcs_utils-DKLK0Mg3.js +0 -466
  318. package/dist/log_sum_exp-DO6z8tSE.js +0 -103
  319. package/dist/mat_mul-DzjTFx-u.js +0 -27
  320. package/dist/mod-Dobti4j4.js +0 -27
  321. package/dist/ones-tIJeHlq-.js +0 -29
  322. package/dist/ops/fusedSoftmax.d.ts +0 -2
  323. package/dist/ops/fusedSoftmax.js +0 -10
  324. package/dist/ops/grads/fusedSoftmax.js +0 -22
  325. package/dist/ops-LuCMAnmM.js +0 -1525
  326. package/dist/random_width-CXVRloNK.js +0 -13670
  327. package/dist/range-CWcz7xFA.js +0 -26
  328. package/dist/reciprocal-C4rNcM-S.js +0 -25
  329. package/dist/relu-BjCh_SYb.js +0 -25
  330. package/dist/reshape-CnIwVG1c.js +0 -25
  331. package/dist/selu_util-OtRzVwW5.js +0 -719
  332. package/dist/shared-DmRsFyaJ.js +0 -3134
  333. package/dist/sin-gpDNRxE0.js +0 -47
  334. package/dist/slice-d0Vo9XTN.js +0 -28
  335. package/dist/softmax-D7Jj3p_P.js +0 -28
  336. package/dist/split-DK2k5eHf.js +0 -25
  337. package/dist/stack-DFatutCx.js +0 -27
  338. package/dist/sum-CJ0ULhmt.js +0 -27
  339. package/dist/tensor1d-vML0r3q6.js +0 -27
  340. package/dist/tensor2d-D76QGjF3.js +0 -30
  341. package/dist/tensor4d-Df1WlVDY.js +0 -30
  342. package/dist/webgpu_util-pLEV9tks.js +0 -80
  343. package/dist/zeros-Bj5rMYA7.js +0 -52
package/dist/ops/webgpu/slice16.js
@@ -0,0 +1,71 @@
+ import { b as c, c as f, e as l } from "../../webgpu_program-Cigz-7RF.js";
+ import { f as g, c as S } from "../../webgpu_util-BBCnKm2X.js";
+ import "../../index-ZyQhjEPo.js";
+ import { r as k } from "../../tensor_util-DV-FP5Q3.js";
+ import { p as y, a as $ } from "../../slice_util-DtEldBfK.js";
+ import { s as x } from "../../tensor-DdQUJZlz.js";
+ function b(o) {
+ switch (o) {
+ case 1:
+ return "1D";
+ case 2:
+ return "2D";
+ case 3:
+ return "3D";
+ case 4:
+ return "4D";
+ }
+ return "";
+ }
+ class w {
+ variableNames = ["source"];
+ uniforms;
+ outputShape;
+ shaderKey;
+ rank;
+ dispatchLayout;
+ dispatch;
+ workPerThread = 1;
+ workgroupSize = [64, 1, 1];
+ start;
+ size = !0;
+ constructor(e, t) {
+ this.outputShape = t, this.rank = t.length, this.dispatchLayout = g(this.outputShape), this.dispatch = S(this.dispatchLayout, this.outputShape, this.workgroupSize, [
+ this.workPerThread,
+ 1,
+ 1
+ ]), this.start = e, this.uniforms = `start : ${c(e.length)}, `, this.shaderKey = "slice";
+ }
+ getUserCode() {
+ const e = c(this.rank);
+ let t;
+ return this.start.length === 1 ? t = this.outputShape.map(() => "sourceLoc = uniforms.start + coords;") : t = this.outputShape.map((r, s) => `sourceLoc.${p[s]} = uniforms.start.${f(s)} + coords.${p[s]};`), `
+ ${l("index")} {
+ if (index < uniforms.size) {
+ var sourceLoc : ${e};
+ let coords = getCoordsFromIndex(index);
+ ${t.join(`
+ `)}
+ result[index] = source[getIndexFromCoords${b(this.rank)}(sourceLoc, uniforms.sourceShape)];
+ }
+ }
+ `;
+ }
+ }
+ const p = ["x", "y", "z", "w", "u", "v"];
+ function C(o) {
+ const { inputs: e, backend: t, attrs: n } = o, { x: r } = e, { begin: s, size: h } = n, [i, a] = y(r, s, h);
+ if ($(r, i, a), x(a) === 0)
+ return t.makeTensorInfo(a, r.dtype, []);
+ const d = new w(i, a), m = [{ type: "int32", data: i }], u = t.runWebGPUProgram(d, [r], r.dtype, m);
+ return u.packed = !0, u;
+ }
+ const D = {
+ kernelName: "Slice16",
+ backendName: "webgpu",
+ kernelFunc: C
+ };
+ k(D);
+ export {
+ C as slice
+ };
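The hunk above registers a "Slice16" kernel for the webgpu backend via tfjs-core's registerKernel. For orientation only, a minimal sketch of how such a registered kernel is reached from outside (the wrapper name here is hypothetical; the package's own wrapper lives in dist/ops/slice16.js):

import * as tf from '@tensorflow/tfjs-core';

// Hypothetical wrapper: dispatching through the engine lets the kernel
// registry pick the kernelFunc registered for the active backend.
function slice16(x: tf.Tensor, begin: number[], size: number[]): tf.Tensor {
  return tf.engine().runKernel('Slice16', { x }, { begin, size }) as tf.Tensor;
}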
package/dist/{variable-Bm2OFwGI.js → ops/webgpu/softmax16.d.ts}
@@ -1,7 +1,6 @@
- import { G as i } from "./index-BzFyqcy-.js";
  /**
  * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
+ * Copyright 2023 Google LLC.
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
  * You may obtain a copy of the License at
@@ -15,9 +14,4 @@ import { G as i } from "./index-BzFyqcy-.js";
  * limitations under the License.
  * =============================================================================
  */
- function m(r, a = !0, e, t) {
- return i.makeVariable(r, a, e, t);
- }
- export {
- m as v
- };
+ export {};
package/dist/ops/webgpu/softmax16.js
@@ -0,0 +1,23 @@
+ import { e as S } from "../../index-ZyQhjEPo.js";
+ import { reshape16 as h } from "../reshape16.js";
+ import b from "./softmax16_program.js";
+ import d from "./softmax16_subgroup_program.js";
+ import x from "./utils/deviceInfo.js";
+ import { r as k } from "../../tensor_util-DV-FP5Q3.js";
+ import { r as l } from "../../reshape-DevtBWtf.js";
+ import { s as z } from "../../tensor-DdQUJZlz.js";
+ function I(a) {
+ const { inputs: m, backend: e, attrs: p } = a, { logits: o } = m, { dim: s } = p, i = e.subgroupMinSize, c = e.subgroupMaxSize, u = x(e).subgroupsSupported, r = l(o, [
+ z(o.shape) / o.shape[s],
+ o.shape[s]
+ ]), f = u ? new d(r.shape, i, c) : new b(r.shape), n = e.runWebGPUProgram(f, [r], "int32");
+ n.packed = !0, r.dispose();
+ const t = S().makeTensorFromTensorInfo(n), g = h(t, o.shape);
+ return t.dispose(), g;
+ }
+ const P = {
+ kernelName: "Softmax16",
+ backendName: "webgpu",
+ kernelFunc: I
+ };
+ k(P);
package/dist/ops/webgpu/softmax16_program.d.ts
@@ -0,0 +1,13 @@
+ import { WebGPUProgram } from '@tensorflow/tfjs-backend-webgpu';
+ export default class SoftmaxProgram implements WebGPUProgram {
+ variableNames: string[];
+ outputShape: number[];
+ shaderKey: string;
+ dispatchLayout: {
+ x: number[];
+ };
+ dispatch: [number, number, number];
+ workgroupSize: [number, number, number];
+ constructor(outputShape: number[]);
+ getUserCode(): string;
+ }
package/dist/ops/webgpu/softmax16_program.js
@@ -0,0 +1,73 @@
+ import { e } from "../../webgpu_program-Cigz-7RF.js";
+ import { f as o } from "../../webgpu_util-BBCnKm2X.js";
+ class i {
+ variableNames = ["logits"];
+ outputShape;
+ shaderKey;
+ dispatchLayout;
+ dispatch;
+ workgroupSize;
+ constructor(r) {
+ this.outputShape = r, this.dispatchLayout = o(this.outputShape), this.dispatch = [this.outputShape[0], 1, 1], this.outputShape[1] >= 4096 ? this.workgroupSize = [256, 1, 1] : this.outputShape[1] < 64 ? this.workgroupSize = [32, 1, 1] : this.workgroupSize = [64, 1, 1], this.shaderKey = "softmax16";
+ }
+ getUserCode() {
+ return `
+ var<workgroup> buf : array<f32, ${this.workgroupSize[0]}>;
+ const blockSize = ${this.workgroupSize[0]};
+ ${e("index")} {
+ let row = index / blockSize;
+ let tid = i32(localId.x);
+ let cols = uniforms.outShape[1];
+ let rowIdx = row * cols;
+
+ var threadMax = -3.402823e+38f;
+ for (var col = tid; col < cols; col += blockSize) {
+ let value = unpack2x16float(u32(logits[rowIdx + col]));
+ threadMax = max(threadMax, max(value.x, value.y));
+ }
+ buf[tid] = threadMax;
+ workgroupBarrier();
+
+ for (var currSize = blockSize >> 1; currSize > 0; currSize = currSize >> 1) {
+ if (tid < currSize) {
+ buf[tid] = max(buf[tid], buf[tid + currSize]);
+ }
+ workgroupBarrier();
+ }
+
+ let rowMaxShared: f32 = buf[0];
+ workgroupBarrier();
+
+ var threadSum = 0.0f;
+ for (var col = tid; col < cols; col += blockSize) {
+ let value = unpack2x16float(u32(logits[rowIdx + col]));
+ let subExp = exp(value.x - rowMaxShared);
+ threadSum += subExp;
+ let subExpY = exp(value.y - rowMaxShared);
+ threadSum += subExpY;
+ }
+ buf[tid] = threadSum;
+ workgroupBarrier();
+
+ for (var currSize = blockSize >> 1; currSize > 0; currSize = currSize >> 1) {
+ if (tid < currSize) {
+ buf[tid] = buf[tid] + buf[tid + currSize];
+ }
+ workgroupBarrier();
+ }
+
+ let rowSumShared: f32 = buf[0];
+
+ for (var col = tid; col < cols; col += blockSize) {
+ let value = unpack2x16float(u32(logits[rowIdx + col]));
+ let value1: f32 = exp(value.x - rowMaxShared) / rowSumShared;
+ let value2: f32 = exp(value.y - rowMaxShared) / rowSumShared;
+ result[rowIdx + col] = i32(pack2x16float(vec2<f32>(value1, value2)));
+ }
+ }
+ `;
+ }
+ }
+ export {
+ i as default
+ };
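The shader above is a standard three-pass, numerically stable softmax over each row of packed f16 pairs: a workgroup-wide max reduction, a sum-of-exponentials reduction, then normalization. A CPU sketch of the same algorithm, for reference only:

// Reference implementation of one row; the shader does exactly this, but
// with the loops spread across a workgroup and tree reductions in shared
// memory (buf) replacing the sequential max/sum.
function softmaxRow(row: Float32Array): Float32Array {
  let rowMax = -Infinity;
  for (const v of row) rowMax = Math.max(rowMax, v);     // pass 1: row max
  const out = new Float32Array(row.length);
  let rowSum = 0;
  for (let i = 0; i < row.length; i++) {                 // pass 2: sum of exps
    out[i] = Math.exp(row[i] - rowMax);
    rowSum += out[i];
  }
  for (let i = 0; i < out.length; i++) out[i] /= rowSum; // pass 3: normalize
  return out;
}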
package/dist/ops/webgpu/softmax16_subgroup_program.d.ts
@@ -0,0 +1,17 @@
+ import { WebGPUProgram } from '@tensorflow/tfjs-backend-webgpu';
+ export default class SoftmaxSubgroupProgram implements WebGPUProgram {
+ variableNames: string[];
+ outputShape: number[];
+ shaderKey: string;
+ dispatchLayout: {
+ x: number[];
+ };
+ dispatch: [number, number, number];
+ workgroupSize: [number, number, number];
+ minSubgroupSize: number;
+ maxSubgroupSize: number;
+ subgroups: boolean;
+ subgroupBuiltins: boolean;
+ constructor(outputShape: number[], minSubgroupSize: number, maxSubgroupSize: number);
+ getUserCode(): string;
+ }
package/dist/ops/webgpu/softmax16_subgroup_program.js
@@ -0,0 +1,75 @@
+ import { e as o } from "../../webgpu_program-Cigz-7RF.js";
+ import { f as u } from "../../webgpu_util-BBCnKm2X.js";
+ class i {
+ variableNames = ["logits"];
+ outputShape;
+ shaderKey;
+ dispatchLayout;
+ dispatch;
+ workgroupSize;
+ minSubgroupSize;
+ maxSubgroupSize;
+ subgroups = !0;
+ subgroupBuiltins = !1;
+ constructor(e, a, r) {
+ this.minSubgroupSize = a, this.maxSubgroupSize = r, this.outputShape = e, this.dispatchLayout = u(this.outputShape), this.dispatch = [this.outputShape[0], 1, 1], this.workgroupSize = [Math.min(64, r), 1, 1], a !== r && (this.subgroupBuiltins = !0, this.workgroupSize = [64, 1, 1]), this.shaderKey = "softmax16subgroup";
+ }
+ getUserCode() {
+ const e = this.maxSubgroupSize !== this.minSubgroupSize;
+ return `
+ ${e ? `var<workgroup> bestValues : array<f32, ${this.workgroupSize[0]}>;` : ""}
+ const blockSize = ${this.workgroupSize[0]};
+ ${o("index")} {
+ let row = index / blockSize;
+ let tid = i32(localId.x);
+ let cols = uniforms.outShape[1];
+ let rowIdx = row * cols;
+
+ var threadMax = -3.402823e+38f;
+ for (var col = tid; col < cols; col += blockSize) {
+ let value = unpack2x16float(u32(logits[rowIdx + col]));
+ threadMax = max(threadMax, max(value.x, value.y));
+ }
+
+ threadMax = subgroupMax(threadMax);
+ ${e ? `
+ let lane = localId.x % subgroupSize;
+ if (lane == 0) {
+ bestValues[localId.x / subgroupSize] = threadMax;
+ }
+ workgroupBarrier();
+ let numSubgroups = blockSize / subgroupSize;
+ threadMax = select(-3.402823e+38f, bestValues[lane], lane < numSubgroups);
+ threadMax = subgroupMax(threadMax);
+ workgroupBarrier();
+ ` : ""}
+
+ var threadSum = 0.0f;
+ for (var col = tid; col < cols; col += blockSize) {
+ let value = unpack2x16float(u32(logits[rowIdx + col]));
+ let subExp = exp(value - threadMax);
+ threadSum += subExp.x + subExp.y;
+ }
+
+ threadSum = subgroupAdd(threadSum);
+ ${e ? `
+ if (lane == 0) {
+ bestValues[localId.x / subgroupSize] = threadSum;
+ }
+ workgroupBarrier();
+ threadSum = select(0.0f, bestValues[lane], lane < numSubgroups);
+ threadSum = subgroupAdd(threadSum);
+ ` : ""}
+
+ for (var col = tid; col < cols; col += ${e ? "i32(subgroupSize)" : "blockSize"}) {
+ let value = unpack2x16float(u32(logits[rowIdx + col]));
+ let valuePair: vec2<f32> = exp(value - threadMax) / threadSum;
+ result[rowIdx + col] = i32(pack2x16float(valuePair));
+ }
+ }
+ `;
+ }
+ }
+ export {
+ i as default
+ };
package/dist/ops/webgpu/softmax16grad.d.ts
@@ -0,0 +1 @@
+ export {};
package/dist/ops/webgpu/softmax16grad.js
@@ -0,0 +1,38 @@
+ import "../../index-ZyQhjEPo.js";
+ import { createReduceInfo as p, reduce as m, ReduceProgram as i } from "./utils/reductions.js";
+ import { isPackedTensor as n } from "../../utilities/packed.js";
+ import l from "./utils/deviceInfo.js";
+ import { r as k } from "../../tensor_util-DV-FP5Q3.js";
+ class x extends i {
+ constructor(e, t) {
+ super(e, t, { reductionOp: "sum", elementwise: !0 }, !0), this.shaderKey = "SoftmaxGrad16", this.variableNames = ["dy", "softmaxOutput"], this.variableComponents = [1, 1];
+ }
+ getReadSnippet() {
+ return `
+ let d: vec2<f32> = unpack2x16float(u32(dy[index]));
+ let l: vec2<f32> = unpack2x16float(u32(softmaxOutput[index]));
+ return d * l;
+ `;
+ }
+ getWriteSnippet() {
+ return `
+ let d: vec2<f32> = unpack2x16float(u32(dy[offset + k]));
+ let l: vec2<f32> = unpack2x16float(u32(softmaxOutput[offset + k]));
+ let outVal = l * (d - bestValue);
+ result[offset + k] = i32(pack2x16float(outVal));
+ `;
+ }
+ }
+ function b(o) {
+ const { dy: e, softmaxOutput: t } = o.inputs, r = o.backend, s = l(r), u = n(e), c = n(t);
+ if (!(u && c))
+ throw new Error("softmaxGradGPU: dy and softmaxOutput must be packed tensors");
+ const a = [e, t], f = p(a, -1), d = new x(s, f);
+ return m(d, a, r);
+ }
+ const v = {
+ kernelName: "Softmax16Grad",
+ backendName: "webgpu",
+ kernelFunc: b
+ };
+ k(v);
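Softmax16Grad is an elementwise-weighted sum reduction: the read snippet accumulates dy·y per row (the reduced value surfaces as bestValue), and the write snippet emits y·(dy − Σ dy·y), the usual softmax Jacobian-vector product. A per-row CPU sketch of what the two snippets compute together:

// dx = y * (dy - dot(dy, y)) for one row, matching the read/write snippets.
function softmaxGradRow(dy: Float32Array, y: Float32Array): Float32Array {
  let dot = 0;
  for (let i = 0; i < dy.length; i++) dot += dy[i] * y[i];          // read snippet + reduction
  const dx = new Float32Array(dy.length);
  for (let i = 0; i < dy.length; i++) dx[i] = y[i] * (dy[i] - dot); // write snippet
  return dx;
}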
package/dist/ops/webgpu/sub16.d.ts
@@ -0,0 +1 @@
+ export {};
package/dist/ops/webgpu/sub16.js
@@ -0,0 +1,14 @@
+ import "../../index-ZyQhjEPo.js";
+ import { BinaryOpProgram as p } from "./utils/binary_op.js";
+ import { B as m } from "../../binary_op_util-pKXltfxI.js";
+ import { r as s } from "../../tensor_util-DV-FP5Q3.js";
+ function c(r) {
+ const { a: e, b: n } = r.inputs, t = r.backend, a = new p(m.SUB, e.shape, n.shape), o = t.runWebGPUProgram(a, [e, n], "int32");
+ return o.packed = !0, o;
+ }
+ const i = {
+ kernelName: "Sub16",
+ backendName: "webgpu",
+ kernelFunc: c
+ };
+ s(i);
package/dist/ops/webgpu/sum16.d.ts
@@ -0,0 +1 @@
+ export {};
package/dist/ops/webgpu/sum16.js
@@ -0,0 +1,40 @@
+ import { createReduceInfo as g, reduce as h, ReduceProgram as x } from "./utils/reductions.js";
+ import "../../index-ZyQhjEPo.js";
+ import { isPackedTensor as k } from "../../utilities/packed.js";
+ import { transpose16 as A } from "../transpose16.js";
+ import P from "./utils/deviceInfo.js";
+ import { r as b } from "../../tensor_util-DV-FP5Q3.js";
+ import { s as I } from "../../sum-_fzj5ZTB.js";
+ import { p as w } from "../../tensor-DdQUJZlz.js";
+ import { a as D, b as K } from "../../axis_util-BvHEw88j.js";
+ class v extends x {
+ shaderKey = "sum16";
+ constructor(e, o, t) {
+ super(
+ e,
+ o,
+ {
+ reductionOp: "sum",
+ elementwise: !1
+ },
+ t
+ ), t && (this.shaderKey += "_packed");
+ }
+ }
+ function y(r) {
+ const { x: e } = r.inputs, { axis: o, keepDims: t } = r.attrs, m = r.backend, a = [], p = P(m), c = k(e);
+ if (!c)
+ return I(e, o, t);
+ let n = w(o ?? -1, e.shape);
+ const i = D(n, e.shape.length);
+ let s = e;
+ i != null && (s = A(e, i), n = K(n.length, s.shape.length), a.push(s));
+ const u = g([s], -1), f = new v(p, u, c), d = h(f, [s], m);
+ return a.forEach((l) => l.dispose()), d;
+ }
+ const N = {
+ kernelName: "Sum16",
+ backendName: "webgpu",
+ kernelFunc: y
+ };
+ b(N);
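The Sum16 kernel only reduces the innermost (packed) axis, so any other requested axis is first rotated into last place with Transpose16; unpacked inputs fall back to the stock sum. A sketch of that dispatch logic, with hypothetical helper names (reduceLastAxis, transpose16) standing in for the minified reduce()/transpose16() calls above:

// PackedTensor is the package's own packed-f16 tensor type (see
// dist/patches/PackedTensor.d.ts); the helpers here are illustrative.
function sum16(x: PackedTensor, axis: number): PackedTensor {
  const rank = x.shape.length;
  const a = axis < 0 ? axis + rank : axis;
  if (a === rank - 1) return reduceLastAxis(x, 'sum');
  const perm = [...Array(rank).keys()].filter((i) => i !== a).concat(a);
  const t = transpose16(x, perm); // move the reduced axis innermost
  const out = reduceLastAxis(t, 'sum');
  t.dispose();
  return out;
}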
package/dist/ops/webgpu/transpose16.d.ts
@@ -0,0 +1 @@
+ export {};
package/dist/ops/webgpu/transpose16.js
@@ -0,0 +1,35 @@
+ import { isPackedTensor as d } from "../../utilities/packed.js";
+ import { e as k } from "../../index-ZyQhjEPo.js";
+ import { reshape16 as i } from "../reshape16.js";
+ import l from "./transpose16_shared_program.js";
+ import P from "./transpose16_program.js";
+ import { r as b } from "../../tensor_util-DV-FP5Q3.js";
+ import { t as T } from "../../transpose-DKELTqhe.js";
+ function w(a) {
+ const { inputs: u, attrs: h } = a, { x: e } = u, { perm: r } = h, m = a.backend, c = d(e);
+ if (c && r[r.length - 1] !== e.shape.length - 1) {
+ const n = e.shape.length, t = n === 4 ? r.map((s) => s - 1).filter((s) => s >= 0) : r, p = n === 4 ? i(e, [e.shape[0] * e.shape[1], e.shape[2], e.shape[3]]) : e, f = new l(p.shape, t), o = m.runWebGPUProgram(f, [p], "int32");
+ if (o.packed = !0, n === 4) {
+ p.dispose();
+ const s = k().makeTensorFromTensorInfo(o), g = i(s, [
+ e.shape[0],
+ e.shape[1],
+ o.shape[1],
+ o.shape[2]
+ ]);
+ return s.dispose(), g;
+ }
+ return o;
+ }
+ if (c) {
+ const n = new P(e.shape, r), t = m.runWebGPUProgram(n, [e], "int32");
+ return t.packed = !0, t;
+ } else
+ return T(e, r);
+ }
+ const F = {
+ kernelName: "Transpose16",
+ backendName: "webgpu",
+ kernelFunc: w
+ };
+ b(F);
package/dist/ops/webgpu/transpose16_program.d.ts
@@ -0,0 +1,16 @@
+ import { WebGPUProgram } from '@tensorflow/tfjs-backend-webgpu';
+ export default class TransposeProgram16 implements WebGPUProgram {
+ variableNames: string[];
+ shaderKey: string;
+ outputShape: number[];
+ dispatchLayout: {
+ x: number[];
+ };
+ dispatch: [number, number, number];
+ workPerThread: number;
+ workgroupSize: [number, number, number];
+ newDim: number[];
+ size: boolean;
+ constructor(aShape: number[], newDim: number[]);
+ getUserCode(): string;
+ }
package/dist/ops/webgpu/transpose16_program.js
@@ -0,0 +1,50 @@
+ import { f as a, c as i } from "../../webgpu_util-BBCnKm2X.js";
+ import { b as h, e as d, c as n } from "../../webgpu_program-Cigz-7RF.js";
+ function p(r) {
+ const e = r.length;
+ if (e > 6)
+ throw Error(`Transpose for rank ${e} is not yet supported`);
+ const o = new Array(e);
+ for (let t = 0; t < r.length; t++)
+ o[r[t]] = `coords.${n(t)}`;
+ return o.join();
+ }
+ class l {
+ variableNames = ["A"];
+ shaderKey;
+ outputShape;
+ dispatchLayout;
+ dispatch;
+ workPerThread = 1;
+ workgroupSize = [64, 1, 1];
+ newDim;
+ size = !0;
+ constructor(e, o) {
+ const t = new Array(e.length);
+ for (let s = 0; s < t.length; s++)
+ t[s] = e[o[s]];
+ this.outputShape = t, this.dispatchLayout = a(this.outputShape), this.dispatch = i(this.dispatchLayout, this.outputShape, this.workgroupSize, [
+ this.workPerThread,
+ 1,
+ 1
+ ]), this.newDim = o, this.shaderKey = `transpose16_${o}`;
+ }
+ getUserCode() {
+ const e = h(this.outputShape.length), o = p(this.newDim);
+ return `
+ ${d("index")} {
+ for(var i = 0; i < ${this.workPerThread}; i = i + 1) {
+ let flatIndex = index * ${this.workPerThread} + i;
+ if(flatIndex < uniforms.size) {
+ let coords = getCoordsFromIndex(flatIndex);
+ result[flatIndex] = A[getIndexFromCoords${this.outputShape.length}D(
+ ${e}(${o}), uniforms.aShape)];
+ }
+ }
+ }
+ `;
+ }
+ }
+ export {
+ l as default
+ };
package/dist/ops/webgpu/transpose16_shared_program.d.ts
@@ -0,0 +1,15 @@
+ import { WebGPUProgram } from '@tensorflow/tfjs-backend-webgpu/dist/webgpu_program';
+ export default class TransposeSharedProgram16 implements WebGPUProgram {
+ variableNames: string[];
+ outputShape: number[];
+ shaderKey: string;
+ dispatchLayout: {
+ x: number[];
+ y: number[];
+ z?: number[];
+ };
+ dispatch: [number, number, number];
+ workgroupSize: [number, number, number];
+ constructor(aShape: number[], newDim: number[]);
+ getUserCode(): string;
+ }
package/dist/ops/webgpu/transpose16_shared_program.js
@@ -0,0 +1,71 @@
+ import { c as a } from "../../webgpu_util-BBCnKm2X.js";
+ import { e as p } from "../../webgpu_program-Cigz-7RF.js";
+ import "../../index-ZyQhjEPo.js";
+ import { a as l } from "../../tensor-DdQUJZlz.js";
+ class y {
+ variableNames = ["A"];
+ outputShape;
+ shaderKey;
+ dispatchLayout;
+ dispatch;
+ // Note that the maximum number of workgroup invocations by webgpu is 256.
+ // Nick: Reduce to 8x8
+ workgroupSize = [8, 8, 1];
+ constructor(t, o) {
+ const i = t.length, e = new Array(i), u = t.slice();
+ u[u.length - 1] *= 2;
+ for (let r = 0; r < e.length; r++)
+ e[r] = u[o[r]];
+ e[e.length - 1] /= 2, this.outputShape = e, this.dispatchLayout = i === 2 ? { x: [0], y: [1] } : { x: [1], y: [2], z: [0] }, this.dispatch = a(this.dispatchLayout, this.outputShape, this.workgroupSize, [2, 1, 1]), this.shaderKey = `transposeShared16_${i}`;
+ }
+ getUserCode() {
+ const t = this.outputShape.length;
+ l(
+ this.workgroupSize[0] === this.workgroupSize[1],
+ () => `Must be a square tile, current tile shape is ${this.workgroupSize[0]} x ${this.workgroupSize[1]}`
+ );
+ const o = this.workgroupSize[0] * 2;
+ return `
+ var<workgroup> tile : array<array<f32, ${o + 1}>, ${o}>;
+ ${p()} {
+ var x = i32(workgroupId.x) * ${o / 2} + i32(localId.x);
+ var y = i32(workgroupId.y) * ${o} + i32(localId.y);
+ let batch = ${t === 3 ? "i32(workgroupId.z)" : "0"};
+ let batchOffsetA = ${t === 3 ? "batch * uniforms.aShapeStrides[0]" : "0"};
+ let batchOffsetOut = ${t === 3 ? "batch * uniforms.outShapeStrides[0]" : "0"};
+
+ let inputWidth = uniforms.outShape[${t === 3 ? "1" : "0"}] / 2; // Output height
+ let inputHeight = uniforms.outShape[${t === 3 ? "2" : "1"}] * 2; // Output width
+ if (x < inputWidth && y < inputHeight) {
+ let unpackedA = unpack2x16float(u32(A[batchOffsetA + y * inputWidth + x]));
+ tile[localId.y][localId.x * 2] = unpackedA.x;
+ tile[localId.y][localId.x * 2 + 1] = unpackedA.y;
+ }
+ // Second load to cover the tile
+ y = y + ${this.workgroupSize[0]};
+ if (x < inputWidth && y < inputHeight) {
+ let unpackedA = unpack2x16float(u32(A[batchOffsetA + y * inputWidth + x]));
+ tile[localId.y + ${this.workgroupSize[0]}][localId.x * 2] = unpackedA.x;
+ tile[localId.y + ${this.workgroupSize[0]}][localId.x * 2 + 1] = unpackedA.y;
+ }
+ workgroupBarrier();
+
+ let outputWidth = uniforms.outShape[${t === 3 ? "2" : "1"}]; // Output width
+ let outputHeight = uniforms.outShape[${t === 3 ? "1" : "0"}] * 2; // Output height
+ x = i32(workgroupId.y) * ${o / 2} + i32(localId.x);
+ y = i32(workgroupId.x) * ${o} + i32(localId.y);
+ if (x < outputWidth && y < outputHeight) {
+ result[batchOffsetOut + y * outputWidth + x] = i32(pack2x16float(vec2<f32>(tile[localId.x * 2][localId.y], tile[localId.x * 2 + 1][localId.y])));
+ }
+ // Second store to cover the tile
+ y = y + ${this.workgroupSize[0]};
+ if (x < outputWidth && y < outputHeight) {
+ result[batchOffsetOut + y * outputWidth + x] = i32(pack2x16float(vec2<f32>(tile[localId.x * 2][localId.y + ${this.workgroupSize[0]}], tile[localId.x * 2 + 1][localId.y + ${this.workgroupSize[0]}])));
+ }
+ }
+ `;
+ }
+ }
+ export {
+ y as default
+ };
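This is a classic shared-memory tiled transpose adapted to the packed layout: each thread loads one i32 (two f16 values), so an 8x8 workgroup covers a 16x16 scalar tile in two loads and two stores, and each tile row is padded by one float (17 wide for a 16-wide tile) to avoid shared-memory bank conflicts. The locality idea the tile buys, in plain CPU form:

// Blocked transpose: process 16x16 tiles so reads and writes both stay in a
// small working set, the same locality the workgroup tile provides on the
// GPU (minus the f16 packing and barriers).
function transposeBlocked(src: Float32Array, rows: number, cols: number): Float32Array {
  const dst = new Float32Array(rows * cols);
  const tile = 16;
  for (let r0 = 0; r0 < rows; r0 += tile) {
    for (let c0 = 0; c0 < cols; c0 += tile) {
      const rMax = Math.min(r0 + tile, rows);
      const cMax = Math.min(c0 + tile, cols);
      for (let r = r0; r < rMax; r++) {
        for (let c = c0; c < cMax; c++) {
          dst[c * rows + r] = src[r * cols + c];
        }
      }
    }
  }
  return dst;
}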
package/dist/ops/webgpu/unpack16.d.ts
@@ -0,0 +1 @@
+ export {};
package/dist/ops/webgpu/unpack16.js
@@ -0,0 +1,49 @@
+ import { f as c, c as r } from "../../webgpu_util-BBCnKm2X.js";
+ import { e as u } from "../../webgpu_program-Cigz-7RF.js";
+ import "../../index-ZyQhjEPo.js";
+ import { r as p } from "../../tensor_util-DV-FP5Q3.js";
+ class l {
+ outputShape;
+ shaderKey = "Unpack16";
+ dispatchLayout;
+ dispatch;
+ workgroupSize = [64, 1, 1];
+ variableNames = ["x"];
+ size = !0;
+ uniforms;
+ outputComponent = 4;
+ variableComponents = [2];
+ scaling = !1;
+ constructor(t) {
+ this.outputShape = [...t.slice(0, -1), t[t.length - 1] * 2], this.dispatchLayout = c(this.outputShape), this.dispatch = r(this.dispatchLayout, this.outputShape, this.workgroupSize, [4, 1, 1]);
+ }
+ useScaling() {
+ this.shaderKey += "_Scaled", this.uniforms = "scaling : f32,", this.scaling = !0;
+ }
+ getUserCode() {
+ return `
+ ${u("index")} {
+ let outIndex = index;
+ if (outIndex < uniforms.size) {
+ let xvec2 = x[index];
+ let v1 = vec4<f32>(
+ unpack2x16float(u32(xvec2.x)),
+ unpack2x16float(u32(xvec2.y))
+ ) ${this.scaling ? "* uniforms.scaling" : ""};
+ result[outIndex] = v1;
+ }
+ }`;
+ }
+ }
+ function h(e) {
+ const { x: t } = e.inputs, { scaling: a } = e.attrs, i = e.backend, n = new l(t.shape), s = a !== 1;
+ s && n.useScaling();
+ const o = [{ type: "float32", data: [1 / a] }];
+ return i.runWebGPUProgram(n, [t], "float32", s ? o : void 0);
+ }
+ const d = {
+ kernelName: "Unpack16",
+ backendName: "webgpu",
+ kernelFunc: h
+ };
+ p(d);
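Unpack16, like every kernel in this release, relies on WGSL's pack2x16float/unpack2x16float to keep two f16 values in each i32 element. A CPU emulation of that layout, assuming a runtime with the Float16 DataView methods (a recent JS addition; older runtimes need a manual f32-to-f16 bit conversion):

// Component 0 occupies the low 16 bits, component 1 the high 16 bits,
// matching the WGSL builtins the shaders use.
function pack2x16float(x: number, y: number): number {
  const view = new DataView(new ArrayBuffer(4));
  view.setFloat16(0, x, true);
  view.setFloat16(2, y, true);
  return view.getUint32(0, true);
}

function unpack2x16float(packed: number): [number, number] {
  const view = new DataView(new ArrayBuffer(4));
  view.setUint32(0, packed, true);
  return [view.getFloat16(0, true), view.getFloat16(2, true)];
}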
package/dist/ops/webgpu/utils/binary_op.d.ts
@@ -0,0 +1,19 @@
+ import { WebGPUProgram } from '@tensorflow/tfjs-backend-webgpu';
+ import { BinaryOpType } from '@tensorflow/tfjs-backend-webgpu/dist/binary_op_util';
+ export { BinaryOpType };
+ export declare class BinaryOpProgram implements WebGPUProgram {
+ dispatch: [number, number, number];
+ dispatchLayout: {
+ x: number[];
+ };
+ outputComponent: number;
+ op: BinaryOpType;
+ outputShape: number[];
+ shaderKey: string;
+ size: boolean;
+ variableNames: string[];
+ workgroupSize: [number, number, number];
+ variableComponents: number[];
+ constructor(op: BinaryOpType, aShape: number[], bShape: number[]);
+ getUserCode(): string;
+ }