@genai-fi/nanogpt 0.9.0 → 0.10.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (343) hide show
  1. package/README.md +352 -14
  2. package/dist/Generator.js +69 -78
  3. package/dist/{RealDiv-D4EzDsC0.js → RealDiv-DgA3z9oO.js} +32 -206
  4. package/dist/Reshape-CF6odzV4.js +16 -0
  5. package/dist/Reshape-_kILl6tK.js +81 -0
  6. package/dist/TeachableLLM.js +28 -22
  7. package/dist/Trainer.d.ts +2 -0
  8. package/dist/Trainer.js +3 -2
  9. package/dist/{axis_util-TbGYJ208.js → axis_util-BvHEw88j.js} +7 -23
  10. package/dist/backend.d.ts +2 -1
  11. package/dist/backend.js +10 -4
  12. package/dist/backend_util-D-rUb2ty.js +474 -0
  13. package/dist/backend_webgpu-B0u2ndUn.js +547 -0
  14. package/dist/binary_op_util-pKXltfxI.js +192 -0
  15. package/dist/broadcast_to-CwF7XIeu.js +30 -0
  16. package/dist/checks/appendCache.js +2 -2
  17. package/dist/checks/attentionMask.js +3 -3
  18. package/dist/checks/check.d.ts +1 -1
  19. package/dist/checks/check.js +8 -8
  20. package/dist/checks/gelu.js +2 -2
  21. package/dist/checks/index.d.ts +2 -0
  22. package/dist/checks/index.js +7 -5
  23. package/dist/checks/matMulGelu.js +6 -6
  24. package/dist/checks/normRMS.js +7 -7
  25. package/dist/checks/normRMSGrad.js +3 -3
  26. package/dist/checks/packUnpack.d.ts +1 -0
  27. package/dist/checks/packUnpack.js +18 -0
  28. package/dist/checks/qkv.js +12 -27
  29. package/dist/checks/rope.js +2 -2
  30. package/dist/checks/weights.js +18 -16
  31. package/dist/complex-CSlYz-2T.js +13 -0
  32. package/dist/complex_util-Yc1A_gV1.js +55 -0
  33. package/dist/concat-BHlIJeyT.js +19 -0
  34. package/dist/concat_util-DcJk7YHS.js +22 -0
  35. package/dist/data/docx.js +1 -1
  36. package/dist/data/parquet.js +2 -2
  37. package/dist/data/pdf.js +1 -1
  38. package/dist/data/textLoader.js +1 -1
  39. package/dist/{dataset-DlZtKmBq.js → dataset-0xP8GjwI.js} +136 -236
  40. package/dist/dropout-C1pM3f11.js +99 -0
  41. package/dist/expand_dims-BPG4fwBP.js +13 -0
  42. package/dist/exports_initializers-xuidcwI4.js +7 -0
  43. package/dist/gather-DykLGqmW.js +10 -0
  44. package/dist/{gelu-Bp_-935b.js → gelu-CNLFZWea.js} +11 -10
  45. package/dist/{gpgpu_math-CDaYiyE_.js → gpgpu_math-DDVJCn6-.js} +90 -265
  46. package/dist/{index-C4L8Cm77.js → index-CieiGp4Y.js} +14 -14
  47. package/dist/index-CjOj7j-u.js +7308 -0
  48. package/dist/{index-Tf7vU29b.js → index-Cp39cXWe.js} +3 -10
  49. package/dist/{index-Dwqa6Zy2.js → index-DvYrXKkX.js} +2 -2
  50. package/dist/index-ZyQhjEPo.js +2157 -0
  51. package/dist/{jszip.min-CjP2V1VV.js → jszip.min-Bz5-11Bk.js} +56 -57
  52. package/dist/kernel_funcs_utils-Dg_-E44D.js +308 -0
  53. package/dist/layers/BaseLayer.d.ts +1 -0
  54. package/dist/layers/BaseLayer.js +7 -6
  55. package/dist/layers/CausalSelfAttention.d.ts +0 -1
  56. package/dist/layers/CausalSelfAttention.js +56 -55
  57. package/dist/layers/MLP.js +15 -16
  58. package/dist/layers/PositionEmbedding.js +5 -14
  59. package/dist/layers/RMSNorm.js +3 -3
  60. package/dist/layers/RoPECache.d.ts +2 -0
  61. package/dist/layers/RoPECache.js +22 -17
  62. package/dist/layers/TiedEmbedding.js +22 -17
  63. package/dist/layers/TransformerBlock.js +21 -20
  64. package/dist/loader/load.js +1 -1
  65. package/dist/loader/loadTransformers.js +1 -1
  66. package/dist/loader/oldZipLoad.js +39 -33
  67. package/dist/loader/save.js +1 -1
  68. package/dist/log_sum_exp-DWI-76TI.js +41 -0
  69. package/dist/main.d.ts +8 -0
  70. package/dist/main.js +63 -52
  71. package/dist/matMul16--R5hOwDG.js +77 -0
  72. package/dist/mat_mul-DeAh4uTH.js +12 -0
  73. package/dist/mod-Gt1rMB4n.js +12 -0
  74. package/dist/models/NanoGPTV1.js +40 -31
  75. package/dist/models/model.d.ts +2 -0
  76. package/dist/models/model.js +37 -29
  77. package/dist/{mulmat_packed_gpu-BT60jmzP.js → mulmat_packed_gpu-BMFhLwta.js} +1 -17
  78. package/dist/{non_max_suppression_impl-CsEgBuMA.js → non_max_suppression_impl-B2W7YjZB.js} +0 -32
  79. package/dist/ones-CAMiP4I2.js +15 -0
  80. package/dist/ops/adamAdjust.js +1 -1
  81. package/dist/ops/adamMoments.d.ts +1 -1
  82. package/dist/ops/adamMoments.js +4 -4
  83. package/dist/ops/add16.d.ts +2 -0
  84. package/dist/ops/add16.js +9 -0
  85. package/dist/ops/appendCache.js +16 -9
  86. package/dist/ops/attentionMask.js +4 -4
  87. package/dist/ops/concat16.d.ts +2 -0
  88. package/dist/ops/concat16.js +9 -0
  89. package/dist/ops/cpu/adamAdjust.js +14 -13
  90. package/dist/ops/cpu/adamMoments.js +10 -9
  91. package/dist/ops/cpu/appendCache.js +9 -8
  92. package/dist/ops/cpu/attentionMask.js +15 -14
  93. package/dist/ops/cpu/fusedSoftmax.js +13 -12
  94. package/dist/ops/cpu/gatherSub.js +9 -24
  95. package/dist/ops/cpu/gelu.js +13 -12
  96. package/dist/ops/cpu/matMul16.d.ts +1 -0
  97. package/dist/ops/cpu/matMul16.js +16 -0
  98. package/dist/ops/cpu/matMulGelu.js +18 -16
  99. package/dist/ops/cpu/matMulMul.js +8 -7
  100. package/dist/ops/cpu/mulDropout.js +4 -3
  101. package/dist/ops/cpu/normRMS.js +11 -10
  102. package/dist/ops/cpu/qkv.js +17 -13
  103. package/dist/ops/cpu/rope.js +23 -22
  104. package/dist/ops/cpu/scatterSub.js +16 -30
  105. package/dist/ops/dot16.d.ts +2 -0
  106. package/dist/ops/dot16.js +42 -0
  107. package/dist/ops/gatherSub.js +1 -1
  108. package/dist/ops/gelu.js +2 -2
  109. package/dist/ops/grads/add16.d.ts +1 -0
  110. package/dist/ops/grads/add16.js +27 -0
  111. package/dist/ops/grads/attentionMask.js +12 -19
  112. package/dist/ops/grads/gelu.js +4 -3
  113. package/dist/ops/grads/matMul16.d.ts +2 -0
  114. package/dist/ops/grads/matMul16.js +9 -0
  115. package/dist/ops/grads/matMulGelu.js +8 -7
  116. package/dist/ops/grads/normRMS.js +8 -7
  117. package/dist/ops/grads/{fusedSoftmax.d.ts → pack16.d.ts} +1 -1
  118. package/dist/ops/grads/pack16.js +7 -0
  119. package/dist/ops/grads/qkv.d.ts +3 -1
  120. package/dist/ops/grads/qkv.js +28 -22
  121. package/dist/ops/grads/rope.d.ts +2 -1
  122. package/dist/ops/grads/rope.js +6 -13
  123. package/dist/ops/grads/softmax16.d.ts +2 -0
  124. package/dist/ops/grads/softmax16.js +26 -0
  125. package/dist/ops/grads/unpack16.d.ts +2 -0
  126. package/dist/ops/grads/unpack16.js +6 -0
  127. package/dist/ops/grads/utils.d.ts +3 -0
  128. package/dist/ops/grads/utils.js +10 -0
  129. package/dist/ops/matMul16.d.ts +15 -0
  130. package/dist/ops/matMul16.js +13 -0
  131. package/dist/ops/matMulGelu.js +1 -1
  132. package/dist/ops/matMulMul.js +1 -1
  133. package/dist/ops/mul16.d.ts +2 -0
  134. package/dist/ops/mul16.js +8 -0
  135. package/dist/ops/mulDrop.js +1 -1
  136. package/dist/ops/normRMS.js +1 -1
  137. package/dist/ops/pack16.d.ts +2 -0
  138. package/dist/ops/pack16.js +6 -0
  139. package/dist/ops/qkv.d.ts +1 -1
  140. package/dist/ops/qkv.js +8 -4
  141. package/dist/ops/reshape16.d.ts +2 -0
  142. package/dist/ops/reshape16.js +43 -0
  143. package/dist/ops/rope.d.ts +1 -1
  144. package/dist/ops/rope.js +8 -10
  145. package/dist/ops/scatterSub.js +1 -1
  146. package/dist/ops/slice16.d.ts +2 -0
  147. package/dist/ops/slice16.js +9 -0
  148. package/dist/ops/softmax16.d.ts +2 -0
  149. package/dist/ops/softmax16.js +12 -0
  150. package/dist/ops/sub16.d.ts +2 -0
  151. package/dist/ops/sub16.js +8 -0
  152. package/dist/ops/sum16.d.ts +2 -0
  153. package/dist/ops/sum16.js +13 -0
  154. package/dist/ops/transpose16.d.ts +3 -0
  155. package/dist/ops/transpose16.js +41 -0
  156. package/dist/ops/unpack16.d.ts +2 -0
  157. package/dist/ops/unpack16.js +6 -0
  158. package/dist/ops/webgl/adamAdjust.js +3 -2
  159. package/dist/ops/webgl/adamMoments.js +2 -1
  160. package/dist/ops/webgl/appendCache.js +2 -1
  161. package/dist/ops/webgl/attentionMask.js +5 -4
  162. package/dist/ops/webgl/fusedSoftmax.js +6 -4
  163. package/dist/ops/webgl/gatherSub.js +7 -6
  164. package/dist/ops/webgl/gelu.js +3 -2
  165. package/dist/ops/webgl/log.js +12 -27
  166. package/dist/ops/webgl/matMul16.d.ts +1 -0
  167. package/dist/ops/webgl/matMul16.js +37 -0
  168. package/dist/ops/webgl/matMulGelu.js +17 -15
  169. package/dist/ops/webgl/matMulMul.js +13 -12
  170. package/dist/ops/webgl/mulDropout.js +9 -8
  171. package/dist/ops/webgl/normRMS.js +8 -7
  172. package/dist/ops/webgl/qkv.js +6 -5
  173. package/dist/ops/webgl/rope.js +11 -10
  174. package/dist/ops/webgl/scatterSub.js +6 -5
  175. package/dist/ops/webgpu/adamAdjust.js +12 -10
  176. package/dist/ops/webgpu/adamMoments.js +27 -22
  177. package/dist/ops/webgpu/add16.d.ts +1 -0
  178. package/dist/ops/webgpu/add16.js +14 -0
  179. package/dist/ops/webgpu/appendCache.js +64 -17
  180. package/dist/ops/webgpu/attentionMask.js +19 -62
  181. package/dist/ops/webgpu/attentionMask32_program.d.ts +19 -0
  182. package/dist/ops/webgpu/attentionMask32_program.js +54 -0
  183. package/dist/ops/webgpu/concat16.d.ts +19 -0
  184. package/dist/ops/webgpu/concat16.js +128 -0
  185. package/dist/ops/webgpu/gatherSub.js +9 -7
  186. package/dist/ops/webgpu/gelu.js +78 -31
  187. package/dist/ops/webgpu/index.js +12 -0
  188. package/dist/ops/webgpu/matMul16.d.ts +1 -0
  189. package/dist/ops/webgpu/matMul16.js +58 -0
  190. package/dist/ops/webgpu/matMul16_program.d.ts +42 -0
  191. package/dist/ops/webgpu/matMul16_program.js +336 -0
  192. package/dist/ops/webgpu/mul16.d.ts +1 -0
  193. package/dist/ops/webgpu/mul16.js +14 -0
  194. package/dist/ops/webgpu/normRMS.js +21 -40
  195. package/dist/ops/webgpu/normRMS16_program.d.ts +9 -0
  196. package/dist/ops/webgpu/normRMS16_program.js +24 -0
  197. package/dist/ops/webgpu/normRMS32_program.d.ts +9 -0
  198. package/dist/ops/webgpu/normRMS32_program.js +24 -0
  199. package/dist/ops/webgpu/normRMSGrad.js +113 -64
  200. package/dist/ops/webgpu/pack16.d.ts +1 -0
  201. package/dist/ops/webgpu/pack16.js +19 -0
  202. package/dist/ops/webgpu/pack16_program.d.ts +19 -0
  203. package/dist/ops/webgpu/pack16_program.js +92 -0
  204. package/dist/ops/webgpu/qkv.js +20 -55
  205. package/dist/ops/webgpu/rope.js +77 -22
  206. package/dist/ops/webgpu/scatterSub.js +9 -7
  207. package/dist/ops/webgpu/slice16.d.ts +7 -0
  208. package/dist/ops/webgpu/slice16.js +71 -0
  209. package/dist/{variable-Bm2OFwGI.js → ops/webgpu/softmax16.d.ts} +2 -8
  210. package/dist/ops/webgpu/softmax16.js +23 -0
  211. package/dist/ops/webgpu/softmax16_program.d.ts +13 -0
  212. package/dist/ops/webgpu/softmax16_program.js +73 -0
  213. package/dist/ops/webgpu/softmax16_subgroup_program.d.ts +17 -0
  214. package/dist/ops/webgpu/softmax16_subgroup_program.js +75 -0
  215. package/dist/ops/webgpu/softmax16grad.d.ts +1 -0
  216. package/dist/ops/webgpu/softmax16grad.js +38 -0
  217. package/dist/ops/webgpu/sub16.d.ts +1 -0
  218. package/dist/ops/webgpu/sub16.js +14 -0
  219. package/dist/ops/webgpu/sum16.d.ts +1 -0
  220. package/dist/ops/webgpu/sum16.js +40 -0
  221. package/dist/ops/webgpu/transpose16.d.ts +1 -0
  222. package/dist/ops/webgpu/transpose16.js +35 -0
  223. package/dist/ops/webgpu/transpose16_program.d.ts +16 -0
  224. package/dist/ops/webgpu/transpose16_program.js +50 -0
  225. package/dist/ops/webgpu/transpose16_shared_program.d.ts +15 -0
  226. package/dist/ops/webgpu/transpose16_shared_program.js +71 -0
  227. package/dist/ops/webgpu/unpack16.d.ts +1 -0
  228. package/dist/ops/webgpu/unpack16.js +49 -0
  229. package/dist/ops/webgpu/utils/binary_op.d.ts +19 -0
  230. package/dist/ops/webgpu/utils/binary_op.js +79 -0
  231. package/dist/ops/webgpu/utils/deviceInfo.d.ts +7 -0
  232. package/dist/ops/webgpu/utils/deviceInfo.js +11 -0
  233. package/dist/ops/webgpu/utils/reductions.d.ts +32 -4
  234. package/dist/ops/webgpu/utils/reductions.js +236 -45
  235. package/dist/ops-CNI3TwqM.js +645 -0
  236. package/dist/pack16-CFUqumar.js +41 -0
  237. package/dist/{papaparse.min-C8l2Kvo1.js → papaparse.min-C0cScC2i.js} +2 -8
  238. package/dist/{parquet-C0Tlmv9c.js → parquet-BE8MU_ge.js} +201 -278
  239. package/dist/patches/PackedTensor.d.ts +12 -0
  240. package/dist/patches/PackedTensor.js +11 -0
  241. package/dist/patches/engine.d.ts +261 -0
  242. package/dist/patches/engine.js +10 -0
  243. package/dist/patches/tape.d.ts +12 -0
  244. package/dist/patches/tape.js +5 -0
  245. package/dist/patches/webgpu_backend.d.ts +18 -0
  246. package/dist/patches/webgpu_backend.js +57 -0
  247. package/dist/{tensor-CZr4dh61.js → patches/webgpu_base.d.ts} +5 -8
  248. package/dist/patches/webgpu_base.js +34 -0
  249. package/dist/patches/webgpu_program.d.ts +36 -0
  250. package/dist/patches/webgpu_program.js +401 -0
  251. package/dist/{pdf-kJD-f258.js → pdf-NIhmP3sq.js} +424 -428
  252. package/dist/random_width-DY6Kk2Dl.js +10051 -0
  253. package/dist/range-BMS52eQi.js +11 -0
  254. package/dist/reciprocal-CTmshQ9J.js +10 -0
  255. package/dist/{register_all_kernels-DIGpEwcf.js → register_all_kernels-Bwu1PTuU.js} +719 -9766
  256. package/dist/relu-yZ2-7WxU.js +10 -0
  257. package/dist/reshape-DevtBWtf.js +10 -0
  258. package/dist/rope-B5UUMsPi.js +32 -0
  259. package/dist/{scatter_nd_util-BQdz--Gn.js → scatter_nd_util-5EL-8VAQ.js} +1 -1
  260. package/dist/selu_util-D1w6yyTO.js +303 -0
  261. package/dist/{shared-DuP7ue-R.js → shared-BRksrJb3.js} +1 -17
  262. package/dist/shared-BuAXb4CI.js +2145 -0
  263. package/dist/sin-BGfy2HZo.js +16 -0
  264. package/dist/slice-D_gkkqZK.js +13 -0
  265. package/dist/slice_util-DtEldBfK.js +261 -0
  266. package/dist/softmax-ZHVebtR1.js +13 -0
  267. package/dist/split-DrfihRpZ.js +10 -0
  268. package/dist/squeeze-DZEpeblb.js +11 -0
  269. package/dist/stack-yOIAalTq.js +13 -0
  270. package/dist/sum-_fzj5ZTB.js +12 -0
  271. package/dist/tensor-DdQUJZlz.js +909 -0
  272. package/dist/tensor-f35l8Odg.js +8 -0
  273. package/dist/tensor1d-CeZuc-Rv.js +12 -0
  274. package/dist/tensor2d-G4Ys2GxX.js +15 -0
  275. package/dist/tensor4d-B8roDgtc.js +15 -0
  276. package/dist/tensor_util-DV-FP5Q3.js +523 -0
  277. package/dist/tfjs_backend-kNyO5L2d.js +653 -0
  278. package/dist/tile-BzyEiF-F.js +13 -0
  279. package/dist/tokeniser/CharTokeniser.js +1 -1
  280. package/dist/tokeniser/bpe.js +1 -1
  281. package/dist/training/Adam.d.ts +2 -1
  282. package/dist/training/Adam.js +12 -28
  283. package/dist/training/AdamExt.d.ts +1 -0
  284. package/dist/training/AdamExt.js +2 -2
  285. package/dist/training/DatasetBuilder.js +3 -20
  286. package/dist/training/FullTrainer.js +82 -64
  287. package/dist/training/Trainer.d.ts +11 -6
  288. package/dist/training/Trainer.js +51 -39
  289. package/dist/training/sparseCrossEntropy.js +3 -3
  290. package/dist/transpose-DKELTqhe.js +38 -0
  291. package/dist/utilities/arrayClose.js +7 -7
  292. package/dist/utilities/dummy.js +35 -27
  293. package/dist/utilities/multinomialCPU.js +2 -2
  294. package/dist/utilities/packed.d.ts +7 -0
  295. package/dist/utilities/packed.js +716 -0
  296. package/dist/utilities/performance.js +1 -1
  297. package/dist/utilities/profile.js +1 -1
  298. package/dist/utilities/safetensors.js +2 -2
  299. package/dist/utilities/sentences.d.ts +5 -0
  300. package/dist/utilities/sentences.js +41 -0
  301. package/dist/utilities/weights.js +2 -2
  302. package/dist/variable-Bhn5bHYv.js +7 -0
  303. package/dist/{webgpu_program-DkQJOJSd.js → webgpu_program-Cigz-7RF.js} +15 -44
  304. package/dist/webgpu_util-BBCnKm2X.js +65 -0
  305. package/dist/zeros-2gldETuK.js +14 -0
  306. package/package.json +4 -3
  307. package/dist/Reshape-Bowtk9BP.js +0 -127
  308. package/dist/Reshape-DUqYftGC.js +0 -30
  309. package/dist/backend_util-CJIiDoV1.js +0 -749
  310. package/dist/broadcast_to-DzlNweb8.js +0 -44
  311. package/dist/concat-B912vBbo.js +0 -33
  312. package/dist/dropout-C-csYCLj.js +0 -193
  313. package/dist/exports_initializers-B8iZMgQ0.js +0 -16
  314. package/dist/gather-Dnpgw-YQ.js +0 -25
  315. package/dist/index-BzFyqcy-.js +0 -4457
  316. package/dist/index-C1rx_Ajs.js +0 -12076
  317. package/dist/kernel_funcs_utils-DKLK0Mg3.js +0 -466
  318. package/dist/log_sum_exp-DO6z8tSE.js +0 -103
  319. package/dist/mat_mul-DzjTFx-u.js +0 -27
  320. package/dist/mod-Dobti4j4.js +0 -27
  321. package/dist/ones-tIJeHlq-.js +0 -29
  322. package/dist/ops/fusedSoftmax.d.ts +0 -2
  323. package/dist/ops/fusedSoftmax.js +0 -10
  324. package/dist/ops/grads/fusedSoftmax.js +0 -22
  325. package/dist/ops-LuCMAnmM.js +0 -1525
  326. package/dist/random_width-CXVRloNK.js +0 -13670
  327. package/dist/range-CWcz7xFA.js +0 -26
  328. package/dist/reciprocal-C4rNcM-S.js +0 -25
  329. package/dist/relu-BjCh_SYb.js +0 -25
  330. package/dist/reshape-CnIwVG1c.js +0 -25
  331. package/dist/selu_util-OtRzVwW5.js +0 -719
  332. package/dist/shared-DmRsFyaJ.js +0 -3134
  333. package/dist/sin-gpDNRxE0.js +0 -47
  334. package/dist/slice-d0Vo9XTN.js +0 -28
  335. package/dist/softmax-D7Jj3p_P.js +0 -28
  336. package/dist/split-DK2k5eHf.js +0 -25
  337. package/dist/stack-DFatutCx.js +0 -27
  338. package/dist/sum-CJ0ULhmt.js +0 -27
  339. package/dist/tensor1d-vML0r3q6.js +0 -27
  340. package/dist/tensor2d-D76QGjF3.js +0 -30
  341. package/dist/tensor4d-Df1WlVDY.js +0 -30
  342. package/dist/webgpu_util-pLEV9tks.js +0 -80
  343. package/dist/zeros-Bj5rMYA7.js +0 -52
@@ -0,0 +1,547 @@
1
+ import { e as D, J as W } from "./index-ZyQhjEPo.js";
2
+ import { e as g, a as _, z as O, s as x, A as $, B as F, C as K, r as Z, t as j, g as X, i as q } from "./tensor-DdQUJZlz.js";
3
+ import { m as J, f as ee, P as te } from "./webgpu_program-Cigz-7RF.js";
4
+ import { i as se, G as N } from "./webgpu_util-BBCnKm2X.js";
5
+ import { K as re, J as ne } from "./tensor_util-DV-FP5Q3.js";
6
+ import { m as k } from "./complex_util-Yc1A_gV1.js";
7
// Register every WebGPU backend flag on the object returned by g(), each with
// an evaluation function that yields a fixed default value. Registration order
// matches the table order below.
const l = g();
for (const [flagName, defaultValue] of [
  ["WEBGPU_DEFERRED_SUBMIT_BATCH_SIZE", 15],
  ["WEBGPU_CPU_FORWARD", true],
  ["WEBGPU_MATMUL_PROGRAM_TYPE", -1],
  ["WEBGPU_USE_NAIVE_CONV2D_TRANSPOSE", true],
  ["WEBGPU_USE_LOW_POWER_GPU", false],
  ["WEBGPU_CPU_HANDOFF_SIZE_THRESHOLD", 1e3],
  ["WEBGPU_USE_PROFILE_TOOL", false],
  ["WEBGPU_IMPORT_EXTERNAL_TEXTURE", true],
  ["WEBGPU_USE_NAIVE_CONV2D_DEBUG", false],
  ["WEBGPU_THRESHOLD_TO_INCREASE_WORKGROUPS_FOR_MATMUL", -1],
  ["WEBGPU_CONV_SEPARATE_IM2COL_SHADER", false],
  ["WEBGPU_PRINT_SHADER", ""],
  ["WEBGPU_ENGINE_COMPILE_ONLY", false],
]) {
  l.registerFlag(flagName, () => defaultValue);
}
21
/**
 * Stores the vendor and architecture strings reported for a GPU adapter and
 * derives an Intel GPU generation number from the architecture string.
 */
class ae {
  /**
   * @param e Optional adapter-info object exposing `vendor` and
   *     `architecture`. When it is falsy, no fields are set on the instance.
   */
  constructor(e) {
    if (e) {
      this.vendor = e.vendor;
      this.architecture = e.architecture;
      this.intelGPUGeneration = this.getIntelGPUGeneration();
    }
  }

  /**
   * @returns For an Intel vendor: the first run of digits in an architecture
   *     string starting with "gen" (0 when it has no digits), or 12 for an
   *     architecture starting with "xe". Returns 0 in every other case,
   *     including when the architecture is not a string.
   */
  getIntelGPUGeneration() {
    // Guard the string methods below: a missing architecture must not throw.
    if (!this.isIntel() || typeof this.architecture !== "string") {
      return 0;
    }
    if (this.architecture.startsWith("gen")) {
      const digits = this.architecture.match(/\d+/);
      // Convert the matched text explicitly rather than coercing the match array.
      return digits === null ? 0 : Number(digits[0]);
    }
    if (this.architecture.startsWith("xe")) {
      return 12;
    }
    return 0;
  }

  /** @returns true when the stored vendor string is exactly "intel". */
  isIntel() {
    return this.vendor === "intel";
  }
}
38
/**
 * Pools buffers created through `device.createBuffer`. Buffers are grouped by
 * the key `z(size, usage)`; the manager tracks which are handed out ("used")
 * and which are available for reuse ("free"), plus byte counters.
 */
class ie {
  /** @param e Device object providing `createBuffer(descriptor)`. */
  constructor(e) {
    this.device = e;
    this.numUsedBuffers = 0;
    this.numFreeBuffers = 0;
    this.freeBuffers = /* @__PURE__ */ new Map();
    this.usedBuffers = /* @__PURE__ */ new Map();
    this.numBytesUsed = 0;
    this.numBytesAllocated = 0;
  }

  /**
   * Hands out a buffer of the given size and usage.
   *
   * @param e Buffer size, passed to `createBuffer` and added to the byte counters.
   * @param t Buffer usage value, passed to `createBuffer`.
   * @param s `mappedAtCreation` value used only when a new buffer is created.
   * @param n When true, a free buffer with the same key is reused if available.
   * @returns The reused or newly created buffer.
   */
  acquireBuffer(e, t, s = !1, n = !0) {
    const key = z(e, t);
    let buffer;
    if (n) {
      if (!this.freeBuffers.has(key)) {
        this.freeBuffers.set(key, []);
      }
      const freeList = this.freeBuffers.get(key);
      if (freeList.length > 0) {
        buffer = freeList.pop();
        this.numFreeBuffers--;
      }
    }
    if (buffer === undefined) {
      buffer = this.device.createBuffer({ size: e, usage: t, mappedAtCreation: s });
      this.numBytesAllocated += e;
    }
    if (!this.usedBuffers.has(key)) {
      this.usedBuffers.set(key, []);
    }
    this.usedBuffers.get(key).push(buffer);
    this.numUsedBuffers++;
    this.numBytesUsed += e;
    return buffer;
  }

  /**
   * Returns a buffer to the manager.
   *
   * @param e Buffer previously returned by `acquireBuffer`.
   * @param t When true the buffer goes to the free pool; otherwise it is destroyed.
   * @throws {Error} When the buffer is not currently tracked as used.
   */
  releaseBuffer(e, t = !0) {
    // Nothing is tracked (never used, or already disposed): ignore the call.
    // Checking the used map (not the free map) ensures buffers acquired with
    // reuse disabled are still released instead of being silently skipped.
    if (this.usedBuffers.size === 0) {
      return;
    }
    const size = e.size;
    const key = z(size, e.usage);
    const inUse = this.usedBuffers.get(key);
    const index = inUse === undefined ? -1 : inUse.indexOf(e);
    if (index < 0) {
      throw new Error("Cannot find the buffer in buffer manager");
    }
    // Swap-remove: order inside the used list is not significant.
    inUse[index] = inUse[inUse.length - 1];
    inUse.pop();
    this.numUsedBuffers--;
    this.numBytesUsed -= size;
    if (t) {
      // The free list may not exist yet if the buffer was acquired without reuse.
      let freeList = this.freeBuffers.get(key);
      if (freeList === undefined) {
        freeList = [];
        this.freeBuffers.set(key, freeList);
      }
      freeList.push(e);
      this.numFreeBuffers++;
    } else {
      e.destroy();
      this.numBytesAllocated -= size;
    }
  }

  /** @returns Number of buffers currently handed out. */
  getNumUsedBuffers() {
    return this.numUsedBuffers;
  }

  /** @returns Number of buffers currently in the free pool. */
  getNumFreeBuffers() {
    return this.numFreeBuffers;
  }

  /** Destroys every tracked buffer (free and used) and resets all state. */
  dispose() {
    for (const buffers of this.freeBuffers.values()) {
      buffers.forEach((buffer) => {
        buffer.destroy();
      });
    }
    for (const buffers of this.usedBuffers.values()) {
      buffers.forEach((buffer) => {
        buffer.destroy();
      });
    }
    this.freeBuffers = /* @__PURE__ */ new Map();
    this.usedBuffers = /* @__PURE__ */ new Map();
    this.numUsedBuffers = 0;
    this.numFreeBuffers = 0;
    this.numBytesUsed = 0;
    this.numBytesAllocated = 0;
  }
}
73
/** Builds the buffer pool key by joining the two arguments with an underscore. */
function z(d, e) {
  return "".concat(d, "_", e);
}
76
/**
 * Pools textures created through `device.createTexture`. Textures are grouped
 * by the key `L(width, height, format, usage)`; the manager tracks used and
 * free textures plus byte counters computed as `width * height * Q(format)`.
 */
class oe {
  /** @param e Device object providing `createTexture(descriptor)`. */
  constructor(e) {
    this.device = e;
    this.numUsedTextures = 0;
    this.numFreeTextures = 0;
    this.freeTextures = /* @__PURE__ */ new Map();
    this.usedTextures = /* @__PURE__ */ new Map();
    this.numBytesUsed = 0;
    this.numBytesAllocated = 0;
  }

  /**
   * Hands out a texture, reusing a free one with the same key when available.
   *
   * @param e Texture width.
   * @param t Texture height.
   * @param s Texture format; `Q` throws for formats it does not support.
   * @param n Texture usage value.
   * @returns The reused or newly created texture.
   */
  acquireTexture(e, t, s, n) {
    const byteSize = e * t * Q(s);
    const key = L(e, t, s, n);
    if (!this.freeTextures.has(key)) {
      this.freeTextures.set(key, []);
    }
    if (!this.usedTextures.has(key)) {
      this.usedTextures.set(key, []);
    }
    const freeList = this.freeTextures.get(key);
    let texture;
    if (freeList.length > 0) {
      texture = freeList.shift();
      this.numFreeTextures--;
    } else {
      // Create first so a failing createTexture leaves the counters untouched.
      texture = this.device.createTexture({
        size: [e, t],
        format: s,
        usage: n
      });
      this.numBytesAllocated += byteSize;
    }
    this.numBytesUsed += byteSize;
    this.numUsedTextures++;
    this.usedTextures.get(key).push(texture);
    return texture;
  }

  /**
   * Moves a texture from the used list to the free pool.
   *
   * @param e Texture previously returned by `acquireTexture`.
   * @throws {Error} When the texture is not currently tracked as used; in that
   *     case no pool state is modified.
   */
  releaseTexture(e) {
    // acquireTexture always creates a free-list entry, so an empty map means
    // nothing is tracked (never used, or already disposed): ignore the call.
    if (this.freeTextures.size === 0) {
      return;
    }
    const width = e.width;
    const height = e.height;
    const format = e.format;
    const key = L(width, height, format, e.usage);
    // Validate ownership before touching any pool state.
    const inUse = this.usedTextures.get(key);
    const index = inUse === undefined ? -1 : inUse.indexOf(e);
    if (index < 0) {
      throw new Error("Cannot release a texture that was never provided by this texture manager");
    }
    const byteSize = width * height * Q(format);
    inUse.splice(index, 1);
    if (!this.freeTextures.has(key)) {
      this.freeTextures.set(key, []);
    }
    this.freeTextures.get(key).push(e);
    this.numFreeTextures++;
    this.numUsedTextures--;
    this.numBytesUsed -= byteSize;
  }

  /** @returns Number of textures currently handed out. */
  getNumUsedTextures() {
    return this.numUsedTextures;
  }

  /** @returns Number of textures currently in the free pool. */
  getNumFreeTextures() {
    return this.numFreeTextures;
  }

  /** Destroys every tracked texture (free and used) and resets all state. */
  dispose() {
    for (const textures of this.freeTextures.values()) {
      textures.forEach((texture) => {
        texture.destroy();
      });
    }
    for (const textures of this.usedTextures.values()) {
      textures.forEach((texture) => {
        texture.destroy();
      });
    }
    this.freeTextures = /* @__PURE__ */ new Map();
    this.usedTextures = /* @__PURE__ */ new Map();
    this.numUsedTextures = 0;
    this.numFreeTextures = 0;
    this.numBytesUsed = 0;
    this.numBytesAllocated = 0;
  }
}
125
/** Builds the texture pool key by joining the four arguments with underscores. */
function L(d, e, t, s) {
  return "".concat(d, "_", e, "_", t, "_", s);
}
128
/**
 * Returns the per-element byte count used for texture size accounting.
 * Only "rgba8unorm" is accepted; any other format throws.
 * NOTE(review): 16 is the value in the original code; rgba8unorm is usually
 * described as 4 bytes per texel, so confirm the intended unit.
 */
function Q(d) {
  switch (d) {
    case "rgba8unorm":
      return 16;
    default:
      throw new Error(`${d} is not supported!`);
  }
}
133
// Value of the WEBGPU_CPU_HANDOFF_SIZE_THRESHOLD flag, read once at module load.
const ue = g().getNumber("WEBGPU_CPU_HANDOFF_SIZE_THRESHOLD");

/**
 * Returns a dispatch size that respects the device's per-dimension workgroup
 * limit. When every entry already fits, the program's own dispatch array is
 * returned. Otherwise only an oversized X dimension with no Y/Z layout is
 * accepted, and it is spread over a square ([r, r, 1]) or, if that still does
 * not fit, a cube ([r, r, r]).
 *
 * @param d Device exposing `limits.maxComputeWorkgroupsPerDimension`.
 * @param e Program exposing `dispatchLayout` and `dispatch`.
 */
const fe = (d, e) => {
  const limit = d.limits.maxComputeWorkgroupsPerDimension;
  const layout = e.dispatchLayout;
  const dispatch = e.dispatch;
  if (dispatch.every((count) => count <= limit)) {
    return dispatch;
  }
  _(
    dispatch[0] > limit && layout.y === void 0 && layout.z === void 0,
    () => "Dispatch size exceeds WebGPU limits in Y or Z dimension."
  );
  const squareSide = Math.ceil(Math.sqrt(dispatch[0]));
  if (squareSide <= limit) {
    return [squareSide, squareSide, 1];
  }
  const cubeSide = Math.ceil(Math.cbrt(dispatch[0]));
  _(cubeSide <= limit, () => "Total dispatch size exceeds WebGPU maximum.");
  return [cubeSide, cubeSide, cubeSide];
};
141
+ class R extends re {
142
+ nextDataId() {
143
+ return R.nextDataId++;
144
+ }
145
+ constructor(e, t) {
146
+ if (super(), this.commandQueueOwnedIds = /* @__PURE__ */ new WeakSet(), this.dispatchCountInPass = 0, this.disposed = !1, this.downloadWaitMs = 0, this.tensorDataPendingDisposal = [], this.queryResolveBuffer = null, this.querySet = null, this.querySetCount = 2, this.stagingPendingDisposal = [], this.uniformPendingDisposal = [], this.uploadWaitMs = 0, this.hasReadSyncWarned = !1, this.hasTimestampQueryWarned = !1, !se())
147
+ throw new Error("WebGPU is not supported on this device");
148
+ this.pipelineCache = {}, this.device = e, this.queue = e.queue, this.commandEncoder = null, this.computePassEncoder = null, this.adapterInfo = new ae(t), this.supportTimestampQuery = this.device.features.has("timestamp-query"), this.thresholdToIncreaseWorkgroups = this.adapterInfo.intelGPUGeneration >= 12 ? 16 : 8, this.bufferManager = new ie(this.device), this.textureManager = new oe(this.device), this.tensorMap = new ne(this, D()), g().getBool("WEBGPU_USE_PROFILE_TOOL") && (this.dummyCanvas = document.createElement("canvas"), this.dummyCanvas.width = 1, this.dummyCanvas.height = 1, this.dummyContext = this.dummyCanvas.getContext("webgpu"), this.dummyContext.configure({
149
+ device: e,
150
+ format: "bgra8unorm"
151
+ }), document.body.appendChild(this.dummyCanvas));
152
+ }
153
+ floatPrecision() {
154
+ return 32;
155
+ }
156
+ /**
157
+ * Dispose the memory if the dataId has 0 refCount. Return true if the memory
158
+ * is released or delayed in this backend, false if there are still
159
+ * references.
160
+ * @param dataId
161
+ * @oaram force Optional, remove the data regardless of refCount
162
+ */
163
+ disposeData(e, t = !1) {
164
+ if (!this.tensorMap.has(e))
165
+ return !0;
166
+ const s = this.tensorMap.get(e);
167
+ return t ? s.refCount = 0 : s.refCount--, s.refCount > 0 ? !1 : (s.complexTensorInfos != null && (this.disposeData(s.complexTensorInfos.real.dataId), this.disposeData(s.complexTensorInfos.imag.dataId)), this.commandQueueOwnedIds.has(e) ? (this.tensorDataPendingDisposal.push(e), !0) : (this.releaseResource(e), this.tensorMap.delete(e), !0));
168
+ }
169
+ memory() {
170
+ return {
171
+ numBytesInGPU: this.bufferManager.numBytesUsed,
172
+ numBytesAllocatedInGPU: this.bufferManager.numBytesAllocated,
173
+ unreliable: !1
174
+ };
175
+ }
176
+ releaseResource(e) {
177
+ const t = this.tensorMap.get(e);
178
+ if (!(!t || !t.resource)) {
179
+ if (t.external) {
180
+ t.resource = null;
181
+ return;
182
+ }
183
+ t.resource instanceof GPUBuffer ? this.bufferManager.releaseBuffer(t.resource) : t.resource instanceof GPUTexture && this.textureManager.releaseTexture(t.resource), t.resource = null;
184
+ }
185
+ }
186
+ /** Return refCount of a `TensorData`. */
187
+ refCount(e) {
188
+ return this.tensorMap.has(e) ? this.tensorMap.get(e).refCount : 0;
189
+ }
190
+ /** Increase refCount of a `TensorData`. */
191
+ incRef(e) {
192
+ const t = this.tensorMap.get(e);
193
+ t.refCount++;
194
+ }
195
+ /** Decrease refCount of a `TensorData`. */
196
+ decRef(e) {
197
+ if (this.tensorMap.has(e)) {
198
+ const t = this.tensorMap.get(e);
199
+ t.refCount--;
200
+ }
201
+ }
202
+ write(e, t, s) {
203
+ if (s === "complex64" && e != null)
204
+ throw new Error("Cannot write to a complex64 dtype. Please use tf.complex(real, imag).");
205
+ const n = { id: this.nextDataId() };
206
+ return this.tensorMap.set(n, { dtype: s, shape: t, values: e, refCount: 1 }), n;
207
+ }
208
+ move(e, t, s, n, r) {
209
+ if (n === "complex64")
210
+ throw new Error("Cannot write to a complex64 dtype. Please use tf.complex(real, imag).");
211
+ this.tensorMap.set(e, { dtype: n, shape: s, values: t, refCount: r });
212
+ }
213
+ submitQueue() {
214
+ this.queue.submit([this.commandEncoder.finish()]), this.commandEncoder = null, this.dispatchCountInPass = 0, this.commandQueueOwnedIds = /* @__PURE__ */ new WeakSet(), this.tensorDataPendingDisposal.forEach((e) => {
215
+ this.releaseResource(e), this.tensorMap.delete(e);
216
+ }), this.uniformPendingDisposal.forEach((e) => this.bufferManager.releaseBuffer(e)), this.stagingPendingDisposal.forEach((e) => this.bufferManager.releaseBuffer(e, !1)), this.tensorDataPendingDisposal = [], this.uniformPendingDisposal = [], this.stagingPendingDisposal = [];
217
+ }
218
+ ensureCommandEncoderReady() {
219
+ this.commandEncoder || (this.commandEncoder = this.device.createCommandEncoder());
220
+ }
221
+ endComputePassEncoder() {
222
+ this.computePassEncoder && (this.computePassEncoder.end(), this.computePassEncoder = null);
223
+ }
224
+ // Check if parallel compilation is done.
225
+ async checkCompileCompletionAsync() {
226
+ let e;
227
+ try {
228
+ e = await Promise.all(Object.values(this.pipelineCache));
229
+ } catch (t) {
230
+ throw new Error(t.message);
231
+ }
232
+ Object.keys(this.pipelineCache).map((t, s) => {
233
+ this.pipelineCache[t] = e[s];
234
+ });
235
+ }
236
+ async getBufferData(e) {
237
+ if (g().getBool("WEBGPU_ENGINE_COMPILE_ONLY"))
238
+ return console.warn("The data may be invalid since WEBGPU_ENGINE_COMPILE_ONLY is true, this can only be called when WEBGPU_ENGINE_COMPILE_ONLY is false"), null;
239
+ const t = e.size, s = this.bufferManager.acquireBuffer(t, GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ);
240
+ this.ensureCommandEncoderReady(), this.endComputePassEncoder(), this.commandEncoder.copyBufferToBuffer(e, 0, s, 0, t), this.submitQueue(), await s.mapAsync(GPUMapMode.READ);
241
+ const n = s.getMappedRange().slice(0);
242
+ return s.unmap(), s != null && this.bufferManager.releaseBuffer(s), g().getBool("WEBGPU_USE_PROFILE_TOOL") && (_(this.dummyContext !== void 0, () => "Fail to get context for profiling tool"), this.dummyContext.getCurrentTexture()), n;
243
+ }
244
+ convertAndCacheOnCPU(e, t) {
245
+ const s = this.tensorMap.get(e);
246
+ return s.values = t, s.values;
247
+ }
248
+ readSync(e) {
249
+ const t = this.tensorMap.get(e), { values: s, complexTensorInfos: n } = t;
250
+ if (s != null || t.dtype === "string")
251
+ return s;
252
+ if (t.dtype === "complex64") {
253
+ const E = this.readSync(n.real.dataId), B = this.readSync(n.imag.dataId), y = O(k(E, B).buffer, "float32");
254
+ return this.convertAndCacheOnCPU(e, y), y;
255
+ }
256
+ this.hasReadSyncWarned || (this.hasReadSyncWarned = !0, console.warn("The performance of synchronously reading data from GPU to CPU is poor on the webgpu backend, please use asynchronous APIs instead."));
257
+ const r = ["opaque", "premultiplied"], a = t.resource, i = a.size;
258
+ _(i % 4 === 0, () => "Because there is 4 bytes for one pixel, buffer size must be multiple of 4.");
259
+ const u = i / 4, o = new ArrayBuffer(i), f = 256, c = 256, h = r.map((E) => new OffscreenCanvas(f, c)), m = new OffscreenCanvas(f, c);
260
+ this.endComputePassEncoder(), h.map((E, B) => {
261
+ const y = E.getContext("webgpu");
262
+ return y.configure({
263
+ device: this.device,
264
+ format: "bgra8unorm",
265
+ usage: GPUTextureUsage.COPY_DST,
266
+ alphaMode: r[B]
267
+ }), y.getCurrentTexture();
268
+ }).map((E, B) => {
269
+ const y = f * 4, b = (P, S, v) => {
270
+ this.ensureCommandEncoderReady(), this.commandEncoder.copyBufferToTexture({
271
+ buffer: a,
272
+ bytesPerRow: y,
273
+ offset: v
274
+ }, {
275
+ texture: E
276
+ }, {
277
+ width: P,
278
+ height: S
279
+ }), this.submitQueue();
280
+ const I = m.getContext("2d", {
281
+ willReadFrequently: !0
282
+ });
283
+ I.clearRect(0, 0, P, S), I.drawImage(h[B], 0, 0);
284
+ const G = I.getImageData(0, 0, P, S).data, H = r[B], M = new Uint8ClampedArray(o, v, P * S * 4);
285
+ for (let p = 0; p < M.length; p += 4)
286
+ if (H === "premultiplied")
287
+ M[p + 3] = G[p + 3];
288
+ else {
289
+ const V = G[p];
290
+ M[p] = G[p + 2], M[p + 1] = G[p + 1], M[p + 2] = V;
291
+ }
292
+ }, Y = Math.floor(u / (f * c));
293
+ let T = f, U = c, C = 0;
294
+ for (let P = 0; P < Y; P++)
295
+ b(T, U, C), C += f * c * 4;
296
+ const A = u % (f * c);
297
+ U = Math.floor(A / f), U > 0 && (b(T, U, C), C += U * (f * 4)), T = A % f, T > 0 && b(T, 1, C);
298
+ });
299
+ const w = O(o, t.dtype);
300
+ return this.convertAndCacheOnCPU(e, w), w;
301
+ }
302
+ async read(e) {
303
+ if (!this.tensorMap.has(e))
304
+ throw new Error(`Tensor ${e} was not registered!`);
305
+ const t = this.tensorMap.get(e), { values: s } = t;
306
+ if (s != null)
307
+ return s;
308
+ let n;
309
+ if (t.dtype === "complex64") {
310
+ const r = await Promise.all([
311
+ this.read(t.complexTensorInfos.real.dataId),
312
+ this.read(t.complexTensorInfos.imag.dataId)
313
+ ]), a = r[0], i = r[1];
314
+ n = k(a, i);
315
+ } else {
316
+ const r = await this.getBufferData(t.resource);
317
+ n = O(r, t.dtype);
318
+ }
319
+ return this.convertAndCacheOnCPU(e, n), n;
320
+ }
321
+ // The source GPUBuffer and destination GPUBuffer have the same size and
322
+ // usage.
323
+ copyBuffer(e) {
324
+ const t = e.size, s = e.usage, n = this.bufferManager.acquireBuffer(t, s);
325
+ return this.ensureCommandEncoderReady(), this.endComputePassEncoder(), this.commandEncoder.copyBufferToBuffer(e, 0, n, 0, t), this.submitQueue(), n;
326
+ }
327
+ /**
328
+ * Create a TF.js tensor out of an existing WebGPU buffer.
329
+ */
330
+ createTensorFromGPUData(e, t, s) {
331
+ let n = e.buffer;
332
+ if (s === "complex64")
333
+ throw new Error("Cannot write to a complex64 dtype. ");
334
+ const r = { id: this.nextDataId() };
335
+ this.tensorMap.set(r, {
336
+ dtype: s,
337
+ shape: t,
338
+ values: null,
339
+ refCount: 1,
340
+ external: e.zeroCopy
341
+ });
342
+ const a = this.tensorMap.get(r), i = N(a.dtype) * x(a.shape);
343
+ if (e.buffer.size < i)
344
+ throw new Error(`GPUBuffer size(${e.buffer.size}) is smaller than tensor size(${i})!`);
345
+ if ((e.buffer.usage & (GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC)) !== (GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC))
346
+ throw new Error("GPUBuffer.usage should include GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC!");
347
+ return e.zeroCopy !== !0 && (n = this.copyBuffer(n)), a.resource = n, D().makeTensorFromDataId(r, t, s, this);
348
+ }
349
+ /**
350
+ * Read tensor to a new GPUBuffer.
351
+ * @param dataId The source tensor.
352
+ */
353
+ readToGPU(e) {
354
+ const t = this.tensorMap.get(e), { values: s, dtype: n, shape: r, resource: a } = t;
355
+ if (n === "complex64")
356
+ throw new Error("Does not support reading buffer for complex64 dtype.");
357
+ if (a == null)
358
+ throw s != null ? new Error("Data is not on GPU but on CPU.") : new Error("There is no data on GPU or CPU.");
359
+ const i = a, u = i.size, o = i.usage, f = this.bufferManager.acquireBuffer(u, o);
360
+ this.ensureCommandEncoderReady(), this.endComputePassEncoder(), this.commandEncoder.copyBufferToBuffer(a, 0, f, 0, u), this.submitQueue();
361
+ const c = this.makeTensorInfo(r, n), h = D().makeTensorFromTensorInfo(c), m = this.tensorMap.get(c.dataId);
362
+ return m.resource = f, { tensorRef: h, buffer: f };
363
+ }
364
+ bufferSync(e) {
365
+ const t = this.readSync(e.dataId);
366
+ if (e.dtype === "string")
367
+ try {
368
+ const s = t.map((n) => $(n));
369
+ return W(e.shape, e.dtype, s);
370
+ } catch {
371
+ throw new Error("Failed to decode encoded string bytes into utf-8");
372
+ }
373
+ return W(e.shape, e.dtype, t);
374
+ }
375
+ async time(e) {
376
+ !this.supportTimestampQuery && !this.hasTimestampQueryWarned && (console.warn("This device doesn't support timestamp-query extension. Start Chrome browser with flag --enable-dawn-features=allow_unsafe_apis to try it again. Otherwise, zero will be shown for the kernel time when profiling mode is enabled."), this.hasTimestampQueryWarned = !0);
377
+ const t = this.activeTimers, s = [];
378
+ let n = !1;
379
+ this.programTimersStack == null ? (this.programTimersStack = s, n = !0) : this.activeTimers.push(s), this.activeTimers = s, e();
380
+ const r = F(this.activeTimers.map((o) => o.query)).filter((o) => o != null), a = F(this.activeTimers.map((o) => o.name)).filter((o) => o != null);
381
+ this.activeTimers = t, n && (this.programTimersStack = null);
382
+ const i = {
383
+ uploadWaitMs: this.uploadWaitMs,
384
+ downloadWaitMs: this.downloadWaitMs,
385
+ kernelMs: null,
386
+ wallMs: null
387
+ }, u = await Promise.all(r);
388
+ return i.kernelMs = K(u), i.getExtraProfileInfo = () => u.map((o, f) => ({ name: a[f], ms: o })).map((o) => `${o.name}: ${o.ms}`).join(", "), this.uploadWaitMs = 0, this.downloadWaitMs = 0, i;
389
+ }
390
+ makeTensorInfo(e, t, s) {
391
+ return t === "string" && s != null && s.length > 0 && Z(s[0]) && (s = s.map((r) => j(r))), { dataId: this.write(s, e, t), shape: e, dtype: t };
392
+ }
393
+ tensorToBinding(e) {
394
+ if (!e)
395
+ return null;
396
+ const s = this.tensorMap.get(e.dataId).resource;
397
+ return s instanceof GPUBuffer ? { buffer: s } : s instanceof GPUTexture ? s.createView() : s;
398
+ }
399
+ uploadToGPU(e) {
400
+ const t = this.tensorMap.get(e);
401
+ if (t.resource != null)
402
+ return;
403
+ const s = N(t.dtype) * x(t.shape);
404
+ let n;
405
+ const r = GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST;
406
+ if (t.values) {
407
+ if (n = this.bufferManager.acquireBuffer(s, r, !0), n.mapState === "unmapped") {
408
+ const a = this.bufferManager.acquireBuffer(s, GPUBufferUsage.MAP_WRITE | GPUBufferUsage.COPY_SRC, !0, !1), i = a.getMappedRange();
409
+ t.dtype === "int32" || t.dtype === "bool" ? new Int32Array(i).set(t.values) : new Float32Array(i).set(t.values), a.unmap(), this.ensureCommandEncoderReady(), this.endComputePassEncoder(), this.commandEncoder.copyBufferToBuffer(a, 0, n, 0, s), this.stagingPendingDisposal.push(a);
410
+ } else {
411
+ const a = n.getMappedRange();
412
+ t.dtype === "int32" || t.dtype === "bool" ? new Int32Array(a).set(t.values) : new Float32Array(a).set(t.values), n.unmap();
413
+ }
414
+ t.values = null;
415
+ } else
416
+ n = this.bufferManager.acquireBuffer(s, r);
417
+ t.resource = n;
418
+ }
419
+ makeUniforms(e) {
420
+ let t = 0, s = 0;
421
+ const n = [];
422
+ let r = 1;
423
+ e.forEach((u) => {
424
+ u.data.length === 0 && (u.data = [1]);
425
+ let o;
426
+ switch (u.data.length) {
427
+ case 1:
428
+ o = 4;
429
+ break;
430
+ case 2:
431
+ o = 8;
432
+ break;
433
+ case 3:
434
+ o = 16;
435
+ break;
436
+ case 4:
437
+ o = 16;
438
+ break;
439
+ case 5:
440
+ o = 16;
441
+ break;
442
+ case 6:
443
+ o = 16;
444
+ break;
445
+ default:
446
+ _(!1, () => `Unsupported ${u.data.length}D shape`);
447
+ }
448
+ (s === 5 || s === 6) && (o = 16), o > r && (r = o), t = Math.ceil(t / o) * o, s = u.data.length, n.push(t), t += u.data.length * 4;
449
+ }), t = Math.ceil(t / r) * r;
450
+ const a = new ArrayBuffer(t);
451
+ e.forEach((u, o) => {
452
+ const f = n[o];
453
+ u.type === "int32" ? new Int32Array(a, f, u.data.length).set(u.data) : u.type === "uint32" ? new Uint32Array(a, f, u.data.length).set(u.data) : new Float32Array(a, f, u.data.length).set(u.data);
454
+ });
455
+ const i = this.bufferManager.acquireBuffer(t, GPUBufferUsage.COPY_DST | GPUBufferUsage.UNIFORM);
456
+ return this.queue.writeBuffer(i, 0, a, 0, t), this.uniformPendingDisposal.push(i), { offset: 0, size: t, buffer: i };
457
+ }
458
+ runWebGPUProgram(e, t, s, n, r) {
459
+ if (r || (r = this.makeTensorInfo(e.outputShape, s)), x(r.shape) === 0)
460
+ return this.tensorMap.get(r.dataId).values = X(r.dtype, 0), r;
461
+ this.uploadToGPU(r.dataId), e.dispatch = fe(this.device, e);
462
+ const a = t.map((u, o) => {
463
+ if (u.dtype === "complex64")
464
+ throw new Error("GPGPUProgram does not support complex64 input. For complex64 dtypes, please separate the program into real and imaginary parts.");
465
+ return this.uploadToGPU(u.dataId), {
466
+ // Returning dtype from tensorMap because it reflects dtype
467
+ // of underlying buffer, rather than abstract dtype.
468
+ dtype: this.tensorMap.get(u.dataId).dtype,
469
+ shape: u.shape,
470
+ name: e.variableNames[o]
471
+ };
472
+ });
473
+ e.shaderKey = J(e, a, r);
474
+ const i = g().getBool("WEBGPU_ENGINE_COMPILE_ONLY");
475
+ return e.shaderKey in this.pipelineCache || (this.pipelineCache[e.shaderKey] = ee(this.device, e, a, r, i)), e.pipeline = this.pipelineCache[e.shaderKey], i || this.recordAndSubmit(e, r, t, n), r;
476
+ }
477
+ recordAndSubmit(e, t, s, n) {
478
+ if (e.pipeline instanceof Promise)
479
+ throw new Error("Please call checkCompileCompletionAsync to ensure parallel compilation is done!");
480
+ let r = [], a = [];
481
+ const i = "int32";
482
+ if (e.pixelsOpType == null) {
483
+ r.push({ type: "float32", data: [NaN] }, { type: "float32", data: [1 / 0] }), a = s.concat(t).map((m) => m.shape);
484
+ const h = "int32";
485
+ a.map((m) => {
486
+ r.push({ type: h, data: m });
487
+ const w = q(m);
488
+ r.push({ type: h, data: w });
489
+ });
490
+ } else {
491
+ const h = q(t.shape);
492
+ r.push({ type: i, data: h });
493
+ }
494
+ if (e.size) {
495
+ const h = x(e.outputShape);
496
+ r.push({
497
+ type: i,
498
+ data: [e.outputComponent ? h / e.outputComponent : h]
499
+ });
500
+ }
501
+ n && (r = [...r, ...n]);
502
+ const u = [
503
+ this.tensorToBinding(t),
504
+ ...s.map((h) => this.tensorToBinding(h)),
505
+ this.makeUniforms(r)
506
+ ];
507
+ s.forEach((h) => {
508
+ this.commandQueueOwnedIds.add(h.dataId);
509
+ }), this.commandQueueOwnedIds.add(t.dataId);
510
+ const o = this.device.createBindGroup({
511
+ layout: e.pipeline.getBindGroupLayout(0),
512
+ entries: u.map((h, m) => ({ binding: m, resource: h }))
513
+ }), f = this.activeTimers != null;
514
+ this.ensureCommandEncoderReady();
515
+ const c = {};
516
+ f && this.supportTimestampQuery ? (this.endComputePassEncoder(), this.querySet == null && (this.querySet = this.device.createQuerySet({
517
+ type: "timestamp",
518
+ count: this.querySetCount
519
+ })), c.timestampWrites = {
520
+ querySet: this.querySet,
521
+ beginningOfPassWriteIndex: 0,
522
+ endOfPassWriteIndex: 1
523
+ }, this.computePassEncoder = this.commandEncoder.beginComputePass(c)) : this.computePassEncoder || (this.computePassEncoder = this.commandEncoder.beginComputePass(c)), this.computePassEncoder.setPipeline(e.pipeline), this.computePassEncoder.setBindGroup(0, o), this.computePassEncoder.dispatchWorkgroups(e.dispatch[0], e.dispatch[1], e.dispatch[2]), this.dispatchCountInPass++, (f || g().get("WEBGPU_DEFERRED_SUBMIT_BATCH_SIZE") <= this.dispatchCountInPass || e.pixelsOpType === te.DRAW) && (this.endComputePassEncoder(), f ? this.activeTimers.push({ name: e.constructor.name, query: this.getQueryTime() }) : this.submitQueue());
524
+ }
525
+ async getQueryTime() {
526
+ if (!this.supportTimestampQuery)
527
+ return 0;
528
+ this.queryResolveBuffer == null && (this.queryResolveBuffer = this.bufferManager.acquireBuffer(this.querySetCount * 8, GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST | GPUBufferUsage.QUERY_RESOLVE)), this.commandEncoder.resolveQuerySet(this.querySet, 0, this.querySetCount, this.queryResolveBuffer, 0);
529
+ const e = this.bufferManager.acquireBuffer(this.querySetCount * 8, GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST);
530
+ this.commandEncoder.copyBufferToBuffer(this.queryResolveBuffer, 0, e, 0, this.querySetCount * 8), this.submitQueue(), await e.mapAsync(GPUMapMode.READ);
531
+ const t = new BigUint64Array(e.getMappedRange()), s = Number(t[1] - t[0]) / 1e6;
532
+ return e.unmap(), this.bufferManager.releaseBuffer(e), s;
533
+ }
534
+ shouldExecuteOnCPU(e, t = ue) {
535
+ return g().getBool("WEBGPU_CPU_FORWARD") && e.every((s) => this.tensorMap.get(s.dataId).resource == null && x(s.shape) < t);
536
+ }
537
+ numDataIds() {
538
+ return this.tensorMap.numDataIds() - this.tensorDataPendingDisposal.length;
539
+ }
540
+ dispose() {
541
+ this.disposed || (this.querySet != null && this.querySet.destroy(), this.bufferManager.dispose(), this.textureManager.dispose(), this.disposed = !0);
542
+ }
543
+ }
544
// Initialise the static `nextDataId` property on `R` to 0 at module load.
// NOTE(review): presumably the counter behind the instance `nextDataId()` calls; confirm.
+ R.nextDataId = 0;
544
// `R` is exported under the short alias `W`.
+ export {
545
+ R as W
546
+ };