@genai-fi/nanogpt 0.9.0 → 0.10.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (343)
  1. package/README.md +352 -14
  2. package/dist/Generator.js +69 -78
  3. package/dist/{RealDiv-D4EzDsC0.js → RealDiv-DgA3z9oO.js} +32 -206
  4. package/dist/Reshape-CF6odzV4.js +16 -0
  5. package/dist/Reshape-_kILl6tK.js +81 -0
  6. package/dist/TeachableLLM.js +28 -22
  7. package/dist/Trainer.d.ts +2 -0
  8. package/dist/Trainer.js +3 -2
  9. package/dist/{axis_util-TbGYJ208.js → axis_util-BvHEw88j.js} +7 -23
  10. package/dist/backend.d.ts +2 -1
  11. package/dist/backend.js +10 -4
  12. package/dist/backend_util-D-rUb2ty.js +474 -0
  13. package/dist/backend_webgpu-B0u2ndUn.js +547 -0
  14. package/dist/binary_op_util-pKXltfxI.js +192 -0
  15. package/dist/broadcast_to-CwF7XIeu.js +30 -0
  16. package/dist/checks/appendCache.js +2 -2
  17. package/dist/checks/attentionMask.js +3 -3
  18. package/dist/checks/check.d.ts +1 -1
  19. package/dist/checks/check.js +8 -8
  20. package/dist/checks/gelu.js +2 -2
  21. package/dist/checks/index.d.ts +2 -0
  22. package/dist/checks/index.js +7 -5
  23. package/dist/checks/matMulGelu.js +6 -6
  24. package/dist/checks/normRMS.js +7 -7
  25. package/dist/checks/normRMSGrad.js +3 -3
  26. package/dist/checks/packUnpack.d.ts +1 -0
  27. package/dist/checks/packUnpack.js +18 -0
  28. package/dist/checks/qkv.js +12 -27
  29. package/dist/checks/rope.js +2 -2
  30. package/dist/checks/weights.js +18 -16
  31. package/dist/complex-CSlYz-2T.js +13 -0
  32. package/dist/complex_util-Yc1A_gV1.js +55 -0
  33. package/dist/concat-BHlIJeyT.js +19 -0
  34. package/dist/concat_util-DcJk7YHS.js +22 -0
  35. package/dist/data/docx.js +1 -1
  36. package/dist/data/parquet.js +2 -2
  37. package/dist/data/pdf.js +1 -1
  38. package/dist/data/textLoader.js +1 -1
  39. package/dist/{dataset-DlZtKmBq.js → dataset-0xP8GjwI.js} +136 -236
  40. package/dist/dropout-C1pM3f11.js +99 -0
  41. package/dist/expand_dims-BPG4fwBP.js +13 -0
  42. package/dist/exports_initializers-xuidcwI4.js +7 -0
  43. package/dist/gather-DykLGqmW.js +10 -0
  44. package/dist/{gelu-Bp_-935b.js → gelu-CNLFZWea.js} +11 -10
  45. package/dist/{gpgpu_math-CDaYiyE_.js → gpgpu_math-DDVJCn6-.js} +90 -265
  46. package/dist/{index-C4L8Cm77.js → index-CieiGp4Y.js} +14 -14
  47. package/dist/index-CjOj7j-u.js +7308 -0
  48. package/dist/{index-Tf7vU29b.js → index-Cp39cXWe.js} +3 -10
  49. package/dist/{index-Dwqa6Zy2.js → index-DvYrXKkX.js} +2 -2
  50. package/dist/index-ZyQhjEPo.js +2157 -0
  51. package/dist/{jszip.min-CjP2V1VV.js → jszip.min-Bz5-11Bk.js} +56 -57
  52. package/dist/kernel_funcs_utils-Dg_-E44D.js +308 -0
  53. package/dist/layers/BaseLayer.d.ts +1 -0
  54. package/dist/layers/BaseLayer.js +7 -6
  55. package/dist/layers/CausalSelfAttention.d.ts +0 -1
  56. package/dist/layers/CausalSelfAttention.js +56 -55
  57. package/dist/layers/MLP.js +15 -16
  58. package/dist/layers/PositionEmbedding.js +5 -14
  59. package/dist/layers/RMSNorm.js +3 -3
  60. package/dist/layers/RoPECache.d.ts +2 -0
  61. package/dist/layers/RoPECache.js +22 -17
  62. package/dist/layers/TiedEmbedding.js +22 -17
  63. package/dist/layers/TransformerBlock.js +21 -20
  64. package/dist/loader/load.js +1 -1
  65. package/dist/loader/loadTransformers.js +1 -1
  66. package/dist/loader/oldZipLoad.js +39 -33
  67. package/dist/loader/save.js +1 -1
  68. package/dist/log_sum_exp-DWI-76TI.js +41 -0
  69. package/dist/main.d.ts +8 -0
  70. package/dist/main.js +63 -52
  71. package/dist/matMul16--R5hOwDG.js +77 -0
  72. package/dist/mat_mul-DeAh4uTH.js +12 -0
  73. package/dist/mod-Gt1rMB4n.js +12 -0
  74. package/dist/models/NanoGPTV1.js +40 -31
  75. package/dist/models/model.d.ts +2 -0
  76. package/dist/models/model.js +37 -29
  77. package/dist/{mulmat_packed_gpu-BT60jmzP.js → mulmat_packed_gpu-BMFhLwta.js} +1 -17
  78. package/dist/{non_max_suppression_impl-CsEgBuMA.js → non_max_suppression_impl-B2W7YjZB.js} +0 -32
  79. package/dist/ones-CAMiP4I2.js +15 -0
  80. package/dist/ops/adamAdjust.js +1 -1
  81. package/dist/ops/adamMoments.d.ts +1 -1
  82. package/dist/ops/adamMoments.js +4 -4
  83. package/dist/ops/add16.d.ts +2 -0
  84. package/dist/ops/add16.js +9 -0
  85. package/dist/ops/appendCache.js +16 -9
  86. package/dist/ops/attentionMask.js +4 -4
  87. package/dist/ops/concat16.d.ts +2 -0
  88. package/dist/ops/concat16.js +9 -0
  89. package/dist/ops/cpu/adamAdjust.js +14 -13
  90. package/dist/ops/cpu/adamMoments.js +10 -9
  91. package/dist/ops/cpu/appendCache.js +9 -8
  92. package/dist/ops/cpu/attentionMask.js +15 -14
  93. package/dist/ops/cpu/fusedSoftmax.js +13 -12
  94. package/dist/ops/cpu/gatherSub.js +9 -24
  95. package/dist/ops/cpu/gelu.js +13 -12
  96. package/dist/ops/cpu/matMul16.d.ts +1 -0
  97. package/dist/ops/cpu/matMul16.js +16 -0
  98. package/dist/ops/cpu/matMulGelu.js +18 -16
  99. package/dist/ops/cpu/matMulMul.js +8 -7
  100. package/dist/ops/cpu/mulDropout.js +4 -3
  101. package/dist/ops/cpu/normRMS.js +11 -10
  102. package/dist/ops/cpu/qkv.js +17 -13
  103. package/dist/ops/cpu/rope.js +23 -22
  104. package/dist/ops/cpu/scatterSub.js +16 -30
  105. package/dist/ops/dot16.d.ts +2 -0
  106. package/dist/ops/dot16.js +42 -0
  107. package/dist/ops/gatherSub.js +1 -1
  108. package/dist/ops/gelu.js +2 -2
  109. package/dist/ops/grads/add16.d.ts +1 -0
  110. package/dist/ops/grads/add16.js +27 -0
  111. package/dist/ops/grads/attentionMask.js +12 -19
  112. package/dist/ops/grads/gelu.js +4 -3
  113. package/dist/ops/grads/matMul16.d.ts +2 -0
  114. package/dist/ops/grads/matMul16.js +9 -0
  115. package/dist/ops/grads/matMulGelu.js +8 -7
  116. package/dist/ops/grads/normRMS.js +8 -7
  117. package/dist/ops/grads/{fusedSoftmax.d.ts → pack16.d.ts} +1 -1
  118. package/dist/ops/grads/pack16.js +7 -0
  119. package/dist/ops/grads/qkv.d.ts +3 -1
  120. package/dist/ops/grads/qkv.js +28 -22
  121. package/dist/ops/grads/rope.d.ts +2 -1
  122. package/dist/ops/grads/rope.js +6 -13
  123. package/dist/ops/grads/softmax16.d.ts +2 -0
  124. package/dist/ops/grads/softmax16.js +26 -0
  125. package/dist/ops/grads/unpack16.d.ts +2 -0
  126. package/dist/ops/grads/unpack16.js +6 -0
  127. package/dist/ops/grads/utils.d.ts +3 -0
  128. package/dist/ops/grads/utils.js +10 -0
  129. package/dist/ops/matMul16.d.ts +15 -0
  130. package/dist/ops/matMul16.js +13 -0
  131. package/dist/ops/matMulGelu.js +1 -1
  132. package/dist/ops/matMulMul.js +1 -1
  133. package/dist/ops/mul16.d.ts +2 -0
  134. package/dist/ops/mul16.js +8 -0
  135. package/dist/ops/mulDrop.js +1 -1
  136. package/dist/ops/normRMS.js +1 -1
  137. package/dist/ops/pack16.d.ts +2 -0
  138. package/dist/ops/pack16.js +6 -0
  139. package/dist/ops/qkv.d.ts +1 -1
  140. package/dist/ops/qkv.js +8 -4
  141. package/dist/ops/reshape16.d.ts +2 -0
  142. package/dist/ops/reshape16.js +43 -0
  143. package/dist/ops/rope.d.ts +1 -1
  144. package/dist/ops/rope.js +8 -10
  145. package/dist/ops/scatterSub.js +1 -1
  146. package/dist/ops/slice16.d.ts +2 -0
  147. package/dist/ops/slice16.js +9 -0
  148. package/dist/ops/softmax16.d.ts +2 -0
  149. package/dist/ops/softmax16.js +12 -0
  150. package/dist/ops/sub16.d.ts +2 -0
  151. package/dist/ops/sub16.js +8 -0
  152. package/dist/ops/sum16.d.ts +2 -0
  153. package/dist/ops/sum16.js +13 -0
  154. package/dist/ops/transpose16.d.ts +3 -0
  155. package/dist/ops/transpose16.js +41 -0
  156. package/dist/ops/unpack16.d.ts +2 -0
  157. package/dist/ops/unpack16.js +6 -0
  158. package/dist/ops/webgl/adamAdjust.js +3 -2
  159. package/dist/ops/webgl/adamMoments.js +2 -1
  160. package/dist/ops/webgl/appendCache.js +2 -1
  161. package/dist/ops/webgl/attentionMask.js +5 -4
  162. package/dist/ops/webgl/fusedSoftmax.js +6 -4
  163. package/dist/ops/webgl/gatherSub.js +7 -6
  164. package/dist/ops/webgl/gelu.js +3 -2
  165. package/dist/ops/webgl/log.js +12 -27
  166. package/dist/ops/webgl/matMul16.d.ts +1 -0
  167. package/dist/ops/webgl/matMul16.js +37 -0
  168. package/dist/ops/webgl/matMulGelu.js +17 -15
  169. package/dist/ops/webgl/matMulMul.js +13 -12
  170. package/dist/ops/webgl/mulDropout.js +9 -8
  171. package/dist/ops/webgl/normRMS.js +8 -7
  172. package/dist/ops/webgl/qkv.js +6 -5
  173. package/dist/ops/webgl/rope.js +11 -10
  174. package/dist/ops/webgl/scatterSub.js +6 -5
  175. package/dist/ops/webgpu/adamAdjust.js +12 -10
  176. package/dist/ops/webgpu/adamMoments.js +27 -22
  177. package/dist/ops/webgpu/add16.d.ts +1 -0
  178. package/dist/ops/webgpu/add16.js +14 -0
  179. package/dist/ops/webgpu/appendCache.js +64 -17
  180. package/dist/ops/webgpu/attentionMask.js +19 -62
  181. package/dist/ops/webgpu/attentionMask32_program.d.ts +19 -0
  182. package/dist/ops/webgpu/attentionMask32_program.js +54 -0
  183. package/dist/ops/webgpu/concat16.d.ts +19 -0
  184. package/dist/ops/webgpu/concat16.js +128 -0
  185. package/dist/ops/webgpu/gatherSub.js +9 -7
  186. package/dist/ops/webgpu/gelu.js +78 -31
  187. package/dist/ops/webgpu/index.js +12 -0
  188. package/dist/ops/webgpu/matMul16.d.ts +1 -0
  189. package/dist/ops/webgpu/matMul16.js +58 -0
  190. package/dist/ops/webgpu/matMul16_program.d.ts +42 -0
  191. package/dist/ops/webgpu/matMul16_program.js +336 -0
  192. package/dist/ops/webgpu/mul16.d.ts +1 -0
  193. package/dist/ops/webgpu/mul16.js +14 -0
  194. package/dist/ops/webgpu/normRMS.js +21 -40
  195. package/dist/ops/webgpu/normRMS16_program.d.ts +9 -0
  196. package/dist/ops/webgpu/normRMS16_program.js +24 -0
  197. package/dist/ops/webgpu/normRMS32_program.d.ts +9 -0
  198. package/dist/ops/webgpu/normRMS32_program.js +24 -0
  199. package/dist/ops/webgpu/normRMSGrad.js +113 -64
  200. package/dist/ops/webgpu/pack16.d.ts +1 -0
  201. package/dist/ops/webgpu/pack16.js +19 -0
  202. package/dist/ops/webgpu/pack16_program.d.ts +19 -0
  203. package/dist/ops/webgpu/pack16_program.js +92 -0
  204. package/dist/ops/webgpu/qkv.js +20 -55
  205. package/dist/ops/webgpu/rope.js +77 -22
  206. package/dist/ops/webgpu/scatterSub.js +9 -7
  207. package/dist/ops/webgpu/slice16.d.ts +7 -0
  208. package/dist/ops/webgpu/slice16.js +71 -0
  209. package/dist/{variable-Bm2OFwGI.js → ops/webgpu/softmax16.d.ts} +2 -8
  210. package/dist/ops/webgpu/softmax16.js +23 -0
  211. package/dist/ops/webgpu/softmax16_program.d.ts +13 -0
  212. package/dist/ops/webgpu/softmax16_program.js +73 -0
  213. package/dist/ops/webgpu/softmax16_subgroup_program.d.ts +17 -0
  214. package/dist/ops/webgpu/softmax16_subgroup_program.js +75 -0
  215. package/dist/ops/webgpu/softmax16grad.d.ts +1 -0
  216. package/dist/ops/webgpu/softmax16grad.js +38 -0
  217. package/dist/ops/webgpu/sub16.d.ts +1 -0
  218. package/dist/ops/webgpu/sub16.js +14 -0
  219. package/dist/ops/webgpu/sum16.d.ts +1 -0
  220. package/dist/ops/webgpu/sum16.js +40 -0
  221. package/dist/ops/webgpu/transpose16.d.ts +1 -0
  222. package/dist/ops/webgpu/transpose16.js +35 -0
  223. package/dist/ops/webgpu/transpose16_program.d.ts +16 -0
  224. package/dist/ops/webgpu/transpose16_program.js +50 -0
  225. package/dist/ops/webgpu/transpose16_shared_program.d.ts +15 -0
  226. package/dist/ops/webgpu/transpose16_shared_program.js +71 -0
  227. package/dist/ops/webgpu/unpack16.d.ts +1 -0
  228. package/dist/ops/webgpu/unpack16.js +49 -0
  229. package/dist/ops/webgpu/utils/binary_op.d.ts +19 -0
  230. package/dist/ops/webgpu/utils/binary_op.js +79 -0
  231. package/dist/ops/webgpu/utils/deviceInfo.d.ts +7 -0
  232. package/dist/ops/webgpu/utils/deviceInfo.js +11 -0
  233. package/dist/ops/webgpu/utils/reductions.d.ts +32 -4
  234. package/dist/ops/webgpu/utils/reductions.js +236 -45
  235. package/dist/ops-CNI3TwqM.js +645 -0
  236. package/dist/pack16-CFUqumar.js +41 -0
  237. package/dist/{papaparse.min-C8l2Kvo1.js → papaparse.min-C0cScC2i.js} +2 -8
  238. package/dist/{parquet-C0Tlmv9c.js → parquet-BE8MU_ge.js} +201 -278
  239. package/dist/patches/PackedTensor.d.ts +12 -0
  240. package/dist/patches/PackedTensor.js +11 -0
  241. package/dist/patches/engine.d.ts +261 -0
  242. package/dist/patches/engine.js +10 -0
  243. package/dist/patches/tape.d.ts +12 -0
  244. package/dist/patches/tape.js +5 -0
  245. package/dist/patches/webgpu_backend.d.ts +18 -0
  246. package/dist/patches/webgpu_backend.js +57 -0
  247. package/dist/{tensor-CZr4dh61.js → patches/webgpu_base.d.ts} +5 -8
  248. package/dist/patches/webgpu_base.js +34 -0
  249. package/dist/patches/webgpu_program.d.ts +36 -0
  250. package/dist/patches/webgpu_program.js +401 -0
  251. package/dist/{pdf-kJD-f258.js → pdf-NIhmP3sq.js} +424 -428
  252. package/dist/random_width-DY6Kk2Dl.js +10051 -0
  253. package/dist/range-BMS52eQi.js +11 -0
  254. package/dist/reciprocal-CTmshQ9J.js +10 -0
  255. package/dist/{register_all_kernels-DIGpEwcf.js → register_all_kernels-Bwu1PTuU.js} +719 -9766
  256. package/dist/relu-yZ2-7WxU.js +10 -0
  257. package/dist/reshape-DevtBWtf.js +10 -0
  258. package/dist/rope-B5UUMsPi.js +32 -0
  259. package/dist/{scatter_nd_util-BQdz--Gn.js → scatter_nd_util-5EL-8VAQ.js} +1 -1
  260. package/dist/selu_util-D1w6yyTO.js +303 -0
  261. package/dist/{shared-DuP7ue-R.js → shared-BRksrJb3.js} +1 -17
  262. package/dist/shared-BuAXb4CI.js +2145 -0
  263. package/dist/sin-BGfy2HZo.js +16 -0
  264. package/dist/slice-D_gkkqZK.js +13 -0
  265. package/dist/slice_util-DtEldBfK.js +261 -0
  266. package/dist/softmax-ZHVebtR1.js +13 -0
  267. package/dist/split-DrfihRpZ.js +10 -0
  268. package/dist/squeeze-DZEpeblb.js +11 -0
  269. package/dist/stack-yOIAalTq.js +13 -0
  270. package/dist/sum-_fzj5ZTB.js +12 -0
  271. package/dist/tensor-DdQUJZlz.js +909 -0
  272. package/dist/tensor-f35l8Odg.js +8 -0
  273. package/dist/tensor1d-CeZuc-Rv.js +12 -0
  274. package/dist/tensor2d-G4Ys2GxX.js +15 -0
  275. package/dist/tensor4d-B8roDgtc.js +15 -0
  276. package/dist/tensor_util-DV-FP5Q3.js +523 -0
  277. package/dist/tfjs_backend-kNyO5L2d.js +653 -0
  278. package/dist/tile-BzyEiF-F.js +13 -0
  279. package/dist/tokeniser/CharTokeniser.js +1 -1
  280. package/dist/tokeniser/bpe.js +1 -1
  281. package/dist/training/Adam.d.ts +2 -1
  282. package/dist/training/Adam.js +12 -28
  283. package/dist/training/AdamExt.d.ts +1 -0
  284. package/dist/training/AdamExt.js +2 -2
  285. package/dist/training/DatasetBuilder.js +3 -20
  286. package/dist/training/FullTrainer.js +82 -64
  287. package/dist/training/Trainer.d.ts +11 -6
  288. package/dist/training/Trainer.js +51 -39
  289. package/dist/training/sparseCrossEntropy.js +3 -3
  290. package/dist/transpose-DKELTqhe.js +38 -0
  291. package/dist/utilities/arrayClose.js +7 -7
  292. package/dist/utilities/dummy.js +35 -27
  293. package/dist/utilities/multinomialCPU.js +2 -2
  294. package/dist/utilities/packed.d.ts +7 -0
  295. package/dist/utilities/packed.js +716 -0
  296. package/dist/utilities/performance.js +1 -1
  297. package/dist/utilities/profile.js +1 -1
  298. package/dist/utilities/safetensors.js +2 -2
  299. package/dist/utilities/sentences.d.ts +5 -0
  300. package/dist/utilities/sentences.js +41 -0
  301. package/dist/utilities/weights.js +2 -2
  302. package/dist/variable-Bhn5bHYv.js +7 -0
  303. package/dist/{webgpu_program-DkQJOJSd.js → webgpu_program-Cigz-7RF.js} +15 -44
  304. package/dist/webgpu_util-BBCnKm2X.js +65 -0
  305. package/dist/zeros-2gldETuK.js +14 -0
  306. package/package.json +4 -3
  307. package/dist/Reshape-Bowtk9BP.js +0 -127
  308. package/dist/Reshape-DUqYftGC.js +0 -30
  309. package/dist/backend_util-CJIiDoV1.js +0 -749
  310. package/dist/broadcast_to-DzlNweb8.js +0 -44
  311. package/dist/concat-B912vBbo.js +0 -33
  312. package/dist/dropout-C-csYCLj.js +0 -193
  313. package/dist/exports_initializers-B8iZMgQ0.js +0 -16
  314. package/dist/gather-Dnpgw-YQ.js +0 -25
  315. package/dist/index-BzFyqcy-.js +0 -4457
  316. package/dist/index-C1rx_Ajs.js +0 -12076
  317. package/dist/kernel_funcs_utils-DKLK0Mg3.js +0 -466
  318. package/dist/log_sum_exp-DO6z8tSE.js +0 -103
  319. package/dist/mat_mul-DzjTFx-u.js +0 -27
  320. package/dist/mod-Dobti4j4.js +0 -27
  321. package/dist/ones-tIJeHlq-.js +0 -29
  322. package/dist/ops/fusedSoftmax.d.ts +0 -2
  323. package/dist/ops/fusedSoftmax.js +0 -10
  324. package/dist/ops/grads/fusedSoftmax.js +0 -22
  325. package/dist/ops-LuCMAnmM.js +0 -1525
  326. package/dist/random_width-CXVRloNK.js +0 -13670
  327. package/dist/range-CWcz7xFA.js +0 -26
  328. package/dist/reciprocal-C4rNcM-S.js +0 -25
  329. package/dist/relu-BjCh_SYb.js +0 -25
  330. package/dist/reshape-CnIwVG1c.js +0 -25
  331. package/dist/selu_util-OtRzVwW5.js +0 -719
  332. package/dist/shared-DmRsFyaJ.js +0 -3134
  333. package/dist/sin-gpDNRxE0.js +0 -47
  334. package/dist/slice-d0Vo9XTN.js +0 -28
  335. package/dist/softmax-D7Jj3p_P.js +0 -28
  336. package/dist/split-DK2k5eHf.js +0 -25
  337. package/dist/stack-DFatutCx.js +0 -27
  338. package/dist/sum-CJ0ULhmt.js +0 -27
  339. package/dist/tensor1d-vML0r3q6.js +0 -27
  340. package/dist/tensor2d-D76QGjF3.js +0 -30
  341. package/dist/tensor4d-Df1WlVDY.js +0 -30
  342. package/dist/webgpu_util-pLEV9tks.js +0 -80
  343. package/dist/zeros-Bj5rMYA7.js +0 -52
@@ -0,0 +1,79 @@
1
+ import { f as s, c as n } from "../../../webgpu_util-BBCnKm2X.js";
2
+ import { g as p } from "../../../binary_op_util-pKXltfxI.js";
3
+ import { B as g } from "../../../binary_op_util-pKXltfxI.js";
4
+ import { j as u } from "../../../index-ZyQhjEPo.js";
5
+ import { e as l } from "../../../webgpu_program-Cigz-7RF.js";
6
+ import { l as r } from "../../../tensor-DdQUJZlz.js";
7
/**
 * WebGPU program applying an elementwise binary op to two tensors stored as
 * packed 16-bit floats.
 *
 * Each i32 element of A/B holds two f16 values (decoded with
 * `unpack2x16float`), so one vec4<i32> load yields eight halves. The shader
 * reads `A[index]` and `B[index]` one-to-one, so broadcasting is NOT
 * supported: both inputs must contain exactly as many elements as the output.
 */
class y {
  dispatch;
  dispatchLayout;
  outputComponent;
  op;
  outputShape;
  shaderKey;
  size = true;
  variableNames = ["A", "B"];
  workgroupSize;
  variableComponents;

  /**
   * @param {*} op - binary op identifier; forwarded to the op-snippet
   *   generator and embedded in the shader key.
   * @param {number[]} aShape - (packed) shape of input A.
   * @param {number[]} bShape - (packed) shape of input B.
   * @throws {Error} if either inner dimension is not a multiple of 4, or if
   *   the inputs would require broadcasting.
   */
  constructor(op, aShape, bShape) {
    this.outputShape = u(aShape, bShape);
    this.dispatchLayout = s(this.outputShape);
    this.op = op;

    const innerDim = (shape) => shape[shape.length - 1];
    const aIsVec4 = aShape.length > 0 && innerDim(aShape) % 4 === 0;
    const bIsVec4 = bShape.length > 0 && innerDim(bShape) % 4 === 0;
    if (!(aIsVec4 && bIsVec4)) {
      const isScalarLike = (shape) => r(shape) || innerDim(shape) === 1;
      if ((aIsVec4 && isScalarLike(bShape)) || (bIsVec4 && isScalarLike(aShape))) {
        throw new Error("Cannot broadcast 16-bit float binary ops with mixed vector sizes");
      }
      throw new Error("16-bit float binary ops require inner dimension to be multiple of 4");
    }

    // The shader indexes A and B with the output index directly, so any
    // broadcast (e.g. [2, 8] op [8]) would read past the end of the smaller
    // input and silently produce garbage. Reject it up front. Shapes that
    // differ only by leading 1s have equal sizes and are still accepted.
    const elementCount = (shape) => shape.reduce((acc, dim) => acc * dim, 1);
    const outputSize = elementCount(this.outputShape);
    if (elementCount(aShape) !== outputSize || elementCount(bShape) !== outputSize) {
      throw new Error("Cannot broadcast 16-bit float binary ops: input shapes must have the same size");
    }

    this.outputComponent = 4;
    this.variableComponents = [4, 4];
    this.shaderKey = `binary_${op}_${this.variableComponents}`;
    this.workgroupSize = [128, 1, 1];
    // Each invocation writes one vec4 of packed words; presumably the last
    // argument to n() is elements-per-thread along x — confirm.
    this.dispatch = n(this.dispatchLayout, this.outputShape, this.workgroupSize, [
      this.outputComponent,
      1,
      1,
    ]);
  }

  /**
   * Generates the WGSL source. Lanes x,y of each vec4<i32> are unpacked into
   * one vec4<f32> and lanes z,w into another; the op runs on both, and the
   * results are repacked into a vec4<i32>.
   * @returns {string} WGSL shader code.
   */
  getUserCode() {
    const valueType = this.outputComponent === 4 ? "vec4<f32>" : "f32";
    const opFunction = `
      fn binaryOperation(a : ${valueType}, b : ${valueType}) -> ${valueType} {
        ${p(this.op, this.outputComponent === 4)}
      };
    `;
    return `
      ${opFunction}
      ${l("index")} {
        if (index < uniforms.size) {
          let a = A[index];
          let b = B[index];

          let v4a1 = vec4<f32>(
            unpack2x16float(u32(a.x)),
            unpack2x16float(u32(a.y))
          );
          let v4a2 = vec4<f32>(
            unpack2x16float(u32(a.z)),
            unpack2x16float(u32(a.w))
          );
          let v4b1 = vec4<f32>(
            unpack2x16float(u32(b.x)),
            unpack2x16float(u32(b.y))
          );
          let v4b2 = vec4<f32>(
            unpack2x16float(u32(b.z)),
            unpack2x16float(u32(b.w))
          );

          let v4res1 = binaryOperation(v4a1, v4b1);
          let v4res2 = binaryOperation(v4a2, v4b2);

          let res = vec4<i32>(
            i32(pack2x16float(v4res1.xy)),
            i32(pack2x16float(v4res1.zw)),
            i32(pack2x16float(v4res2.xy)),
            i32(pack2x16float(v4res2.zw))
          );

          result[index] = res;
        }
      }
    `;
  }
}
76
+ export {
77
+ y as BinaryOpProgram,
78
+ g as BinaryOpType
79
+ };
@@ -0,0 +1,7 @@
1
+ import { default as WebGPUBackendPatch } from '../../../patches/webgpu_backend';
2
/** Subgroup capabilities of a WebGPU device, as summarised by createDeviceInformation. */
+ export interface DeviceInformation {
3
/** True when the GPU device advertises the "subgroups" feature. */
+ subgroupsSupported: boolean;
4
/** Largest subgroup size reported by the backend (`backend.subgroupMaxSize`). */
+ subgroupMaxSize: number;
5
/** True only when subgroups are supported AND the backend's min and max subgroup sizes differ. */
+ variableSubgroups: boolean;
6
+ }
7
/** Reads subgroup support and size limits from a patched WebGPU backend. */
+ export default function createDeviceInformation(backend: WebGPUBackendPatch): DeviceInformation;
@@ -0,0 +1,11 @@
1
/**
 * Summarises the subgroup capabilities of a WebGPU backend.
 *
 * @param {object} u - backend exposing `device.features`, `subgroupMinSize`
 *   and `subgroupMaxSize`.
 * @returns {{subgroupsSupported: boolean, subgroupMaxSize: number, variableSubgroups: boolean}}
 *   `variableSubgroups` is only true when subgroups are supported and the
 *   min/max subgroup sizes differ.
 */
function o(u) {
  const hasSubgroups = u.device.features.has("subgroups");
  const maxSize = u.subgroupMaxSize;
  const sizesDiffer = u.subgroupMinSize !== maxSize;
  return {
    subgroupsSupported: hasSubgroups,
    subgroupMaxSize: maxSize,
    variableSubgroups: sizesDiffer && hasSubgroups,
  };
}
9
+ export {
10
+ o as default
11
+ };
@@ -1,9 +1,37 @@
1
1
  import { backend_util, TensorInfo } from '@tensorflow/tfjs-core';
2
2
  import { WebGPUProgram } from '@tensorflow/tfjs-backend-webgpu/dist/webgpu_program';
3
3
  import { WebGPUBackend } from '@tensorflow/tfjs-backend-webgpu';
4
- export interface ReduceWebGPUProgram extends WebGPUProgram {
5
- inputShape: number[];
4
+ import { DeviceInformation } from './deviceInfo';
5
/** Options controlling what a ReduceProgram computes. */
+ export interface ReduceParams {
6
/** 'sum' adds the row; 'mean' additionally divides by the number of reduced values. */
+ reductionOp: 'mean' | 'sum';
7
/** When true, the output keeps the full [batchSize, inSize] shape and the write snippet runs for every element of the row. */
+ elementwise?: boolean;
6
8
  }
7
- export declare function createReductionShader(workgroupSizeX: number, reductionOp: 'mean' | 'sum', inputSnippet: string, reducedSnippet: string, outputSnippet: string): string;
8
9
/** Collapses the reduced axes of inputs[0] into one window: windowSize = inSize = product of reduced dims, batchSize = remaining elements, outSize = 1. */
  export declare function createReduceInfo(inputs: TensorInfo[], axis: number | number[]): backend_util.ReduceInfo;
9
- export declare function reduce(program: ReduceWebGPUProgram, inputs: TensorInfo[], backend: WebGPUBackend): TensorInfo;
10
/** WebGPU program that sums or averages the innermost axis of an input viewed as [batchSize, inSize]. */
+ export declare class ReduceProgram implements WebGPUProgram {
11
+ outputShape: number[];
12
+ shaderKey: string;
13
+ dispatchLayout: {
14
+ x: number[];
15
+ };
16
+ dispatch: [number, number, number];
17
+ workgroupSize: [number, number, number];
18
+ variableNames: string[];
19
+ uniforms: string;
20
/** [batchSize, inSize]; the last entry is passed to the shader as `reduceSize`. */
+ inputShape: number[];
21
+ size: boolean;
22
/** Selects the packed-f16 shaders (each input word holds two halves) instead of the plain f32 shader. */
+ packed: boolean;
23
+ outputComponent: number;
24
+ variableComponents?: number[];
25
+ elementwise: boolean;
26
/** Set when the device supports subgroups; the workgroup width is then capped at subgroupMaxSize. */
+ subgroups: boolean;
27
/** Set when the device reports variable subgroup sizes. */
+ subgroupBuiltins: boolean;
28
+ deviceInfo: DeviceInformation;
29
+ params: ReduceParams;
30
/** Workgroup width is 64 when reduceInfo.inSize is a multiple of 64, otherwise 32. */
+ constructor(deviceInfo: DeviceInformation, reduceInfo: backend_util.ReduceInfo, params: ReduceParams, packed: boolean);
31
/** Hook: WGSL that stores the result; the default writes `result[outputIndex]` (packed with pack2x16float when packed). */
+ protected getWriteSnippet(): string;
32
/** Hook: WGSL inserted after each readInput in the accumulation loop; the default is empty. */
+ protected getPreprocessSnippet(): string;
33
/** Hook: WGSL inserted after the reduction (and mean division); the default is empty. */
+ protected getPostprocessSnippet(): string;
34
/** Hook: body of the WGSL `readInput(index)` function. */
+ protected getReadSnippet(): string;
35
+ getUserCode(): string;
36
+ }
37
/** Runs the program with reduceSize = last dim of program.inputShape and reshapes the flat result based on inputs[0].shape. */
+ export declare function reduce(program: ReduceProgram, inputs: TensorInfo[], backend: WebGPUBackend): TensorInfo;
@@ -1,68 +1,259 @@
1
- import { p as l, j as d } from "../../../index-BzFyqcy-.js";
2
- import { g as p } from "../../../webgpu_program-DkQJOJSd.js";
3
- import { r as f } from "../../../Reshape-DUqYftGC.js";
4
- import { c as x } from "../../../axis_util-TbGYJ208.js";
5
- function I(e, r, t, s, u) {
1
+ import { e as l } from "../../../index-ZyQhjEPo.js";
2
+ import { e as a } from "../../../webgpu_program-Cigz-7RF.js";
3
+ import { reshape16 as S } from "../../reshape16.js";
4
+ import { f as b } from "../../../webgpu_util-BBCnKm2X.js";
5
+ import { p as f, s as c } from "../../../tensor-DdQUJZlz.js";
6
+ import { c as h } from "../../../axis_util-BvHEw88j.js";
7
/**
 * Returns the WGSL snippet that combines every invocation's partial
 * `bestValue` into one workgroup-wide total, left in `bestValue`.
 *
 * The strategy depends on the flags:
 *  - `e && !u` (subgroups, fixed subgroup size): a single `subgroupAdd`.
 *    This relies on the whole workgroup fitting in one subgroup; ReduceProgram
 *    caps workgroupSize at `subgroupMaxSize` when subgroups are supported.
 *  - `e && u` (subgroups, variable subgroup size):
 *    1. `subgroupAdd` within each subgroup.
 *    2. Lane 0 stores its subgroup's partial in `bestValues`.
 *    3. A barrier.
 *    4. A second `subgroupAdd` over the per-subgroup partials.
 *    This path reads the `subgroupSize` builtin.
 *    NOTE(review): the builtin is presumably exposed when
 *    ReduceProgram.subgroupBuiltins is set — confirm.
 *  - `!e` (no subgroups): a tree reduction through the `bestValues`
 *    workgroup array. It needs `Length` and `DIV_CEIL` from the enclosing
 *    shader.
 *
 * @param {boolean} e - whether subgroup operations are used.
 * @param {boolean} u - whether the device reports variable subgroup sizes.
 * @param {number} t - workgroup size along x, interpolated into the WGSL.
 * @param {boolean} i - true when `bestValue` is `vec2<f32>` (picks the zero literal).
 * @returns {string} WGSL statements to splice into a reduction kernel.
 */
+ function d(e, u, t, i) {
8
+ return e && !u ? `
9
+ bestValue = subgroupAdd(bestValue);
10
+ ` : e ? `
11
+ bestValue = subgroupAdd(bestValue);
12
+ let lane = localId.x % subgroupSize;
13
+ if (lane == 0) {
14
+ bestValues[localId.x / subgroupSize] = bestValue;
15
+ }
16
+ workgroupBarrier();
17
+ let numSubgroups = ${t} / subgroupSize;
18
+ bestValue = select(${i ? "vec2<f32>(0.0f)" : "0.0f"}, bestValues[lane], lane < numSubgroups);
19
+ bestValue = subgroupAdd(bestValue);
20
+ ` : `
21
+ bestValues[localId.x] = bestValue;
22
+ workgroupBarrier();
23
+
24
+ var reduceSize = min(u32(Length), ${t}u);
25
+ for (var currentSize = reduceSize / 2u; reduceSize > 1u;
26
+ currentSize = reduceSize / 2u) {
27
+ let interval = DIV_CEIL(reduceSize, 2u);
28
+ if (localId.x < currentSize) {
29
+ let candidate = bestValues[localId.x + interval];
30
+ bestValue = bestValue + candidate;
31
+ bestValues[localId.x] = bestValue;
32
+ }
33
+ reduceSize = interval;
34
+ workgroupBarrier();
35
+ }
36
+
37
+ bestValue = bestValues[0];
38
+ `;
39
+ }
40
/**
 * Builds the packed-f16 *elementwise* reduction shader.
 *
 * Each workgroup reduces one row of `uniforms.reduceSize` packed words.
 * `readInput` unpacks every word into two halves (vec2<f32>), and both halves
 * are added to the running f32 sum, so 'mean' divides by `reduceSize * 2`.
 * After the reduction, `outputSnippet` runs in a strided loop over the row
 * (index `k`), so the reduced value can be applied back to every element.
 *
 * @param {object} e - generator options:
 *   - workgroupSizeX, subgroups, variableSubgroups: control the layout and
 *     the cross-invocation reduction.
 *   - reductionOp: 'mean' | 'sum'.
 *   - inputReadSnippet: body of readInput; defaults to unpacking x[index].
 *   - inputSnippet: runs after each read; `candidate` is a `var`.
 *   - reducedSnippet: optional.
 *   - outputSnippet: the per-element write.
 * @returns {string} complete WGSL source.
 *
 * NOTE(review): `inputSnippet` and `outputSnippet` are interpolated without a
 * fallback, so leaving either undefined would emit the text "undefined" into
 * the WGSL.
 */
+ function g(e) {
41
+ const u = `${e.workgroupSizeX}`, t = e.subgroups && !e.variableSubgroups ? "" : `
42
+ var<workgroup> bestValues : array<f32, ${e.workgroupSizeX}>;
43
+ `, i = d(e.subgroups, e.variableSubgroups, e.workgroupSizeX, !1);
6
44
  return `
7
45
  fn DIV_CEIL(a : u32, b : u32) -> u32 {
8
46
  return ((a - 1u) / b + 1u);
9
47
  }
10
48
 
11
- ${`
12
- var<workgroup> xBestValues : array<f32, ${e}>;
13
- `}
49
+ fn readInput(index: i32) -> vec2<f32> {
50
+ ${e.inputReadSnippet ? e.inputReadSnippet : `
51
+ let packed = u32(x[index]);
52
+ return unpack2x16float(packed);
53
+ `}
54
+ }
55
+
56
+ ${t}
14
57
 
15
- ${p("index")} {
16
- let outputIndex = index / ${e};
58
+ ${a("index")} {
59
+ let outputIndex = index / ${u};
17
60
  let offset = outputIndex * uniforms.reduceSize;
18
- var bestValue = 0.0;
61
+ var bestValue = 0.0f;
19
62
  let Length = uniforms.reduceSize;
63
+ let tid = i32(localId.x);
20
64
 
21
- for (var k = i32(localId.x); k < Length;
22
- k = k + ${e}) {
23
- var candidate = f32(x[offset + k]);
24
- ${t}
25
- bestValue = bestValue + candidate;
65
+ for (var k = tid; k < Length;
66
+ k = k + ${u}) {
67
+ var candidate = readInput(offset + k);
68
+ ${e.inputSnippet}
69
+ bestValue = bestValue + candidate.x + candidate.y;
26
70
  }
27
- xBestValues[localId.x] = bestValue;
28
- workgroupBarrier();
71
+
72
+ ${i}
73
+ bestValue = bestValue ${e.reductionOp === "mean" ? "/ f32(uniforms.reduceSize * 2i)" : ""};
74
+
75
+ ${e.reducedSnippet ? e.reducedSnippet : ""}
76
+
77
+ for (var k = tid; k < Length;
78
+ k = k + ${u}) {
79
+ ${e.outputSnippet}
80
+ }
81
+ }
82
+ `;
83
+ }
84
/**
 * Builds the packed-f16 row reduction shader (non-elementwise).
 *
 * Each workgroup reduces two consecutive rows, `offset1` and
 * `offset2 = offset1 + reduceSize`. Their sums are kept as the two lanes of
 * a `vec2<f32>` `bestValue`, so a single packed word can hold both results.
 * Each packed input word contributes two halves, hence 'mean' divides by
 * `reduceSize * 2`. Unlike the elementwise variant, `outputSnippet` runs once
 * per invocation, not in a loop over the row.
 *
 * @param {object} e - generator options:
 *   - workgroupSizeX, subgroups, variableSubgroups.
 *   - reductionOp: 'mean' | 'sum'.
 *   - inputReadSnippet: optional.
 *   - inputSnippet: spliced in twice, once per row.
 *   - reducedSnippet: optional.
 *   - outputSnippet.
 * @returns {string} complete WGSL source.
 *
 * NOTE(review): because `inputSnippet` is spliced in twice in the same scope,
 * a snippet that declares a `let` would be redeclared — confirm that callers
 * avoid this.
 */
+ function k(e) {
85
+ const u = `${e.workgroupSizeX}`, t = e.subgroups && !e.variableSubgroups ? "" : `
86
+ var<workgroup> bestValues : array<vec2<f32>, ${e.workgroupSizeX}>;
87
+ `, i = d(e.subgroups, e.variableSubgroups, e.workgroupSizeX, !0);
88
+ return `
89
+ fn DIV_CEIL(a : u32, b : u32) -> u32 {
90
+ return ((a - 1u) / b + 1u);
91
+ }
92
+
93
+ fn readInput(index: i32) -> vec2<f32> {
94
+ ${e.inputReadSnippet ? e.inputReadSnippet : `
95
+ let packed = u32(x[index]);
96
+ return unpack2x16float(packed);
97
+ `}
98
+ }
99
+
100
+ ${t}
29
101
 
30
- var reduceSize = min(u32(Length), ${e}u);
31
- for (var currentSize = reduceSize / 2u; reduceSize > 1u;
32
- currentSize = reduceSize / 2u) {
33
- let interval = DIV_CEIL(reduceSize, 2u);
34
- if (localId.x < currentSize) {
35
- let candidate = xBestValues[localId.x + interval];
36
- bestValue = bestValue + candidate;
37
- xBestValues[localId.x] = bestValue;
38
- }
39
- reduceSize = interval;
40
- workgroupBarrier();
102
+ ${a("index")} {
103
+ let outputIndex = index / ${u};
104
+ let offset1 = outputIndex * 2 * uniforms.reduceSize;
105
+ let offset2 = offset1 + uniforms.reduceSize;
106
+ var bestValue = vec2<f32>(0.0f, 0.0f);
107
+ let Length = uniforms.reduceSize;
108
+ let tid = i32(localId.x);
109
+
110
+ for (var k = tid; k < Length;
111
+ k = k + ${u}) {
112
+ var candidate = readInput(offset1 + k);
113
+ ${e.inputSnippet}
114
+ let bv1 = candidate.x + candidate.y;
115
+
116
+ candidate = readInput(offset2 + k);
117
+ ${e.inputSnippet}
118
+ let bv2 = candidate.x + candidate.y;
119
+
120
+ bestValue = bestValue + vec2<f32>(bv1, bv2);
121
+ }
122
+ ${i}
123
+ bestValue = bestValue ${e.reductionOp === "mean" ? "/ f32(uniforms.reduceSize * 2i)" : ""};
124
+
125
+ ${e.reducedSnippet ?? ""}
126
+ ${e.outputSnippet}
127
+ }
128
+ `;
129
+ }
130
/**
 * Picks the packed-f16 shader generator for the given options: the
 * elementwise variant (`g`) when `e.elementwise` is truthy, otherwise the
 * two-row variant (`k`).
 * @param {object} e - shader generator options.
 * @returns {string} WGSL source.
 */
function z(e) {
  if (e.elementwise) {
    return g(e);
  }
  return k(e);
}
133
/**
 * Builds the unpacked (f32) reduction shader.
 *
 * Each workgroup reduces one row of `uniforms.reduceSize` values read via
 * `readInput` (default: `x[index]`). 'mean' divides by `reduceSize`. After
 * the reduction, `outputSnippet` runs in a strided loop over the row
 * (index `k`). A write that targets only `result[outputIndex]` is therefore
 * repeated once per loop iteration.
 *
 * @param {object} e - generator options:
 *   - workgroupSizeX, subgroups, variableSubgroups.
 *   - reductionOp: 'mean' | 'sum'.
 *   - inputReadSnippet: optional.
 *   - inputSnippet.
 *   - reducedSnippet.
 *   - outputSnippet.
 * @returns {string} complete WGSL source.
 *
 * NOTE(review): unlike g and k, `reducedSnippet` has no fallback here, so it
 * must be supplied (ReduceProgram passes "" by default).
 */
+ function x(e) {
134
+ const u = `${e.workgroupSizeX}`, t = e.subgroups && !e.variableSubgroups ? "" : `
135
+ var<workgroup> bestValues : array<f32, ${e.workgroupSizeX}>;
136
+ `, i = d(e.subgroups, e.variableSubgroups, e.workgroupSizeX, !1);
137
+ return `
138
+ fn DIV_CEIL(a : u32, b : u32) -> u32 {
139
+ return ((a - 1u) / b + 1u);
140
+ }
141
+
142
+ fn readInput(index: i32) -> f32 {
143
+ ${e.inputReadSnippet ? e.inputReadSnippet : `
144
+ return x[index];
145
+ `}
146
+ }
147
+
148
+ ${t}
149
+
150
+ ${a("index")} {
151
+ let outputIndex = index / ${e.workgroupSizeX};
152
+ let offset = outputIndex * uniforms.reduceSize;
153
+ var bestValue = 0.0f;
154
+ let Length = uniforms.reduceSize;
155
+ let tid = i32(localId.x);
156
+
157
+ for (var k = tid; k < Length;
158
+ k = k + ${e.workgroupSizeX}) {
159
+ var candidate = readInput(offset + k);
160
+ ${e.inputSnippet}
161
+ bestValue = bestValue + candidate;
41
162
  }
163
+ ${i}
42
164
 
43
- bestValue = xBestValues[0] ${r === "mean" ? "/ f32(uniforms.reduceSize)" : ""};
165
+ bestValue = bestValue ${e.reductionOp === "mean" ? "/ f32(uniforms.reduceSize)" : ""};
44
166
 
45
- ${s}
167
+ ${e.reducedSnippet}
46
168
 
47
- for (var k = i32(localId.x); k < Length;
48
- k = k + ${e}) {
49
- ${u}
169
+ for (var k = tid; k < Length;
170
+ k = k + ${u}) {
171
+ ${e.outputSnippet}
50
172
  }
51
173
  }
52
174
  `;
53
175
  }
54
- function V(e, r) {
55
- const t = e[0], u = l(r, t.shape), [, n] = x(t.shape, u), a = d(n), i = d(t.shape) / a;
56
- return { windowSize: a, inSize: a, batchSize: i, outSize: 1 };
176
/**
 * Builds the ReduceInfo for reducing the first input over `u` (an axis or a
 * list of axes). All reduced dimensions collapse into a single window; the
 * remaining elements form the batch.
 *
 * Presumably f = parseAxisParam, h = computeOutAndReduceShapes (its second
 * element is the reduced shape) and c = sizeFromShape — confirm.
 *
 * @param {TensorInfo[]} e - inputs; only `e[0].shape` is used.
 * @param {number|number[]} u - axis or axes to reduce.
 * @returns {{windowSize: number, inSize: number, batchSize: number, outSize: number}}
 */
function C(e, u) {
  const input = e[0];
  const { shape } = input;
  const axes = f(u, shape);
  const reducedDims = h(shape, axes)[1];
  const windowSize = c(reducedDims);
  const batchSize = c(shape) / windowSize;
  return { windowSize, inSize: windowSize, batchSize, outSize: 1 };
}
180
/**
 * WebGPU program that sums or averages the innermost axis of a tensor.
 *
 * The input is viewed as `[batchSize, inSize]` (from a ReduceInfo), and
 * `uniforms.reduceSize` is the row length. Each workgroup produces one output
 * row, because the shaders compute `outputIndex = index / workgroupSizeX`.
 *  - `packed === true`: each input word holds two f16 values. Non-elementwise
 *    kernels reduce TWO rows per workgroup and pack both results into one
 *    word, giving `batchSize / 2` outputs.
 *  - `packed === false`: plain f32 input reduced ONE row per workgroup,
 *    giving `batchSize` outputs.
 * With `params.elementwise`, the output keeps the full `[batchSize, inSize]`
 * shape. Subclasses customise the shader through the snippet hooks.
 */
class y {
  outputShape;
  shaderKey = "reduce16";
  dispatchLayout;
  dispatch;
  workgroupSize = [64, 1, 1];
  variableNames = ["x"];
  uniforms = "reduceSize : i32,";
  inputShape;
  size = false;
  packed = true;
  outputComponent;
  variableComponents;
  elementwise;
  subgroups = false;
  subgroupBuiltins = false;
  deviceInfo;
  params;

  /**
   * @param {object} u - device information (subgroupsSupported,
   *   subgroupMaxSize, variableSubgroups).
   * @param {object} t - ReduceInfo providing `batchSize` and `inSize`.
   * @param {object} i - ReduceParams (`reductionOp`, optional `elementwise`).
   * @param {boolean} n - whether the input holds packed f16 pairs.
   */
  constructor(u, t, i, n) {
    this.params = i;
    this.inputShape = [t.batchSize, t.inSize];
    this.deviceInfo = u;
    this.packed = n;

    // Prefer 64 lanes when the row length divides evenly, otherwise 32.
    const preferredWidth = t.inSize % 64 === 0 ? 64 : 32;
    if (u.subgroupsSupported) {
      // Never exceed the largest subgroup, so a fixed-size-subgroup device can
      // finish the reduction with a single subgroupAdd.
      this.workgroupSize = [Math.min(preferredWidth, u.subgroupMaxSize), 1, 1];
      this.subgroups = true;
      if (u.variableSubgroups) {
        this.subgroupBuiltins = true;
      }
    } else {
      this.workgroupSize[0] = preferredWidth;
    }

    // Packed kernels fold two rows into one packed output word. The unpacked
    // kernel reduces one row per workgroup, so halving its row count would
    // leave half of the rows unprocessed.
    const reducedRows = this.packed ? t.batchSize / 2 : t.batchSize;
    this.outputShape = i.elementwise ? [t.batchSize, t.inSize] : [reducedRows];
    this.dispatchLayout = b(this.outputShape);
    // Dispatch is counted in workgroups: one per output row.
    this.dispatch = [i.elementwise ? t.batchSize : reducedRows, 1, 1];
    this.outputComponent = 1;
    this.variableComponents = [1];
    this.elementwise = i.elementwise === true;
  }

  /** WGSL that stores the reduced value; packs it with pack2x16float when packed. */
  getWriteSnippet() {
    if (this.packed) {
      return "result[outputIndex] = i32(pack2x16float(bestValue));";
    }
    return "result[outputIndex] = bestValue;";
  }

  /** WGSL run after each `readInput` in the accumulation loop (default: nothing). */
  getPreprocessSnippet() {
    return "";
  }

  /** WGSL run after the reduction and mean division (default: nothing). */
  getPostprocessSnippet() {
    return "";
  }

  /** Body of the WGSL `readInput(index)` function. */
  getReadSnippet() {
    if (this.packed) {
      return `
        let packed = u32(x[index]);
        return unpack2x16float(packed);
      `;
    }
    return "return x[index];";
  }

  /** Generates WGSL via the packed (`z`) or unpacked (`x`) shader generator. */
  getUserCode() {
    const options = {
      ...this.params,
      workgroupSizeX: this.workgroupSize[0],
      subgroups: this.subgroups,
      variableSubgroups: this.deviceInfo.variableSubgroups,
      inputReadSnippet: this.getReadSnippet(),
      inputSnippet: this.getPreprocessSnippet(),
      outputSnippet: this.getWriteSnippet(),
      reducedSnippet: this.getPostprocessSnippet(),
    };
    return this.packed ? z(options) : x(options);
  }
}
58
- function $(e, r, t) {
59
- const s = [], u = r[0], a = [{ type: "int32", data: [e.inputShape[1]] }], o = t.runWebGPUProgram(e, r, "float32", a);
60
- s.push(o);
61
- const i = f({ inputs: { x: o }, attrs: { shape: u.shape }, backend: t });
62
- return s.forEach((c) => t.disposeData(c.dataId)), i;
241
/**
 * Runs a ReduceProgram and reshapes its flat output to the reduced shape.
 *
 * @param {object} e - the ReduceProgram to run (uses `inputShape`, `packed`,
 *   `elementwise`).
 * @param {TensorInfo[]} u - program inputs; `u[0].shape` determines the
 *   output shape.
 * @param {object} t - WebGPU backend used to execute the program.
 * @returns {Tensor} one of:
 *   - elementwise: same shape as the input;
 *   - packed: last axis dropped and the second-to-last halved (two rows per
 *     packed word);
 *   - unpacked: last axis dropped (a rank-1 input yields a scalar).
 */
function X(e, u, t) {
  const { shape } = u[0];
  // The shader strides over the innermost (reduced) axis of inputShape.
  const uniformData = [{ type: "int32", data: [e.inputShape[e.inputShape.length - 1]] }];
  const resultInfo = t.runWebGPUProgram(e, u, e.packed ? "int32" : "float32", uniformData);
  resultInfo.packed = e.packed ?? false;

  let targetShape;
  if (e.elementwise) {
    targetShape = shape;
  } else if (e.packed) {
    targetShape = [...shape.slice(0, -2), shape[shape.length - 2] / 2];
  } else {
    // slice(0, -1) is identical for rank >= 2, and a rank-1 input maps to []
    // instead of [undefined].
    targetShape = shape.slice(0, -1);
  }

  const flat = l().makeTensorFromTensorInfo(resultInfo);
  try {
    return S(flat, targetShape);
  } finally {
    // Release the intermediate even when the reshape throws.
    flat.dispose();
  }
}
64
255
  export {
65
- V as createReduceInfo,
66
- I as createReductionShader,
67
- $ as reduce
256
+ y as ReduceProgram,
257
+ C as createReduceInfo,
258
+ X as reduce
68
259
  };