@genai-fi/nanogpt 0.9.1 → 0.10.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (343)
  1. package/README.md +352 -14
  2. package/dist/Generator.js +69 -78
  3. package/dist/{RealDiv-D4EzDsC0.js → RealDiv-DgA3z9oO.js} +32 -206
  4. package/dist/Reshape-CF6odzV4.js +16 -0
  5. package/dist/Reshape-_kILl6tK.js +81 -0
  6. package/dist/TeachableLLM.js +28 -22
  7. package/dist/Trainer.d.ts +2 -0
  8. package/dist/Trainer.js +3 -2
  9. package/dist/{axis_util-TbGYJ208.js → axis_util-BvHEw88j.js} +7 -23
  10. package/dist/backend.d.ts +2 -1
  11. package/dist/backend.js +10 -4
  12. package/dist/backend_util-D-rUb2ty.js +474 -0
  13. package/dist/backend_webgpu-B0u2ndUn.js +547 -0
  14. package/dist/binary_op_util-pKXltfxI.js +192 -0
  15. package/dist/broadcast_to-CwF7XIeu.js +30 -0
  16. package/dist/checks/appendCache.js +2 -2
  17. package/dist/checks/attentionMask.js +3 -3
  18. package/dist/checks/check.d.ts +1 -1
  19. package/dist/checks/check.js +8 -8
  20. package/dist/checks/gelu.js +2 -2
  21. package/dist/checks/index.d.ts +2 -0
  22. package/dist/checks/index.js +7 -5
  23. package/dist/checks/matMulGelu.js +6 -6
  24. package/dist/checks/normRMS.js +7 -7
  25. package/dist/checks/normRMSGrad.js +3 -3
  26. package/dist/checks/packUnpack.d.ts +1 -0
  27. package/dist/checks/packUnpack.js +18 -0
  28. package/dist/checks/qkv.js +12 -27
  29. package/dist/checks/rope.js +2 -2
  30. package/dist/checks/weights.js +18 -16
  31. package/dist/complex-CSlYz-2T.js +13 -0
  32. package/dist/complex_util-Yc1A_gV1.js +55 -0
  33. package/dist/concat-BHlIJeyT.js +19 -0
  34. package/dist/concat_util-DcJk7YHS.js +22 -0
  35. package/dist/data/docx.js +1 -1
  36. package/dist/data/parquet.js +2 -2
  37. package/dist/data/pdf.js +1 -1
  38. package/dist/data/textLoader.js +1 -1
  39. package/dist/{dataset-DlZtKmBq.js → dataset-0xP8GjwI.js} +136 -236
  40. package/dist/dropout-C1pM3f11.js +99 -0
  41. package/dist/expand_dims-BPG4fwBP.js +13 -0
  42. package/dist/exports_initializers-xuidcwI4.js +7 -0
  43. package/dist/gather-DykLGqmW.js +10 -0
  44. package/dist/{gelu-Bp_-935b.js → gelu-CNLFZWea.js} +11 -10
  45. package/dist/{gpgpu_math-CDaYiyE_.js → gpgpu_math-DDVJCn6-.js} +90 -265
  46. package/dist/{index-C4L8Cm77.js → index-CieiGp4Y.js} +14 -14
  47. package/dist/index-CjOj7j-u.js +7308 -0
  48. package/dist/{index-Tf7vU29b.js → index-Cp39cXWe.js} +3 -10
  49. package/dist/{index-Dwqa6Zy2.js → index-DvYrXKkX.js} +2 -2
  50. package/dist/index-ZyQhjEPo.js +2157 -0
  51. package/dist/{jszip.min-CjP2V1VV.js → jszip.min-Bz5-11Bk.js} +56 -57
  52. package/dist/kernel_funcs_utils-Dg_-E44D.js +308 -0
  53. package/dist/layers/BaseLayer.d.ts +1 -0
  54. package/dist/layers/BaseLayer.js +7 -6
  55. package/dist/layers/CausalSelfAttention.d.ts +0 -1
  56. package/dist/layers/CausalSelfAttention.js +56 -55
  57. package/dist/layers/MLP.js +15 -16
  58. package/dist/layers/PositionEmbedding.js +5 -14
  59. package/dist/layers/RMSNorm.js +3 -3
  60. package/dist/layers/RoPECache.d.ts +2 -0
  61. package/dist/layers/RoPECache.js +22 -17
  62. package/dist/layers/TiedEmbedding.js +22 -17
  63. package/dist/layers/TransformerBlock.js +21 -20
  64. package/dist/loader/load.js +1 -1
  65. package/dist/loader/loadTransformers.js +1 -1
  66. package/dist/loader/oldZipLoad.js +39 -33
  67. package/dist/loader/save.js +1 -1
  68. package/dist/log_sum_exp-DWI-76TI.js +41 -0
  69. package/dist/main.d.ts +8 -0
  70. package/dist/main.js +63 -52
  71. package/dist/matMul16--R5hOwDG.js +77 -0
  72. package/dist/mat_mul-DeAh4uTH.js +12 -0
  73. package/dist/mod-Gt1rMB4n.js +12 -0
  74. package/dist/models/NanoGPTV1.js +40 -31
  75. package/dist/models/model.d.ts +2 -0
  76. package/dist/models/model.js +37 -29
  77. package/dist/{mulmat_packed_gpu-BT60jmzP.js → mulmat_packed_gpu-BMFhLwta.js} +1 -17
  78. package/dist/{non_max_suppression_impl-CsEgBuMA.js → non_max_suppression_impl-B2W7YjZB.js} +0 -32
  79. package/dist/ones-CAMiP4I2.js +15 -0
  80. package/dist/ops/adamAdjust.js +1 -1
  81. package/dist/ops/adamMoments.d.ts +1 -1
  82. package/dist/ops/adamMoments.js +4 -4
  83. package/dist/ops/add16.d.ts +2 -0
  84. package/dist/ops/add16.js +9 -0
  85. package/dist/ops/appendCache.js +16 -9
  86. package/dist/ops/attentionMask.js +4 -4
  87. package/dist/ops/concat16.d.ts +2 -0
  88. package/dist/ops/concat16.js +9 -0
  89. package/dist/ops/cpu/adamAdjust.js +14 -13
  90. package/dist/ops/cpu/adamMoments.js +10 -9
  91. package/dist/ops/cpu/appendCache.js +9 -8
  92. package/dist/ops/cpu/attentionMask.js +15 -14
  93. package/dist/ops/cpu/fusedSoftmax.js +13 -12
  94. package/dist/ops/cpu/gatherSub.js +9 -24
  95. package/dist/ops/cpu/gelu.js +13 -12
  96. package/dist/ops/cpu/matMul16.d.ts +1 -0
  97. package/dist/ops/cpu/matMul16.js +16 -0
  98. package/dist/ops/cpu/matMulGelu.js +18 -16
  99. package/dist/ops/cpu/matMulMul.js +8 -7
  100. package/dist/ops/cpu/mulDropout.js +4 -3
  101. package/dist/ops/cpu/normRMS.js +11 -10
  102. package/dist/ops/cpu/qkv.js +17 -13
  103. package/dist/ops/cpu/rope.js +23 -22
  104. package/dist/ops/cpu/scatterSub.js +16 -30
  105. package/dist/ops/dot16.d.ts +2 -0
  106. package/dist/ops/dot16.js +42 -0
  107. package/dist/ops/gatherSub.js +1 -1
  108. package/dist/ops/gelu.js +2 -2
  109. package/dist/ops/grads/add16.d.ts +1 -0
  110. package/dist/ops/grads/add16.js +27 -0
  111. package/dist/ops/grads/attentionMask.js +12 -19
  112. package/dist/ops/grads/gelu.js +4 -3
  113. package/dist/ops/grads/matMul16.d.ts +2 -0
  114. package/dist/ops/grads/matMul16.js +9 -0
  115. package/dist/ops/grads/matMulGelu.js +8 -7
  116. package/dist/ops/grads/normRMS.js +8 -7
  117. package/dist/ops/grads/{fusedSoftmax.d.ts → pack16.d.ts} +1 -1
  118. package/dist/ops/grads/pack16.js +7 -0
  119. package/dist/ops/grads/qkv.d.ts +3 -1
  120. package/dist/ops/grads/qkv.js +28 -22
  121. package/dist/ops/grads/rope.d.ts +2 -1
  122. package/dist/ops/grads/rope.js +6 -13
  123. package/dist/ops/grads/softmax16.d.ts +2 -0
  124. package/dist/ops/grads/softmax16.js +26 -0
  125. package/dist/ops/grads/unpack16.d.ts +2 -0
  126. package/dist/ops/grads/unpack16.js +6 -0
  127. package/dist/ops/grads/utils.d.ts +3 -0
  128. package/dist/ops/grads/utils.js +10 -0
  129. package/dist/ops/matMul16.d.ts +15 -0
  130. package/dist/ops/matMul16.js +13 -0
  131. package/dist/ops/matMulGelu.js +1 -1
  132. package/dist/ops/matMulMul.js +1 -1
  133. package/dist/ops/mul16.d.ts +2 -0
  134. package/dist/ops/mul16.js +8 -0
  135. package/dist/ops/mulDrop.js +1 -1
  136. package/dist/ops/normRMS.js +1 -1
  137. package/dist/ops/pack16.d.ts +2 -0
  138. package/dist/ops/pack16.js +6 -0
  139. package/dist/ops/qkv.d.ts +1 -1
  140. package/dist/ops/qkv.js +8 -4
  141. package/dist/ops/reshape16.d.ts +2 -0
  142. package/dist/ops/reshape16.js +43 -0
  143. package/dist/ops/rope.d.ts +1 -1
  144. package/dist/ops/rope.js +8 -10
  145. package/dist/ops/scatterSub.js +1 -1
  146. package/dist/ops/slice16.d.ts +2 -0
  147. package/dist/ops/slice16.js +9 -0
  148. package/dist/ops/softmax16.d.ts +2 -0
  149. package/dist/ops/softmax16.js +12 -0
  150. package/dist/ops/sub16.d.ts +2 -0
  151. package/dist/ops/sub16.js +8 -0
  152. package/dist/ops/sum16.d.ts +2 -0
  153. package/dist/ops/sum16.js +13 -0
  154. package/dist/ops/transpose16.d.ts +3 -0
  155. package/dist/ops/transpose16.js +41 -0
  156. package/dist/ops/unpack16.d.ts +2 -0
  157. package/dist/ops/unpack16.js +6 -0
  158. package/dist/ops/webgl/adamAdjust.js +3 -2
  159. package/dist/ops/webgl/adamMoments.js +2 -1
  160. package/dist/ops/webgl/appendCache.js +2 -1
  161. package/dist/ops/webgl/attentionMask.js +5 -4
  162. package/dist/ops/webgl/fusedSoftmax.js +6 -4
  163. package/dist/ops/webgl/gatherSub.js +7 -6
  164. package/dist/ops/webgl/gelu.js +3 -2
  165. package/dist/ops/webgl/log.js +12 -27
  166. package/dist/ops/webgl/matMul16.d.ts +1 -0
  167. package/dist/ops/webgl/matMul16.js +37 -0
  168. package/dist/ops/webgl/matMulGelu.js +17 -15
  169. package/dist/ops/webgl/matMulMul.js +13 -12
  170. package/dist/ops/webgl/mulDropout.js +9 -8
  171. package/dist/ops/webgl/normRMS.js +8 -7
  172. package/dist/ops/webgl/qkv.js +6 -5
  173. package/dist/ops/webgl/rope.js +11 -10
  174. package/dist/ops/webgl/scatterSub.js +6 -5
  175. package/dist/ops/webgpu/adamAdjust.js +12 -10
  176. package/dist/ops/webgpu/adamMoments.js +27 -22
  177. package/dist/ops/webgpu/add16.d.ts +1 -0
  178. package/dist/ops/webgpu/add16.js +14 -0
  179. package/dist/ops/webgpu/appendCache.js +64 -17
  180. package/dist/ops/webgpu/attentionMask.js +19 -62
  181. package/dist/ops/webgpu/attentionMask32_program.d.ts +19 -0
  182. package/dist/ops/webgpu/attentionMask32_program.js +54 -0
  183. package/dist/ops/webgpu/concat16.d.ts +19 -0
  184. package/dist/ops/webgpu/concat16.js +128 -0
  185. package/dist/ops/webgpu/gatherSub.js +9 -7
  186. package/dist/ops/webgpu/gelu.js +78 -31
  187. package/dist/ops/webgpu/index.js +12 -0
  188. package/dist/ops/webgpu/matMul16.d.ts +1 -0
  189. package/dist/ops/webgpu/matMul16.js +58 -0
  190. package/dist/ops/webgpu/matMul16_program.d.ts +42 -0
  191. package/dist/ops/webgpu/matMul16_program.js +336 -0
  192. package/dist/ops/webgpu/mul16.d.ts +1 -0
  193. package/dist/ops/webgpu/mul16.js +14 -0
  194. package/dist/ops/webgpu/normRMS.js +21 -40
  195. package/dist/ops/webgpu/normRMS16_program.d.ts +9 -0
  196. package/dist/ops/webgpu/normRMS16_program.js +24 -0
  197. package/dist/ops/webgpu/normRMS32_program.d.ts +9 -0
  198. package/dist/ops/webgpu/normRMS32_program.js +24 -0
  199. package/dist/ops/webgpu/normRMSGrad.js +113 -64
  200. package/dist/ops/webgpu/pack16.d.ts +1 -0
  201. package/dist/ops/webgpu/pack16.js +19 -0
  202. package/dist/ops/webgpu/pack16_program.d.ts +19 -0
  203. package/dist/ops/webgpu/pack16_program.js +92 -0
  204. package/dist/ops/webgpu/qkv.js +20 -55
  205. package/dist/ops/webgpu/rope.js +77 -22
  206. package/dist/ops/webgpu/scatterSub.js +9 -7
  207. package/dist/ops/webgpu/slice16.d.ts +7 -0
  208. package/dist/ops/webgpu/slice16.js +71 -0
  209. package/dist/{variable-Bm2OFwGI.js → ops/webgpu/softmax16.d.ts} +2 -8
  210. package/dist/ops/webgpu/softmax16.js +23 -0
  211. package/dist/ops/webgpu/softmax16_program.d.ts +13 -0
  212. package/dist/ops/webgpu/softmax16_program.js +73 -0
  213. package/dist/ops/webgpu/softmax16_subgroup_program.d.ts +17 -0
  214. package/dist/ops/webgpu/softmax16_subgroup_program.js +75 -0
  215. package/dist/ops/webgpu/softmax16grad.d.ts +1 -0
  216. package/dist/ops/webgpu/softmax16grad.js +38 -0
  217. package/dist/ops/webgpu/sub16.d.ts +1 -0
  218. package/dist/ops/webgpu/sub16.js +14 -0
  219. package/dist/ops/webgpu/sum16.d.ts +1 -0
  220. package/dist/ops/webgpu/sum16.js +40 -0
  221. package/dist/ops/webgpu/transpose16.d.ts +1 -0
  222. package/dist/ops/webgpu/transpose16.js +35 -0
  223. package/dist/ops/webgpu/transpose16_program.d.ts +16 -0
  224. package/dist/ops/webgpu/transpose16_program.js +50 -0
  225. package/dist/ops/webgpu/transpose16_shared_program.d.ts +15 -0
  226. package/dist/ops/webgpu/transpose16_shared_program.js +71 -0
  227. package/dist/ops/webgpu/unpack16.d.ts +1 -0
  228. package/dist/ops/webgpu/unpack16.js +49 -0
  229. package/dist/ops/webgpu/utils/binary_op.d.ts +19 -0
  230. package/dist/ops/webgpu/utils/binary_op.js +79 -0
  231. package/dist/ops/webgpu/utils/deviceInfo.d.ts +7 -0
  232. package/dist/ops/webgpu/utils/deviceInfo.js +11 -0
  233. package/dist/ops/webgpu/utils/reductions.d.ts +32 -4
  234. package/dist/ops/webgpu/utils/reductions.js +236 -45
  235. package/dist/ops-CNI3TwqM.js +645 -0
  236. package/dist/pack16-CFUqumar.js +41 -0
  237. package/dist/{papaparse.min-C8l2Kvo1.js → papaparse.min-C0cScC2i.js} +2 -8
  238. package/dist/{parquet-C0Tlmv9c.js → parquet-BE8MU_ge.js} +201 -278
  239. package/dist/patches/PackedTensor.d.ts +12 -0
  240. package/dist/patches/PackedTensor.js +11 -0
  241. package/dist/patches/engine.d.ts +261 -0
  242. package/dist/patches/engine.js +10 -0
  243. package/dist/patches/tape.d.ts +12 -0
  244. package/dist/patches/tape.js +5 -0
  245. package/dist/patches/webgpu_backend.d.ts +18 -0
  246. package/dist/patches/webgpu_backend.js +57 -0
  247. package/dist/{tensor-CZr4dh61.js → patches/webgpu_base.d.ts} +5 -8
  248. package/dist/patches/webgpu_base.js +34 -0
  249. package/dist/patches/webgpu_program.d.ts +36 -0
  250. package/dist/patches/webgpu_program.js +401 -0
  251. package/dist/{pdf-kJD-f258.js → pdf-NIhmP3sq.js} +424 -428
  252. package/dist/random_width-DY6Kk2Dl.js +10051 -0
  253. package/dist/range-BMS52eQi.js +11 -0
  254. package/dist/reciprocal-CTmshQ9J.js +10 -0
  255. package/dist/{register_all_kernels-DIGpEwcf.js → register_all_kernels-Bwu1PTuU.js} +719 -9766
  256. package/dist/relu-yZ2-7WxU.js +10 -0
  257. package/dist/reshape-DevtBWtf.js +10 -0
  258. package/dist/rope-B5UUMsPi.js +32 -0
  259. package/dist/{scatter_nd_util-BQdz--Gn.js → scatter_nd_util-5EL-8VAQ.js} +1 -1
  260. package/dist/selu_util-D1w6yyTO.js +303 -0
  261. package/dist/{shared-DuP7ue-R.js → shared-BRksrJb3.js} +1 -17
  262. package/dist/shared-BuAXb4CI.js +2145 -0
  263. package/dist/sin-BGfy2HZo.js +16 -0
  264. package/dist/slice-D_gkkqZK.js +13 -0
  265. package/dist/slice_util-DtEldBfK.js +261 -0
  266. package/dist/softmax-ZHVebtR1.js +13 -0
  267. package/dist/split-DrfihRpZ.js +10 -0
  268. package/dist/squeeze-DZEpeblb.js +11 -0
  269. package/dist/stack-yOIAalTq.js +13 -0
  270. package/dist/sum-_fzj5ZTB.js +12 -0
  271. package/dist/tensor-DdQUJZlz.js +909 -0
  272. package/dist/tensor-f35l8Odg.js +8 -0
  273. package/dist/tensor1d-CeZuc-Rv.js +12 -0
  274. package/dist/tensor2d-G4Ys2GxX.js +15 -0
  275. package/dist/tensor4d-B8roDgtc.js +15 -0
  276. package/dist/tensor_util-DV-FP5Q3.js +523 -0
  277. package/dist/tfjs_backend-kNyO5L2d.js +653 -0
  278. package/dist/tile-BzyEiF-F.js +13 -0
  279. package/dist/tokeniser/CharTokeniser.js +1 -1
  280. package/dist/tokeniser/bpe.js +1 -1
  281. package/dist/training/Adam.d.ts +2 -1
  282. package/dist/training/Adam.js +12 -28
  283. package/dist/training/AdamExt.d.ts +1 -0
  284. package/dist/training/AdamExt.js +2 -2
  285. package/dist/training/DatasetBuilder.js +3 -20
  286. package/dist/training/FullTrainer.js +55 -48
  287. package/dist/training/Trainer.d.ts +11 -6
  288. package/dist/training/Trainer.js +51 -39
  289. package/dist/training/sparseCrossEntropy.js +3 -3
  290. package/dist/transpose-DKELTqhe.js +38 -0
  291. package/dist/utilities/arrayClose.js +7 -7
  292. package/dist/utilities/dummy.js +35 -27
  293. package/dist/utilities/multinomialCPU.js +2 -2
  294. package/dist/utilities/packed.d.ts +7 -0
  295. package/dist/utilities/packed.js +716 -0
  296. package/dist/utilities/performance.js +1 -1
  297. package/dist/utilities/profile.js +1 -1
  298. package/dist/utilities/safetensors.js +2 -2
  299. package/dist/utilities/sentences.d.ts +5 -0
  300. package/dist/utilities/sentences.js +41 -0
  301. package/dist/utilities/weights.js +2 -2
  302. package/dist/variable-Bhn5bHYv.js +7 -0
  303. package/dist/{webgpu_program-DkQJOJSd.js → webgpu_program-Cigz-7RF.js} +15 -44
  304. package/dist/webgpu_util-BBCnKm2X.js +65 -0
  305. package/dist/zeros-2gldETuK.js +14 -0
  306. package/package.json +4 -3
  307. package/dist/Reshape-Bowtk9BP.js +0 -127
  308. package/dist/Reshape-DUqYftGC.js +0 -30
  309. package/dist/backend_util-CJIiDoV1.js +0 -749
  310. package/dist/broadcast_to-DzlNweb8.js +0 -44
  311. package/dist/concat-B912vBbo.js +0 -33
  312. package/dist/dropout-C-csYCLj.js +0 -193
  313. package/dist/exports_initializers-B8iZMgQ0.js +0 -16
  314. package/dist/gather-Dnpgw-YQ.js +0 -25
  315. package/dist/index-BzFyqcy-.js +0 -4457
  316. package/dist/index-C1rx_Ajs.js +0 -12076
  317. package/dist/kernel_funcs_utils-DKLK0Mg3.js +0 -466
  318. package/dist/log_sum_exp-DO6z8tSE.js +0 -103
  319. package/dist/mat_mul-DzjTFx-u.js +0 -27
  320. package/dist/mod-Dobti4j4.js +0 -27
  321. package/dist/ones-tIJeHlq-.js +0 -29
  322. package/dist/ops/fusedSoftmax.d.ts +0 -2
  323. package/dist/ops/fusedSoftmax.js +0 -10
  324. package/dist/ops/grads/fusedSoftmax.js +0 -22
  325. package/dist/ops-LuCMAnmM.js +0 -1525
  326. package/dist/random_width-CXVRloNK.js +0 -13670
  327. package/dist/range-CWcz7xFA.js +0 -26
  328. package/dist/reciprocal-C4rNcM-S.js +0 -25
  329. package/dist/relu-BjCh_SYb.js +0 -25
  330. package/dist/reshape-CnIwVG1c.js +0 -25
  331. package/dist/selu_util-OtRzVwW5.js +0 -719
  332. package/dist/shared-DmRsFyaJ.js +0 -3134
  333. package/dist/sin-gpDNRxE0.js +0 -47
  334. package/dist/slice-d0Vo9XTN.js +0 -28
  335. package/dist/softmax-D7Jj3p_P.js +0 -28
  336. package/dist/split-DK2k5eHf.js +0 -25
  337. package/dist/stack-DFatutCx.js +0 -27
  338. package/dist/sum-CJ0ULhmt.js +0 -27
  339. package/dist/tensor1d-vML0r3q6.js +0 -27
  340. package/dist/tensor2d-D76QGjF3.js +0 -30
  341. package/dist/tensor4d-Df1WlVDY.js +0 -30
  342. package/dist/webgpu_util-pLEV9tks.js +0 -80
  343. package/dist/zeros-Bj5rMYA7.js +0 -52
@@ -0,0 +1,401 @@
1
+ import { h as z } from "../index-ZyQhjEPo.js";
2
+ import { g as F, d as b, a as x, b as g, s as O, t as l, c as y } from "../webgpu_program-Cigz-7RF.js";
3
+ import { e as _, a as A, b as L } from "../tensor-DdQUJZlz.js";
4
/**
 * Kinds of pixel-transfer programs. Shaped like a compiled TypeScript numeric
 * enum: both `name → value` and `value → name` entries are present.
 */
const N = /* @__PURE__ */ (() => {
  const kinds = {};
  for (const [name, value] of [
    ["FROM_PIXELS", 0],
    ["DRAW", 1],
  ]) {
    kinds[name] = value;
    kinds[value] = name;
  }
  return kinds;
})();
5
/**
 * Compiles a program into a WebGPU compute pipeline.
 *
 * When the `WEBGPU_PRINT_SHADER` flag is non-empty, the generated WGSL is
 * printed for matching programs:
 *  - `"all"` (case-insensitive) prints every shader;
 *  - otherwise the flag is a comma-separated list of case-insensitive
 *    substrings of `shaderKey`. Entries are trimmed and empty entries are
 *    ignored, so `"matmul,"` no longer prints every shader and
 *    `"matmul, softmax"` still matches softmax programs.
 *
 * @param s - GPU device used to create the shader module and the pipeline.
 * @param t - the program; `shaderKey` is used for filtering/logging and its
 *   constructor name as the module/pipeline label.
 * @param e - input infos forwarded to the shader builder.
 * @param o - output tensor info; only `dtype` and `shape` are read.
 * @param u - when truthy, `createComputePipelineAsync` is used instead of
 *   `createComputePipeline`.
 * @returns whatever the chosen device method returns.
 */
const K = (s, t, e, o, u) => {
  const outputData = { dtype: o.dtype, shape: o.shape };
  const source = D(e, outputData, t);
  const label = t.constructor.name;
  const shaderModule = s.createShaderModule({ code: source, label });

  const printSetting = _().get("WEBGPU_PRINT_SHADER");
  if (printSetting !== "") {
    const setting = printSetting.toLowerCase().trim();
    const filters = setting
      .split(",")
      .map((entry) => entry.trim())
      .filter((entry) => entry !== "");
    // `shaderKey` is only inspected when filtering, so "all" works even for
    // programs without a key.
    if (setting === "all" || filters.some((entry) => t.shaderKey.toLowerCase().includes(entry))) {
      console.group(t.shaderKey);
      console.debug(source);
      console.groupEnd();
    }
  }

  const descriptor = {
    compute: { module: shaderModule, entryPoint: "_start" },
    label,
    layout: "auto",
  };
  return u ? s.createComputePipelineAsync(descriptor) : s.createComputePipeline(descriptor);
};
23
/**
 * Assembles the complete WGSL source for a WebGPU program.
 *
 * Two layouts are produced:
 *  - pixel-transfer programs (`e.pixelsOpType != null`) get a fixed `Uniform`
 *    struct and a single storage binding;
 *  - every other program gets a generated `Uniforms` struct (NaN/Infinity,
 *    per-input shape and strides, output shape and strides, optional `size`
 *    and program-specific uniforms), an output binding, one read-only binding
 *    per input, and the coordinate/accessor helper snippets.
 *
 * @param s - input infos, indexed in step with `e.variableNames` (`dtype` and
 *   `shape` are read here; each info is also passed to `U`).
 * @param t - output description `{ dtype, shape }`.
 * @param e - the program. NOTE: a falsy `e.outputComponent` is replaced with 1
 *   in place, so the caller's program object is mutated.
 * @returns the WGSL source text.
 */
function D(s, t, e) {
  const prefixSnippets = [];
  const flatWorkgroupSize = e.workgroupSize[0] * e.workgroupSize[1] * e.workgroupSize[2];
  // The entry point declares `subgroup_invocation_id` / `subgroup_size` when
  // `subgroupBuiltins` is set; those builtins require the `subgroups`
  // extension, so it is enabled for either flag and on both code paths.
  const enableSnippet = e.subgroups || e.subgroupBuiltins ? "enable subgroups;" : "";
  e.outputComponent = e.outputComponent ? e.outputComponent : 1;

  // getGlobalIndex(): a flat dispatch can use globalId.x directly; otherwise
  // the index is rebuilt from the workgroup id and the flat workgroup size.
  const globalIndexBody = R(e)
    ? " return i32(globalId.x);"
    : [
        " return i32((workgroupId.z * numWorkgroups.x * numWorkgroups.y +",
        `workgroupId.y * numWorkgroups.x + workgroupId.x) * ${flatWorkgroupSize}u +`,
        "localIndex);",
        "",
      ].join("\n");
  prefixSnippets.push(
    [
      "",
      "",
      "var<private> localId: vec3<u32>;",
      "var<private> localIndex: u32;",
      "var<private> globalId: vec3<u32>;",
      "var<private> numWorkgroups: vec3<u32>;",
      "var<private> workgroupId: vec3<u32>;",
      e.subgroupBuiltins ? "var<private> subgroupInvocationId: u32;" : "",
      e.subgroupBuiltins ? "var<private> subgroupSize: u32;" : "",
      "",
      "// Only used when the y/z dimension of workgroup size is 1.",
      "fn getGlobalIndex() -> i32 {",
      globalIndexBody,
      "}",
      "",
    ].join("\n"),
  );

  if (e.pixelsOpType != null) {
    const ioSnippet =
      e.pixelsOpType === N.FROM_PIXELS
        ? `@group(0) @binding(0) var<storage, read_write> result: array<${b(t.dtype, e.outputComponent)}>;`
        : `@group(0) @binding(1) var<storage, read> inBuf : array<${b(s[0].dtype, e.outputComponent)}>;`;
    const stridesType = t.shape.length === 3 ? "vec2<i32>" : "i32";
    prefixSnippets.push(
      [
        "",
        "struct Uniform {",
        `outShapeStrides : ${stridesType},`,
        "size : i32,",
        "numChannels : i32,",
        "alpha : f32,",
        "};",
        "",
        ioSnippet,
        "@group(0) @binding(2) var<uniform> uniforms: Uniform;",
        "",
      ].join("\n"),
    );
    const useGlobalIndex = w(e);
    const pieces = [C, prefixSnippets.join("\n"), x(t.shape), e.getUserCode(), m(useGlobalIndex, e)];
    if (enableSnippet !== "") {
      // `enable` directives must precede every other declaration.
      pieces.unshift(enableSnippet);
    }
    return pieces.join("\n");
  }

  // Uniform struct: NaN/Infinity, then shape + strides for every input and
  // the output, then optional size and program-specific fields.
  let uniformDecl = "struct Uniforms { NAN : f32, INFINITY : f32, ";
  e.variableNames.forEach((name, i) => {
    const field = name.charAt(0).toLowerCase() + name.slice(1);
    const rank = s[i].shape.length;
    uniformDecl += `${field}Shape : ${g(rank)}, `;
    uniformDecl += `${field}ShapeStrides: ${g(rank - 1)}, `;
  });
  uniformDecl += `outShape : ${g(t.shape.length)}, `;
  uniformDecl += `\noutShapeStrides: ${g(t.shape.length - 1)}, `;
  if (e.size) {
    uniformDecl += "size : i32, ";
  }
  if (e.uniforms) {
    uniformDecl += e.uniforms;
  }
  uniformDecl += "};";
  // P inserts `@align(16)` around vec5/vec6 members.
  prefixSnippets.push(P(uniformDecl));

  // Output buffer at binding 0.
  prefixSnippets.push(
    e.atomic
      ? "\n@group(0) @binding(0) var<storage, read_write> result: array<atomic<i32>>;\n"
      : `\n@group(0) @binding(0) var<storage, read_write> result: array<${b(t.dtype, e.outputComponent)}>;\n`,
  );
  // One read-only storage buffer per input, at bindings 1..n.
  e.variableNames.forEach((name, i) => {
    const component = e.variableComponents ? e.variableComponents[i] : e.outputComponent;
    prefixSnippets.push(
      `\n@group(0) @binding(${1 + i}) var<storage, read> ${name}: array<${b(s[i].dtype, component)}>;\n`,
    );
  });
  // The uniform declaration always contains at least the struct header, so
  // the uniform binding is always emitted, right after the inputs.
  prefixSnippets.push(
    `\n@group(0) @binding(${1 + e.variableNames.length}) var<uniform> uniforms: Uniforms;\n`,
  );

  const outputCoordsSnippet = B(t.shape, e.dispatchLayout);
  const sources = [
    enableSnippet,
    C,
    prefixSnippets.join("\n") + T,
    x(t.shape),
    outputCoordsSnippet,
    W(t.shape.length),
  ];
  if (!e.atomic) {
    sources.push(G(t.shape, t.dtype, e.outputComponent));
  }
  e.variableNames.forEach((name, i) => {
    sources.push(`${x(s[i].shape, name)}`);
  });
  const isFlatLayout = e.dispatchLayout.x.length === t.shape.length;
  const inputSnippets = s.map((info, i) =>
    U(info, t.shape, e.variableComponents ? e.variableComponents[i] : e.outputComponent, isFlatLayout),
  );
  sources.push(inputSnippets.join("\n"));
  sources.push(e.getUserCode());
  sources.push(m(w(e), e));
  return sources.join("\n");
}
119
/**
 * WGSL helpers that the shader builder places at the top of generated
 * shaders: vec5/vec6 structs, bounds checks, coordinate → flat-index helpers
 * for ranks 1–6, and NaN tests. Built line by line; the text is emitted into
 * shaders verbatim.
 */
const C = [
  "",
  "struct vec5 {x: i32, y: i32, z: i32, w: i32, u: i32};",
  "struct vec6 {x: i32, y: i32, z: i32, w: i32, u: i32, v: i32};",
  "",
  "// Checks whether coordinates lie within the bounds of the shape.",
  "fn coordsInBounds2D(coord : vec2<i32>, shape : vec2<i32>) -> bool {",
  "return all(coord >= vec2<i32>(0)) && all(coord < shape);",
  "}",
  "fn coordsInBounds3D(coord : vec3<i32>, shape : vec3<i32>) -> bool {",
  "return all(coord >= vec3<i32>(0)) && all(coord < shape);",
  "}",
  "fn coordsInBounds4D(coord : vec4<i32>, shape : vec4<i32>) -> bool {",
  "return all(coord >= vec4<i32>(0)) && all(coord < shape);",
  "}",
  "",
  "fn getIndexFromCoords1D(coord : i32, shape : i32) -> i32 {",
  "return coord;",
  "}",
  "fn getIndexFromCoords2D(coords : vec2<i32>, shape : vec2<i32>) -> i32 {",
  "return dot(coords, vec2<i32>(shape.y, 1));",
  "}",
  "fn getIndexFromCoords3D(coords : vec3<i32>, shape : vec3<i32>) -> i32 {",
  "return dot(coords, vec3<i32>(shape.y * shape.z, shape.z, 1));",
  "}",
  "fn getIndexFromCoords4D(coords : vec4<i32>, shape : vec4<i32>) -> i32 {",
  "return dot(coords, vec4<i32>(",
  "shape.y * shape.z * shape.w, shape.z * shape.w, shape.w, 1));",
  "}",
  "fn getIndexFromCoords5D(coords : vec5, shape : vec5) -> i32 {",
  "let shapeStrides: vec5 = vec5(shape.y * shape.z * shape.w * shape.u, shape.z * shape.w * shape.u, shape.w * shape.u, shape.u, 1);",
  "return coords.x*shapeStrides.x + coords.y*shapeStrides.y + coords.z*shapeStrides.z + coords.w*shapeStrides.w + coords.u*shapeStrides.u;",
  "}",
  "fn getIndexFromCoords6D(coords : vec6, shape : vec6) -> i32 {",
  "let shapeStrides: vec6 = vec6(shape.y * shape.z * shape.w * shape.u * shape.v, shape.z * shape.w * shape.u * shape.v, shape.w * shape.u * shape.v, shape.u * shape.v, shape.v, 1);",
  "return coords.x*shapeStrides.x + coords.y*shapeStrides.y + coords.z*shapeStrides.z + coords.w*shapeStrides.w + coords.u*shapeStrides.u + coords.v*shapeStrides.v;",
  "}",
  "",
  "// NaN defination in IEEE 754-1985 is :",
  "// - sign = either 0 or 1.",
  "// - biased exponent = all 1 bits.",
  "// - fraction = anything except all 0 bits (since all 0 bits represents infinity).",
  "// https://en.wikipedia.org/wiki/IEEE_754-1985#Representation_of_non-numbers",
  "fn isnan(val: f32) -> bool {",
  "let floatToUint: u32 = bitcast<u32>(val);",
  "return (floatToUint & 0x7fffffffu) > 0x7f800000u;",
  "}",
  "fn isnanVec4(val : vec4<f32>) -> vec4<bool> {",
  "let floatToUint: vec4<u32> = bitcast<vec4<u32>>(val);",
  "return (floatToUint & vec4<u32>(0x7fffffffu)) > vec4<u32>(0x7f800000u);",
  "}",
  "",
].join("\n");

/**
 * WGSL `isinf` helper. It reads `uniforms.INFINITY`, so it is only valid in
 * shaders that declare that uniform.
 */
const T = [
  "",
  "fn isinf(val: f32) -> bool {",
  "return abs(val) == uniforms.INFINITY;",
  "}",
  "",
].join("\n");
174
+ function j(s, t) {
175
+ const e = s.name, o = s.shape.length, u = g(o), a = "get" + e.charAt(0).toUpperCase() + e.slice(1), n = ["d0", "d1", "d2", "d3", "d4", "d5"].slice(0, o), r = n.map((v) => `${v} : i32`).join(", ");
176
+ if (o < 1)
177
+ return `
178
+ fn ${a}() -> ${l(t)} {
179
+ return ${l(t)}(${e}[0]);
180
+ }
181
+ `;
182
+ const d = `uniforms.${e.charAt(0).toLowerCase() + e.slice(1)}Shape`;
183
+ let f = `${o}D`;
184
+ return o === 0 && (f = "1D"), `
185
+ fn ${a}(${r}) -> ${l(t)} {
186
+ return ${l(t)}(${e}[getIndexFromCoords${f}(${u}(${n.join(",")}),
187
+ ${d})${t === 1 ? "" : ` / ${t}`}]);
188
+ }
189
+ `;
190
+ }
191
/**
 * Emits `get<Name>ByOutputIndex(globalIndex)` and
 * `get<Name>ByOutputCoords(coords)`: read input `s` at the element matching
 * an output position, zeroing broadcast dimensions.
 *
 * @param s - input info (`name`, `shape`).
 * @param t - output shape.
 * @param e - vector component count; when not 1, the flat index is divided by it.
 * @param o - whether the dispatch layout is flat; together with equal shapes
 *   this enables the direct-index fast path.
 * @returns the WGSL snippet.
 */
function E(s, t, e, o) {
  const texName = s.name;
  const capName = texName.charAt(0).toUpperCase() + texName.slice(1);
  const fnBase = `get${capName}ByOutput`;
  const inRank = s.shape.length;
  const outRank = t.length;
  const coordsType = g(outRank);
  const valueType = l(e);
  const divisor = e === 1 ? "" : ` / ${e}`;

  // Same shape and flat dispatch: the output index addresses the input directly.
  if (L(s.shape, t) && o) {
    const coordsIndex = outRank > 1 ? "getOutputIndexFromCoords(coords)" : "coords";
    return [
      "",
      `fn ${fnBase}Index(globalIndex : i32) -> ${valueType} {`,
      `return ${valueType}(${texName}[globalIndex]);`,
      "}",
      "",
      `fn ${fnBase}Coords(coords : ${coordsType}) -> ${valueType} {`,
      `return ${valueType}(${texName}[${coordsIndex}${divisor}]);`,
      "}",
      "",
    ].join("\n");
  }

  // Scalar inputs ignore the position and defer to the plain accessor.
  if (inRank === 0) {
    return [
      "",
      `fn ${fnBase}Index(globalIndex : i32) -> ${valueType}{`,
      `return get${capName}();`,
      "}",
      "",
      `fn ${fnBase}Coords(coords : ${coordsType}) -> ${valueType}{`,
      `return get${capName}();`,
      "}",
      "",
    ].join("\n");
  }

  const rankDiff = outRank - inRank;
  const broadcastDims = z(s.shape, t);
  let zeroBroadcast;
  if (outRank < 2 && broadcastDims.length >= 1) {
    zeroBroadcast = "coords = 0;";
  } else {
    zeroBroadcast = broadcastDims.map((dim) => `coords.${y(dim + rankDiff)} = 0;`).join("\n");
  }

  let inputCoords = "coords";
  if (outRank > 1) {
    const components = s.shape.map((_, k) => `coords.${y(k + rankDiff)}`).join(", ");
    inputCoords = `${g(inRank)}(${components})`;
  }

  const shapeUniform = `uniforms.${texName.charAt(0).toLowerCase() + texName.slice(1)}Shape`;
  const readLine = `return ${valueType}(${texName}[getIndexFromCoords${inRank}D(${inputCoords}, ${shapeUniform})${divisor}]);`;
  return [
    "",
    `fn ${fnBase}Index(globalIndex : i32) -> ${valueType} {`,
    "var coords = getCoordsFromIndex(globalIndex);",
    zeroBroadcast,
    readLine,
    "}",
    "",
    `fn ${fnBase}Coords(coordsIn : ${coordsType}) -> ${valueType} {`,
    "var coords = coordsIn;",
    zeroBroadcast,
    readLine,
    "}",
    "",
  ].join("\n");
}
240
/**
 * Concatenates the input accessors for `s`. The coordinate accessor is always
 * emitted. The by-output accessors are added only when the input rank does
 * not exceed the output rank.
 *
 * @param s - input info.
 * @param t - output shape.
 * @param e - vector component count.
 * @param o - whether the dispatch layout is flat.
 * @returns the WGSL snippet.
 */
function U(s, t, e, o) {
  const coordAccessor = j(s, e);
  if (s.shape.length > t.length) {
    return coordAccessor;
  }
  return coordAccessor + E(s, t, e, o);
}
244
/**
 * Emits `getOutputCoords()` for a dispatch layout.
 *
 * @param s - output shape.
 * @param t - dispatch layout `{ x, y?, z? }`: each list names the output
 *   dimensions derived from that `globalId` component.
 * @returns the WGSL snippet, or `""` when the layout does not cover exactly
 *   `s.length` dimensions.
 */
function B(s, t) {
  const { x: xDims, y: yDims = [], z: zDims = [] } = t;
  const outRank = s.length;
  if (xDims.length + yDims.length + zDims.length !== outRank) {
    return "";
  }
  const coordsType = g(outRank);

  // Everything on x: recover the coordinates from the flat global index.
  if (xDims.length === outRank) {
    return `fn getOutputCoords() -> ${coordsType}{\nlet globalIndex = getGlobalIndex();\nreturn getCoordsFromIndex(globalIndex);\n}\n`;
  }

  // Otherwise each globalId component yields one dimension directly, or
  // several via symbolic strides over uniforms.outShape.
  let gather = "";
  [xDims, yDims, zDims].forEach((dims, axis) => {
    if (dims.length === 1) {
      gather += `let d${dims[0]} = i32(globalId[${axis}]);`;
    } else if (dims.length > 1) {
      const strides = O(dims, "uniforms.outShape");
      gather += `var index${axis} = i32(globalId[${axis}]);`;
      for (let k = 0; k < strides.length; k++) {
        gather += `let d${dims[k]} = index${axis} / ${strides[k]};`;
        gather +=
          k === strides.length - 1
            ? `let d${dims[k + 1]} = index${axis} - d${dims[k]} * ${strides[k]};`
            : `index${axis} = index${axis} - d${dims[k]} * ${strides[k]};`;
      }
    }
  });

  // Reaching this point requires a non-empty y or z list, so outRank >= 1
  // and the coordinate list below is never empty.
  const coords = Array.from({ length: outRank }, (_, k) => `d${k}`).join(",");
  return `fn getOutputCoords() -> ${coordsType} {\n${gather}\nreturn ${coordsType}(${coords}); }`;
}
277
/**
 * WGSL `getOutputIndexFromCoords` definitions keyed by output rank. Ranks 0
 * and 1 share the scalar form. Built once at module load.
 */
const OUTPUT_INDEX_SNIPPETS = (() => {
  const fn = (coordsType, ...body) =>
    ["", `fn getOutputIndexFromCoords(coords : ${coordsType}) -> i32 {`, ...body, "}", ""].join("\n");
  const scalar = fn("i32", "return coords;");
  return new Map([
    [0, scalar],
    [1, scalar],
    [2, fn("vec2<i32>", "return dot(coords, vec2<i32>(uniforms.outShapeStrides, 1));")],
    [3, fn("vec3<i32>", "return dot(coords, vec3<i32>(uniforms.outShapeStrides.x, uniforms.outShapeStrides.y, 1));")],
    [
      4,
      fn(
        "vec4<i32>",
        "return dot(coords, vec4<i32>(",
        "uniforms.outShapeStrides.x, uniforms.outShapeStrides.y, uniforms.outShapeStrides.z, 1));",
      ),
    ],
    [
      5,
      fn(
        "vec5",
        "return coords.x * uniforms.outShapeStrides.x +",
        "coords.y * uniforms.outShapeStrides.y +",
        "coords.z * uniforms.outShapeStrides.z +",
        "coords.w * uniforms.outShapeStrides.w +",
        "coords.u;",
      ),
    ],
    [
      6,
      fn(
        "vec6",
        "return coords.x * uniforms.outShapeStrides.x +",
        "coords.y * uniforms.outShapeStrides.y +",
        "coords.z * uniforms.outShapeStrides.z +",
        "coords.w * uniforms.outShapeStrides.w +",
        "coords.u * uniforms.outShapeStrides.u +",
        "coords.v;",
      ),
    ],
  ]);
})();

/**
 * Returns the WGSL `getOutputIndexFromCoords` definition for an output of rank `s`.
 * Unsupported ranks go through the assertion helper `A`, and `""` is
 * returned if that helper does not throw.
 *
 * @param s - output rank (0–6 supported).
 * @returns the WGSL snippet.
 */
function W(s) {
  const snippet = OUTPUT_INDEX_SNIPPETS.get(s);
  if (snippet !== undefined) {
    return snippet;
  }
  A(false, () => `Unsupported ${s}D shape`);
  return "";
}
339
/**
 * Reports whether the program dispatches along x only (y and z dispatch
 * counts are both 1).
 *
 * @param s - the program; `s.dispatch` is read.
 * @returns true for a flat dispatch.
 */
function R(s) {
  return [1, 2].every((axis) => s.dispatch[axis] === 1);
}
342
/**
 * Emits the WGSL output writers.
 *
 * `setOutputAtIndex` and `setOutputAtIndexI32` are always emitted. For outputs
 * of rank >= 2, `setOutputAtCoords` and `setOutputAtCoordsI32` are added as
 * well. These convert coordinates to a flat index, divided by the component
 * count when it is not 1.
 *
 * @param s - output shape.
 * @param t - output dtype.
 * @param e - output vector component count.
 * @returns the WGSL snippet.
 */
function G(s, t, e) {
  const gpuType = b(t, e);
  const valueType = l(e);
  const valueTypeI32 = l(e, "i32");
  const lines = [
    `fn setOutputAtIndex(flatIndex : i32, value : ${valueType}) {`,
    `result[flatIndex] = ${gpuType}(value);`,
    "}",
    "",
    `fn setOutputAtIndexI32(flatIndex : i32, value : ${valueTypeI32}) {`,
    `result[flatIndex] = ${gpuType}(value);`,
    "}",
  ];

  const rank = s.length;
  if (rank >= 2) {
    const dims = ["d0", "d1", "d2", "d3", "d4", "d5"].slice(0, rank);
    const params = dims.map((dim) => `${dim} : i32`).join(", ");
    const flatIndexLine = `let flatIndex = getOutputIndexFromCoords(${g(rank)}(${dims.join(", ")}));`;
    const indexArg = e === 1 ? "flatIndex" : `flatIndex / ${e}`;
    lines.push(
      "",
      `fn setOutputAtCoords(${params}, value : ${valueType}) {`,
      flatIndexLine,
      `setOutputAtIndex(${indexArg}, value);`,
      "}",
      `fn setOutputAtCoordsI32(${params}, value : ${valueTypeI32}) {`,
      flatIndexLine,
      `setOutputAtIndexI32(${indexArg}, value);`,
      "}",
    );
  }

  lines.push("");
  return lines.join("\n");
}
367
/**
 * Adds `@align(16)` to uniform-struct members declared as `vec5`/`vec6` and
 * to the member that immediately follows such a declaration.
 *
 * @param s - uniform struct declaration text.
 * @returns the declaration with alignment attributes inserted.
 */
function P(s) {
  const withOwnAlignment = s.replace(/(\w+)\s*:\s*vec(5|6)/g, (member) => `@align(16) ${member}`);
  return withOwnAlignment.replace(
    /vec(5|6)\s*,\s*(\w+)/g,
    (_match, width, nextMember) => `vec${width}, @align(16) ${nextMember}`,
  );
}
373
/**
 * Reports whether the dispatch layout maps every output dimension to x,
 * i.e. it has no own non-empty `y` or `z` list.
 *
 * @param s - the program; `s.dispatchLayout` is read.
 * @returns true for a flat layout.
 */
function w(s) {
  const layout = s.dispatchLayout;
  for (const axis of ["y", "z"]) {
    if (layout.hasOwnProperty(axis) && layout[axis]?.length !== 0) {
      return false;
    }
  }
  return true;
}
376
/**
 * Emits the WGSL compute entry point `_start`.
 *
 * `_start` copies the builtin ids into module-private variables and then
 * calls the program's `main`. When `t.subgroupBuiltins` is set, it also
 * copies the subgroup invocation id and subgroup size.
 *
 * @param s - when truthy, `main` receives `getGlobalIndex()`; otherwise it
 *   takes no arguments.
 * @param t - the program; passed to `F` for the attribute header line.
 * @returns the WGSL snippet.
 */
function m(s, t) {
  const header = `${F(t)}`;
  const withSubgroups = Boolean(t.subgroupBuiltins);
  const ifSubgroups = (text) => (withSubgroups ? text : "");
  return [
    "",
    header,
    "fn _start(@builtin(local_invocation_id) LocalId : vec3<u32>,",
    "@builtin(global_invocation_id) GlobalId : vec3<u32>,",
    "@builtin(local_invocation_index) LocalIndex: u32,",
    "@builtin(workgroup_id) WorkgroupId : vec3<u32>,",
    ifSubgroups("@builtin(subgroup_invocation_id) SubgroupInvocationId : u32,"),
    ifSubgroups("@builtin(subgroup_size) SubgroupSize : u32,"),
    "@builtin(num_workgroups) NumWorkgroups : vec3<u32>) {",
    "localId = LocalId;",
    "localIndex = LocalIndex;",
    "globalId = GlobalId;",
    "numWorkgroups = NumWorkgroups;",
    "workgroupId = WorkgroupId;",
    `${ifSubgroups("subgroupInvocationId = SubgroupInvocationId")};`,
    `${ifSubgroups("subgroupSize = SubgroupSize")};`,
    `${s ? "main(getGlobalIndex());" : "main();"};`,
    "}",
    "",
  ].join("\n");
}
397
+ export {
398
+ N as PixelsOpType,
399
+ K as compileProgram,
400
+ m as getStartHeaderString
401
+ };