@genai-fi/nanogpt 0.19.1 → 0.20.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (347) hide show
  1. package/dist/BaseTokeniser-DSg9zcYq.js +221 -0
  2. package/dist/DatasetBuilder-DgURD85T.js +712 -0
  3. package/dist/Generator.js +2 -11941
  4. package/dist/RealDiv-DBu0FQqT.js +362 -0
  5. package/dist/Reshape-CABOPB9d.js +94 -0
  6. package/dist/Reshape-DqO3r8BC.js +17 -0
  7. package/dist/TeachableLLM.d.ts +5 -5
  8. package/dist/TeachableLLM.js +2 -273
  9. package/dist/Trainer.js +2 -244
  10. package/dist/backend.js +12 -12
  11. package/dist/backend_util-Cg-roD1p.js +399 -0
  12. package/dist/binary_op_util-CrYk9LXL.js +103 -0
  13. package/dist/checks/appendCache.js +54 -21
  14. package/dist/checks/attentionMask.js +55 -36
  15. package/dist/checks/check.js +31 -19
  16. package/dist/checks/gelu.js +45 -17
  17. package/dist/checks/index.js +25 -25
  18. package/dist/checks/matMulGelu.js +83 -27
  19. package/dist/checks/normRMS.js +27 -15
  20. package/dist/checks/normRMSGrad.js +21 -11
  21. package/dist/checks/packUnpack.js +45 -17
  22. package/dist/checks/qkv.js +33 -33
  23. package/dist/checks/rope.js +29 -35
  24. package/dist/checks/weights.d.ts +1 -1
  25. package/dist/checks/weights.js +25 -29
  26. package/dist/chunk-BPntVaq0.js +23 -0
  27. package/dist/complex_util-CkazZsaH.js +60 -0
  28. package/dist/concat_util-CWDZCBlA.js +19 -0
  29. package/dist/data/docx.js +3044 -13
  30. package/dist/data/pdf.js +16 -13
  31. package/dist/data/textLoader.js +607 -112
  32. package/dist/dist-BewPQWjc.js +7572 -0
  33. package/dist/dist-DVmq73nz.js +8775 -0
  34. package/dist/dist-DXwIvKxl.js +896 -0
  35. package/dist/dist-VEU5mfO0.js +7545 -0
  36. package/dist/gelu-Bf1HW1RY.js +27 -0
  37. package/dist/gpgpu_math-DvLcCH6u.js +1612 -0
  38. package/dist/inference/types.js +0 -1
  39. package/dist/kernel_funcs_utils-HiXOOx3f.js +229 -0
  40. package/dist/layers/BaseLayer.d.ts +2 -2
  41. package/dist/layers/BaseLayer.js +75 -73
  42. package/dist/layers/CausalSelfAttention.d.ts +1 -1
  43. package/dist/layers/CausalSelfAttention.js +98 -85
  44. package/dist/layers/LoRA.js +47 -57
  45. package/dist/layers/MLP.d.ts +1 -1
  46. package/dist/layers/MLP.js +33 -43
  47. package/dist/layers/PositionEmbedding.d.ts +1 -1
  48. package/dist/layers/PositionEmbedding.js +26 -30
  49. package/dist/layers/RMSNorm.d.ts +1 -1
  50. package/dist/layers/RMSNorm.js +19 -21
  51. package/dist/layers/RoPECache.js +336 -49
  52. package/dist/layers/TiedEmbedding.d.ts +1 -1
  53. package/dist/layers/TiedEmbedding.js +30 -34
  54. package/dist/layers/TransformerBlock.d.ts +1 -1
  55. package/dist/layers/TransformerBlock.js +50 -39
  56. package/dist/layers/WeightStore.js +68 -75
  57. package/dist/loader/load.js +2 -68
  58. package/dist/loader/loadHF.d.ts +2 -2
  59. package/dist/loader/loadHF.js +2 -22
  60. package/dist/loader/loadTransformers.d.ts +1 -1
  61. package/dist/loader/loadTransformers.js +2 -44
  62. package/dist/loader/loadZipMeta.js +15 -15
  63. package/dist/loader/newZipLoad.js +2 -31
  64. package/dist/loader/oldZipLoad.d.ts +2 -2
  65. package/dist/loader/oldZipLoad.js +2 -80
  66. package/dist/loader/save.d.ts +5 -5
  67. package/dist/loader/save.js +2 -90
  68. package/dist/loader/types.d.ts +9 -8
  69. package/dist/loader/types.js +0 -1
  70. package/dist/main-CPjeMv0G.js +13500 -0
  71. package/dist/main.d.ts +1 -1
  72. package/dist/main.js +16 -109
  73. package/dist/matMul16-BNfZSnNM.js +81 -0
  74. package/dist/matMulGelu-CPTntosE.js +162 -0
  75. package/dist/models/NanoGPTV1.js +2 -99
  76. package/dist/models/NanoGPTV2.js +2 -90
  77. package/dist/models/config.js +34 -47
  78. package/dist/models/factory.d.ts +1 -1
  79. package/dist/models/factory.js +2 -16
  80. package/dist/models/model.d.ts +2 -2
  81. package/dist/models/model.js +2 -134
  82. package/dist/ops/adamAdjust.js +15 -6
  83. package/dist/ops/adamMoments.js +13 -6
  84. package/dist/ops/add16.js +9 -6
  85. package/dist/ops/appendCache.js +22 -19
  86. package/dist/ops/attentionMask.js +12 -6
  87. package/dist/ops/concat16.js +7 -8
  88. package/dist/ops/cpu/adamAdjust.js +15 -17
  89. package/dist/ops/cpu/adamMoments.js +15 -15
  90. package/dist/ops/cpu/appendCache.js +64 -22
  91. package/dist/ops/cpu/attentionMask.js +15 -21
  92. package/dist/ops/cpu/fusedSoftmax.js +21 -28
  93. package/dist/ops/cpu/gatherSub.js +11 -17
  94. package/dist/ops/cpu/gelu.js +34 -38
  95. package/dist/ops/cpu/matMul16.js +13 -14
  96. package/dist/ops/cpu/matMulGelu.js +39 -51
  97. package/dist/ops/cpu/matMulMul.js +19 -22
  98. package/dist/ops/cpu/mulDropout.js +19 -22
  99. package/dist/ops/cpu/normRMS.js +33 -37
  100. package/dist/ops/cpu/qkv.js +72 -40
  101. package/dist/ops/cpu/rope.js +79 -36
  102. package/dist/ops/cpu/scatterSub.js +11 -22
  103. package/dist/ops/dot16.js +28 -41
  104. package/dist/ops/dropout.js +10 -13
  105. package/dist/ops/dropout16.js +20 -23
  106. package/dist/ops/gatherSub.js +10 -6
  107. package/dist/ops/gelu.js +2 -8
  108. package/dist/ops/globalNorm.js +18 -12
  109. package/dist/ops/grads/add16.js +27 -26
  110. package/dist/ops/grads/attentionMask.js +26 -21
  111. package/dist/ops/grads/dropout16.js +0 -1
  112. package/dist/ops/grads/gelu.js +2 -5
  113. package/dist/ops/grads/matMul16.js +2 -9
  114. package/dist/ops/grads/matMulGelu.js +21 -16
  115. package/dist/ops/grads/mul16.js +0 -3
  116. package/dist/ops/grads/normRMS.js +34 -30
  117. package/dist/ops/grads/pack16.js +2 -6
  118. package/dist/ops/grads/qkv.js +44 -32
  119. package/dist/ops/grads/rope.js +2 -5
  120. package/dist/ops/grads/softmax16.js +21 -23
  121. package/dist/ops/grads/unpack16.js +2 -5
  122. package/dist/ops/grads/utils.js +9 -11
  123. package/dist/ops/matMul16.js +2 -13
  124. package/dist/ops/matMulGelu.js +16 -10
  125. package/dist/ops/matMulMul.js +13 -6
  126. package/dist/ops/mul16.js +42 -38
  127. package/dist/ops/mulDrop.js +12 -6
  128. package/dist/ops/normRMS.js +18 -15
  129. package/dist/ops/pack16.js +2 -5
  130. package/dist/ops/qkv.js +12 -6
  131. package/dist/ops/reshape16.js +31 -39
  132. package/dist/ops/rope.d.ts +1 -1
  133. package/dist/ops/rope.js +2 -7
  134. package/dist/ops/scatterSub.js +10 -6
  135. package/dist/ops/slice16.js +9 -7
  136. package/dist/ops/softmax16.js +7 -7
  137. package/dist/ops/sub16.js +9 -6
  138. package/dist/ops/sum16.js +12 -12
  139. package/dist/ops/transpose16.js +29 -37
  140. package/dist/ops/unpack16.js +2 -6
  141. package/dist/ops/webgl/adamAdjust.js +62 -29
  142. package/dist/ops/webgl/adamMoments.js +30 -26
  143. package/dist/ops/webgl/appendCache.js +30 -21
  144. package/dist/ops/webgl/attentionMask.js +43 -24
  145. package/dist/ops/webgl/dropout16.js +11 -10
  146. package/dist/ops/webgl/fusedSoftmax.js +69 -79
  147. package/dist/ops/webgl/gatherSub.js +27 -26
  148. package/dist/ops/webgl/gelu.js +32 -34
  149. package/dist/ops/webgl/log.js +14 -23
  150. package/dist/ops/webgl/matMul16.js +36 -44
  151. package/dist/ops/webgl/matMulGelu.js +2 -9
  152. package/dist/ops/webgl/matMulMul.js +23 -27
  153. package/dist/ops/webgl/mulDropout.js +31 -40
  154. package/dist/ops/webgl/normRMS.js +92 -71
  155. package/dist/ops/webgl/qkv.js +35 -27
  156. package/dist/ops/webgl/rope.js +37 -21
  157. package/dist/ops/webgl/scatterSub.js +27 -26
  158. package/dist/ops/webgpu/adamAdjust.js +59 -39
  159. package/dist/ops/webgpu/adamMoments.js +62 -46
  160. package/dist/ops/webgpu/add16.js +13 -12
  161. package/dist/ops/webgpu/appendCache.js +79 -54
  162. package/dist/ops/webgpu/attentionMask.js +41 -25
  163. package/dist/ops/webgpu/attentionMask32_program.js +34 -26
  164. package/dist/ops/webgpu/clipScale.js +44 -57
  165. package/dist/ops/webgpu/concat16.js +96 -111
  166. package/dist/ops/webgpu/dropout16.js +40 -32
  167. package/dist/ops/webgpu/gatherSub.js +43 -30
  168. package/dist/ops/webgpu/gelu.js +88 -82
  169. package/dist/ops/webgpu/index.js +16 -16
  170. package/dist/ops/webgpu/matMul16.js +69 -64
  171. package/dist/ops/webgpu/matMul16_program.js +152 -192
  172. package/dist/ops/webgpu/mul16.js +13 -12
  173. package/dist/ops/webgpu/norm2.js +45 -75
  174. package/dist/ops/webgpu/normRMS.js +25 -33
  175. package/dist/ops/webgpu/normRMS16_program.js +21 -18
  176. package/dist/ops/webgpu/normRMS32_program.js +21 -18
  177. package/dist/ops/webgpu/normRMSGrad.js +125 -184
  178. package/dist/ops/webgpu/pack16.js +20 -17
  179. package/dist/ops/webgpu/pack16_program.js +48 -47
  180. package/dist/ops/webgpu/qkv.js +63 -23
  181. package/dist/ops/webgpu/rope.js +85 -57
  182. package/dist/ops/webgpu/scatterSub.js +43 -30
  183. package/dist/ops/webgpu/slice16.js +66 -61
  184. package/dist/ops/webgpu/softmax16.js +17 -20
  185. package/dist/ops/webgpu/softmax16_program.js +34 -18
  186. package/dist/ops/webgpu/softmax16_subgroup_program.js +40 -45
  187. package/dist/ops/webgpu/softmax16grad.js +30 -36
  188. package/dist/ops/webgpu/sub16.js +13 -12
  189. package/dist/ops/webgpu/sum16.js +28 -37
  190. package/dist/ops/webgpu/transpose16.js +36 -33
  191. package/dist/ops/webgpu/transpose16_program.js +40 -39
  192. package/dist/ops/webgpu/transpose16_shared_program.js +53 -44
  193. package/dist/ops/webgpu/unpack16.js +49 -37
  194. package/dist/ops/webgpu/utils/binary_op.js +70 -68
  195. package/dist/ops/webgpu/utils/deviceInfo.d.ts +1 -1
  196. package/dist/ops/webgpu/utils/deviceInfo.js +10 -10
  197. package/dist/ops/webgpu/utils/reductions.js +136 -148
  198. package/dist/pack16-Ck-spx_F.js +39 -0
  199. package/dist/patches/webgpu_backend.d.ts +2 -2
  200. package/dist/patches/webgpu_backend.js +42 -55
  201. package/dist/patches/webgpu_base.js +21 -33
  202. package/dist/patches/webgpu_program.js +213 -320
  203. package/dist/pdf-UoDqCYzz.js +16726 -0
  204. package/dist/picomatch-3tUnMMbd.js +1063 -0
  205. package/dist/rope-CbeGlsV8.js +25 -0
  206. package/dist/selu_util-zkAx5doH.js +24 -0
  207. package/dist/shared-D1coEFea.js +1314 -0
  208. package/dist/shared-DOgWaqvL.js +5 -0
  209. package/dist/slice_util-Dgb3ANWI.js +208 -0
  210. package/dist/tfjs_backend-BjuQ5FqB.js +614 -0
  211. package/dist/tokeniser/BaseTokeniser.js +2 -124
  212. package/dist/tokeniser/CharTokeniser.js +91 -106
  213. package/dist/tokeniser/bpe.js +163 -166
  214. package/dist/tokeniser/messages.js +0 -1
  215. package/dist/tokeniser/type.js +0 -1
  216. package/dist/training/AdamW.js +127 -137
  217. package/dist/training/BasicTrainer.d.ts +1 -1
  218. package/dist/training/BasicTrainer.js +264 -264
  219. package/dist/training/DatasetBuilder.js +2 -86
  220. package/dist/training/Evaluator.d.ts +1 -1
  221. package/dist/training/Evaluator.js +47 -38
  222. package/dist/training/LRScheduler.js +37 -33
  223. package/dist/training/PreTrainer.d.ts +2 -2
  224. package/dist/training/PreTrainer.js +21 -19
  225. package/dist/training/SFTTrainer.d.ts +2 -2
  226. package/dist/training/SFTTrainer.js +23 -21
  227. package/dist/training/loss.js +17 -22
  228. package/dist/training/orthoGrad.js +9 -9
  229. package/dist/training/sparseCrossEntropy.js +45 -67
  230. package/dist/training/tasks/ConversationTask.d.ts +1 -1
  231. package/dist/training/tasks/ConversationTask.js +36 -38
  232. package/dist/training/tasks/PretrainingTask.d.ts +1 -1
  233. package/dist/training/tasks/PretrainingTask.js +41 -46
  234. package/dist/training/tasks/StartSentenceTask.d.ts +1 -1
  235. package/dist/training/tasks/StartSentenceTask.js +44 -48
  236. package/dist/training/tasks/Task.d.ts +1 -1
  237. package/dist/training/tasks/Task.js +53 -66
  238. package/dist/training/tasks/splitter.js +17 -20
  239. package/dist/training/types.d.ts +2 -2
  240. package/dist/training/types.js +0 -1
  241. package/dist/training/validation.d.ts +1 -1
  242. package/dist/training/validation.js +2 -84
  243. package/dist/utilities/arrayClose.js +15 -19
  244. package/dist/utilities/datasetID.js +17 -20
  245. package/dist/utilities/dummy.d.ts +1 -1
  246. package/dist/utilities/dummy.js +33 -40
  247. package/dist/utilities/multinomialCPU.js +8 -12
  248. package/dist/utilities/naming.js +0 -1
  249. package/dist/utilities/packed.js +10 -12
  250. package/dist/utilities/parameters.d.ts +1 -1
  251. package/dist/utilities/parameters.js +32 -51
  252. package/dist/utilities/performance.js +15 -15
  253. package/dist/utilities/profile.js +32 -37
  254. package/dist/utilities/safetensors.js +49 -79
  255. package/dist/utilities/sentences.d.ts +1 -1
  256. package/dist/utilities/sentences.js +29 -38
  257. package/dist/utilities/tokenParse.js +16 -20
  258. package/dist/utilities/topP.js +11 -12
  259. package/dist/utilities/waitForModel.d.ts +1 -1
  260. package/dist/utilities/waitForModel.js +11 -11
  261. package/dist/utilities/weights.js +37 -42
  262. package/dist/utilities/yielder.js +6 -6
  263. package/dist/webgpu-Dt7BMzWz.js +525 -0
  264. package/dist/webgpu_program-WOyIVMlZ.js +392 -0
  265. package/dist/webgpu_util-B_F3SShA.js +106 -0
  266. package/package.json +9 -10
  267. package/dist/RealDiv-CGwv0liw.js +0 -365
  268. package/dist/Reshape-BW__R4mZ.js +0 -79
  269. package/dist/Reshape-CPBkTIH2.js +0 -14
  270. package/dist/_commonjsHelpers-ByX85dGu.js +0 -33
  271. package/dist/axis_util-GTVlo58H.js +0 -55
  272. package/dist/backend_util-GaFarB78.js +0 -425
  273. package/dist/backend_webgpu-BqASlsbV.js +0 -545
  274. package/dist/binary_op_util-pKXltfxI.js +0 -192
  275. package/dist/broadcast_to-eS93CCN_.js +0 -28
  276. package/dist/clip_by_value-DDA7rrcT.js +0 -12
  277. package/dist/complex-DI35Q-gW.js +0 -11
  278. package/dist/complex_util-Yc1A_gV1.js +0 -55
  279. package/dist/concat-CAQpCret.js +0 -17
  280. package/dist/concat_util-D18dJ4fD.js +0 -22
  281. package/dist/data/parquet.d.ts +0 -2
  282. package/dist/data/parquet.js +0 -17
  283. package/dist/dataset-CGGp1z9P.js +0 -1124
  284. package/dist/dropout_util--NxWuYg2.js +0 -27
  285. package/dist/expand_dims-Bkd1YD5x.js +0 -11
  286. package/dist/exports_initializers-CYzKLjN7.js +0 -7
  287. package/dist/floor-BQtb-Azg.js +0 -9
  288. package/dist/gather-qIqEqaGn.js +0 -9
  289. package/dist/gelu-B220X1Go.js +0 -26
  290. package/dist/gpgpu_math-BwvV12df.js +0 -2022
  291. package/dist/index-CUXkjxiT.js +0 -3516
  292. package/dist/index-CieiGp4Y.js +0 -349
  293. package/dist/index-CjOWnMXP.js +0 -7308
  294. package/dist/index-Cp39cXWe.js +0 -1016
  295. package/dist/index-D5v913EJ.js +0 -4
  296. package/dist/index-DmeWGGmS.js +0 -1074
  297. package/dist/index-DvYrXKkX.js +0 -113
  298. package/dist/index-Ksja3su6.js +0 -151
  299. package/dist/index-xuotMAFm.js +0 -118
  300. package/dist/jszip.min-BZhlzntC.js +0 -2313
  301. package/dist/kernel_funcs_utils-pq0CK9co.js +0 -306
  302. package/dist/matMul16-BcVC_E62.js +0 -80
  303. package/dist/matMulGelu-JNLZqKQp.js +0 -163
  304. package/dist/mat_mul-DhG0Newp.js +0 -11
  305. package/dist/mod-CSdCpRjf.js +0 -11
  306. package/dist/non_max_suppression_impl-B2W7YjZB.js +0 -102
  307. package/dist/not_equal-hurPF26l.js +0 -64
  308. package/dist/ones-BytntneX.js +0 -14
  309. package/dist/ops-CsXeTq1P.js +0 -476
  310. package/dist/pack16-bqltoUlR.js +0 -39
  311. package/dist/papaparse.min-C0cScC2i.js +0 -418
  312. package/dist/parquet-Bqjmp2vo.js +0 -44231
  313. package/dist/pdf-NIhmP3sq.js +0 -19477
  314. package/dist/rand_util-CZ7yLoUm.js +0 -50
  315. package/dist/random_normal-IBRrha8a.js +0 -14
  316. package/dist/random_width-DN5ZtQkM.js +0 -9796
  317. package/dist/range-C-CjF-LI.js +0 -10
  318. package/dist/relu-J_X6MUzx.js +0 -9
  319. package/dist/reshape-BDOuCSNW.js +0 -9
  320. package/dist/resize_nearest_neighbor-BojqlfRe.js +0 -150
  321. package/dist/rope-DcrZM_e6.js +0 -24
  322. package/dist/scatter_nd_util-ByNJaL6I.js +0 -46
  323. package/dist/segment_util-Dasb2Zaf.js +0 -43
  324. package/dist/selu_util-BLhIqRkw.js +0 -44
  325. package/dist/shared-3agzAqQ_.js +0 -53
  326. package/dist/shared-CagdqkLh.js +0 -2143
  327. package/dist/slice-BzS11Qh0.js +0 -12
  328. package/dist/slice_util-CC35pLmT.js +0 -153
  329. package/dist/softmax-D4q1LJN7.js +0 -12
  330. package/dist/split-C2Sj255c.js +0 -9
  331. package/dist/squeeze-ho4wLUek.js +0 -10
  332. package/dist/stack-DudVrtmG.js +0 -11
  333. package/dist/step-BTxPtq1r.js +0 -261
  334. package/dist/sum-BpiwSWvg.js +0 -11
  335. package/dist/tensor-BWFldCso.js +0 -8
  336. package/dist/tensor1d-LMGMIUlr.js +0 -11
  337. package/dist/tensor2d-BnXMKScO.js +0 -14
  338. package/dist/tensor4d-C6UCG_u8.js +0 -14
  339. package/dist/tfjs_backend-BGnG-ppu.js +0 -654
  340. package/dist/tile-CFy-xTO6.js +0 -11
  341. package/dist/transpose-9kRxIXWR.js +0 -36
  342. package/dist/unsorted_segment_sum-DJvk5xnh.js +0 -277
  343. package/dist/variable-Ck482e3n.js +0 -7
  344. package/dist/webgpu_program-B4HmApL1.js +0 -525
  345. package/dist/webgpu_util-DYlGSwOJ.js +0 -64
  346. package/dist/zeros-DvZpK8s6.js +0 -13
  347. package/dist/zeros_like-CWjDdwr-.js +0 -721
@@ -1,65 +1,70 @@
1
- import { c as V, m as U, a as g, n as X, U as y, _ as Y } from "../../index-CUXkjxiT.js";
2
- import { isPackedTensor as x } from "../../utilities/packed.js";
3
- import { reshape16 as C } from "../reshape16.js";
4
- import { matMulMul as Z } from "../matMulMul.js";
5
- import { matMulGelu as $ } from "../matMulGelu.js";
6
- import ee from "./matMul16_program.js";
7
- import { m as v } from "../../mat_mul-DhG0Newp.js";
8
- import { r as E } from "../../relu-J_X6MUzx.js";
9
- import { r as F } from "../../reshape-BDOuCSNW.js";
10
- import { t as W } from "../../transpose-9kRxIXWR.js";
11
- function se(m) {
12
- const { A: s, B: t } = m.inputs, { transposeA: p, transposeB: i, scale: h, activation: c, scaleA: d, scaleB: f, forceOutputShape: r, perm: k, causalMask: w, pastLen: z } = m.attrs, I = m.backend, B = !x(s), b = !x(t);
13
- if (B && b) {
14
- const u = d !== void 0 ? U(s, g(d)) : s, l = f !== void 0 ? U(t, g(f)) : t;
15
- if (w)
16
- throw new Error("Causal mask is not supported for unpacked MatMul16.");
17
- let e;
18
- if (h !== void 0)
19
- e = Z(u, l, g(h), p, i);
20
- else if (c === "gelu")
21
- e = $(u, l);
22
- else if (c === "relu2") {
23
- const o = v(u, l, p, i), A = E(o);
24
- o.dispose(), e = X(A), A.dispose();
25
- } else c === "relu" ? e = E(v(u, l, p, i)) : e = v(u, l, p, i);
26
- if (k)
27
- if (r) {
28
- const o = F(e, r);
29
- e.dispose();
30
- const A = W(o, k);
31
- return o.dispose(), A;
32
- } else {
33
- const o = W(e, k);
34
- return e.dispose(), o;
35
- }
36
- else if (r) {
37
- const o = F(e, r);
38
- return e.dispose(), o;
39
- } else
40
- return e;
41
- }
42
- if (B && !b)
43
- throw new Error("When using mixed precision, A must be packed if B is packed.");
44
- if (!B && b)
45
- throw new Error("When using mixed precision, B must be packed if A is packed.");
46
- const M = s.shape.length, S = t.shape.length, N = s.shape.slice(0, -2), q = t.shape.slice(0, -2), R = y(N), D = y(q), K = Y(s.shape.slice(0, -2), t.shape.slice(0, -2)), L = Math.max(R, D), T = s.shape[M - 2], _ = t.shape[S - 2], j = s.shape[M - 1] * 2, H = t.shape[S - 1] * 2, G = C(s, [R, s.shape[M - 2], s.shape[M - 1]]), O = C(t, [D, t.shape[S - 2], t.shape[S - 1]]), a = new ee(L, T, _, j, H, p, i), n = [];
47
- h !== void 0 && (a.useScale(), n.push({ type: "float32", data: [h] })), d !== void 0 && (a.useScaleA(), n.push({ type: "float32", data: [d] })), f !== void 0 && (a.useScaleB(), n.push({ type: "float32", data: [f] })), c !== void 0 && a.useActivation(c), w && (a.useCausalMask(), n.push({ type: "int32", data: [z || 0] }));
48
- const P = a.outputShape.length;
49
- r && (m.attrs.originalShape = a.outputShape);
50
- const J = r ?? K.concat([a.outputShape[P - 2], a.outputShape[P - 1]]);
51
- a.setOutputShape(J, k);
52
- const Q = I.runWebGPUProgram(
53
- a,
54
- [G, O],
55
- "packedF16",
56
- n.length > 0 ? n : void 0
57
- );
58
- return G.dispose(), O.dispose(), Q;
1
+ import { Ii as e, In as t, Wr as n, _n as r, br as i, gr as a, hn as o, lt as s, oc as c, w as l } from "../../dist-BewPQWjc.js";
2
+ import { isPackedTensor as u } from "../../utilities/packed.js";
3
+ import { reshape16 as d } from "../reshape16.js";
4
+ import { matMulMul as f } from "../matMulMul.js";
5
+ import { matMulGelu as p } from "../matMulGelu.js";
6
+ import m from "./matMul16_program.js";
7
+ //#region lib/ops/webgpu/matMul16.ts
8
+ function h(e) {
9
+ let { A: h, B: g } = e.inputs, { transposeA: _, transposeB: v, scale: y, activation: b, scaleA: x, scaleB: S, forceOutputShape: C, perm: w, causalMask: T, pastLen: E } = e.attrs, D = e.backend, O = !u(h), k = !u(g);
10
+ if (O && k) {
11
+ let e = x === void 0 ? h : n(h, r(x)), t = S === void 0 ? g : n(g, r(S));
12
+ if (T) throw Error("Causal mask is not supported for unpacked MatMul16.");
13
+ let c;
14
+ if (y !== void 0) c = f(e, t, r(y), _, v);
15
+ else if (b === "gelu") c = p(e, t);
16
+ else if (b === "relu2") {
17
+ let n = a(e, t, _, v), r = s(n);
18
+ n.dispose(), c = o(r), r.dispose();
19
+ } else c = b === "relu" ? s(a(e, t, _, v)) : a(e, t, _, v);
20
+ if (w) if (C) {
21
+ let e = i(c, C);
22
+ c.dispose();
23
+ let t = l(e, w);
24
+ return e.dispose(), t;
25
+ } else {
26
+ let e = l(c, w);
27
+ return c.dispose(), e;
28
+ }
29
+ else if (C) {
30
+ let e = i(c, C);
31
+ return c.dispose(), e;
32
+ } else return c;
33
+ }
34
+ if (O && !k) throw Error("When using mixed precision, A must be packed if B is packed.");
35
+ if (!O && k) throw Error("When using mixed precision, B must be packed if A is packed.");
36
+ let A = h.shape.length, j = g.shape.length, M = h.shape.slice(0, -2), N = g.shape.slice(0, -2), P = c(M), F = c(N), I = t(h.shape.slice(0, -2), g.shape.slice(0, -2)), L = Math.max(P, F), R = h.shape[A - 2], z = g.shape[j - 2], B = h.shape[A - 1] * 2, V = g.shape[j - 1] * 2, H = d(h, [
37
+ P,
38
+ h.shape[A - 2],
39
+ h.shape[A - 1]
40
+ ]), U = d(g, [
41
+ F,
42
+ g.shape[j - 2],
43
+ g.shape[j - 1]
44
+ ]), W = new m(L, R, z, B, V, _, v), G = [];
45
+ y !== void 0 && (W.useScale(), G.push({
46
+ type: "float32",
47
+ data: [y]
48
+ })), x !== void 0 && (W.useScaleA(), G.push({
49
+ type: "float32",
50
+ data: [x]
51
+ })), S !== void 0 && (W.useScaleB(), G.push({
52
+ type: "float32",
53
+ data: [S]
54
+ })), b !== void 0 && W.useActivation(b), T && (W.useCausalMask(), G.push({
55
+ type: "int32",
56
+ data: [E || 0]
57
+ }));
58
+ let K = W.outputShape.length;
59
+ C && (e.attrs.originalShape = W.outputShape);
60
+ let q = C ?? I.concat([W.outputShape[K - 2], W.outputShape[K - 1]]);
61
+ W.setOutputShape(q, w);
62
+ let J = D.runWebGPUProgram(W, [H, U], "packedF16", G.length > 0 ? G : void 0);
63
+ return H.dispose(), U.dispose(), J;
59
64
  }
60
- const te = {
61
- kernelName: "MatMul16",
62
- backendName: "webgpu",
63
- kernelFunc: se
64
- };
65
- V(te);
65
+ e({
66
+ kernelName: "MatMul16",
67
+ backendName: "webgpu",
68
+ kernelFunc: h
69
+ });
70
+ //#endregion
@@ -1,149 +1,119 @@
1
- import { U as f } from "../../index-CUXkjxiT.js";
2
- import { e as p } from "../../webgpu_program-B4HmApL1.js";
3
- class B {
4
- variableNames = ["A", "B"];
5
- outputShape;
6
- shaderKey = "MatMul16TB";
7
- dispatchLayout;
8
- dispatch;
9
- workgroupSize = [8, 8, 1];
10
- // 8x8 threads for 32x32 tile
11
- dimInner;
12
- transposeA = !1;
13
- transposeB = !0;
14
- broadcastBatch = !0;
15
- tileInner = 32;
16
- uniforms;
17
- scale = !1;
18
- scaleA = !1;
19
- scaleB = !1;
20
- activation;
21
- causalMask = !1;
22
- outputComponent;
23
- variableComponents;
24
- outputIndexSnippet;
25
- outputStrideSnippet;
26
- constructor(e, t, o, a, i, r = !1, s = !1) {
27
- if (this.transposeA = r, this.transposeB = s, this.variableComponents = [2, 2], this.outputComponent = 2, this.shaderKey = `MatMul16TB_${t}_${o}_${a}_${i}_${r ? "TA" : ""}${s ? "TB" : ""}`, r) {
28
- if (this.outputShape = [e, a, i / 2], this.dimInner = t, t !== o)
29
- throw new Error("Inner dimensions of A and B must match for MatMul16 transposeA");
30
- } else if (s) {
31
- if (this.outputShape = [e, t, o / 2], this.dimInner = i, i !== a)
32
- throw new Error("Inner dimensions of A and B must match for MatMul16 transposeB");
33
- } else if (this.outputShape = [e, t, i / 2], this.dimInner = a, a !== o)
34
- throw new Error("Inner dimensions of A and B must match for MatMul16");
35
- if (this.dimInner % this.tileInner !== 0)
36
- throw new Error(`Inner dimension ${this.dimInner} must be multiple of ${this.tileInner}`);
37
- if (this.dispatchLayout = { x: [2], y: [1], z: [0] }, this.dispatch = [
38
- Math.ceil(this.outputShape[2] / (this.workgroupSize[0] * 2)),
39
- // 4 unpacked cols per thread = 2 packed cols
40
- Math.ceil(this.outputShape[1] / (this.workgroupSize[1] * 4)),
41
- // 4 rows per thread
42
- this.outputShape[0]
43
- ], i % 32 !== 0)
44
- throw new Error("Head size must be even for MatMul16 transposeB");
45
- if (a % 32 !== 0)
46
- throw new Error("Head size must be even for MatMul16 transposeB");
47
- if (t % 32 !== 0)
48
- throw new Error("Sequence length must be multiple of 32 for MatMul16 transposeB");
49
- if (o % 32 !== 0)
50
- throw new Error("Sequence length must be multiple of 32 for MatMul16 transposeB");
51
- this.outputIndexSnippet = "var idx0 = getOutputIndexFromCoords(vec3<i32>(batch, gRow, gColPacked));", this.outputStrideSnippet = "idx0 = idx0 + uniforms.outShapeStrides[1]; // Next row";
52
- }
53
- addUniform(e) {
54
- this.uniforms ? this.uniforms += `, ${e}` : this.uniforms = e;
55
- }
56
- /* Note: this is done after constructor because it shouldn't affect dispatch */
57
- setOutputShape(e, t) {
58
- const o = f(e), a = f(this.outputShape);
59
- if (o !== a)
60
- throw new Error(`New shape size ${o} must match current size ${a}`);
61
- function i(c, u) {
62
- return [`${c} / ${u}`, `${c} % ${u}`];
63
- }
64
- const r = this.outputShape;
65
- let s = [];
66
- if (e.length === r.length + 1)
67
- if (e[0] * e[1] === r[0])
68
- s = [
69
- ...i("batch", e[1]),
70
- // batch / B2, batch % B2
71
- "gRow",
72
- "gColPacked"
73
- ], this.shaderKey += `_batchSplit_${e[1]}`;
74
- else if (e[e.length - 2] * e[e.length - 1] === r[r.length - 1])
75
- s = [
76
- "batch",
77
- "gRow",
78
- ...i("gColPacked", e[e.length - 1])
79
- // gColPacked / N2, gColPacked % N2
80
- ], this.shaderKey += `_colSplit_${e[e.length - 1]}`;
81
- else
82
- throw new Error("Unsupported output shape split");
83
- else if (e.length === r.length)
84
- s = ["batch", "gRow", "gColPacked"];
85
- else if (e.length === 2 && r[0] === 1)
86
- s = ["gRow", "gColPacked"], this.shaderKey += "_batchRemoved";
87
- else
88
- throw new Error(`Unsupported output shape rank change: ${r.length} -> ${e.length}}`);
89
- let n = [];
90
- if (t) {
91
- if (t.length !== e.length)
92
- throw new Error("Permutation length must match output rank");
93
- n = t.map((c) => s[c]), this.shaderKey += `_perm_${t.join("")}`;
94
- } else
95
- n = s;
96
- const l = n.findIndex((c) => c === "gRow"), h = `vec${e.length}<i32>(${n.join(", ")})`;
97
- this.outputIndexSnippet = `var idx0: i32 = getOutputIndexFromCoords(${h});`, this.outputStrideSnippet = `idx0 = idx0 + uniforms.outShapeStrides${l === 0 ? "" : `[${l}]`}; `, t ? this.outputShape = t.map((c) => e[c]) : this.outputShape = e;
98
- }
99
- useScale() {
100
- this.addUniform("scale: f32"), this.scale = !0, this.shaderKey += "_scaled";
101
- }
102
- useScaleA() {
103
- this.addUniform("scaleA: f32"), this.scaleA = !0, this.shaderKey += "_scaledA";
104
- }
105
- useScaleB() {
106
- this.addUniform("scaleB: f32"), this.scaleB = !0, this.shaderKey += "_scaledB";
107
- }
108
- useActivation(e) {
109
- this.activation = e, this.shaderKey += `_${e}`;
110
- }
111
- useCausalMask() {
112
- this.causalMask = !0, this.addUniform("pastLen: i32"), this.shaderKey += "_causalMask";
113
- }
114
- activationSnippet() {
115
- return this.activation === "gelu" ? `
116
- // TODO: revisit after https://github.com/gpuweb/gpuweb/issues/4458 is resolved
117
- fn tanhComplete(x: vec4<f32>) -> vec4<f32> {
118
- return vec4<f32>(
119
- select(tanh(x.x), sign(x.x), abs(x.x) > 15.0f),
120
- select(tanh(x.y), sign(x.y), abs(x.y) > 15.0f),
121
- select(tanh(x.z), sign(x.z), abs(x.z) > 15.0f),
122
- select(tanh(x.w), sign(x.w), abs(x.w) > 15.0f),
123
- );
124
- }
125
- fn activation(x : vec4<f32>) -> vec4<f32> {
126
- let x3 = x * x * x;
127
- var inner = fma(vec4<f32>(${0.044715}f), x3, x);
128
- inner = ${0.7978845608028654}f * inner;
129
- inner = tanhComplete(inner);
130
- inner = 0.5f * (1.0f + inner);
131
- return x * inner;
132
- }
133
- ` : this.activation === "relu2" ? `
134
- fn activation(x : vec4<f32>) -> vec4<f32> {
135
- let y = max(x, vec4<f32>(0.0));
136
- return y * y;
137
- }
138
- ` : this.activation === "relu" ? `
139
- fn activation(x : vec4<f32>) -> vec4<f32> {
140
- return max(x, vec4<f32>(0.0));
141
- }
142
- ` : "";
143
- }
144
- /* Transpose when writing to shared memory */
145
- readASnippet() {
146
- const e = `
1
+ import { oc as e } from "../../dist-BewPQWjc.js";
2
+ import { s as t } from "../../webgpu_program-WOyIVMlZ.js";
3
+ //#region lib/ops/webgpu/matMul16_program.ts
4
+ var n = class {
5
+ variableNames = ["A", "B"];
6
+ outputShape;
7
+ shaderKey = "MatMul16TB";
8
+ dispatchLayout;
9
+ dispatch;
10
+ workgroupSize = [
11
+ 8,
12
+ 8,
13
+ 1
14
+ ];
15
+ dimInner;
16
+ transposeA = !1;
17
+ transposeB = !0;
18
+ broadcastBatch = !0;
19
+ tileInner = 32;
20
+ uniforms;
21
+ scale = !1;
22
+ scaleA = !1;
23
+ scaleB = !1;
24
+ activation;
25
+ causalMask = !1;
26
+ outputComponent;
27
+ variableComponents;
28
+ outputIndexSnippet;
29
+ outputStrideSnippet;
30
+ constructor(e, t, n, r, i, a = !1, o = !1) {
31
+ if (this.transposeA = a, this.transposeB = o, this.variableComponents = [2, 2], this.outputComponent = 2, this.shaderKey = `MatMul16TB_${t}_${n}_${r}_${i}_${a ? "TA" : ""}${o ? "TB" : ""}`, a) {
32
+ if (this.outputShape = [
33
+ e,
34
+ r,
35
+ i / 2
36
+ ], this.dimInner = t, t !== n) throw Error("Inner dimensions of A and B must match for MatMul16 transposeA");
37
+ } else if (o) {
38
+ if (this.outputShape = [
39
+ e,
40
+ t,
41
+ n / 2
42
+ ], this.dimInner = i, i !== r) throw Error("Inner dimensions of A and B must match for MatMul16 transposeB");
43
+ } else if (this.outputShape = [
44
+ e,
45
+ t,
46
+ i / 2
47
+ ], this.dimInner = r, r !== n) throw Error("Inner dimensions of A and B must match for MatMul16");
48
+ if (this.dimInner % this.tileInner !== 0) throw Error(`Inner dimension ${this.dimInner} must be multiple of ${this.tileInner}`);
49
+ if (this.dispatchLayout = {
50
+ x: [2],
51
+ y: [1],
52
+ z: [0]
53
+ }, this.dispatch = [
54
+ Math.ceil(this.outputShape[2] / (this.workgroupSize[0] * 2)),
55
+ Math.ceil(this.outputShape[1] / (this.workgroupSize[1] * 4)),
56
+ this.outputShape[0]
57
+ ], i % 32 != 0 || r % 32 != 0) throw Error("Head size must be even for MatMul16 transposeB");
58
+ if (t % 32 != 0 || n % 32 != 0) throw Error("Sequence length must be multiple of 32 for MatMul16 transposeB");
59
+ this.outputIndexSnippet = "var idx0 = getOutputIndexFromCoords(vec3<i32>(batch, gRow, gColPacked));", this.outputStrideSnippet = "idx0 = idx0 + uniforms.outShapeStrides[1]; // Next row";
60
+ }
61
+ addUniform(e) {
62
+ this.uniforms ? this.uniforms += `, ${e}` : this.uniforms = e;
63
+ }
64
+ setOutputShape(t, n) {
65
+ let r = e(t), i = e(this.outputShape);
66
+ if (r !== i) throw Error(`New shape size ${r} must match current size ${i}`);
67
+ function a(e, t) {
68
+ return [`${e} / ${t}`, `${e} % ${t}`];
69
+ }
70
+ let o = this.outputShape, s = [];
71
+ if (t.length === o.length + 1) if (t[0] * t[1] === o[0]) s = [
72
+ ...a("batch", t[1]),
73
+ "gRow",
74
+ "gColPacked"
75
+ ], this.shaderKey += `_batchSplit_${t[1]}`;
76
+ else if (t[t.length - 2] * t[t.length - 1] === o[o.length - 1]) s = [
77
+ "batch",
78
+ "gRow",
79
+ ...a("gColPacked", t[t.length - 1])
80
+ ], this.shaderKey += `_colSplit_${t[t.length - 1]}`;
81
+ else throw Error("Unsupported output shape split");
82
+ else if (t.length === o.length) s = [
83
+ "batch",
84
+ "gRow",
85
+ "gColPacked"
86
+ ];
87
+ else if (t.length === 2 && o[0] === 1) s = ["gRow", "gColPacked"], this.shaderKey += "_batchRemoved";
88
+ else throw Error(`Unsupported output shape rank change: ${o.length} -> ${t.length}}`);
89
+ let c = [];
90
+ if (n) {
91
+ if (n.length !== t.length) throw Error("Permutation length must match output rank");
92
+ c = n.map((e) => s[e]), this.shaderKey += `_perm_${n.join("")}`;
93
+ } else c = s;
94
+ let l = c.findIndex((e) => e === "gRow"), u = `vec${t.length}<i32>(${c.join(", ")})`;
95
+ this.outputIndexSnippet = `var idx0: i32 = getOutputIndexFromCoords(${u});`, this.outputStrideSnippet = `idx0 = idx0 + uniforms.outShapeStrides${l === 0 ? "" : `[${l}]`}; `, n ? this.outputShape = n.map((e) => t[e]) : this.outputShape = t;
96
+ }
97
+ useScale() {
98
+ this.addUniform("scale: f32"), this.scale = !0, this.shaderKey += "_scaled";
99
+ }
100
+ useScaleA() {
101
+ this.addUniform("scaleA: f32"), this.scaleA = !0, this.shaderKey += "_scaledA";
102
+ }
103
+ useScaleB() {
104
+ this.addUniform("scaleB: f32"), this.scaleB = !0, this.shaderKey += "_scaledB";
105
+ }
106
+ useActivation(e) {
107
+ this.activation = e, this.shaderKey += `_${e}`;
108
+ }
109
+ useCausalMask() {
110
+ this.causalMask = !0, this.addUniform("pastLen: i32"), this.shaderKey += "_causalMask";
111
+ }
112
+ activationSnippet() {
113
+ return this.activation === "gelu" ? "\n // TODO: revisit after https://github.com/gpuweb/gpuweb/issues/4458 is resolved\n fn tanhComplete(x: vec4<f32>) -> vec4<f32> {\n return vec4<f32>(\n select(tanh(x.x), sign(x.x), abs(x.x) > 15.0f),\n select(tanh(x.y), sign(x.y), abs(x.y) > 15.0f),\n select(tanh(x.z), sign(x.z), abs(x.z) > 15.0f),\n select(tanh(x.w), sign(x.w), abs(x.w) > 15.0f),\n );\n }\n fn activation(x : vec4<f32>) -> vec4<f32> {\n let x3 = x * x * x;\n var inner = fma(vec4<f32>(0.044715f), x3, x);\n inner = 0.7978845608028654f * inner;\n inner = tanhComplete(inner);\n inner = 0.5f * (1.0f + inner);\n return x * inner;\n }\n " : this.activation === "relu2" ? "\n fn activation(x : vec4<f32>) -> vec4<f32> {\n let y = max(x, vec4<f32>(0.0));\n return y * y;\n }\n " : this.activation === "relu" ? "\n fn activation(x : vec4<f32>) -> vec4<f32> {\n return max(x, vec4<f32>(0.0));\n }\n " : "";
114
+ }
115
+ readASnippet() {
116
+ let e = `
147
117
  var col = i32(localId.x);
148
118
  var row = i32(localId.y) * 4;
149
119
  var packedA: vec2<i32> = A[offsetA + row * strideA + col];
@@ -172,7 +142,7 @@ class B {
172
142
  ${this.scaleA ? "Arow3 = Arow3 * uniforms.scaleA;" : ""}
173
143
  ${this.scaleA ? "Arow4 = Arow4 * uniforms.scaleA;" : ""}
174
144
  `;
175
- return this.transposeA ? `{
145
+ return this.transposeA ? `{
176
146
  ${e}
177
147
  mm_Asub[row][col] = Arow1;
178
148
  mm_Asub[row + 1][col] = Arow2;
@@ -189,10 +159,9 @@ class B {
189
159
  mm_Asub[col + 2][row] = vec4<f32>(Arow1.z, Arow2.z, Arow3.z, Arow4.z);
190
160
  mm_Asub[col + 3][row] = vec4<f32>(Arow1.w, Arow2.w, Arow3.w, Arow4.w);
191
161
  }`;
192
- }
193
- /* Transpose when writing to shared memory */
194
- readBSnippet() {
195
- const e = `
162
+ }
163
+ readBSnippet() {
164
+ let e = `
196
165
  var col = i32(localId.x);
197
166
  var row = i32(localId.y) * 4;
198
167
  var packedB: vec2<i32> = B[offsetB + row * strideB + col];
@@ -221,7 +190,7 @@ class B {
221
190
  ${this.scaleB ? "Brow3 = Brow3 * uniforms.scaleB;" : ""}
222
191
  ${this.scaleB ? "Brow4 = Brow4 * uniforms.scaleB;" : ""}
223
192
  `;
224
- return this.transposeB ? `{
193
+ return this.transposeB ? `{
225
194
  ${e}
226
195
 
227
196
  col = i32(localId.x) * 4;
@@ -238,47 +207,46 @@ class B {
238
207
  mm_Bsub[row + 2][col] = Brow3;
239
208
  mm_Bsub[row + 3][col] = Brow4;
240
209
  }`;
241
- }
242
- baseIndexSnippets() {
243
- const e = `
210
+ }
211
+ baseIndexSnippets() {
212
+ let e = "";
213
+ e = this.transposeB ? "let baseB = getIndexFromCoords3D(vec3<i32>(batchB, globalColStart, 0), vec3<i32>(uniforms.bShape.x, uniforms.bShape.y, strideB));" : "let baseB = getIndexFromCoords3D(vec3<i32>(batchB, 0, globalColStart / 4), vec3<i32>(uniforms.bShape.x, uniforms.bShape.y, strideB));";
214
+ let t = "";
215
+ return t = this.transposeA ? "let baseA = getIndexFromCoords3D(vec3<i32>(batchA, 0, globalRowStart / 4), vec3<i32>(uniforms.aShape.x, uniforms.aShape.y, strideA));" : "let baseA = getIndexFromCoords3D(vec3<i32>(batchA, globalRowStart, 0), vec3<i32>(uniforms.aShape.x, uniforms.aShape.y, strideA));", `
216
+
244
217
  let strideA = uniforms.aShape.z / 2;
245
218
  let strideB = uniforms.bShape.z / 2;
246
- `;
247
- let t = "";
248
- this.transposeB ? t = "let baseB = getIndexFromCoords3D(vec3<i32>(batchB, globalColStart, 0), vec3<i32>(uniforms.bShape.x, uniforms.bShape.y, strideB));" : t = "let baseB = getIndexFromCoords3D(vec3<i32>(batchB, 0, globalColStart / 4), vec3<i32>(uniforms.bShape.x, uniforms.bShape.y, strideB));";
249
- let o = "";
250
- return this.transposeA ? o = "let baseA = getIndexFromCoords3D(vec3<i32>(batchA, 0, globalRowStart / 4), vec3<i32>(uniforms.aShape.x, uniforms.aShape.y, strideA));" : o = "let baseA = getIndexFromCoords3D(vec3<i32>(batchA, globalRowStart, 0), vec3<i32>(uniforms.aShape.x, uniforms.aShape.y, strideA));", `
251
- ${e}
252
- ${o}
219
+
253
220
  ${t}
221
+ ${e}
254
222
  `;
255
- }
256
- offsetSnippets() {
257
- let e = "";
258
- this.transposeA ? e = "let offsetA = baseA + kStart * strideA;" : e = "let offsetA = baseA + kStart / 4;";
259
- let t = "";
260
- return this.transposeB ? t = "let offsetB = baseB + kStart / 4;" : t = "let offsetB = baseB + kStart * strideB;", `
223
+ }
224
+ offsetSnippets() {
225
+ let e = "";
226
+ e = this.transposeA ? "let offsetA = baseA + kStart * strideA;" : "let offsetA = baseA + kStart / 4;";
227
+ let t = "";
228
+ return t = this.transposeB ? "let offsetB = baseB + kStart / 4;" : "let offsetB = baseB + kStart * strideB;", `
261
229
  ${e}
262
230
  ${t}
263
231
  `;
264
- }
265
- getUserCode() {
266
- const e = this.transposeA, t = this.tileInner, o = this.workgroupSize[1] * 4, a = this.workgroupSize[0] * 4, i = e ? o : t, r = e ? t : o, s = this.dimInner, n = Math.ceil(s / t);
267
- return `
268
- var<workgroup> mm_Asub : array<array<vec4<f32>, ${i / 4 + (this.transposeA ? 0 : 1)}>, ${r}>;
269
- var<workgroup> mm_Bsub : array<array<vec4<f32>, ${a / 4 + (this.transposeB ? 1 : 0)}>, ${t}>;
232
+ }
233
+ getUserCode() {
234
+ let e = this.transposeA, n = this.tileInner, r = this.workgroupSize[1] * 4, i = this.workgroupSize[0] * 4, a = e ? r : n, o = e ? n : r, s = this.dimInner, c = Math.ceil(s / n);
235
+ return `
236
+ var<workgroup> mm_Asub : array<array<vec4<f32>, ${a / 4 + +!this.transposeA}>, ${o}>;
237
+ var<workgroup> mm_Bsub : array<array<vec4<f32>, ${i / 4 + +!!this.transposeB}>, ${n}>;
270
238
 
271
239
  ${this.activation ? this.activationSnippet() : ""}
272
240
 
273
- ${p()} {
241
+ ${t()} {
274
242
  let batch = i32(globalId.z);
275
243
  let batchA = ${this.broadcastBatch ? "batch % uniforms.aShape[0]" : "batch"};
276
244
  let batchB = ${this.broadcastBatch ? "batch % uniforms.bShape[0]" : "batch"};
277
245
  var kStart = 0;
278
246
  let localRow = i32(localId.y);
279
247
  let localCol = i32(localId.x);
280
- let globalRowStart = i32(workgroupId.y) * ${o};
281
- let globalColStart = i32(workgroupId.x) * ${a};
248
+ let globalRowStart = i32(workgroupId.y) * ${r};
249
+ let globalColStart = i32(workgroupId.x) * ${i};
282
250
 
283
251
  // 4 rows x 4 cols accumulator
284
252
  // acc[i] holds row i (4 cols)
@@ -288,16 +256,16 @@ class B {
288
256
 
289
257
  ${this.baseIndexSnippets()}
290
258
 
291
- for (var t = 0; t < ${n}; t++) {
259
+ for (var t = 0; t < ${c}; t++) {
292
260
  ${this.offsetSnippets()}
293
261
 
294
262
  ${this.readASnippet()}
295
263
  ${this.readBSnippet()}
296
264
 
297
- kStart = kStart + ${t};
265
+ kStart = kStart + ${n};
298
266
  workgroupBarrier();
299
267
 
300
- for (var k = 0; k < ${t}; k++) {
268
+ for (var k = 0; k < ${n}; k++) {
301
269
  // Load 4 columns of B as a vec4
302
270
  let bVec = mm_Bsub[k][localCol];
303
271
  let aVec = mm_Asub[k][localRow];
@@ -316,14 +284,7 @@ class B {
316
284
 
317
285
  ${this.outputIndexSnippet}
318
286
  for (var i = 0; i < 4; i = i + 1) {
319
- ${this.causalMask ? `
320
- // Causal Masking: mask if col > row + pastLen
321
- let r = gRow + i;
322
- let cBase = gColPacked * 2;
323
- let cVec = vec4<i32>(cBase, cBase + 1, cBase + 2, cBase + 3);
324
- let mask = cVec > vec4<i32>(r + uniforms.pastLen);
325
- acc[i] = select(acc[i], vec4<f32>(-uniforms.INFINITY), mask);
326
- ` : ""}
287
+ ${this.causalMask ? "\n // Causal Masking: mask if col > row + pastLen\n let r = gRow + i;\n let cBase = gColPacked * 2;\n let cVec = vec4<i32>(cBase, cBase + 1, cBase + 2, cBase + 3);\n let mask = cVec > vec4<i32>(r + uniforms.pastLen);\n acc[i] = select(acc[i], vec4<f32>(-uniforms.INFINITY), mask);\n " : ""}
327
288
 
328
289
  ${this.activation ? "acc[i] = activation(acc[i]);" : ""}
329
290
  ${this.scale ? "acc[i] = acc[i] * uniforms.scale;" : ""}
@@ -336,8 +297,7 @@ class B {
336
297
  }
337
298
  }
338
299
  `;
339
- }
340
- }
341
- export {
342
- B as default
300
+ }
343
301
  };
302
+ //#endregion
303
+ export { n as default };
@@ -1,13 +1,14 @@
1
- import { c as t } from "../../index-CUXkjxiT.js";
2
- import { BinaryOpScalarProgram as s, BinaryOpProgram as c } from "./utils/binary_op.js";
3
- import { B as a } from "../../binary_op_util-pKXltfxI.js";
4
- function m(n) {
5
- const { a: e, b: r } = n.inputs, o = n.backend, p = r.shape.length === 0 ? new s(a.MUL, e.shape) : new c(a.MUL, e.shape, r.shape);
6
- return o.runWebGPUProgram(p, [e, r], "packedF16");
1
+ import { Ii as e } from "../../dist-BewPQWjc.js";
2
+ import { t } from "../../binary_op_util-CrYk9LXL.js";
3
+ import { BinaryOpProgram as n, BinaryOpScalarProgram as r } from "./utils/binary_op.js";
4
+ //#region lib/ops/webgpu/mul16.ts
5
+ function i(e) {
6
+ let { a: i, b: a } = e.inputs, o = e.backend, s = a.shape.length === 0 ? new r(t.MUL, i.shape) : new n(t.MUL, i.shape, a.shape);
7
+ return o.runWebGPUProgram(s, [i, a], "packedF16");
7
8
  }
8
- const i = {
9
- kernelName: "Mul16",
10
- backendName: "webgpu",
11
- kernelFunc: m
12
- };
13
- t(i);
9
+ e({
10
+ kernelName: "Mul16",
11
+ backendName: "webgpu",
12
+ kernelFunc: i
13
+ });
14
+ //#endregion