whispercpp 1.3.5 → 1.3.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (610)
  1. checksums.yaml +4 -4
  2. data/LICENSE +1 -1
  3. data/README.md +99 -2
  4. data/ext/extconf.rb +1 -0
  5. data/ext/ruby_whisper.c +20 -4
  6. data/ext/ruby_whisper.h +30 -2
  7. data/ext/ruby_whisper_context.c +216 -124
  8. data/ext/ruby_whisper_context_params.c +163 -0
  9. data/ext/ruby_whisper_model.c +0 -1
  10. data/ext/ruby_whisper_params.c +0 -1
  11. data/ext/ruby_whisper_segment.c +0 -1
  12. data/ext/ruby_whisper_token.c +29 -9
  13. data/ext/ruby_whisper_transcribe.cpp +4 -1
  14. data/ext/ruby_whisper_vad_context.c +48 -1
  15. data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
  16. data/ext/ruby_whisper_vad_params.c +0 -1
  17. data/ext/ruby_whisper_vad_segment.c +0 -1
  18. data/ext/ruby_whisper_vad_segments.c +0 -1
  19. data/ext/sources/CMakeLists.txt +1 -1
  20. data/ext/sources/bindings/javascript/package.json +1 -1
  21. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  22. data/ext/sources/examples/bench/bench.cpp +23 -18
  23. data/ext/sources/examples/cli/cli.cpp +8 -0
  24. data/ext/sources/examples/common-ggml.cpp +2 -0
  25. data/ext/sources/examples/miniaudio.h +4507 -2131
  26. data/ext/sources/examples/server/server.cpp +18 -4
  27. data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -2
  28. data/ext/sources/examples/talk-llama/llama-adapter.cpp +7 -13
  29. data/ext/sources/examples/talk-llama/llama-adapter.h +4 -3
  30. data/ext/sources/examples/talk-llama/llama-arch.cpp +335 -17
  31. data/ext/sources/examples/talk-llama/llama-arch.h +42 -0
  32. data/ext/sources/examples/talk-llama/llama-batch.cpp +3 -1
  33. data/ext/sources/examples/talk-llama/llama-chat.cpp +21 -1
  34. data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
  35. data/ext/sources/examples/talk-llama/llama-context.cpp +508 -520
  36. data/ext/sources/examples/talk-llama/llama-context.h +27 -28
  37. data/ext/sources/examples/talk-llama/llama-cparams.h +5 -0
  38. data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
  39. data/ext/sources/examples/talk-llama/llama-grammar.cpp +8 -8
  40. data/ext/sources/examples/talk-llama/llama-graph.cpp +583 -130
  41. data/ext/sources/examples/talk-llama/llama-graph.h +131 -10
  42. data/ext/sources/examples/talk-llama/llama-hparams.cpp +57 -40
  43. data/ext/sources/examples/talk-llama/llama-hparams.h +79 -10
  44. data/ext/sources/examples/talk-llama/llama-impl.cpp +4 -4
  45. data/ext/sources/examples/talk-llama/llama-impl.h +13 -1
  46. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +3 -1
  47. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +274 -89
  48. data/ext/sources/examples/talk-llama/llama-kv-cache.h +2 -3
  49. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
  50. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
  51. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +11 -13
  52. data/ext/sources/examples/talk-llama/llama-mmap.cpp +28 -11
  53. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +527 -119
  54. data/ext/sources/examples/talk-llama/llama-model-loader.h +35 -5
  55. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +60 -46
  56. data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
  57. data/ext/sources/examples/talk-llama/llama-model.cpp +1365 -647
  58. data/ext/sources/examples/talk-llama/llama-model.h +72 -19
  59. data/ext/sources/examples/talk-llama/llama-quant.cpp +578 -346
  60. data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +190 -76
  61. data/ext/sources/examples/talk-llama/{llama-sampling.h → llama-sampler.h} +0 -2
  62. data/ext/sources/examples/talk-llama/llama-vocab.cpp +118 -48
  63. data/ext/sources/examples/talk-llama/llama-vocab.h +5 -0
  64. data/ext/sources/examples/talk-llama/llama.cpp +76 -22
  65. data/ext/sources/examples/talk-llama/llama.h +63 -30
  66. data/ext/sources/examples/talk-llama/models/afmoe.cpp +2 -3
  67. data/ext/sources/examples/talk-llama/models/apertus.cpp +3 -3
  68. data/ext/sources/examples/talk-llama/models/arcee.cpp +3 -3
  69. data/ext/sources/examples/talk-llama/models/arctic.cpp +4 -5
  70. data/ext/sources/examples/talk-llama/models/baichuan.cpp +4 -3
  71. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +1 -2
  72. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +3 -5
  73. data/ext/sources/examples/talk-llama/models/bert.cpp +13 -7
  74. data/ext/sources/examples/talk-llama/models/bitnet.cpp +9 -24
  75. data/ext/sources/examples/talk-llama/models/bloom.cpp +2 -2
  76. data/ext/sources/examples/talk-llama/models/chameleon.cpp +3 -3
  77. data/ext/sources/examples/talk-llama/models/chatglm.cpp +2 -2
  78. data/ext/sources/examples/talk-llama/models/codeshell.cpp +3 -3
  79. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +3 -3
  80. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +2 -2
  81. data/ext/sources/examples/talk-llama/models/command-r.cpp +2 -2
  82. data/ext/sources/examples/talk-llama/models/dbrx.cpp +4 -5
  83. data/ext/sources/examples/talk-llama/models/deci.cpp +3 -3
  84. data/ext/sources/examples/talk-llama/models/deepseek.cpp +4 -6
  85. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +24 -21
  86. data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
  87. data/ext/sources/examples/talk-llama/models/dots1.cpp +4 -6
  88. data/ext/sources/examples/talk-llama/models/dream.cpp +3 -3
  89. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +4 -6
  90. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +3 -3
  91. data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
  92. data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
  93. data/ext/sources/examples/talk-llama/models/exaone.cpp +3 -3
  94. data/ext/sources/examples/talk-llama/models/exaone4.cpp +3 -3
  95. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +2 -4
  96. data/ext/sources/examples/talk-llama/models/falcon.cpp +3 -3
  97. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +1 -1
  98. data/ext/sources/examples/talk-llama/models/gemma.cpp +1 -1
  99. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +1 -1
  100. data/ext/sources/examples/talk-llama/models/gemma3.cpp +1 -1
  101. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +7 -7
  102. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +3 -3
  103. data/ext/sources/examples/talk-llama/models/glm4.cpp +14 -7
  104. data/ext/sources/examples/talk-llama/models/gpt2.cpp +2 -2
  105. data/ext/sources/examples/talk-llama/models/gptneox.cpp +2 -2
  106. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +4 -5
  107. data/ext/sources/examples/talk-llama/models/granite.cpp +4 -5
  108. data/ext/sources/examples/talk-llama/models/grok.cpp +4 -4
  109. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +5 -7
  110. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +3 -3
  111. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +4 -5
  112. data/ext/sources/examples/talk-llama/models/internlm2.cpp +3 -3
  113. data/ext/sources/examples/talk-llama/models/jais.cpp +2 -2
  114. data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
  115. data/ext/sources/examples/talk-llama/models/jamba.cpp +3 -3
  116. data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
  117. data/ext/sources/examples/talk-llama/models/lfm2.cpp +145 -124
  118. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +4 -4
  119. data/ext/sources/examples/talk-llama/models/llada.cpp +3 -3
  120. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +4 -4
  121. data/ext/sources/examples/talk-llama/models/llama.cpp +18 -11
  122. data/ext/sources/examples/talk-llama/models/maincoder.cpp +3 -3
  123. data/ext/sources/examples/talk-llama/models/{graph-context-mamba.cpp → mamba-base.cpp} +9 -3
  124. data/ext/sources/examples/talk-llama/models/mamba.cpp +1 -2
  125. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +11 -5
  126. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +14 -13
  127. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +4 -5
  128. data/ext/sources/examples/talk-llama/models/mistral3.cpp +4 -4
  129. data/ext/sources/examples/talk-llama/models/models.h +181 -46
  130. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +2 -9
  131. data/ext/sources/examples/talk-llama/models/mpt.cpp +2 -2
  132. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +26 -14
  133. data/ext/sources/examples/talk-llama/models/nemotron.cpp +3 -3
  134. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +2 -2
  135. data/ext/sources/examples/talk-llama/models/olmo.cpp +3 -3
  136. data/ext/sources/examples/talk-llama/models/olmo2.cpp +3 -3
  137. data/ext/sources/examples/talk-llama/models/olmoe.cpp +4 -4
  138. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +1 -1
  139. data/ext/sources/examples/talk-llama/models/openelm.cpp +3 -3
  140. data/ext/sources/examples/talk-llama/models/orion.cpp +3 -3
  141. data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
  142. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +3 -3
  143. data/ext/sources/examples/talk-llama/models/phi2.cpp +2 -2
  144. data/ext/sources/examples/talk-llama/models/phi3.cpp +3 -3
  145. data/ext/sources/examples/talk-llama/models/plamo.cpp +3 -3
  146. data/ext/sources/examples/talk-llama/models/plamo2.cpp +9 -5
  147. data/ext/sources/examples/talk-llama/models/plamo3.cpp +2 -2
  148. data/ext/sources/examples/talk-llama/models/plm.cpp +15 -14
  149. data/ext/sources/examples/talk-llama/models/qwen.cpp +2 -2
  150. data/ext/sources/examples/talk-llama/models/qwen2.cpp +3 -3
  151. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +4 -4
  152. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +3 -3
  153. data/ext/sources/examples/talk-llama/models/qwen3.cpp +12 -9
  154. data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
  155. data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
  156. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +15 -8
  157. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +84 -432
  158. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +9 -18
  159. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +8 -17
  160. data/ext/sources/examples/talk-llama/models/refact.cpp +2 -2
  161. data/ext/sources/examples/talk-llama/models/rnd1.cpp +4 -4
  162. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +2 -0
  163. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +2 -0
  164. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +3 -3
  165. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +4 -4
  166. data/ext/sources/examples/talk-llama/models/smollm3.cpp +3 -3
  167. data/ext/sources/examples/talk-llama/models/stablelm.cpp +2 -2
  168. data/ext/sources/examples/talk-llama/models/starcoder.cpp +2 -2
  169. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +3 -3
  170. data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
  171. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +2 -2
  172. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +2 -2
  173. data/ext/sources/examples/talk-llama/models/xverse.cpp +3 -3
  174. data/ext/sources/examples/talk-llama/unicode.cpp +21 -65
  175. data/ext/sources/ggml/CMakeLists.txt +9 -3
  176. data/ext/sources/ggml/include/ggml-backend.h +1 -1
  177. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  178. data/ext/sources/ggml/include/ggml-cpu.h +5 -0
  179. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  180. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  181. data/ext/sources/ggml/include/ggml-rpc.h +6 -1
  182. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  183. data/ext/sources/ggml/include/ggml.h +56 -9
  184. data/ext/sources/ggml/src/CMakeLists.txt +3 -0
  185. data/ext/sources/ggml/src/ggml-alloc.c +4 -9
  186. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  187. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  188. data/ext/sources/ggml/src/ggml-backend-reg.cpp +28 -86
  189. data/ext/sources/ggml/src/ggml-backend.cpp +5 -2
  190. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  191. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +6 -2
  192. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
  193. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
  194. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +348 -189
  195. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +40 -85
  196. data/ext/sources/ggml/src/ggml-cann/common.h +3 -4
  197. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +44 -62
  198. data/ext/sources/ggml/src/ggml-common.h +11 -0
  199. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +16 -11
  200. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -19
  201. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  202. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
  203. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +85 -1
  204. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2744 -548
  205. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1653 -0
  206. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
  207. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
  208. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
  209. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +118 -18
  210. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +107 -26
  211. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  212. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  213. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +3 -0
  214. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +59 -12
  215. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +15 -0
  216. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
  217. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +965 -252
  218. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +584 -197
  219. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +903 -188
  220. data/ext/sources/ggml/src/ggml-cpu/ops.h +1 -0
  221. data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
  222. data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
  223. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2890 -679
  224. data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
  225. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
  226. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +111 -3
  227. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
  228. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +17 -0
  229. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +1 -1
  230. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +19 -10
  231. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +32 -30
  232. data/ext/sources/ggml/src/ggml-cuda/common.cuh +134 -18
  233. data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
  234. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +6 -3
  235. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +78 -64
  236. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +384 -143
  237. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +36 -22
  238. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +3 -3
  239. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +26 -5
  240. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
  241. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +127 -12
  242. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
  243. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  244. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +595 -200
  245. data/ext/sources/ggml/src/ggml-cuda/mean.cu +9 -8
  246. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +173 -6
  247. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
  248. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +158 -85
  249. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +34 -22
  250. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +127 -67
  251. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
  252. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +157 -65
  253. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -0
  254. data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
  255. data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
  256. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +1 -1
  257. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
  258. data/ext/sources/ggml/src/ggml-cuda/rope.cu +233 -133
  259. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
  260. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
  261. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +56 -32
  262. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  263. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
  264. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
  265. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
  266. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
  267. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
  268. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
  269. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +3 -3
  270. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +0 -1
  271. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +199 -135
  272. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
  273. data/ext/sources/ggml/src/ggml-cuda/unary.cu +55 -0
  274. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
  275. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
  276. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +10 -0
  277. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +82 -45
  278. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +334 -160
  279. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +7 -5
  280. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +328 -197
  281. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
  282. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +765 -234
  283. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
  284. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +412 -265
  285. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +23 -23
  286. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +1 -1
  287. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.h → hex-dma.h} +28 -3
  288. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
  289. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  290. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
  291. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +1 -1
  292. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +27 -37
  293. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +6 -35
  294. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  295. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
  296. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
  297. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  302. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  303. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  304. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
  305. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  306. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  307. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +20 -1347
  308. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +211 -13
  309. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +1119 -952
  310. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +254 -244
  311. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +36 -36
  312. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +155 -138
  313. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
  314. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  315. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +209 -114
  316. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
  317. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  318. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  319. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  320. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
  321. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +6 -0
  322. data/ext/sources/ggml/src/ggml-impl.h +62 -0
  323. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  324. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
  325. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  326. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
  327. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +274 -73
  328. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +22 -4
  329. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +102 -36
  330. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +174 -23
  331. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +580 -280
  332. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +5 -4
  333. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +320 -107
  334. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1068 -825
  335. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +19 -1
  336. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +3108 -636
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +204 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  365. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
  366. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  367. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  368. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  369. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
  370. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  371. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
  372. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  373. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
  374. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
  375. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  376. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  377. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  378. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  379. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  380. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  381. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  382. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  383. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  384. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  385. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  386. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  387. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  388. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  389. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  390. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  391. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  392. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
  393. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  394. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  395. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  396. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  397. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  398. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  399. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
  400. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
  401. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
  402. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
  403. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  404. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  405. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  406. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  407. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  408. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
  409. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  410. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
  411. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
  412. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
  413. data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
  414. data/ext/sources/ggml/src/ggml-quants.c +96 -5
  415. data/ext/sources/ggml/src/ggml-quants.h +3 -0
  416. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +15 -88
  417. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
  418. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +1 -0
  419. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
  420. data/ext/sources/ggml/src/ggml-sycl/common.hpp +315 -10
  421. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +69 -1
  422. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
  423. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
  424. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
  425. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +78 -68
  426. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
  427. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
  428. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
  429. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
  430. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
  431. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
  432. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  433. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
  434. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
  435. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +316 -51
  436. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
  437. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  438. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
  439. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
  440. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
  441. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  442. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
  443. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  444. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  445. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  446. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  447. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  448. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  449. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  450. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  451. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  452. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
  453. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
  454. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
  455. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
  456. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
  457. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
  458. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
  459. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
  460. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
  461. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
  462. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
  463. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
  464. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
  465. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
  466. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
  467. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
  468. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
  469. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
  470. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
  471. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
  472. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
  473. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
  474. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
  475. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
  476. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
  477. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
  478. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
  479. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
  480. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
  481. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
  482. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
  483. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
  484. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
  485. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
  486. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
  487. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
  488. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +13 -0
  489. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  490. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  491. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  492. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  493. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  494. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  495. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  496. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  497. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  498. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  499. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  500. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  501. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  502. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  503. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  504. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  505. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  506. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  507. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  508. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  509. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  510. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  511. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
  512. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
  513. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  514. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
  515. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  516. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  517. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  518. data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
  519. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  520. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  521. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  522. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  523. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  524. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  525. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  526. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
  527. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  528. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  529. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  530. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
  531. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
  532. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +1 -1
  533. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1250 -465
  534. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
  535. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  536. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +374 -170
  537. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +66 -22
  538. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +389 -201
  539. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +106 -58
  540. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  541. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
  542. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
  543. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
  544. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
  545. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -3
  546. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +8 -4
  547. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
  548. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
  549. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +3 -3
  550. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
  551. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +36 -63
  552. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
  553. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
  554. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
  555. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +10 -5
  556. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
  557. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  558. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +16 -10
  559. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +55 -35
  560. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1314 -109
  561. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1660 -1371
  562. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  563. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  564. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  565. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
  566. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
  567. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
  568. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +6 -0
  569. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  570. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +40 -5
  571. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +105 -60
  572. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
  573. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +68 -257
  574. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +692 -23
  575. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_reg_tile.tmpl.wgsl → mul_mat_reg_tile.wgsl} +28 -128
  576. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +31 -137
  577. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
  578. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  579. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  580. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +9 -36
  581. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
  582. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  583. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
  584. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
  585. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
  586. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +9 -6
  587. data/ext/sources/ggml/src/ggml.c +167 -33
  588. data/ext/sources/ggml/src/gguf.cpp +229 -44
  589. data/ext/sources/src/whisper.cpp +6 -28
  590. data/sig/whisper.rbs +43 -2
  591. data/test/test_context_params.rb +82 -0
  592. data/test/test_token.rb +11 -0
  593. data/test/test_vad_context.rb +58 -8
  594. data/test/test_whisper.rb +20 -0
  595. data/whispercpp.gemspec +1 -1
  596. metadata +240 -28
  597. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  598. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
  599. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
  600. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
  601. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
  602. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
  603. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
  604. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
  605. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
  606. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
  607. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  608. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
  609. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
  610. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
@@ -8,8 +8,6 @@
8
8
  #include "ggml-backend-impl.h"
9
9
  #include "ggml-impl.h"
10
10
  #include "ggml-webgpu-shader-lib.hpp"
11
- #include "ggml-wgsl-shaders.hpp"
12
- #include "pre_wgsl.hpp"
13
11
 
14
12
  #ifdef __EMSCRIPTEN__
15
13
  # include <emscripten/emscripten.h>
@@ -21,16 +19,30 @@
21
19
  #include <condition_variable>
22
20
  #include <cstdint>
23
21
  #include <cstring>
24
- #include <iostream>
22
+ #ifdef GGML_WEBGPU_GPU_PROFILE
23
+ # include <iomanip>
24
+ #endif
25
+ #if defined(GGML_WEBGPU_DEBUG) || defined(GGML_WEBGPU_CPU_PROFILE) || defined(GGML_WEBGPU_GPU_PROFILE)
26
+ # include <iostream>
27
+ #endif
25
28
  #include <map>
29
+ #include <memory>
26
30
  #include <mutex>
27
31
  #include <optional>
28
32
  #include <string>
33
+ #include <utility>
29
34
  #include <vector>
30
35
 
31
36
  #define ROUNDUP_POW2(x, pow2) (((x) + ((pow2) - 1)) & ~((pow2) - 1))
32
37
  #define CEIL_DIV(M, N) (((M) + (N) - 1) / (N))
33
38
 
39
+ // Return a rectangular grid of workgroups with minimal over-provisioned workgroups.
40
+ // Assumes that the total number of workgroups does not exceed max_per_dim^2.
41
+ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim, uint32_t & wg_x, uint32_t & wg_y) {
42
+ wg_y = std::max(1u, CEIL_DIV(total_wg, max_per_dim));
43
+ wg_x = CEIL_DIV(total_wg, wg_y);
44
+ }
45
+
34
46
  #ifdef GGML_WEBGPU_DEBUG
35
47
  # define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl
36
48
  # define WEBGPU_DEBUG_BUF_ELEMS 512
@@ -47,7 +59,6 @@
47
59
  double cpu_total_time_##id = \
48
60
  std::chrono::duration<double, std::milli>(cpu_total_end_##id - cpu_total_start_##id).count(); \
49
61
  (ctx)->cpu_time_ms[#id] += cpu_total_time_##id;
50
-
51
62
  // fine-grained timing (not included in totals)
52
63
  # define WEBGPU_CPU_PROFILE_DETAIL_START(id) auto cpu_detail_start_##id = std::chrono::high_resolution_clock::now();
53
64
 
@@ -64,56 +75,34 @@
64
75
  #endif // GGML_WEBGPU_CPU_PROFILE
65
76
 
66
77
  #ifdef GGML_WEBGPU_GPU_PROFILE
67
- # define WEBGPU_NUM_TIMESTAMP_QUERY_BUFS 24
78
+ # define WEBGPU_NUM_TIMESTAMP_QUERY_BUFS 32
68
79
  # define WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES 16 // e.g. enough for two timestamps
69
80
  #endif
70
81
 
71
82
  /* Constants */
72
83
 
73
- // Track https://github.com/gpuweb/gpuweb/issues/5315 for fixes to implementations so this can be removed.
74
- #define WEBGPU_MAX_WG_SIZE 288
75
-
76
- #define WEBGPU_MUL_MAT_WG_SIZE 256
77
- #define WEBGPU_NUM_PARAM_BUFS 32u
78
- #define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 8u
84
+ #define WEBGPU_NUM_PARAM_BUFS 96u
85
+ #define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 32u
79
86
  #define WEBGPU_WAIT_ANY_TIMEOUT_MS 0
80
- // Maximum number of in-flight submissions per-thread, to avoid exhausting the parameter buffer pool
81
- #define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE
87
+ // Maximum number of in-flight submissions per-thread, to avoid exhausting the
88
+ // parameter buffer pool
89
+ #define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD (WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE)
82
90
  #define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters
83
- #define WEBGPU_NUM_SET_ROWS_ERROR_BUFS 32
84
91
  #define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4
85
- #define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4
92
+ #define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4
86
93
 
87
- // For operations which process a row in parallel, this seems like a reasonable default
94
+ // For operations which process a row in parallel, this seems like a reasonable
95
+ // default
88
96
  #define WEBGPU_ROW_SPLIT_WG_SIZE 64
89
97
 
90
- // Matrix multiplication parameters
91
-
92
- // Register tiling parameters
93
- #define WEBGPU_MUL_MAT_TILE_M 8
94
- #define WEBGPU_MUL_MAT_TILE_N 8
95
- #define WEBGPU_MUL_MAT_WG_SIZE_M 8
96
- #define WEBGPU_MUL_MAT_WG_SIZE_N 8
97
- #define WEBGPU_MUL_MAT_TILE_K 32
98
-
99
- // Subgroup matrix parameters
100
- // The number of subgroups in the M dimension
101
- #define WEBGPU_MUL_MAT_SUBGROUP_M 2
102
- // The number of subgroups in the N dimension
103
- #define WEBGPU_MUL_MAT_SUBGROUP_N 2
104
- // The number of subgroup matrices each subgroup accumulates over
105
- #define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4
106
- #define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2
107
-
108
- // Matrix-vector multiplication parameters
109
- #define WEBGPU_MUL_MAT_VEC_WG_SIZE 256
110
- // Must be multiple of 4 to work with vectorized paths, and must divide mul_mat_vec wg size
111
- #define WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG 64
112
- #define WEBGPU_MUL_MAT_VEC_TILE_K 256
98
+ // Track https://github.com/gpuweb/gpuweb/issues/5315 for fixes to
99
+ // implementations so this can be removed, necessary only for get_rows right now
100
+ #define WEBGPU_MAX_WG_SIZE 288
113
101
 
114
102
  /* End Constants */
115
103
 
116
- // This is a "fake" base pointer, since WebGPU buffers do not have pointers to their locations.
104
+ // This is a "fake" base pointer, since WebGPU buffers do not have pointers to
105
+ // their locations.
117
106
  static void * const webgpu_ptr_base = (void *) (uintptr_t) 0x1000; // NOLINT
118
107
 
119
108
  // Always returns the base offset of a tensor, regardless of views.
@@ -133,47 +122,70 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device,
133
122
  wgpu::BufferUsage usage,
134
123
  const char * label);
135
124
 
136
- struct webgpu_pool_bufs {
137
- wgpu::Buffer host_buf;
138
- wgpu::Buffer dev_buf;
139
- };
140
-
141
- // The futures to wait on for a single queue submission
142
- struct webgpu_submission_futures {
143
- std::vector<wgpu::FutureWaitInfo> futures;
144
- };
145
-
146
125
  // Holds a pool of parameter buffers for WebGPU operations
147
126
  struct webgpu_buf_pool {
148
- std::vector<webgpu_pool_bufs> free;
149
-
150
- std::mutex mutex;
151
-
127
+ std::vector<wgpu::Buffer> free;
128
+
129
+ // The pool must be synchronized because
130
+ // 1. The memset pool is shared globally by every ggml buffer,
131
+ // since allocating a pool per ggml buffer would consume too much memory.
132
+ // 2. For the per-thread buffer pools in webgpu_context,
133
+ // buffers are allocated and freed in Dawn callbacks,
134
+ // which can run on a different thread than the calling thread.
135
+ std::mutex mutex;
152
136
  std::condition_variable cv;
137
+ size_t cur_pool_size;
138
+ size_t max_pool_size;
139
+ wgpu::Device device;
140
+ wgpu::BufferUsage dev_buf_usage;
141
+ size_t buf_size;
142
+ bool should_grow;
153
143
 
154
144
  void init(wgpu::Device device,
155
145
  int num_bufs,
156
146
  size_t buf_size,
157
147
  wgpu::BufferUsage dev_buf_usage,
158
- wgpu::BufferUsage host_buf_usage) {
148
+ bool should_grow = false,
149
+ size_t max_pool_size = WEBGPU_NUM_PARAM_BUFS * 2) {
150
+ this->max_pool_size = max_pool_size;
151
+ this->cur_pool_size = num_bufs;
152
+ this->device = device;
153
+ this->dev_buf_usage = dev_buf_usage;
154
+ this->buf_size = buf_size;
155
+ this->should_grow = should_grow;
159
156
  for (int i = 0; i < num_bufs; i++) {
160
- wgpu::Buffer host_buf;
161
157
  wgpu::Buffer dev_buf;
162
- ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_pool_buf");
163
158
  ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf");
164
- free.push_back({ host_buf, dev_buf });
159
+ free.push_back(dev_buf);
165
160
  }
166
161
  }
167
162
 
168
- webgpu_pool_bufs alloc_bufs() {
163
+ wgpu::Buffer alloc_bufs() {
169
164
  std::unique_lock<std::mutex> lock(mutex);
165
+ if (!free.empty()) {
166
+ wgpu::Buffer buf = free.back();
167
+ free.pop_back();
168
+ return buf;
169
+ }
170
+
171
+ // Try growing the pool if no free buffers
172
+ if (free.empty() && cur_pool_size < max_pool_size && should_grow) {
173
+ cur_pool_size++;
174
+ wgpu::Buffer dev_buf;
175
+ ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf");
176
+
177
+ if (!dev_buf) {
178
+ GGML_ABORT("webgpu_buf_pool: failed to allocate buffers");
179
+ }
180
+ return dev_buf;
181
+ }
170
182
  cv.wait(lock, [this] { return !free.empty(); });
171
- webgpu_pool_bufs bufs = free.back();
183
+ wgpu::Buffer buf = free.back();
172
184
  free.pop_back();
173
- return bufs;
185
+ return buf;
174
186
  }
175
187
 
176
- void free_bufs(std::vector<webgpu_pool_bufs> bufs) {
188
+ void free_bufs(std::vector<wgpu::Buffer> bufs) {
177
189
  std::lock_guard<std::mutex> lock(mutex);
178
190
  free.insert(free.end(), bufs.begin(), bufs.end());
179
191
  cv.notify_all();
@@ -181,12 +193,15 @@ struct webgpu_buf_pool {
181
193
 
182
194
  void cleanup() {
183
195
  std::lock_guard<std::mutex> lock(mutex);
184
- for (auto & bufs : free) {
185
- bufs.host_buf.Destroy();
186
- bufs.dev_buf.Destroy();
196
+ for (auto & buf : free) {
197
+ if (buf) {
198
+ buf.Destroy();
199
+ }
187
200
  }
188
201
  free.clear();
189
202
  }
203
+
204
+ ~webgpu_buf_pool() { this->cleanup(); }
190
205
  };
191
206
 
192
207
  #ifdef GGML_WEBGPU_GPU_PROFILE
@@ -248,188 +263,155 @@ struct webgpu_gpu_profile_buf_pool {
248
263
  }
249
264
  free.clear();
250
265
  }
251
- };
252
- #endif
253
266
 
254
- struct webgpu_pipeline {
255
- wgpu::ComputePipeline pipeline;
256
- std::string name;
257
- void * context = nullptr;
267
+ ~webgpu_gpu_profile_buf_pool() { this->cleanup(); }
258
268
  };
269
+ #endif
259
270
 
260
271
  struct webgpu_command {
261
- wgpu::CommandBuffer commands;
262
- webgpu_pool_bufs params_bufs;
263
- std::optional<webgpu_pool_bufs> set_rows_error_bufs;
272
+ uint32_t num_kernels;
273
+ wgpu::CommandBuffer commands;
274
+ std::vector<wgpu::Buffer> params_bufs;
264
275
  #ifdef GGML_WEBGPU_GPU_PROFILE
265
276
  webgpu_gpu_profile_bufs timestamp_query_bufs;
266
277
  std::string pipeline_name;
267
278
  #endif
268
279
  };
269
280
 
270
- struct flash_attn_pipeline_key {
271
- int q_type;
272
- int kv_type;
273
- int dst_type;
274
- uint32_t head_dim_qk;
275
- uint32_t head_dim_v;
276
- bool kv_direct;
277
- bool has_mask;
278
- bool has_sinks;
279
- bool uses_logit_softcap;
280
-
281
- bool operator==(const flash_attn_pipeline_key & other) const {
282
- return q_type == other.q_type && kv_type == other.kv_type && dst_type == other.dst_type &&
283
- head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v && kv_direct == other.kv_direct &&
284
- has_mask == other.has_mask && has_sinks == other.has_sinks &&
285
- uses_logit_softcap == other.uses_logit_softcap;
286
- }
287
- };
281
+ struct webgpu_capabilities {
282
+ wgpu::Limits limits;
283
+ bool supports_subgroup_matrix = false;
288
284
 
289
- // Same hash combine function as in boost
290
- template <typename T> inline void ggml_webgpu_hash_combine(size_t & seed, const T & value) {
291
- seed ^= std::hash<T>{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
292
- }
293
-
294
- struct flash_attn_pipeline_key_hash {
295
- size_t operator()(const flash_attn_pipeline_key & key) const {
296
- size_t seed = 0;
297
- ggml_webgpu_hash_combine(seed, key.q_type);
298
- ggml_webgpu_hash_combine(seed, key.kv_type);
299
- ggml_webgpu_hash_combine(seed, key.dst_type);
300
- ggml_webgpu_hash_combine(seed, key.head_dim_qk);
301
- ggml_webgpu_hash_combine(seed, key.head_dim_v);
302
- ggml_webgpu_hash_combine(seed, key.kv_direct);
303
- ggml_webgpu_hash_combine(seed, key.has_mask);
304
- ggml_webgpu_hash_combine(seed, key.has_sinks);
305
- ggml_webgpu_hash_combine(seed, key.uses_logit_softcap);
306
- return seed;
307
- }
285
+ uint32_t sg_mat_m = 0;
286
+ uint32_t sg_mat_n = 0;
287
+ uint32_t sg_mat_k = 0;
288
+
289
+ uint32_t subgroup_size = 0;
290
+ uint32_t max_subgroup_size = 0;
291
+ size_t memset_bytes_per_thread;
308
292
  };
309
293
 
310
- // All the base objects needed to run operations on a WebGPU device
311
- struct webgpu_context_struct {
294
+ // Stores global webgpu members
295
+ struct webgpu_global_context_struct {
312
296
  wgpu::Instance instance;
313
297
  wgpu::Adapter adapter;
314
298
  wgpu::Device device;
315
299
  wgpu::Queue queue;
316
- wgpu::Limits limits;
317
300
 
318
- uint32_t max_subgroup_size;
301
+ webgpu_capabilities capabilities;
302
+ // Shared buffer to move data from device to host
303
+ wgpu::Buffer get_tensor_staging_buf;
304
+ // Global mutex for pipeline and staging buffer, will be refactored to exclude pipeline caches.
305
+ std::recursive_mutex mutex;
319
306
 
320
- bool supports_subgroup_matrix = false;
321
- uint32_t sg_mat_m;
322
- uint32_t sg_mat_n;
323
- uint32_t sg_mat_k;
307
+ webgpu_buf_pool memset_buf_pool;
308
+ std::map<int, webgpu_pipeline> memset_pipelines; // variant or type index
324
309
 
325
- std::recursive_mutex mutex;
326
- std::atomic_uint inflight_threads = 0;
310
+ #ifdef GGML_WEBGPU_CPU_PROFILE
311
+ // Profiling: labeled CPU time in ms (total)
312
+ std::unordered_map<std::string, double> cpu_time_ms;
313
+ // Profiling: detailed CPU time in ms
314
+ std::unordered_map<std::string, double> cpu_detail_ms;
315
+ #endif
327
316
 
328
- webgpu_buf_pool param_buf_pool;
329
- webgpu_buf_pool set_rows_error_buf_pool;
317
+ #ifdef GGML_WEBGPU_GPU_PROFILE
318
+ // Profiling: per-shader GPU time in ms
319
+ std::unordered_map<std::string, double> shader_gpu_time_ms;
320
+ // Profiling: pool of timestamp query buffers (one per operation)
321
+ webgpu_gpu_profile_buf_pool timestamp_query_buf_pool;
322
+ #endif
323
+
324
+ #ifdef GGML_WEBGPU_DEBUG
325
+ wgpu::Buffer debug_host_buf;
326
+ wgpu::Buffer debug_dev_buf;
327
+ #endif
328
+
329
+ ~webgpu_global_context_struct() {
330
+ if (this->get_tensor_staging_buf) {
331
+ this->get_tensor_staging_buf.Destroy();
332
+ this->get_tensor_staging_buf = nullptr;
333
+ }
334
+ #ifdef GGML_WEBGPU_DEBUG
335
+ if (this->debug_host_buf) {
336
+ this->debug_host_buf.Destroy();
337
+ this->debug_host_buf = nullptr;
338
+ }
339
+ if (this->debug_dev_buf) {
340
+ this->debug_dev_buf.Destroy();
341
+ this->debug_dev_buf = nullptr;
342
+ }
343
+ #endif
344
+ }
345
+ };
330
346
 
331
- pre_wgsl::Preprocessor p;
347
+ typedef std::shared_ptr<webgpu_global_context_struct> webgpu_global_context;
332
348
 
333
- std::map<int, webgpu_pipeline> memset_pipelines; // variant or type index
349
+ struct webgpu_submission {
350
+ wgpu::FutureWaitInfo submit_done;
351
+ #ifdef GGML_WEBGPU_GPU_PROFILE
352
+ std::vector<wgpu::FutureWaitInfo> profile_futures;
353
+ #endif
354
+ };
334
355
 
335
- std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> mul_mat_pipelines; // src0_type, src1_type, vectorized
336
- std::map<int, std::map<int, std::map<int, webgpu_pipeline>>>
337
- mul_mat_vec_pipelines; // src0_type, src1_type, vectorized
356
+ // All the base objects needed to run operations on a WebGPU device
357
+ struct webgpu_context_struct {
358
+ // Points to global instances owned by ggml_backend_webgpu_reg_context
359
+ webgpu_global_context global_ctx;
338
360
 
339
- std::unordered_map<flash_attn_pipeline_key, webgpu_pipeline, flash_attn_pipeline_key_hash> flash_attn_pipelines;
361
+ std::unique_ptr<ggml_webgpu_shader_lib> shader_lib;
340
362
 
341
- std::map<int, std::map<int, webgpu_pipeline>> set_rows_pipelines; // dst_type, vectorized
342
- std::map<int, std::map<int, webgpu_pipeline>> get_rows_pipelines; // src_type, vectorized
363
+ webgpu_buf_pool param_buf_pool;
364
+ wgpu::Buffer set_rows_dev_error_buf;
365
+ wgpu::Buffer set_rows_host_error_buf;
343
366
 
344
367
  std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines; // src_type, dst_type
345
- std::map<int, std::map<int, webgpu_pipeline>> add_pipelines; // type, inplace
346
- std::map<int, std::map<int, webgpu_pipeline>> sub_pipelines; // type, inplace
347
- std::map<int, std::map<int, webgpu_pipeline>> mul_pipelines; // type, inplace
348
- std::map<int, std::map<int, webgpu_pipeline>> div_pipelines; // type, inplace
349
368
 
350
369
  std::map<int, webgpu_pipeline> rms_norm_pipelines; // inplace
351
370
  std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> rope_pipelines; // type, ff, inplace
352
371
  std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> glu_pipelines; // glu_op, type, split
353
- std::map<int, webgpu_pipeline> scale_pipelines; // inplace
372
+
354
373
  std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> soft_max_pipelines; // mask_type, has_sink, inplace
355
- std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> unary_pipelines; // unary_op, type, inplace
356
374
 
357
375
  size_t memset_bytes_per_thread;
358
-
359
- // Staging buffer for reading data from the GPU
360
- wgpu::Buffer get_tensor_staging_buf;
361
-
362
- #ifdef GGML_WEBGPU_DEBUG
363
- wgpu::Buffer debug_host_buf;
364
- wgpu::Buffer debug_dev_buf;
365
- #endif
366
-
367
- #ifdef GGML_WEBGPU_CPU_PROFILE
368
- // Profiling: labeled CPU time in ms (total)
369
- std::unordered_map<std::string, double> cpu_time_ms;
370
- // Profiling: detailed CPU time in ms
371
- std::unordered_map<std::string, double> cpu_detail_ms;
372
- #endif
373
-
374
- #ifdef GGML_WEBGPU_GPU_PROFILE
375
- // Profiling: per-shader GPU time in ms
376
- std::unordered_map<std::string, double> shader_gpu_time_ms;
377
- // Profiling: pool of timestamp query buffers (one per operation)
378
- webgpu_gpu_profile_buf_pool timestamp_query_buf_pool;
379
- #endif
380
376
  };
381
377
 
382
378
  typedef std::shared_ptr<webgpu_context_struct> webgpu_context;
383
379
 
380
+ // Metadata required for the ggml backend registration/discovery interface
384
381
  struct ggml_backend_webgpu_reg_context {
385
- webgpu_context webgpu_ctx;
386
- size_t device_count;
387
- const char * name;
382
+ // Since the Instance is a global entrypoint into the WebGPU API, it lives here
383
+ webgpu_global_context webgpu_global_ctx;
384
+ size_t device_count;
385
+ const char * name;
388
386
  };
389
387
 
388
+ // Per-device struct for the global logical device interface
390
389
  struct ggml_backend_webgpu_device_context {
391
- webgpu_context webgpu_ctx;
392
- std::string device_name;
393
- std::string device_desc;
390
+ webgpu_global_context webgpu_global_ctx;
391
+ std::string device_name;
392
+ std::string device_desc;
394
393
  };
395
394
 
395
+ // Per-thread data required to actually run WebGPU operations in a backend instance
396
396
  struct ggml_backend_webgpu_context {
397
397
  webgpu_context webgpu_ctx;
398
398
  std::string name;
399
399
  };
400
400
 
401
+ // Per-thread data related to buffers
401
402
  struct ggml_backend_webgpu_buffer_context {
402
- webgpu_context webgpu_ctx;
403
- wgpu::Buffer buffer;
404
- std::string label;
403
+ wgpu::Buffer buffer;
404
+ std::string label;
405
+ webgpu_global_context global_ctx;
405
406
 
406
- ggml_backend_webgpu_buffer_context(webgpu_context ctx, wgpu::Buffer buf, std::string lbl) :
407
- webgpu_ctx(std::move(ctx)),
407
+ ggml_backend_webgpu_buffer_context(wgpu::Buffer buf, std::string lbl, webgpu_global_context global_ctx_) :
408
408
  buffer(std::move(buf)),
409
- label(std::move(lbl)) {}
409
+ label(std::move(lbl)),
410
+ global_ctx(std::move(global_ctx_)) {}
410
411
  };
411
412
 
412
413
  /* WebGPU object initializations */
413
414
 
414
- // Process a WGSL shader string, replacing tokens of the form {{KEY}} with
415
- // the corresponding values provided in `repls`.
416
- static std::string ggml_webgpu_process_shader_repls(const char * src,
417
- const std::map<std::string, std::string> & repls) {
418
- if (!src) {
419
- return std::string();
420
- }
421
- std::string s = src;
422
- for (const auto & kv : repls) {
423
- std::string token = "{{" + kv.first + "}}";
424
- size_t pos = 0;
425
- while ((pos = s.find(token, pos)) != std::string::npos) {
426
- s.replace(pos, token.length(), kv.second);
427
- pos += kv.second.length();
428
- }
429
- }
430
- return s;
431
- }
432
-
433
415
  static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device,
434
416
  const char * shader_code,
435
417
  const char * label,
@@ -473,44 +455,113 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device,
473
455
 
474
456
  /** WebGPU Actions */
475
457
 
458
+ static bool ggml_backend_webgpu_handle_wait_status(wgpu::WaitStatus status, bool allow_timeout = false) {
459
+ switch (status) {
460
+ case wgpu::WaitStatus::Success:
461
+ return true;
462
+ case wgpu::WaitStatus::TimedOut:
463
+ if (allow_timeout) {
464
+ return false;
465
+ }
466
+ GGML_LOG_ERROR("ggml_webgpu: WaitAny timed out unexpectedly\n");
467
+ return false;
468
+ case wgpu::WaitStatus::Error:
469
+ GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
470
+ return false;
471
+ default:
472
+ GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n");
473
+ return false;
474
+ }
475
+ }
476
+
477
+ #ifdef GGML_WEBGPU_GPU_PROFILE
478
+ static void ggml_backend_webgpu_erase_completed_futures(std::vector<wgpu::FutureWaitInfo> & futures) {
479
+ futures.erase(std::remove_if(futures.begin(), futures.end(),
480
+ [](const wgpu::FutureWaitInfo & info) { return info.completed; }),
481
+ futures.end());
482
+ }
483
+
484
+ static void ggml_backend_webgpu_wait_profile_futures(webgpu_global_context & ctx,
485
+ std::vector<wgpu::FutureWaitInfo> & futures,
486
+ bool block) {
487
+ if (futures.empty()) {
488
+ return;
489
+ }
490
+
491
+ uint64_t timeout_ms = block ? UINT64_MAX : 0;
492
+ if (block) {
493
+ while (!futures.empty()) {
494
+ auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms);
495
+ if (ggml_backend_webgpu_handle_wait_status(waitStatus)) {
496
+ ggml_backend_webgpu_erase_completed_futures(futures);
497
+ }
498
+ }
499
+ } else {
500
+ auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms);
501
+ if (ggml_backend_webgpu_handle_wait_status(waitStatus, true)) {
502
+ ggml_backend_webgpu_erase_completed_futures(futures);
503
+ }
504
+ }
505
+ }
506
+ #endif
507
+
476
508
  // Wait for the queue to finish processing all submitted work
477
- static void ggml_backend_webgpu_wait(webgpu_context & ctx,
478
- std::vector<webgpu_submission_futures> & futures,
479
- bool block = true) {
480
- // If we have too many in-flight submissions, wait on the oldest one first. If there are many threads,
481
- // inflight_max may be 0, meaning that we must wait on all futures.
482
- uint64_t timeout_ms = block ? UINT64_MAX : 0;
483
- uint32_t inflight_threads = ctx->inflight_threads;
484
- uint32_t inflight_max = WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD / std::max(inflight_threads, 1u);
485
- while (futures.size() >= inflight_max && futures.size() > 0) {
486
- ctx->instance.WaitAny(futures[0].futures.size(), futures[0].futures.data(), UINT64_MAX);
487
- futures.erase(futures.begin());
488
- }
489
- size_t i = 0;
490
- while (i < futures.size()) {
491
- auto waitStatus = ctx->instance.WaitAny(futures[i].futures.size(), futures[i].futures.data(), timeout_ms);
492
- switch (waitStatus) {
493
- case wgpu::WaitStatus::Success:
494
- futures.erase(futures.begin() + i);
495
- break;
496
- case wgpu::WaitStatus::TimedOut:
497
- i++;
498
- break;
499
- case wgpu::WaitStatus::Error:
500
- GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
501
- break;
502
- default:
503
- GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n");
504
- break;
509
+ static void ggml_backend_webgpu_wait(webgpu_global_context & ctx,
510
+ std::vector<webgpu_submission> & subs,
511
+ bool block = true) {
512
+ // If we have too many in-flight submissions, wait on the oldest one first.
513
+ if (subs.empty()) {
514
+ return;
515
+ }
516
+ while (subs.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD) {
517
+ auto waitStatus = ctx->instance.WaitAny(1, &subs[0].submit_done, UINT64_MAX);
518
+ if (ggml_backend_webgpu_handle_wait_status(waitStatus)) {
519
+ #ifdef GGML_WEBGPU_GPU_PROFILE
520
+ ggml_backend_webgpu_wait_profile_futures(ctx, subs[0].profile_futures, true);
521
+ #endif
522
+ subs.erase(subs.begin());
523
+ }
524
+ }
525
+
526
+ if (subs.empty()) {
527
+ return;
528
+ }
529
+
530
+ if (block) {
531
+ for (auto & sub : subs) {
532
+ while (!sub.submit_done.completed) {
533
+ auto waitStatus = ctx->instance.WaitAny(1, &sub.submit_done, UINT64_MAX);
534
+ ggml_backend_webgpu_handle_wait_status(waitStatus);
535
+ }
536
+ #ifdef GGML_WEBGPU_GPU_PROFILE
537
+ ggml_backend_webgpu_wait_profile_futures(ctx, sub.profile_futures, true);
538
+ #endif
539
+ }
540
+ subs.clear();
541
+ } else {
542
+ // Poll each submit future once and remove completed submissions.
543
+ for (auto sub = subs.begin(); sub != subs.end();) {
544
+ auto waitStatus = ctx->instance.WaitAny(1, &sub->submit_done, 0);
545
+ ggml_backend_webgpu_handle_wait_status(waitStatus, true);
546
+ #ifdef GGML_WEBGPU_GPU_PROFILE
547
+ ggml_backend_webgpu_wait_profile_futures(ctx, sub->profile_futures, false);
548
+ if (sub->submit_done.completed && sub->profile_futures.empty()) {
549
+ #else
550
+ if (sub->submit_done.completed) {
551
+ #endif
552
+ sub = subs.erase(sub);
553
+ } else {
554
+ ++sub;
555
+ }
505
556
  }
506
557
  }
507
558
  }
508
559
 
509
- static void ggml_backend_webgpu_map_buffer(webgpu_context & ctx,
510
- wgpu::Buffer & buffer,
511
- wgpu::MapMode mode,
512
- size_t offset,
513
- size_t size) {
560
+ static void ggml_backend_webgpu_map_buffer(webgpu_global_context & ctx,
561
+ wgpu::Buffer & buffer,
562
+ wgpu::MapMode mode,
563
+ size_t offset,
564
+ size_t size) {
514
565
  ctx->instance.WaitAny(buffer.MapAsync(mode, offset, size, wgpu::CallbackMode::AllowSpontaneous,
515
566
  [](wgpu::MapAsyncStatus status, wgpu::StringView message) {
516
567
  if (status != wgpu::MapAsyncStatus::Success) {
@@ -525,7 +576,7 @@ static void ggml_backend_webgpu_map_buffer(webgpu_context & ctx,
525
576
  // This function adds debugging information to shaders, as WebGPU does not support printing directly.
526
577
  // To use, add a bind group entry to the setup for the shader you are debugging, add the buffer and
527
578
  // debug statements in the shader, and then call this function after encoding the commands and submitting them.
528
- static void ggml_backend_webgpu_debug(webgpu_context & ctx) {
579
+ static void ggml_backend_webgpu_debug(webgpu_global_context & ctx) {
529
580
  wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
530
581
  encoder.CopyBufferToBuffer(ctx->debug_dev_buf, 0, ctx->debug_host_buf, 0, ctx->debug_host_buf.GetSize());
531
582
  wgpu::CommandBuffer commands = encoder.Finish();
@@ -537,53 +588,32 @@ static void ggml_backend_webgpu_debug(webgpu_context & ctx) {
537
588
  }
538
589
  #endif
539
590
 
540
- static webgpu_submission_futures ggml_backend_webgpu_submit(webgpu_context ctx, std::vector<webgpu_command> commands) {
591
+ static webgpu_submission ggml_backend_webgpu_submit(webgpu_global_context & ctx,
592
+ std::vector<webgpu_command> & commands,
593
+ webgpu_buf_pool & param_buf_pool) {
541
594
  std::vector<wgpu::CommandBuffer> command_buffers;
542
- std::vector<webgpu_pool_bufs> params_bufs;
543
- std::vector<webgpu_pool_bufs> set_rows_error_bufs;
595
+ std::vector<wgpu::Buffer> params_bufs;
596
+ webgpu_submission submission;
544
597
  #ifdef GGML_WEBGPU_GPU_PROFILE
545
598
  std::vector<std::pair<std::string, webgpu_gpu_profile_bufs>> pipeline_name_and_ts_bufs;
546
599
  #endif
547
600
 
548
601
  for (const auto & command : commands) {
549
602
  command_buffers.push_back(command.commands);
550
- params_bufs.push_back(command.params_bufs);
551
- if (command.set_rows_error_bufs) {
552
- set_rows_error_bufs.push_back(command.set_rows_error_bufs.value());
553
- }
603
+ params_bufs.insert(params_bufs.end(), command.params_bufs.begin(), command.params_bufs.end());
554
604
  }
555
605
  ctx->queue.Submit(command_buffers.size(), command_buffers.data());
556
606
 
557
- std::vector<wgpu::FutureWaitInfo> futures;
558
-
559
607
  wgpu::Future p_f = ctx->queue.OnSubmittedWorkDone(
560
608
  wgpu::CallbackMode::AllowSpontaneous,
561
- [ctx, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
609
+ [&param_buf_pool, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
562
610
  if (status != wgpu::QueueWorkDoneStatus::Success) {
563
611
  GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", std::string(message).c_str());
564
612
  }
565
613
  // Free the staged buffers
566
- ctx->param_buf_pool.free_bufs({ params_bufs });
614
+ param_buf_pool.free_bufs(params_bufs);
567
615
  });
568
- futures.push_back({ p_f });
569
-
570
- for (const auto & bufs : set_rows_error_bufs) {
571
- wgpu::Future f = bufs.host_buf.MapAsync(
572
- wgpu::MapMode::Read, 0, bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous,
573
- [ctx, bufs](wgpu::MapAsyncStatus status, wgpu::StringView message) {
574
- if (status != wgpu::MapAsyncStatus::Success) {
575
- GGML_LOG_ERROR("ggml_webgpu: Failed to map error buffer: %s\n", std::string(message).c_str());
576
- } else {
577
- const uint32_t * error_data = (const uint32_t *) bufs.host_buf.GetConstMappedRange();
578
- if (*error_data) {
579
- GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported.");
580
- }
581
- // We can't unmap in here due to WebGPU reentrancy limitations.
582
- ctx->set_rows_error_buf_pool.free_bufs({ bufs });
583
- }
584
- });
585
- futures.push_back({ f });
586
- }
616
+ submission.submit_done = { p_f };
587
617
 
588
618
  #ifdef GGML_WEBGPU_GPU_PROFILE
589
619
  for (const auto & command : commands) {
@@ -600,52 +630,54 @@ static webgpu_submission_futures ggml_backend_webgpu_submit(webgpu_context ctx,
600
630
  // WebGPU timestamps are in ns; convert to ms
601
631
  double elapsed_ms = double(ts_data[1] - ts_data[0]) * 1e-6;
602
632
  ctx->shader_gpu_time_ms[label] += elapsed_ms;
603
- // We can't unmap in here due to WebGPU reentrancy limitations.
604
- ctx->timestamp_query_buf_pool.free_bufs({ ts_bufs });
605
633
  }
634
+ // We can't unmap in here due to WebGPU reentrancy limitations.
635
+ ctx->timestamp_query_buf_pool.free_bufs({ ts_bufs });
606
636
  });
607
- futures.push_back({ f });
637
+ submission.profile_futures.push_back({ f });
608
638
  }
609
639
  #endif
610
- return { futures };
640
+ return submission;
611
641
  }
612
642
 
613
- static webgpu_command ggml_backend_webgpu_build(webgpu_context & ctx,
614
- webgpu_pipeline & pipeline,
615
- std::vector<uint32_t> params,
616
- std::vector<wgpu::BindGroupEntry> bind_group_entries,
617
- uint32_t wg_x,
618
- uint32_t wg_y = 1,
619
- std::optional<webgpu_pool_bufs> set_rows_error_bufs = std::nullopt) {
620
- webgpu_pool_bufs params_bufs = ctx->param_buf_pool.alloc_bufs();
621
-
622
- ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0, params_bufs.host_buf.GetSize());
623
- uint32_t * _params = (uint32_t *) params_bufs.host_buf.GetMappedRange();
624
- for (size_t i = 0; i < params.size(); i++) {
625
- _params[i] = params[i];
626
- };
643
+ static webgpu_command ggml_backend_webgpu_build_multi(
644
+ webgpu_global_context & ctx,
645
+ webgpu_buf_pool & param_buf_pool,
646
+ const std::vector<webgpu_pipeline> & pipelines,
647
+ const std::vector<std::vector<uint32_t>> & params_list,
648
+ const std::vector<std::vector<wgpu::BindGroupEntry>> & bind_group_entries_list,
649
+ const std::vector<std::pair<uint32_t, uint32_t>> & workgroups_list) {
650
+ GGML_ASSERT(pipelines.size() == params_list.size());
651
+ GGML_ASSERT(pipelines.size() == bind_group_entries_list.size());
652
+ GGML_ASSERT(pipelines.size() == workgroups_list.size());
653
+
654
+ std::vector<wgpu::Buffer> params_bufs_list;
655
+ std::vector<wgpu::BindGroup> bind_groups;
627
656
 
628
- params_bufs.host_buf.Unmap();
657
+ for (size_t i = 0; i < pipelines.size(); i++) {
658
+ wgpu::Buffer params_bufs = param_buf_pool.alloc_bufs();
629
659
 
630
- uint32_t params_bufs_binding_num = bind_group_entries.size();
631
- bind_group_entries.push_back({ .binding = params_bufs_binding_num,
632
- .buffer = params_bufs.dev_buf,
633
- .offset = 0,
634
- .size = params_bufs.dev_buf.GetSize() });
660
+ std::vector<wgpu::BindGroupEntry> entries = bind_group_entries_list[i];
661
+ uint32_t params_binding_num = entries.size();
662
+ entries.push_back(
663
+ { .binding = params_binding_num, .buffer = params_bufs, .offset = 0, .size = params_bufs.GetSize() });
635
664
 
636
- wgpu::BindGroupDescriptor bind_group_desc;
637
- bind_group_desc.layout = pipeline.pipeline.GetBindGroupLayout(0);
638
- bind_group_desc.entryCount = bind_group_entries.size();
639
- bind_group_desc.entries = bind_group_entries.data();
640
- bind_group_desc.label = pipeline.name.c_str();
641
- wgpu::BindGroup bind_group = ctx->device.CreateBindGroup(&bind_group_desc);
665
+ wgpu::BindGroupDescriptor bind_group_desc;
666
+ bind_group_desc.layout = pipelines[i].pipeline.GetBindGroupLayout(0);
667
+ bind_group_desc.entryCount = entries.size();
668
+ bind_group_desc.entries = entries.data();
669
+ bind_group_desc.label = pipelines[i].name.c_str();
670
+ bind_groups.push_back(ctx->device.CreateBindGroup(&bind_group_desc));
671
+
672
+ params_bufs_list.push_back(params_bufs);
673
+ }
642
674
 
643
675
  wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
644
- encoder.CopyBufferToBuffer(params_bufs.host_buf, 0, params_bufs.dev_buf, 0, params_bufs.dev_buf.GetSize());
676
+ for (size_t i = 0; i < params_bufs_list.size(); i++) {
677
+ ctx->queue.WriteBuffer(params_bufs_list[i], 0, params_list[i].data(), params_list[i].size() * sizeof(uint32_t));
678
+ }
645
679
 
646
680
  #ifdef GGML_WEBGPU_GPU_PROFILE
647
- // --- Profiling: GPU timestamp queries ---
648
- // Allocate a timestamp query buffer (2 timestamps: start/end)
649
681
  webgpu_gpu_profile_bufs ts_bufs = ctx->timestamp_query_buf_pool.alloc_bufs();
650
682
  if (ts_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) {
651
683
  ts_bufs.host_buf.Unmap();
@@ -659,50 +691,63 @@ static webgpu_command ggml_backend_webgpu_build(webgpu_context &
659
691
  #else
660
692
  wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
661
693
  #endif
662
- pass.SetPipeline(pipeline.pipeline);
663
- pass.SetBindGroup(0, bind_group);
664
- pass.DispatchWorkgroups(wg_x, wg_y, 1);
694
+ for (size_t i = 0; i < pipelines.size(); i++) {
695
+ pass.SetPipeline(pipelines[i].pipeline);
696
+ pass.SetBindGroup(0, bind_groups[i]);
697
+ pass.DispatchWorkgroups(workgroups_list[i].first, workgroups_list[i].second, 1);
698
+ }
665
699
  pass.End();
666
700
 
667
701
  #ifdef GGML_WEBGPU_GPU_PROFILE
668
- // Resolve the query set into the device buffer
669
702
  encoder.ResolveQuerySet(ts_bufs.query_set, 0, 2, ts_bufs.dev_buf, 0);
670
703
  encoder.CopyBufferToBuffer(ts_bufs.dev_buf, 0, ts_bufs.host_buf, 0, ts_bufs.host_buf.GetSize());
671
704
  #endif
672
705
 
673
- // If there are SET_ROWS operations in this submission, copy their error buffers to the host.
674
- if (set_rows_error_bufs) {
675
- encoder.CopyBufferToBuffer(set_rows_error_bufs->dev_buf, 0, set_rows_error_bufs->host_buf, 0,
676
- set_rows_error_bufs->host_buf.GetSize());
677
- }
678
-
679
706
  wgpu::CommandBuffer commands = encoder.Finish();
680
707
  webgpu_command result = {};
681
708
  result.commands = commands;
682
- result.params_bufs = params_bufs;
683
- result.set_rows_error_bufs = set_rows_error_bufs;
709
+ result.params_bufs = params_bufs_list;
710
+ result.num_kernels = pipelines.size();
684
711
  #ifdef GGML_WEBGPU_GPU_PROFILE
685
712
  result.timestamp_query_bufs = ts_bufs;
686
- result.pipeline_name = pipeline.name;
713
+ // TODO: handle multiple pipeline names
714
+ result.pipeline_name = pipelines.front().name;
687
715
  #endif
688
716
  return result;
689
717
  }
690
718
 
691
- static void ggml_backend_webgpu_buffer_memset(webgpu_context & ctx,
692
- wgpu::Buffer & buf,
693
- uint32_t value,
694
- size_t offset,
695
- size_t size) {
719
+ static webgpu_command ggml_backend_webgpu_build(webgpu_global_context & ctx,
720
+ webgpu_buf_pool & param_buf_pool,
721
+ webgpu_pipeline & pipeline,
722
+ std::vector<uint32_t> params,
723
+ std::vector<wgpu::BindGroupEntry> bind_group_entries,
724
+ uint32_t wg_x,
725
+ uint32_t wg_y = 1) {
726
+ return ggml_backend_webgpu_build_multi(ctx, param_buf_pool,
727
+ {
728
+ pipeline
729
+ },
730
+ { std::move(params) }, { std::move(bind_group_entries) },
731
+ { { wg_x, wg_y } });
732
+ }
733
+
734
+ static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx,
735
+ wgpu::Buffer & buf,
736
+ uint32_t value,
737
+ size_t offset,
738
+ size_t size) {
696
739
  std::vector<uint32_t> params = { (uint32_t) offset, (uint32_t) size, value };
697
740
  std::vector<wgpu::BindGroupEntry> entries = {
698
741
  { .binding = 0, .buffer = buf, .offset = 0, .size = buf.GetSize() }
699
742
  };
700
- size_t bytes_per_wg = WEBGPU_MAX_WG_SIZE * ctx->memset_bytes_per_thread;
743
+ size_t bytes_per_wg = WEBGPU_MAX_WG_SIZE * ctx->capabilities.memset_bytes_per_thread;
701
744
  uint32_t wg_x = CEIL_DIV(size + 3, bytes_per_wg);
702
745
 
703
- webgpu_command command = ggml_backend_webgpu_build(ctx, ctx->memset_pipelines[0], params, entries, wg_x);
704
- std::vector<webgpu_submission_futures> futures = { ggml_backend_webgpu_submit(ctx, { command }) };
705
- ggml_backend_webgpu_wait(ctx, futures);
746
+ webgpu_command command =
747
+ ggml_backend_webgpu_build(ctx, ctx->memset_buf_pool, ctx->memset_pipelines[0], params, entries, wg_x);
748
+ std::vector<webgpu_command> commands = { command };
749
+ std::vector<webgpu_submission> sub = { ggml_backend_webgpu_submit(ctx, commands, ctx->memset_buf_pool) };
750
+ ggml_backend_webgpu_wait(ctx, sub);
706
751
  }
707
752
 
708
753
  /** End WebGPU Actions */
@@ -714,7 +759,6 @@ static const char * ggml_backend_webgpu_name(ggml_backend_t backend) {
714
759
  return ctx->name.c_str();
715
760
  }
716
761
 
717
- // TODO: implement proper cleanup
718
762
  static void ggml_backend_webgpu_free(ggml_backend_t backend) {
719
763
  ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *) backend->context;
720
764
  WEBGPU_LOG_DEBUG("ggml_backend_webgpu_free(" << ctx->name << ")");
@@ -722,19 +766,19 @@ static void ggml_backend_webgpu_free(ggml_backend_t backend) {
722
766
  #ifdef GGML_WEBGPU_CPU_PROFILE
723
767
  std::cout << "\n[ggml_webgpu cpu profiling summary]\n";
724
768
  double total_cpu = 0.0;
725
- for (const auto & kv : ctx->webgpu_ctx->cpu_time_ms) {
769
+ for (const auto & kv : ctx->webgpu_ctx->global_ctx->cpu_time_ms) {
726
770
  total_cpu += kv.second;
727
771
  }
728
772
  std::cout << "ggml_webgpu: total cpu time: " << total_cpu << " ms\n";
729
773
  std::cout << "ggml_webgpu: cpu breakdown:\n";
730
- for (const auto & kv : ctx->webgpu_ctx->cpu_time_ms) {
774
+ for (const auto & kv : ctx->webgpu_ctx->global_ctx->cpu_time_ms) {
731
775
  double pct = (total_cpu > 0.0) ? (kv.second / total_cpu * 100.0) : 0.0;
732
776
  std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n";
733
777
  }
734
- if (ctx->webgpu_ctx->cpu_detail_ms.size() > 0) {
778
+ if (ctx->webgpu_ctx->global_ctx->cpu_detail_ms.size() > 0) {
735
779
  std::cout << "ggml_webgpu: cpu detailed breakdown:\n";
736
780
  }
737
- for (const auto & kv : ctx->webgpu_ctx->cpu_detail_ms) {
781
+ for (const auto & kv : ctx->webgpu_ctx->global_ctx->cpu_detail_ms) {
738
782
  double pct = (total_cpu > 0.0) ? (kv.second / total_cpu * 100.0) : 0.0;
739
783
  std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n";
740
784
  }
@@ -743,14 +787,15 @@ static void ggml_backend_webgpu_free(ggml_backend_t backend) {
743
787
  #ifdef GGML_WEBGPU_GPU_PROFILE
744
788
  std::cout << "\n[ggml_webgpu gpu profiling summary]\n";
745
789
  double total_gpu = 0.0;
746
- for (const auto & kv : ctx->webgpu_ctx->shader_gpu_time_ms) {
790
+ for (const auto & kv : ctx->webgpu_ctx->global_ctx->shader_gpu_time_ms) {
747
791
  total_gpu += kv.second;
748
792
  }
749
793
  std::cout << "ggml_webgpu: total gpu time (all shaders): " << total_gpu << " ms\n";
750
794
  std::cout << "\nggml_webgpu: gpu breakdown:\n";
751
- for (const auto & kv : ctx->webgpu_ctx->shader_gpu_time_ms) {
795
+ for (const auto & kv : ctx->webgpu_ctx->global_ctx->shader_gpu_time_ms) {
752
796
  double pct = (total_gpu > 0.0) ? (kv.second / total_gpu * 100.0) : 0.0;
753
- std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n";
797
+ std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << std::fixed << std::setprecision(2)
798
+ << pct << "%)\n";
754
799
  }
755
800
  #endif
756
801
 
@@ -758,9 +803,8 @@ static void ggml_backend_webgpu_free(ggml_backend_t backend) {
758
803
  std::cout << "ggml_webgpu: gpu/cpu ratio: " << (total_cpu > 0.0 ? total_gpu / total_cpu : 0.0) << "\n";
759
804
  #endif
760
805
 
761
- #if !defined(GGML_WEBGPU_CPU_PROFILE) && !defined(GGML_WEBGPU_GPU_PROFILE)
762
- GGML_UNUSED(ctx);
763
- #endif
806
+ delete ctx;
807
+ delete backend;
764
808
  }
765
809
 
766
810
  static size_t ggml_webgpu_tensor_offset(const ggml_tensor * tensor) {
@@ -774,12 +818,12 @@ static wgpu::Buffer ggml_webgpu_tensor_buf(const ggml_tensor * tensor) {
774
818
 
775
819
  static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, const ggml_tensor * t) {
776
820
  size_t offset = ggml_webgpu_tensor_offset(t);
777
- return offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
821
+ return offset & (ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1);
778
822
  }
779
823
 
780
824
  static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, const ggml_tensor * t) {
781
825
  size_t offset = ggml_webgpu_tensor_offset(t);
782
- return offset & ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
826
+ return offset & ~(ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1);
783
827
  }
784
828
 
785
829
  static size_t ggml_webgpu_tensor_binding_size(webgpu_context & ctx, ggml_tensor * t) {
@@ -792,6 +836,30 @@ static bool ggml_webgpu_tensor_equal(ggml_tensor * a, ggml_tensor * b) {
792
836
  (ggml_webgpu_tensor_offset(a) == ggml_webgpu_tensor_offset(b));
793
837
  }
794
838
 
839
+ // Used to determine if two tensors share the same buffer and their byte ranges overlap,
840
+ static bool ggml_webgpu_tensor_overlap(ggml_tensor * a, ggml_tensor * b) {
841
+ return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) &&
842
+ ggml_webgpu_tensor_offset(a) < (ggml_webgpu_tensor_offset(b) + ggml_nbytes(b)) &&
843
+ ggml_webgpu_tensor_offset(b) < (ggml_webgpu_tensor_offset(a) + ggml_nbytes(a));
844
+ }
845
+
846
+ struct binary_overlap_flags {
847
+ bool inplace; // src0 == dst
848
+ bool overlap; // src1 == dst
849
+ bool src_overlap;
850
+ };
851
+
852
+ static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0,
853
+ ggml_tensor * src1,
854
+ ggml_tensor * dst) {
855
+ binary_overlap_flags flags = {};
856
+ flags.inplace = ggml_webgpu_tensor_equal(src0, dst);
857
+ flags.overlap = ggml_webgpu_tensor_overlap(src1, dst);
858
+ flags.src_overlap = ggml_webgpu_tensor_overlap(src0, src1);
859
+
860
+ return flags;
861
+ }
862
+
795
863
  static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
796
864
  uint32_t ne = (uint32_t) ggml_nelements(dst);
797
865
 
@@ -820,22 +888,85 @@ static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, g
820
888
  };
821
889
 
822
890
  uint32_t wg_x = CEIL_DIV(ne, WEBGPU_MAX_WG_SIZE);
823
- return ggml_backend_webgpu_build(ctx, ctx->cpy_pipelines[src->type][dst->type], params, entries, wg_x);
891
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, ctx->cpy_pipelines[src->type][dst->type],
892
+ params, entries, wg_x);
893
+ }
894
+
895
+ static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
896
+ ggml_webgpu_shader_lib_context shader_lib_ctx = {
897
+ .src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
898
+ };
899
+
900
+ webgpu_pipeline pipeline = ctx->shader_lib->get_pad_pipeline(shader_lib_ctx);
901
+
902
+ auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
903
+
904
+ const uint32_t ne = (uint32_t) ggml_nelements(dst);
905
+
906
+ std::vector<uint32_t> params = {
907
+ ne,
908
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
909
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
910
+ // Strides (in elements)
911
+ (uint32_t) (src->nb[0] / ggml_type_size(src->type)),
912
+ (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
913
+ (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
914
+ (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
915
+ // Shapes
916
+ (uint32_t) src->ne[0],
917
+ (uint32_t) src->ne[1],
918
+ (uint32_t) src->ne[2],
919
+ (uint32_t) src->ne[3],
920
+ (uint32_t) dst->ne[0],
921
+ (uint32_t) dst->ne[1],
922
+ (uint32_t) dst->ne[2],
923
+ (uint32_t) dst->ne[3],
924
+ // Pad sizes
925
+ (uint32_t) ggml_get_op_params_i32(dst, 0),
926
+ (uint32_t) ggml_get_op_params_i32(dst, 1),
927
+ (uint32_t) ggml_get_op_params_i32(dst, 2),
928
+ (uint32_t) ggml_get_op_params_i32(dst, 3),
929
+ (uint32_t) ggml_get_op_params_i32(dst, 4),
930
+ (uint32_t) ggml_get_op_params_i32(dst, 5),
931
+ (uint32_t) ggml_get_op_params_i32(dst, 6),
932
+ (uint32_t) ggml_get_op_params_i32(dst, 7),
933
+ };
934
+
935
+ std::vector<wgpu::BindGroupEntry> entries = {
936
+ { .binding = 0,
937
+ .buffer = ggml_webgpu_tensor_buf(src),
938
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src),
939
+ .size = ggml_webgpu_tensor_binding_size(ctx, src) },
940
+ { .binding = 1,
941
+ .buffer = ggml_webgpu_tensor_buf(dst),
942
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
943
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
944
+ };
945
+
946
+ uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
947
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
824
948
  }
825
949
 
826
950
  static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
827
951
  ggml_tensor * src,
828
952
  ggml_tensor * idx,
829
953
  ggml_tensor * dst) {
830
- // For set rows specifically, we need to check if src and idx are empty tensors.
954
+ // For set rows specifically, we need to check if src and idx are empty
955
+ // tensors.
831
956
  if (ggml_is_empty(src) || ggml_is_empty(idx)) {
832
957
  return std::nullopt;
833
958
  }
834
959
 
835
- webgpu_pool_bufs error_bufs = ctx->set_rows_error_buf_pool.alloc_bufs();
836
- if (error_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) {
837
- error_bufs.host_buf.Unmap();
838
- }
960
+ ggml_webgpu_shader_lib_context shader_lib_ctx = {
961
+ .src0 = src,
962
+ .src1 = idx,
963
+ .dst = dst,
964
+ .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
965
+ };
966
+
967
+ webgpu_pipeline pipeline = ctx->shader_lib->get_set_rows_pipeline(shader_lib_ctx);
968
+
969
+ auto * decisions = static_cast<ggml_webgpu_set_rows_shader_decisions *>(pipeline.context.get());
839
970
 
840
971
  std::vector<uint32_t> params = {
841
972
  (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
@@ -865,44 +996,67 @@ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
865
996
  { .binding = 2,
866
997
  .buffer = ggml_webgpu_tensor_buf(dst),
867
998
  .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
868
- .size = ggml_webgpu_tensor_binding_size(ctx, dst) },
869
- { .binding = 3, .buffer = error_bufs.dev_buf, .offset = 0, .size = error_bufs.dev_buf.GetSize() }
999
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
870
1000
  };
871
1001
 
872
- int vectorized = src->ne[0] % 4 == 0;
873
- webgpu_pipeline pipeline = ctx->set_rows_pipelines[0][vectorized];
874
- uint32_t threads;
875
- if (vectorized) {
1002
+ if (decisions->i64_idx) {
1003
+ entries.push_back({ .binding = 3,
1004
+ .buffer = ctx->set_rows_dev_error_buf,
1005
+ .offset = 0,
1006
+ .size = ctx->set_rows_dev_error_buf.GetSize() });
1007
+ }
1008
+
1009
+ uint32_t threads;
1010
+ if (decisions->vec4) {
876
1011
  threads = (src->ne[1] * src->ne[2] * src->ne[3]) * (src->ne[0] / 4);
877
1012
  } else {
878
1013
  threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3];
879
1014
  }
1015
+ uint32_t wg_x = CEIL_DIV(threads, decisions->wg_size);
1016
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, 1);
1017
+ }
880
1018
 
881
- uint32_t wg_x = CEIL_DIV(threads, WEBGPU_MAX_WG_SIZE);
882
-
883
- return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, 1, error_bufs);
1019
+ // Workgroup size is a common constant
1020
+ static std::vector<wgpu::ConstantEntry> ggml_webgpu_wg_size_entry(uint32_t wg_size) {
1021
+ std::vector<wgpu::ConstantEntry> constants(1);
1022
+ constants[0].key = "wg_size";
1023
+ constants[0].value = wg_size;
1024
+ return constants;
884
1025
  }
885
1026
 
886
1027
  static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx,
887
1028
  ggml_tensor * src,
888
1029
  ggml_tensor * idx,
889
1030
  ggml_tensor * dst) {
890
- std::vector<uint32_t> params = {
891
- (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
892
- (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)),
893
- (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
894
- // Convert byte-strides to element-strides
895
- (uint32_t) (src->nb[1] / ggml_type_size(src->type)), (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
896
- (uint32_t) (src->nb[3] / ggml_type_size(src->type)), (uint32_t) (idx->nb[0] / ggml_type_size(idx->type)),
897
- (uint32_t) (idx->nb[1] / ggml_type_size(idx->type)), (uint32_t) (idx->nb[2] / ggml_type_size(idx->type)),
898
- (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
899
- (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
900
- // Shape of dst
901
- (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3],
902
- // Shape of idx
903
- (uint32_t) (idx->ne[1]), (uint32_t) (idx->ne[2])
1031
+ ggml_webgpu_shader_lib_context shader_lib_ctx = {
1032
+ .src0 = src,
1033
+ .src1 = nullptr,
1034
+ .dst = dst,
1035
+ .max_wg_size = WEBGPU_MAX_WG_SIZE,
904
1036
  };
905
1037
 
1038
+ webgpu_pipeline pipeline = ctx->shader_lib->get_get_rows_pipeline(shader_lib_ctx);
1039
+ auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
1040
+
1041
+ std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
1042
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)),
1043
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1044
+ (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
1045
+ (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
1046
+ (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
1047
+ (uint32_t) (idx->nb[0] / ggml_type_size(idx->type)),
1048
+ (uint32_t) (idx->nb[1] / ggml_type_size(idx->type)),
1049
+ (uint32_t) (idx->nb[2] / ggml_type_size(idx->type)),
1050
+ (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
1051
+ (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
1052
+ (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
1053
+ (uint32_t) dst->ne[0],
1054
+ (uint32_t) dst->ne[1],
1055
+ (uint32_t) dst->ne[2],
1056
+ (uint32_t) dst->ne[3],
1057
+ (uint32_t) (idx->ne[1]),
1058
+ (uint32_t) (idx->ne[2]) };
1059
+
906
1060
  std::vector<wgpu::BindGroupEntry> entries = {
907
1061
  { .binding = 0,
908
1062
  .buffer = ggml_webgpu_tensor_buf(src),
@@ -918,68 +1072,45 @@ static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx,
918
1072
  .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
919
1073
  };
920
1074
 
921
- uint32_t wg_x = CEIL_DIV(dst->ne[1] * dst->ne[2] * dst->ne[3], WEBGPU_MAX_WG_SIZE);
1075
+ uint32_t wg_x = CEIL_DIV(dst->ne[1] * dst->ne[2] * dst->ne[3], decisions->wg_size);
922
1076
 
923
- uint32_t vectorized = src->type == GGML_TYPE_F32 && dst->ne[0] % 4 == 0;
924
- webgpu_pipeline pipeline = ctx->get_rows_pipelines[src->type][vectorized];
925
- return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
1077
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
926
1078
  }
927
1079
 
928
1080
  static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
929
1081
  ggml_tensor * src0,
930
1082
  ggml_tensor * src1,
931
1083
  ggml_tensor * dst) {
932
- std::vector<uint32_t> params = {
933
- (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
934
- (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
935
- (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
936
- (uint32_t) dst->ne[0], // number of rows in result (M, transposed)
937
- (uint32_t) dst->ne[1], // number of columns in result (N)
938
- (uint32_t) src0->ne[0], // number of columns in src0/src1 (K)
939
- (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 1
940
- (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 1
941
- (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 2
942
- (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 2
943
- (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 3
944
- (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 3
945
- (uint32_t) src0->ne[2], // batch size in dimension 2
946
- (uint32_t) src0->ne[3], // batch size in dimension 3
947
- (uint32_t) (src1->ne[2] / src0->ne[2]), // broadcast in dimension 2
948
- (uint32_t) (src1->ne[3] / src0->ne[3]) // broadcast in dimension 3
949
- };
950
-
951
- std::vector<wgpu::BindGroupEntry> entries = {
952
- { .binding = 0,
953
- .buffer = ggml_webgpu_tensor_buf(src0),
954
- .offset = ggml_webgpu_tensor_align_offset(ctx, src0),
955
- .size = ggml_webgpu_tensor_binding_size(ctx, src0) },
956
- { .binding = 1,
957
- .buffer = ggml_webgpu_tensor_buf(src1),
958
- .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
959
- .size = ggml_webgpu_tensor_binding_size(ctx, src1) },
960
- { .binding = 2,
961
- .buffer = ggml_webgpu_tensor_buf(dst),
962
- .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
963
- .size = ggml_webgpu_tensor_binding_size(ctx, dst) },
964
- };
965
-
966
- webgpu_pipeline pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][0];
967
-
968
- uint32_t wg_x = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], WEBGPU_MUL_MAT_WG_SIZE);
969
- uint32_t wg_y = 1;
1084
+ // Determine if this is a mat-vec operation
1085
+ bool is_vec = (dst->ne[1] == 1);
970
1086
 
1087
+ // Determine if we should use fast path
971
1088
  bool use_fast = false;
972
1089
  switch (src1->type) {
973
1090
  case GGML_TYPE_F16:
974
1091
  use_fast = (src0->type == GGML_TYPE_F16);
975
1092
  break;
976
1093
  case GGML_TYPE_F32:
1094
+ // TODO: implement better mat-mat for k-quants, mat-vec for all k-quants except q6_K
977
1095
  switch (src0->type) {
978
1096
  case GGML_TYPE_F32:
979
1097
  case GGML_TYPE_F16:
980
1098
  case GGML_TYPE_Q4_0:
1099
+ case GGML_TYPE_Q4_1:
1100
+ case GGML_TYPE_Q5_0:
1101
+ case GGML_TYPE_Q5_1:
1102
+ case GGML_TYPE_Q8_0:
1103
+ case GGML_TYPE_Q8_1:
1104
+ case GGML_TYPE_Q6_K:
981
1105
  use_fast = true;
982
1106
  break;
1107
+ case GGML_TYPE_Q2_K:
1108
+ case GGML_TYPE_Q3_K:
1109
+ case GGML_TYPE_Q4_K:
1110
+ case GGML_TYPE_Q5_K:
1111
+ // we don't have fast mat-vec for these types, but we do have (semi) fast mat-mat
1112
+ use_fast = !is_vec;
1113
+ break;
983
1114
  default:
984
1115
  break;
985
1116
  }
@@ -988,44 +1119,110 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
988
1119
  break;
989
1120
  }
990
1121
 
991
- if (use_fast) {
992
- int vectorized = src0->ne[0] % 4 == 0 && dst->ne[0] % 4 == 0 && dst->ne[1] % 4 == 0;
993
- if (dst->ne[1] == 1) {
994
- // We don't support vectorized mul_mat_vec for quantized types
995
- vectorized = vectorized && (src0->type < 2);
996
- pipeline = ctx->mul_mat_vec_pipelines[src0->type][src1->type][vectorized];
997
- uint32_t batches = dst->ne[2] * dst->ne[3];
998
- uint32_t output_groups = CEIL_DIV(dst->ne[0], WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG);
999
- uint32_t total_wg = output_groups * batches;
1000
- wg_x = total_wg % ctx->limits.maxComputeWorkgroupsPerDimension;
1001
- wg_y = CEIL_DIV(total_wg, ctx->limits.maxComputeWorkgroupsPerDimension);
1002
- } else {
1003
- pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][vectorized];
1004
- uint32_t wg_m;
1005
- uint32_t wg_n;
1006
- #ifndef __EMSCRIPTEN__
1007
- if (ctx->supports_subgroup_matrix) {
1008
- // The total number of subgroups/workgroups needed per matrix.
1009
- uint32_t wg_m_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_M * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M * ctx->sg_mat_m;
1010
- wg_m = CEIL_DIV(dst->ne[0], wg_m_sg_tile);
1011
- uint32_t wg_n_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N * ctx->sg_mat_n;
1012
- wg_n = CEIL_DIV(dst->ne[1], wg_n_sg_tile);
1013
- } else {
1014
- #endif
1015
- uint32_t tile_m_s = WEBGPU_MUL_MAT_TILE_M * WEBGPU_MUL_MAT_WG_SIZE_M;
1016
- uint32_t tile_n_s = WEBGPU_MUL_MAT_TILE_N * WEBGPU_MUL_MAT_WG_SIZE_N;
1017
- wg_m = CEIL_DIV(dst->ne[0], tile_m_s);
1018
- wg_n = CEIL_DIV(dst->ne[1], tile_n_s);
1019
- #ifndef __EMSCRIPTEN__
1020
- }
1021
- #endif
1122
+ ggml_webgpu_shader_lib_context shader_lib_ctx = {
1123
+ .src0 = src0,
1124
+ .src1 = src1,
1125
+ .dst = dst,
1126
+ .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
1127
+ .supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix,
1128
+ .sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m,
1129
+ .sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n,
1130
+ .sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k,
1131
+ .max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size,
1132
+ };
1133
+
1134
+ // Get or create pipeline
1135
+ webgpu_pipeline pipeline;
1136
+
1137
+ if (use_fast && is_vec) {
1138
+ pipeline = ctx->shader_lib->get_mul_mat_vec_pipeline(shader_lib_ctx);
1139
+ } else if (use_fast) {
1140
+ pipeline = ctx->shader_lib->get_mul_mat_fast_pipeline(shader_lib_ctx);
1141
+ } else {
1142
+ pipeline = ctx->shader_lib->get_mul_mat_legacy_pipeline(shader_lib_ctx);
1143
+ }
1144
+
1145
+ // Build params
1146
+ std::vector<uint32_t> params = {
1147
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
1148
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
1149
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1150
+ (uint32_t) dst->ne[0],
1151
+ (uint32_t) dst->ne[1],
1152
+ (uint32_t) src0->ne[0],
1153
+ (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
1154
+ (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
1155
+ (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
1156
+ (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
1157
+ (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
1158
+ (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
1159
+ (uint32_t) src0->ne[2],
1160
+ (uint32_t) src0->ne[3],
1161
+ (uint32_t) (src1->ne[2] / src0->ne[2]),
1162
+ (uint32_t) (src1->ne[3] / src0->ne[3])
1163
+ };
1164
+
1165
+ // Build bind group entries
1166
+ std::vector<wgpu::BindGroupEntry> entries = {
1167
+ { .binding = 0,
1168
+ .buffer = ggml_webgpu_tensor_buf(src0),
1169
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src0),
1170
+ .size = ggml_webgpu_tensor_binding_size(ctx, src0) },
1171
+ { .binding = 1,
1172
+ .buffer = ggml_webgpu_tensor_buf(src1),
1173
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
1174
+ .size = ggml_webgpu_tensor_binding_size(ctx, src1) },
1175
+ { .binding = 2,
1176
+ .buffer = ggml_webgpu_tensor_buf(dst),
1177
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
1178
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) },
1179
+ };
1022
1180
 
1023
- wg_x = wg_m * wg_n * dst->ne[2] * dst->ne[3];
1181
+ // Calculate workgroup dimensions
1182
+ uint32_t wg_x = 1;
1183
+ uint32_t wg_y = 1;
1184
+ const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
1185
+
1186
+ if (use_fast && is_vec) {
1187
+ auto * decisions = static_cast<ggml_webgpu_mul_mat_vec_shader_decisions *>(pipeline.context.get());
1188
+
1189
+ uint32_t batches = dst->ne[2] * dst->ne[3];
1190
+ uint32_t output_groups = CEIL_DIV(dst->ne[0], decisions->outputs_per_wg);
1191
+ uint32_t total_wg = output_groups * batches;
1192
+ compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
1193
+ } else if (use_fast) {
1194
+ auto * decisions = static_cast<ggml_webgpu_mul_mat_shader_decisions *>(pipeline.context.get());
1195
+
1196
+ // Fast-path tiled/subgroup calculations
1197
+ uint32_t wg_m;
1198
+ uint32_t wg_n;
1199
+ if (decisions->use_subgroup_matrix) {
1200
+ uint32_t wg_m_sg_tile =
1201
+ decisions->subgroup_m * decisions->subgroup_matrix_m * ctx->global_ctx->capabilities.sg_mat_m;
1202
+ wg_m = CEIL_DIV(dst->ne[0], wg_m_sg_tile);
1203
+ uint32_t wg_n_sg_tile =
1204
+ decisions->subgroup_n * decisions->subgroup_matrix_n * ctx->global_ctx->capabilities.sg_mat_n;
1205
+ wg_n = CEIL_DIV(dst->ne[1], wg_n_sg_tile);
1206
+ } else {
1207
+ uint32_t tile_m_s = decisions->tile_m * decisions->wg_size_m;
1208
+ uint32_t tile_n_s = decisions->tile_n * decisions->wg_size_n;
1209
+ wg_m = CEIL_DIV(dst->ne[0], tile_m_s);
1210
+ wg_n = CEIL_DIV(dst->ne[1], tile_n_s);
1024
1211
  }
1212
+ uint32_t total_wg = wg_m * wg_n * dst->ne[2] * dst->ne[3];
1213
+ compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
1214
+
1215
+ } else { // legacy
1216
+ auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
1217
+ uint32_t wg_size = decisions->wg_size;
1218
+ uint32_t total_wg = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], wg_size);
1219
+ compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
1025
1220
  }
1026
- return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
1221
+
1222
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y);
1027
1223
  }
1028
1224
 
1225
+ #ifndef __EMSCRIPTEN__
1029
1226
  static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx,
1030
1227
  ggml_tensor * Q,
1031
1228
  ggml_tensor * K,
@@ -1109,105 +1306,97 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx,
1109
1306
  .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
1110
1307
  .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
1111
1308
 
1112
- bool kv_direct =
1113
- (K->type == GGML_TYPE_F16) && (Q->ne[0] % ctx->sg_mat_k == 0) && (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);
1114
-
1115
- flash_attn_pipeline_key key = {
1116
- .q_type = Q->type,
1117
- .kv_type = K->type,
1118
- .dst_type = dst->type,
1119
- .head_dim_qk = (uint32_t) Q->ne[0],
1120
- .head_dim_v = (uint32_t) V->ne[0],
1121
- .kv_direct = kv_direct,
1122
- .has_mask = static_cast<bool>(has_mask),
1123
- .has_sinks = static_cast<bool>(has_sinks),
1124
- .uses_logit_softcap = logit_softcap != 0.0f,
1309
+ ggml_webgpu_shader_lib_context shader_lib_ctx = {
1310
+ .src0 = Q,
1311
+ .src1 = K,
1312
+ .src2 = V,
1313
+ .src3 = mask,
1314
+ .src4 = sinks,
1315
+ .dst = dst,
1316
+ .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
1317
+ .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
1318
+ .sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m,
1319
+ .sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n,
1320
+ .sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k,
1321
+ .max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size,
1125
1322
  };
1126
1323
 
1127
- webgpu_pipeline pipeline;
1128
- ggml_webgpu_flash_attn_shader_decisions decisions = {};
1324
+ webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline(shader_lib_ctx);
1129
1325
 
1130
- auto it = ctx->flash_attn_pipelines.find(key);
1131
- if (it != ctx->flash_attn_pipelines.end()) {
1132
- pipeline = it->second;
1133
- decisions = *static_cast<ggml_webgpu_flash_attn_shader_decisions *>(pipeline.context);
1134
- } else {
1135
- std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
1136
- it = ctx->flash_attn_pipelines.find(key);
1137
- if (it != ctx->flash_attn_pipelines.end()) {
1138
- pipeline = it->second;
1139
- decisions = *static_cast<ggml_webgpu_flash_attn_shader_decisions *>(pipeline.context);
1140
- } else {
1141
- ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = { .kv_type = K->type,
1142
- .head_dim_qk = (uint32_t) Q->ne[0],
1143
- .head_dim_v = (uint32_t) V->ne[0],
1144
- .kv_direct = kv_direct,
1145
- .has_mask = static_cast<bool>(has_mask),
1146
- .has_sinks = static_cast<bool>(has_sinks),
1147
- .uses_logit_softcap = logit_softcap != 0.0f,
1148
- .sg_mat_m = ctx->sg_mat_m,
1149
- .sg_mat_n = ctx->sg_mat_n,
1150
- .sg_mat_k = ctx->sg_mat_k,
1151
- .wg_mem_limit_bytes =
1152
- ctx->limits.maxComputeWorkgroupStorageSize,
1153
- .max_subgroup_size = ctx->max_subgroup_size };
1154
-
1155
- ggml_webgpu_processed_shader processed =
1156
- ggml_webgpu_preprocess_flash_attn_shader(ctx->p, wgsl_flash_attn, shader_lib_ctx);
1157
- pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
1158
- pipeline.context = new ggml_webgpu_flash_attn_shader_decisions(processed.decisions);
1159
- ctx->flash_attn_pipelines.emplace(key, pipeline);
1160
- decisions = processed.decisions;
1161
- }
1162
- }
1326
+ auto * decisions = static_cast<ggml_webgpu_flash_attn_shader_decisions *>(pipeline.context.get());
1163
1327
 
1164
- uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions.q_tile);
1328
+ uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions->q_tile);
1165
1329
  uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches
1166
- return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
1330
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
1167
1331
  }
1332
+ #endif
1168
1333
 
1169
1334
  static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
1170
- uint32_t ne = (uint32_t) ggml_nelements(dst);
1171
- ggml_unary_op unary_op = ggml_get_unary_op(dst);
1172
- uint32_t inplace = ggml_webgpu_tensor_equal(src, dst);
1173
-
1174
- std::vector<uint32_t> params = {
1175
- ne, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
1176
- (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1177
- // Convert byte-strides to element-strides
1178
- (uint32_t) (src->nb[0] / ggml_type_size(src->type)), (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
1179
- (uint32_t) (src->nb[2] / ggml_type_size(src->type)), (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
1180
- (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
1181
- (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
1182
- // Logical shapes
1183
- (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) dst->ne[0],
1184
- (uint32_t) dst->ne[1], (uint32_t) dst->ne[2]
1335
+ bool is_unary = dst->op == GGML_OP_UNARY;
1336
+ bool inplace = ggml_webgpu_tensor_equal(src, dst) || (dst->op == GGML_OP_FILL);
1337
+
1338
+ ggml_webgpu_shader_lib_context shader_lib_ctx = {
1339
+ .src0 = src,
1340
+ .src1 = nullptr,
1341
+ .dst = dst,
1342
+ .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
1343
+ .inplace = inplace,
1185
1344
  };
1186
1345
 
1187
- switch (unary_op) {
1188
- case GGML_UNARY_OP_XIELU:
1189
- {
1190
- // Get float parameters and reinterpret their bit patterns as uint32_t
1191
- // for passing through the params buffer
1192
- float alpha_n = ggml_get_op_params_f32(dst, 1);
1193
- float alpha_p = ggml_get_op_params_f32(dst, 2);
1194
- float beta = ggml_get_op_params_f32(dst, 3);
1195
- float eps = ggml_get_op_params_f32(dst, 4);
1196
- params.push_back(*reinterpret_cast<const uint32_t *>(&alpha_n));
1197
- params.push_back(*reinterpret_cast<const uint32_t *>(&alpha_p));
1198
- params.push_back(*reinterpret_cast<const uint32_t *>(&beta));
1199
- params.push_back(*reinterpret_cast<const uint32_t *>(&eps));
1346
+ webgpu_pipeline pipeline = ctx->shader_lib->get_unary_pipeline(shader_lib_ctx);
1347
+
1348
+ auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
1349
+
1350
+ uint32_t ne = (uint32_t) ggml_nelements(dst);
1351
+
1352
+ std::vector<uint32_t> params = { ne,
1353
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
1354
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1355
+ (uint32_t) (src->nb[0] / ggml_type_size(src->type)),
1356
+ (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
1357
+ (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
1358
+ (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
1359
+ (uint32_t) src->ne[0],
1360
+ (uint32_t) src->ne[1],
1361
+ (uint32_t) src->ne[2] };
1362
+
1363
+ ggml_tensor * effective_src = src;
1364
+ if (is_unary) {
1365
+ ggml_unary_op unary_op = ggml_get_unary_op(dst);
1366
+ switch (unary_op) {
1367
+ case GGML_UNARY_OP_XIELU:
1368
+ {
1369
+ // Get float parameters and reinterpret their bit patterns as uint32_t
1370
+ // for passing through the params buffer
1371
+ float alpha_n = ggml_get_op_params_f32(dst, 1);
1372
+ float alpha_p = ggml_get_op_params_f32(dst, 2);
1373
+ float beta = ggml_get_op_params_f32(dst, 3);
1374
+ float eps = ggml_get_op_params_f32(dst, 4);
1375
+ params.push_back(*reinterpret_cast<const uint32_t *>(&alpha_n));
1376
+ params.push_back(*reinterpret_cast<const uint32_t *>(&alpha_p));
1377
+ params.push_back(*reinterpret_cast<const uint32_t *>(&beta));
1378
+ params.push_back(*reinterpret_cast<const uint32_t *>(&eps));
1379
+ break;
1380
+ }
1381
+ default:
1200
1382
  break;
1201
- }
1202
- default:
1203
- break;
1383
+ }
1384
+ } else if (dst->op == GGML_OP_CLAMP) {
1385
+ float clamp_min = ggml_get_op_params_f32(dst, 0);
1386
+ float clamp_max = ggml_get_op_params_f32(dst, 1);
1387
+ params.push_back(*reinterpret_cast<const uint32_t *>(&clamp_min));
1388
+ params.push_back(*reinterpret_cast<const uint32_t *>(&clamp_max));
1389
+ } else if (dst->op == GGML_OP_FILL) {
1390
+ float fill_val = ggml_get_op_params_f32(dst, 0);
1391
+ params.push_back(*reinterpret_cast<const uint32_t *>(&fill_val));
1392
+ effective_src = dst; // fill simply fills dst
1204
1393
  }
1205
1394
 
1206
1395
  std::vector<wgpu::BindGroupEntry> entries = {
1207
1396
  { .binding = 0,
1208
- .buffer = ggml_webgpu_tensor_buf(src),
1209
- .offset = ggml_webgpu_tensor_align_offset(ctx, src),
1210
- .size = ggml_webgpu_tensor_binding_size(ctx, src) },
1397
+ .buffer = ggml_webgpu_tensor_buf(effective_src),
1398
+ .offset = ggml_webgpu_tensor_align_offset(ctx, effective_src),
1399
+ .size = ggml_webgpu_tensor_binding_size(ctx, effective_src) },
1211
1400
  };
1212
1401
  if (!inplace) {
1213
1402
  entries.push_back({ .binding = 1,
@@ -1216,21 +1405,54 @@ static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * s
1216
1405
  .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
1217
1406
  }
1218
1407
 
1219
- uint32_t wg_x = CEIL_DIV(ne, WEBGPU_MAX_WG_SIZE);
1220
- return ggml_backend_webgpu_build(ctx, ctx->unary_pipelines[unary_op][dst->type][inplace], params, entries, wg_x);
1408
+ uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
1409
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
1221
1410
  }
1222
1411
 
1223
- static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
1224
- ggml_tensor * src0,
1225
- ggml_tensor * src1,
1226
- ggml_tensor * dst,
1227
- webgpu_pipeline & pipeline,
1228
- bool inplace) {
1412
+ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
1413
+ ggml_tensor * src0,
1414
+ ggml_tensor * src1,
1415
+ ggml_tensor * dst) {
1416
+ binary_overlap_flags flags = ggml_webgpu_detect_binary_overlap(src0, src1, dst);
1417
+
1418
+ ggml_webgpu_shader_lib_context shader_lib_ctx = {
1419
+ .src0 = src0,
1420
+ .src1 = src1,
1421
+ .dst = dst,
1422
+ .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
1423
+ .inplace = flags.inplace,
1424
+ .overlap = flags.overlap,
1425
+ .src_overlap = flags.src_overlap,
1426
+ };
1427
+
1428
+ webgpu_pipeline pipeline = ctx->shader_lib->get_binary_pipeline(shader_lib_ctx);
1429
+
1430
+ auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
1431
+
1432
+ uint32_t ne = (uint32_t) ggml_nelements(dst);
1433
+
1434
+ size_t src0_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, src0);
1435
+ size_t src1_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, src1);
1436
+
1437
+ uint32_t offset_merged_src0 = 0;
1438
+ uint32_t offset_merged_src1 = 0;
1439
+ if (flags.src_overlap) {
1440
+ size_t min_off = std::min(src0_webgpu_tensor_align_offset, src1_webgpu_tensor_align_offset);
1441
+ offset_merged_src0 = (uint32_t) ((src0_webgpu_tensor_align_offset - min_off) / ggml_type_size(src0->type));
1442
+ offset_merged_src1 = (uint32_t) ((src1_webgpu_tensor_align_offset - min_off) / ggml_type_size(src0->type));
1443
+ }
1444
+
1229
1445
  std::vector<uint32_t> params = {
1230
- (uint32_t) ggml_nelements(dst),
1446
+ ne,
1231
1447
  (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
1232
1448
  (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
1233
1449
  (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1450
+ offset_merged_src0,
1451
+ offset_merged_src1,
1452
+ (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),
1453
+ (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
1454
+ (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
1455
+ (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
1234
1456
  (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)),
1235
1457
  (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
1236
1458
  (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
@@ -1244,6 +1466,79 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
1244
1466
  (uint32_t) src1->ne[3],
1245
1467
  };
1246
1468
 
1469
+ std::vector<wgpu::BindGroupEntry> entries;
1470
+
1471
+ if (flags.src_overlap) {
1472
+ size_t merged_offset = std::min(src0_webgpu_tensor_align_offset, src1_webgpu_tensor_align_offset);
1473
+ size_t merged_end = std::max(src0_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, src0),
1474
+ src1_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, src1));
1475
+ entries.push_back({
1476
+ .binding = 0,
1477
+ .buffer = ggml_webgpu_tensor_buf(src0),
1478
+ .offset = merged_offset,
1479
+ .size = merged_end - merged_offset,
1480
+ });
1481
+ entries.push_back({
1482
+ .binding = 1,
1483
+ .buffer = ggml_webgpu_tensor_buf(dst),
1484
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
1485
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst),
1486
+ });
1487
+ } else {
1488
+ entries.push_back({
1489
+ .binding = 0,
1490
+ .buffer = ggml_webgpu_tensor_buf(src0),
1491
+ .offset = src0_webgpu_tensor_align_offset,
1492
+ .size = ggml_webgpu_tensor_binding_size(ctx, src0),
1493
+ });
1494
+ entries.push_back({
1495
+ .binding = 1,
1496
+ .buffer = ggml_webgpu_tensor_buf(src1),
1497
+ .offset = src1_webgpu_tensor_align_offset,
1498
+ .size = ggml_webgpu_tensor_binding_size(ctx, src1),
1499
+ });
1500
+ if (!flags.inplace && !flags.overlap) {
1501
+ entries.push_back({
1502
+ .binding = 2,
1503
+ .buffer = ggml_webgpu_tensor_buf(dst),
1504
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
1505
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst),
1506
+ });
1507
+ }
1508
+ }
1509
+
1510
+ uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
1511
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
1512
+ }
1513
+
1514
+ static webgpu_command ggml_webgpu_concat(webgpu_context & ctx,
1515
+ ggml_tensor * src0,
1516
+ ggml_tensor * src1,
1517
+ ggml_tensor * dst) {
1518
+ uint32_t ne = (uint32_t) ggml_nelements(dst);
1519
+ uint32_t dim = (uint32_t) dst->op_params[0];
1520
+
1521
+ std::vector<uint32_t> params = {
1522
+ ne,
1523
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
1524
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
1525
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1526
+ (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),
1527
+ (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
1528
+ (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
1529
+ (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
1530
+ (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)),
1531
+ (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
1532
+ (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
1533
+ (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
1534
+ (uint32_t) dst->ne[0],
1535
+ (uint32_t) dst->ne[1],
1536
+ (uint32_t) dst->ne[2],
1537
+ (uint32_t) dst->ne[3],
1538
+ dim,
1539
+ (uint32_t) src0->ne[dim]
1540
+ };
1541
+
1247
1542
  std::vector<wgpu::BindGroupEntry> entries = {
1248
1543
  { .binding = 0,
1249
1544
  .buffer = ggml_webgpu_tensor_buf(src0),
@@ -1252,17 +1547,66 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
1252
1547
  { .binding = 1,
1253
1548
  .buffer = ggml_webgpu_tensor_buf(src1),
1254
1549
  .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
1255
- .size = ggml_webgpu_tensor_binding_size(ctx, src1) }
1550
+ .size = ggml_webgpu_tensor_binding_size(ctx, src1) },
1551
+ { .binding = 2,
1552
+ .buffer = ggml_webgpu_tensor_buf(dst),
1553
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
1554
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
1555
+ };
1556
+
1557
+ ggml_webgpu_shader_lib_context shader_lib_ctx = {
1558
+ .src0 = src0,
1559
+ .src1 = src1,
1560
+ .dst = dst,
1561
+ .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
1562
+ };
1563
+
1564
+ webgpu_pipeline pipeline = ctx->shader_lib->get_concat_pipeline(shader_lib_ctx);
1565
+ auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
1566
+ uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
1567
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
1568
+ }
1569
+
1570
+ static webgpu_command ggml_webgpu_repeat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * dst) {
1571
+ uint32_t ne = (uint32_t) ggml_nelements(dst);
1572
+
1573
+ std::vector<uint32_t> params = { ne,
1574
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) /
1575
+ ggml_type_size(src0->type)),
1576
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1577
+ (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),
1578
+ (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
1579
+ (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
1580
+ (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
1581
+ (uint32_t) (src0->ne[0]),
1582
+ (uint32_t) (src0->ne[1]),
1583
+ (uint32_t) (src0->ne[2]),
1584
+ (uint32_t) (src0->ne[3]),
1585
+ (uint32_t) (dst->ne[0]),
1586
+ (uint32_t) (dst->ne[1]),
1587
+ (uint32_t) (dst->ne[2]) };
1588
+
1589
+ std::vector<wgpu::BindGroupEntry> entries = {
1590
+ { .binding = 0,
1591
+ .buffer = ggml_webgpu_tensor_buf(src0),
1592
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src0),
1593
+ .size = ggml_webgpu_tensor_binding_size(ctx, src0) },
1594
+ { .binding = 1,
1595
+ .buffer = ggml_webgpu_tensor_buf(dst),
1596
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
1597
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
1598
+ };
1599
+
1600
+ ggml_webgpu_shader_lib_context shader_lib_ctx = {
1601
+ .src0 = src0,
1602
+ .dst = dst,
1603
+ .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
1256
1604
  };
1257
- if (!inplace) {
1258
- entries.push_back({ .binding = 2,
1259
- .buffer = ggml_webgpu_tensor_buf(dst),
1260
- .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
1261
- .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
1262
- }
1263
1605
 
1264
- uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
1265
- return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
1606
+ webgpu_pipeline pipeline = ctx->shader_lib->get_repeat_pipeline(shader_lib_ctx);
1607
+ auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
1608
+ uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
1609
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
1266
1610
  }
1267
1611
 
1268
1612
  static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
@@ -1297,7 +1641,8 @@ static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * s
1297
1641
  .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
1298
1642
  }
1299
1643
 
1300
- return ggml_backend_webgpu_build(ctx, ctx->rms_norm_pipelines[inplace], params, entries, ggml_nrows(src));
1644
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, ctx->rms_norm_pipelines[inplace], params,
1645
+ entries, ggml_nrows(src));
1301
1646
  }
1302
1647
 
1303
1648
  static webgpu_command ggml_webgpu_rope(webgpu_context & ctx,
@@ -1312,7 +1657,12 @@ static webgpu_command ggml_webgpu_rope(webgpu_context & ctx,
1312
1657
  const int mode = ((int32_t *) dst->op_params)[2];
1313
1658
  const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
1314
1659
 
1315
- float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
1660
+ float freq_base;
1661
+ float freq_scale;
1662
+ float ext_factor;
1663
+ float attn_factor;
1664
+ float beta_fast;
1665
+ float beta_slow;
1316
1666
  memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
1317
1667
  memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
1318
1668
  memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
@@ -1384,7 +1734,7 @@ static webgpu_command ggml_webgpu_rope(webgpu_context & ctx,
1384
1734
 
1385
1735
  webgpu_pipeline pipeline = ctx->rope_pipelines[dst->type][has_freq_factor][inplace];
1386
1736
  uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
1387
- return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
1737
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
1388
1738
  }
1389
1739
 
1390
1740
  static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
@@ -1436,12 +1786,24 @@ static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0,
1436
1786
 
1437
1787
  webgpu_pipeline pipeline = ctx->glu_pipelines[ggml_get_glu_op(dst)][dst->type][split];
1438
1788
  uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
1439
- return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
1789
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
1440
1790
  }
1441
1791
 
1442
1792
  static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
1443
- int inplace = ggml_webgpu_tensor_equal(src, dst);
1793
+ bool inplace = ggml_webgpu_tensor_equal(src, dst);
1794
+
1795
+ ggml_webgpu_shader_lib_context shader_lib_ctx = {
1796
+ .src0 = src,
1797
+ .src1 = nullptr,
1798
+ .dst = dst,
1799
+ .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
1800
+ .inplace = inplace,
1801
+ };
1444
1802
 
1803
+ webgpu_pipeline pipeline = ctx->shader_lib->get_scale_pipeline(shader_lib_ctx);
1804
+ auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
1805
+
1806
+ // params unchanged
1445
1807
  std::vector<uint32_t> params = {
1446
1808
  (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
1447
1809
  (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
@@ -1459,12 +1821,14 @@ static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src,
1459
1821
  *(uint32_t *) &dst->op_params[1] // bias
1460
1822
  };
1461
1823
 
1824
+ // bindgroups unchanged
1462
1825
  std::vector<wgpu::BindGroupEntry> entries = {
1463
1826
  { .binding = 0,
1464
1827
  .buffer = ggml_webgpu_tensor_buf(src),
1465
1828
  .offset = ggml_webgpu_tensor_align_offset(ctx, src),
1466
1829
  .size = ggml_webgpu_tensor_binding_size(ctx, src) }
1467
1830
  };
1831
+
1468
1832
  if (!inplace) {
1469
1833
  entries.push_back({ .binding = 1,
1470
1834
  .buffer = ggml_webgpu_tensor_buf(dst),
@@ -1472,8 +1836,8 @@ static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src,
1472
1836
  .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
1473
1837
  }
1474
1838
 
1475
- uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
1476
- return ggml_backend_webgpu_build(ctx, ctx->scale_pipelines[inplace], params, entries, wg_x);
1839
+ uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size);
1840
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
1477
1841
  }
1478
1842
 
1479
1843
  static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx,
@@ -1545,15 +1909,261 @@ static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx,
1545
1909
  .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
1546
1910
  }
1547
1911
 
1548
- return ggml_backend_webgpu_build(ctx, ctx->soft_max_pipelines[mask_type][has_sink][inplace], params, entries,
1912
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool,
1913
+ ctx->soft_max_pipelines[mask_type][has_sink][inplace], params, entries,
1549
1914
  ggml_nrows(dst));
1550
1915
  }
1551
1916
 
1917
+ static webgpu_command ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
1918
+ std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
1919
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1920
+ (uint32_t) src->ne[0] };
1921
+
1922
+ std::vector<wgpu::BindGroupEntry> entries = {
1923
+ { .binding = 0,
1924
+ .buffer = ggml_webgpu_tensor_buf(src),
1925
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src),
1926
+ .size = ggml_webgpu_tensor_binding_size(ctx, src) },
1927
+ { .binding = 1,
1928
+ .buffer = ggml_webgpu_tensor_buf(dst),
1929
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
1930
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
1931
+ };
1932
+
1933
+ ggml_webgpu_shader_lib_context shader_lib_ctx = {
1934
+ .src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
1935
+ };
1936
+
1937
+ webgpu_pipeline pipeline = ctx->shader_lib->get_argmax_pipeline(shader_lib_ctx);
1938
+ uint32_t wg_x = ggml_nelements(dst);
1939
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
1940
+ }
1941
+
1942
+ static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
1943
+ bool is_top_k = dst->op == GGML_OP_TOP_K;
1944
+
1945
+ ggml_webgpu_shader_lib_context shader_lib_ctx = {
1946
+ .src0 = src,
1947
+ .src1 = nullptr,
1948
+ .dst = dst,
1949
+ .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
1950
+ .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
1951
+ };
1952
+
1953
+ webgpu_pipeline argsort_pipeline = ctx->shader_lib->get_argsort_pipeline(shader_lib_ctx);
1954
+ auto * argsort_decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(argsort_pipeline.context.get());
1955
+
1956
+ webgpu_pipeline argsort_merge_pipeline = ctx->shader_lib->get_argsort_merge_pipeline(shader_lib_ctx);
1957
+
1958
+ const uint32_t src_ne0 = (uint32_t) src->ne[0];
1959
+ const uint32_t nrows = (uint32_t) ggml_nrows(src);
1960
+ const uint32_t npr = CEIL_DIV(src_ne0, argsort_decisions->wg_size);
1961
+ const uint32_t block_size =
1962
+ is_top_k ? std::min(argsort_decisions->wg_size, (uint32_t) dst->ne[0]) : argsort_decisions->wg_size;
1963
+ uint32_t out_ne0 = src_ne0;
1964
+ if (is_top_k) {
1965
+ if (npr > 1) {
1966
+ const uint32_t last_tile = src_ne0 - (npr - 1) * argsort_decisions->wg_size;
1967
+ out_ne0 = (npr - 1) * block_size + std::min(last_tile, block_size);
1968
+ } else {
1969
+ out_ne0 = block_size;
1970
+ }
1971
+ }
1972
+
1973
+ uint32_t merge_len = block_size;
1974
+ uint32_t merge_passes = 0;
1975
+ while (merge_len < out_ne0) {
1976
+ merge_len <<= 1;
1977
+ merge_passes++;
1978
+ }
1979
+
1980
+ const bool start_in_tmp = (merge_passes % 2) == 1;
1981
+
1982
+ const size_t dst_offset = ggml_webgpu_tensor_offset(dst);
1983
+ const size_t idx_nbytes = out_ne0 * ggml_nrows(dst) * sizeof(int32_t);
1984
+ const size_t tmp_offset =
1985
+ ROUNDUP_POW2(dst_offset + idx_nbytes, ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
1986
+ const size_t tmp_binding_size = ROUNDUP_POW2(idx_nbytes, WEBGPU_STORAGE_BUF_BINDING_MULT);
1987
+ const size_t dst_binding_size =
1988
+ ROUNDUP_POW2(idx_nbytes + ggml_webgpu_tensor_misalignment(ctx, dst), WEBGPU_STORAGE_BUF_BINDING_MULT);
1989
+
1990
+ const uint32_t offset_src = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type));
1991
+ const uint32_t offset_dst = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type));
1992
+ const uint32_t offset_tmp = 0;
1993
+ const uint32_t stride_src1 = (uint32_t) (src->nb[1] / ggml_type_size(src->type));
1994
+ const uint32_t stride_src2 = (uint32_t) (src->nb[2] / ggml_type_size(src->type));
1995
+ const uint32_t stride_src3 = (uint32_t) (src->nb[3] / ggml_type_size(src->type));
1996
+ const uint32_t stride_idx1 = out_ne0;
1997
+ const uint32_t stride_idx2 = out_ne0 * (uint32_t) dst->ne[1];
1998
+ const uint32_t stride_idx3 = stride_idx2 * (uint32_t) dst->ne[2];
1999
+
2000
+ std::vector<webgpu_pipeline> pipelines;
2001
+ std::vector<std::vector<uint32_t>> params_list;
2002
+ std::vector<std::vector<wgpu::BindGroupEntry>> entries_list;
2003
+ std::vector<std::pair<uint32_t, uint32_t>> workgroups_list;
2004
+
2005
+ const uint32_t init_offset = start_in_tmp ? offset_tmp : offset_dst;
2006
+ const size_t init_align_offset = start_in_tmp ? tmp_offset : ggml_webgpu_tensor_align_offset(ctx, dst);
2007
+ const size_t init_binding_size = start_in_tmp ? tmp_binding_size : dst_binding_size;
2008
+
2009
+ std::vector<uint32_t> init_params = {
2010
+ offset_src, init_offset, stride_src1, stride_src2, stride_src3, stride_idx1,
2011
+ stride_idx2, stride_idx3, src_ne0, (uint32_t) src->ne[1], (uint32_t) src->ne[2], out_ne0,
2012
+ block_size, npr, nrows
2013
+ };
2014
+
2015
+ const uint32_t total_wg_init = npr * nrows;
2016
+ const uint32_t max_wg = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
2017
+ const uint32_t wg_x_init = std::min(total_wg_init, max_wg);
2018
+ const uint32_t wg_y_init = CEIL_DIV(total_wg_init, wg_x_init);
2019
+ std::vector<wgpu::BindGroupEntry> init_entries = {
2020
+ { .binding = 0,
2021
+ .buffer = ggml_webgpu_tensor_buf(src),
2022
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src),
2023
+ .size = ggml_webgpu_tensor_binding_size(ctx, src) },
2024
+ { .binding = 1, .buffer = ggml_webgpu_tensor_buf(dst), .offset = init_align_offset, .size = init_binding_size }
2025
+ };
2026
+
2027
+ pipelines.push_back(argsort_pipeline);
2028
+ params_list.push_back(std::move(init_params));
2029
+ entries_list.push_back(std::move(init_entries));
2030
+ workgroups_list.push_back({ wg_x_init, wg_y_init });
2031
+
2032
+ if (merge_passes == 0) {
2033
+ return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_buf_pool, pipelines, params_list,
2034
+ entries_list, workgroups_list);
2035
+ }
2036
+
2037
+ bool in_is_tmp = start_in_tmp;
2038
+ uint32_t len = block_size;
2039
+ while (len < out_ne0) {
2040
+ const uint32_t nm = CEIL_DIV(out_ne0, 2 * len);
2041
+
2042
+ const bool out_is_tmp = !in_is_tmp;
2043
+ const uint32_t offset_in = in_is_tmp ? offset_tmp : offset_dst;
2044
+ const uint32_t offset_out = out_is_tmp ? offset_tmp : offset_dst;
2045
+ const size_t align_in = in_is_tmp ? tmp_offset : ggml_webgpu_tensor_align_offset(ctx, dst);
2046
+ const size_t align_out = out_is_tmp ? tmp_offset : ggml_webgpu_tensor_align_offset(ctx, dst);
2047
+ const size_t size_in = in_is_tmp ? tmp_binding_size : dst_binding_size;
2048
+ const size_t size_out = out_is_tmp ? tmp_binding_size : dst_binding_size;
2049
+ const uint32_t top_k_out = (is_top_k && nm == 1) ? (uint32_t) dst->ne[0] : out_ne0;
2050
+ const uint32_t stride_out1 = top_k_out;
2051
+ const uint32_t stride_out2 = top_k_out * (uint32_t) dst->ne[1];
2052
+ const uint32_t stride_out3 = stride_out2 * (uint32_t) dst->ne[2];
2053
+
2054
+ std::vector<uint32_t> merge_params = { offset_src,
2055
+ offset_in,
2056
+ offset_out,
2057
+ stride_src1,
2058
+ stride_src2,
2059
+ stride_src3,
2060
+ stride_idx1,
2061
+ stride_idx2,
2062
+ stride_idx3,
2063
+ stride_out1,
2064
+ stride_out2,
2065
+ stride_out3,
2066
+ out_ne0,
2067
+ (uint32_t) src->ne[1],
2068
+ (uint32_t) src->ne[2],
2069
+ top_k_out,
2070
+ len,
2071
+ nm,
2072
+ nrows };
2073
+
2074
+ std::vector<wgpu::BindGroupEntry> merge_entries = {
2075
+ { .binding = 0,
2076
+ .buffer = ggml_webgpu_tensor_buf(src),
2077
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src),
2078
+ .size = ggml_webgpu_tensor_binding_size(ctx, src) },
2079
+ { .binding = 1, .buffer = ggml_webgpu_tensor_buf(dst), .offset = align_in, .size = size_in },
2080
+ { .binding = 2, .buffer = ggml_webgpu_tensor_buf(dst), .offset = align_out, .size = size_out }
2081
+ };
2082
+
2083
+ const uint32_t total_wg_merge = nm * nrows;
2084
+ const uint32_t wg_x_merge = std::min(total_wg_merge, max_wg);
2085
+ const uint32_t wg_y_merge = CEIL_DIV(total_wg_merge, wg_x_merge);
2086
+ workgroups_list.push_back({ wg_x_merge, wg_y_merge });
2087
+ pipelines.push_back(argsort_merge_pipeline);
2088
+ params_list.push_back(std::move(merge_params));
2089
+ entries_list.push_back(std::move(merge_entries));
2090
+
2091
+ len <<= 1;
2092
+ in_is_tmp = !in_is_tmp;
2093
+ }
2094
+
2095
+ return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_buf_pool, pipelines, params_list, entries_list,
2096
+ workgroups_list);
2097
+ }
2098
+
2099
+ static webgpu_command ggml_webgpu_cumsum(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
2100
+ std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
2101
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
2102
+ (uint32_t) src->ne[0] };
2103
+
2104
+ std::vector<wgpu::BindGroupEntry> entries = {
2105
+ { .binding = 0,
2106
+ .buffer = ggml_webgpu_tensor_buf(src),
2107
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src),
2108
+ .size = ggml_webgpu_tensor_binding_size(ctx, src) },
2109
+ { .binding = 1,
2110
+ .buffer = ggml_webgpu_tensor_buf(dst),
2111
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
2112
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
2113
+ };
2114
+
2115
+ ggml_webgpu_shader_lib_context shader_lib_ctx = {
2116
+ .src0 = src,
2117
+ .src1 = nullptr,
2118
+ .dst = dst,
2119
+ .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
2120
+ };
2121
+
2122
+ webgpu_pipeline pipeline = ctx->shader_lib->get_cumsum_pipeline(shader_lib_ctx);
2123
+ uint32_t wg_x = ggml_nrows(dst);
2124
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
2125
+ }
2126
+
2127
+ static webgpu_command ggml_webgpu_sum_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
2128
+ bool total_sum = dst->op == GGML_OP_SUM;
2129
+ std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
2130
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
2131
+ total_sum ? 0 : (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
2132
+ total_sum ? 0 : (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
2133
+ total_sum ? 0 : (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
2134
+ total_sum ? static_cast<uint32_t>(ggml_nelements(src)) : (uint32_t) src->ne[0],
2135
+ total_sum ? 1 : (uint32_t) src->ne[1],
2136
+ total_sum ? 1 : (uint32_t) src->ne[2] };
2137
+
2138
+ std::vector<wgpu::BindGroupEntry> entries = {
2139
+ { .binding = 0,
2140
+ .buffer = ggml_webgpu_tensor_buf(src),
2141
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src),
2142
+ .size = ggml_webgpu_tensor_binding_size(ctx, src) },
2143
+ { .binding = 1,
2144
+ .buffer = ggml_webgpu_tensor_buf(dst),
2145
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
2146
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
2147
+ };
2148
+
2149
+ ggml_webgpu_shader_lib_context shader_lib_ctx = {
2150
+ .src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
2151
+ };
2152
+
2153
+ webgpu_pipeline pipeline = ctx->shader_lib->get_sum_rows_pipeline(shader_lib_ctx);
2154
+
2155
+ uint32_t wg_x = total_sum ? 1 : ggml_nrows(dst);
2156
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
2157
+ }
2158
+
1552
2159
  // Returns the encoded command, or std::nullopt if the operation is a no-op
1553
2160
  static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
1554
2161
  if (ggml_is_empty(node)) {
1555
2162
  return std::nullopt;
1556
2163
  }
2164
+ if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
2165
+ return std::nullopt;
2166
+ }
1557
2167
  WEBGPU_LOG_DEBUG("ggml_webgpu_encode_node(" << node << ", " << ggml_op_name(node->op) << ")");
1558
2168
 
1559
2169
  ggml_tensor * src0 = node->src[0];
@@ -1578,27 +2188,20 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
1578
2188
  case GGML_OP_MUL_MAT:
1579
2189
  return ggml_webgpu_mul_mat(ctx, src0, src1, node);
1580
2190
  case GGML_OP_FLASH_ATTN_EXT:
2191
+ #ifndef __EMSCRIPTEN__
1581
2192
  return ggml_webgpu_flash_attn(ctx, src0, src1, src2, node->src[3], node->src[4], node);
2193
+ #else
2194
+ return std::nullopt;
2195
+ #endif
1582
2196
  case GGML_OP_ADD:
1583
- {
1584
- int inplace = ggml_webgpu_tensor_equal(src0, node);
1585
- return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipelines[node->type][inplace], inplace);
1586
- }
1587
2197
  case GGML_OP_SUB:
1588
- {
1589
- int inplace = ggml_webgpu_tensor_equal(src0, node);
1590
- return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->sub_pipelines[node->type][inplace], inplace);
1591
- }
1592
2198
  case GGML_OP_MUL:
1593
- {
1594
- int inplace = ggml_webgpu_tensor_equal(src0, node);
1595
- return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_pipelines[node->type][inplace], inplace);
1596
- }
1597
2199
  case GGML_OP_DIV:
1598
- {
1599
- int inplace = ggml_webgpu_tensor_equal(src0, node);
1600
- return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->div_pipelines[node->type][inplace], inplace);
1601
- }
2200
+ return ggml_webgpu_binary_op(ctx, src0, src1, node);
2201
+ case GGML_OP_CONCAT:
2202
+ return ggml_webgpu_concat(ctx, src0, src1, node);
2203
+ case GGML_OP_REPEAT:
2204
+ return ggml_webgpu_repeat(ctx, src0, node);
1602
2205
  case GGML_OP_RMS_NORM:
1603
2206
  return ggml_webgpu_rms_norm(ctx, src0, node);
1604
2207
  case GGML_OP_ROPE:
@@ -1610,7 +2213,27 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
1610
2213
  case GGML_OP_SOFT_MAX:
1611
2214
  return ggml_webgpu_soft_max(ctx, src0, src1, src2, node);
1612
2215
  case GGML_OP_UNARY:
2216
+ case GGML_OP_CLAMP:
2217
+ case GGML_OP_FILL:
2218
+ case GGML_OP_LOG:
2219
+ case GGML_OP_SQR:
2220
+ case GGML_OP_SQRT:
2221
+ case GGML_OP_SIN:
2222
+ case GGML_OP_COS:
1613
2223
  return ggml_webgpu_unary_op(ctx, src0, node);
2224
+ case GGML_OP_PAD:
2225
+ return ggml_webgpu_pad(ctx, src0, node);
2226
+ case GGML_OP_ARGMAX:
2227
+ return ggml_webgpu_argmax(ctx, src0, node);
2228
+ case GGML_OP_ARGSORT:
2229
+ case GGML_OP_TOP_K:
2230
+ // we reuse the same argsort implementation for top_k
2231
+ return ggml_webgpu_argsort(ctx, src0, node);
2232
+ case GGML_OP_CUMSUM:
2233
+ return ggml_webgpu_cumsum(ctx, src0, node);
2234
+ case GGML_OP_SUM:
2235
+ case GGML_OP_SUM_ROWS:
2236
+ return ggml_webgpu_sum_rows(ctx, src0, node);
1614
2237
  default:
1615
2238
  return std::nullopt;
1616
2239
  }
@@ -1619,39 +2242,57 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
1619
2242
  static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
1620
2243
  WEBGPU_LOG_DEBUG("ggml_backend_webgpu_graph_compute(" << cgraph->n_nodes << " nodes)");
1621
2244
 
1622
- ggml_backend_webgpu_context * backend_ctx = static_cast<ggml_backend_webgpu_context *>(backend->context);
2245
+ ggml_backend_webgpu_context * backend_ctx = (ggml_backend_webgpu_context *) backend->context;
1623
2246
  webgpu_context ctx = backend_ctx->webgpu_ctx;
1624
2247
 
1625
2248
  WEBGPU_CPU_PROFILE_TOTAL_START(graph_compute);
1626
2249
 
1627
- ctx->inflight_threads++;
2250
+ std::vector<webgpu_command> commands;
2251
+ std::vector<webgpu_submission> subs;
2252
+ uint32_t num_batched_kernels = 0;
2253
+ bool contains_set_rows = false;
1628
2254
 
1629
- std::vector<webgpu_command> commands;
1630
- std::vector<webgpu_submission_futures> futures;
1631
2255
  for (int i = 0; i < cgraph->n_nodes; i++) {
2256
+ if (cgraph->nodes[i]->op == GGML_OP_SET_ROWS) {
2257
+ contains_set_rows = true;
2258
+ }
1632
2259
  if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) {
1633
2260
  commands.push_back(*cmd);
2261
+ num_batched_kernels += cmd.value().num_kernels;
1634
2262
  }
1635
- // compute the batch size based on the number of inflight threads
1636
- uint32_t inflight_threads = ctx->inflight_threads;
1637
- uint32_t batch_size = std::min(std::max(1u, WEBGPU_NUM_PARAM_BUFS / std::max(inflight_threads, 1u)),
1638
- WEBGPU_COMMAND_SUBMIT_BATCH_SIZE);
1639
- if (commands.size() >= batch_size) {
1640
- futures.push_back(ggml_backend_webgpu_submit(ctx, commands));
2263
+
2264
+ if (num_batched_kernels >= WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) {
2265
+ num_batched_kernels = 0;
2266
+ subs.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool));
1641
2267
  // Process events and check for completed submissions
1642
- ctx->instance.ProcessEvents();
1643
- ggml_backend_webgpu_wait(ctx, futures, false);
2268
+ ctx->global_ctx->instance.ProcessEvents();
2269
+ ggml_backend_webgpu_wait(ctx->global_ctx, subs, false);
1644
2270
  commands.clear();
1645
2271
  }
1646
2272
  }
1647
2273
  if (!commands.empty()) {
1648
- webgpu_submission_futures new_futures = ggml_backend_webgpu_submit(ctx, commands);
1649
- futures.push_back(new_futures);
2274
+ subs.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool));
2275
+ commands.clear();
2276
+ }
2277
+
2278
+ // If there are SET_ROWS operations in this graph, copy the error buffers to the host for checking.
2279
+ if (contains_set_rows) {
2280
+ wgpu::CommandEncoder encoder = ctx->global_ctx->device.CreateCommandEncoder();
2281
+ encoder.CopyBufferToBuffer(ctx->set_rows_dev_error_buf, 0, ctx->set_rows_host_error_buf, 0,
2282
+ ctx->set_rows_host_error_buf.GetSize());
2283
+ wgpu::CommandBuffer set_rows_commands = encoder.Finish();
2284
+ ctx->global_ctx->queue.Submit(1, &set_rows_commands);
2285
+ ggml_backend_webgpu_map_buffer(ctx->global_ctx, ctx->set_rows_host_error_buf, wgpu::MapMode::Read, 0,
2286
+ ctx->set_rows_host_error_buf.GetSize());
2287
+ const uint32_t * error_data = (const uint32_t *) ctx->set_rows_host_error_buf.GetConstMappedRange();
2288
+ if (*error_data) {
2289
+ GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported.");
2290
+ }
2291
+ ctx->set_rows_host_error_buf.Unmap();
1650
2292
  }
1651
2293
 
1652
- ggml_backend_webgpu_wait(ctx, futures);
1653
- ctx->inflight_threads--;
1654
- WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx);
2294
+ ggml_backend_webgpu_wait(ctx->global_ctx, subs);
2295
+ WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx->global_ctx);
1655
2296
  return GGML_STATUS_SUCCESS;
1656
2297
  }
1657
2298
 
@@ -1678,7 +2319,10 @@ static ggml_backend_i ggml_backend_webgpu_i = {
1678
2319
 
1679
2320
  static void ggml_backend_webgpu_buffer_free_buffer(ggml_backend_buffer_t buffer) {
1680
2321
  ggml_backend_webgpu_buffer_context * ctx = static_cast<ggml_backend_webgpu_buffer_context *>(buffer->context);
1681
- ctx->buffer.Destroy();
2322
+ if (ctx != nullptr && ctx->buffer != nullptr) {
2323
+ ctx->buffer.Destroy();
2324
+ delete ctx;
2325
+ }
1682
2326
  }
1683
2327
 
1684
2328
  // Returns the "fake" base pointer.
@@ -1693,7 +2337,9 @@ static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffe
1693
2337
  size_t offset,
1694
2338
  size_t size) {
1695
2339
  if (size == 0) {
1696
- WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor: size is zero, nothing to do.");
2340
+ WEBGPU_LOG_DEBUG(
2341
+ "ggml_backend_webgpu_buffer_memset_tensor: size is zero, "
2342
+ "nothing to do.");
1697
2343
  return;
1698
2344
  }
1699
2345
 
@@ -1708,8 +2354,8 @@ static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffe
1708
2354
 
1709
2355
  // This is a trick to set all bytes of a u32 to the same 1 byte value.
1710
2356
  uint32_t val32 = (uint32_t) value * 0x01010101;
1711
- ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, val32, total_offset, size);
1712
- WEBGPU_CPU_PROFILE_TOTAL_END(memset_tensor, buf_ctx->webgpu_ctx);
2357
+ ggml_backend_webgpu_buffer_memset(buf_ctx->global_ctx, buf_ctx->buffer, val32, total_offset, size);
2358
+ WEBGPU_CPU_PROFILE_TOTAL_END(memset_tensor, buf_ctx->global_ctx);
1713
2359
  }
1714
2360
 
1715
2361
  static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
@@ -1718,15 +2364,14 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
1718
2364
  size_t offset,
1719
2365
  size_t size) {
1720
2366
  WEBGPU_CPU_PROFILE_TOTAL_START(set_tensor);
1721
- ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
1722
- webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx;
2367
+ ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
1723
2368
 
1724
2369
  WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_set_tensor(" << buf_ctx->label << ", " << tensor << ", " << data
1725
2370
  << ", " << offset << ", " << size << ")");
1726
2371
 
1727
2372
  size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
1728
2373
 
1729
- webgpu_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size / 4) * 4);
2374
+ buf_ctx->global_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size / 4) * 4);
1730
2375
 
1731
2376
  if (size % 4 != 0) {
1732
2377
  // If size is not a multiple of 4, we need to memset the remaining bytes
@@ -1739,21 +2384,21 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
1739
2384
  ((uint8_t *) &val32)[i] = ((const uint8_t *) data)[size - remaining_size + i];
1740
2385
  }
1741
2386
  // memset the remaining bytes
1742
- ggml_backend_webgpu_buffer_memset(webgpu_ctx, buf_ctx->buffer, val32, total_offset + (size - remaining_size),
1743
- remaining_size);
2387
+ ggml_backend_webgpu_buffer_memset(buf_ctx->global_ctx, buf_ctx->buffer, val32,
2388
+ total_offset + (size - remaining_size), remaining_size);
1744
2389
  } else {
1745
2390
  // wait for WriteBuffer to complete
1746
- webgpu_ctx->instance.WaitAny(
1747
- webgpu_ctx->queue.OnSubmittedWorkDone(wgpu::CallbackMode::AllowSpontaneous,
2391
+ buf_ctx->global_ctx->instance.WaitAny(buf_ctx->global_ctx->queue.OnSubmittedWorkDone(
2392
+ wgpu::CallbackMode::AllowSpontaneous,
1748
2393
  [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
1749
2394
  if (status != wgpu::QueueWorkDoneStatus::Success) {
1750
2395
  GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n",
1751
2396
  std::string(message).c_str());
1752
2397
  }
1753
2398
  }),
1754
- UINT64_MAX);
2399
+ UINT64_MAX);
1755
2400
  }
1756
- WEBGPU_CPU_PROFILE_TOTAL_END(set_tensor, webgpu_ctx);
2401
+ WEBGPU_CPU_PROFILE_TOTAL_END(set_tensor, buf_ctx->global_ctx);
1757
2402
  }
1758
2403
 
1759
2404
  static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
@@ -1765,53 +2410,56 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
1765
2410
  ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
1766
2411
  WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_get_tensor(" << buf_ctx->label << ", " << tensor << ", " << data
1767
2412
  << ", " << offset << ", " << size << ")");
1768
- webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx;
1769
- wgpu::Device device = webgpu_ctx->device;
2413
+ wgpu::Device device = buf_ctx->global_ctx->device;
1770
2414
 
1771
2415
  size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
1772
2416
 
1773
2417
  size_t final_size = size;
1774
2418
  if (size % 4 != 0) {
1775
- // If size is not a multiple of 4, we need to round it up to the next multiple of 4
2419
+ // If size is not a multiple of 4, we need to round it up to the next
2420
+ // multiple of 4
1776
2421
  final_size = size + (4 - (size % 4));
1777
2422
  }
1778
2423
 
1779
- std::lock_guard<std::recursive_mutex> lock(webgpu_ctx->mutex);
2424
+ std::lock_guard<std::recursive_mutex> lock(buf_ctx->global_ctx->mutex);
1780
2425
 
1781
- if (webgpu_ctx->get_tensor_staging_buf == nullptr || webgpu_ctx->get_tensor_staging_buf.GetSize() < final_size) {
2426
+ if (buf_ctx->global_ctx->get_tensor_staging_buf == nullptr ||
2427
+ buf_ctx->global_ctx->get_tensor_staging_buf.GetSize() < final_size) {
1782
2428
  // Create a new staging buffer if it doesn't exist or is too small
1783
- if (webgpu_ctx->get_tensor_staging_buf) {
1784
- webgpu_ctx->get_tensor_staging_buf.Destroy();
2429
+ if (buf_ctx->global_ctx->get_tensor_staging_buf) {
2430
+ buf_ctx->global_ctx->get_tensor_staging_buf.Destroy();
1785
2431
  }
1786
- ggml_webgpu_create_buffer(device, webgpu_ctx->get_tensor_staging_buf, final_size,
2432
+ ggml_webgpu_create_buffer(device, buf_ctx->global_ctx->get_tensor_staging_buf, final_size,
1787
2433
  wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "get_tensor_staging_buf");
1788
2434
  }
1789
2435
 
1790
2436
  // Copy the data from the buffer to the staging buffer
1791
2437
  wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
1792
- encoder.CopyBufferToBuffer(buf_ctx->buffer, total_offset, webgpu_ctx->get_tensor_staging_buf, 0, final_size);
2438
+ encoder.CopyBufferToBuffer(buf_ctx->buffer, total_offset, buf_ctx->global_ctx->get_tensor_staging_buf, 0,
2439
+ final_size);
1793
2440
  wgpu::CommandBuffer commands = encoder.Finish();
1794
2441
 
1795
2442
  // Submit the command buffer to the queue
1796
- webgpu_ctx->queue.Submit(1, &commands);
2443
+ buf_ctx->global_ctx->queue.Submit(1, &commands);
1797
2444
 
1798
2445
  // Map the staging buffer to read the data
1799
- ggml_backend_webgpu_map_buffer(webgpu_ctx, webgpu_ctx->get_tensor_staging_buf, wgpu::MapMode::Read, 0, final_size);
2446
+ ggml_backend_webgpu_map_buffer(buf_ctx->global_ctx, buf_ctx->global_ctx->get_tensor_staging_buf,
2447
+ wgpu::MapMode::Read, 0, final_size);
1800
2448
  // Must specify size here since the staging buffer might be larger than the tensor size
1801
- const void * mapped_range = webgpu_ctx->get_tensor_staging_buf.GetConstMappedRange(0, final_size);
2449
+ const void * mapped_range = buf_ctx->global_ctx->get_tensor_staging_buf.GetConstMappedRange(0, final_size);
1802
2450
 
1803
2451
  // Copy the data from the mapped range to the output buffer
1804
2452
  std::memcpy(data, mapped_range, size);
1805
- webgpu_ctx->get_tensor_staging_buf.Unmap();
1806
- WEBGPU_CPU_PROFILE_TOTAL_END(get_tensor, webgpu_ctx);
2453
+ buf_ctx->global_ctx->get_tensor_staging_buf.Unmap();
2454
+ WEBGPU_CPU_PROFILE_TOTAL_END(get_tensor, buf_ctx->global_ctx);
1807
2455
  }
1808
2456
 
1809
2457
  static void ggml_backend_webgpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
1810
2458
  WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_clear(" << buffer << ", " << (uint32_t) value << ")");
1811
2459
  WEBGPU_CPU_PROFILE_TOTAL_START(clear);
1812
2460
  ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
1813
- ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, value, 0, buffer->size);
1814
- WEBGPU_CPU_PROFILE_TOTAL_END(clear, buf_ctx->webgpu_ctx);
2461
+ ggml_backend_webgpu_buffer_memset(buf_ctx->global_ctx, buf_ctx->buffer, value, 0, buffer->size);
2462
+ WEBGPU_CPU_PROFILE_TOTAL_END(clear, buf_ctx->global_ctx);
1815
2463
  }
1816
2464
 
1817
2465
  static ggml_backend_buffer_i ggml_backend_webgpu_buffer_interface = {
@@ -1823,7 +2471,8 @@ static ggml_backend_buffer_i ggml_backend_webgpu_buffer_interface = {
1823
2471
  /* .get_tensor = */ ggml_backend_webgpu_buffer_get_tensor,
1824
2472
  /* .cpy_tensor = */ NULL, // TODO: optional, implement this
1825
2473
  /* .clear = */ ggml_backend_webgpu_buffer_clear,
1826
- /* .reset = */ NULL, // TODO: optional, think it coordinates with .init_tensor
2474
+ /* .reset = */ NULL, // TODO: optional, think it coordinates with
2475
+ // .init_tensor
1827
2476
  };
1828
2477
 
1829
2478
  /* End GGML Backend Buffer Interface */
@@ -1841,31 +2490,60 @@ static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_b
1841
2490
  int buffer_id = buffer_count++;
1842
2491
  std::string buf_name = "tensor_buf" + std::to_string(buffer_id);
1843
2492
  WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_type_alloc_buffer_" << buffer_id << ": " << size << " bytes");
1844
- ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
1845
2493
 
1846
- wgpu::Buffer buf;
1847
- ggml_webgpu_create_buffer(ctx->webgpu_ctx->device, buf, ROUNDUP_POW2(size, WEBGPU_STORAGE_BUF_BINDING_MULT),
2494
+ ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
2495
+ wgpu::Buffer buf;
2496
+ ggml_webgpu_create_buffer(ctx->webgpu_global_ctx->device, buf, ROUNDUP_POW2(size, WEBGPU_STORAGE_BUF_BINDING_MULT),
1848
2497
  wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst,
1849
2498
  buf_name.c_str());
1850
2499
 
1851
2500
  ggml_backend_webgpu_buffer_context * buf_ctx =
1852
- new ggml_backend_webgpu_buffer_context(ctx->webgpu_ctx, buf, buf_name);
2501
+ new ggml_backend_webgpu_buffer_context(buf, buf_name, ctx->webgpu_global_ctx);
1853
2502
 
1854
2503
  return ggml_backend_buffer_init(buft, ggml_backend_webgpu_buffer_interface, buf_ctx, size);
1855
2504
  }
1856
2505
 
1857
2506
  static size_t ggml_backend_webgpu_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
1858
- ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
1859
- return ctx->webgpu_ctx->limits.minStorageBufferOffsetAlignment;
2507
+ ggml_backend_webgpu_device_context * dev_ctx =
2508
+ static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
2509
+ return dev_ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment;
1860
2510
  }
1861
2511
 
1862
- // maxBufferSize might be larger, but you can't bind more than maxStorageBufferBindingSize to a single binding.
2512
+ // maxBufferSize might be larger, but you can't bind more than
2513
+ // maxStorageBufferBindingSize to a single binding.
1863
2514
  static size_t ggml_backend_webgpu_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
1864
- ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
1865
- return ctx->webgpu_ctx->limits.maxStorageBufferBindingSize;
2515
+ ggml_backend_webgpu_device_context * dev_ctx =
2516
+ static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
2517
+ return dev_ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize;
1866
2518
  }
1867
2519
 
1868
- /* End GGML Backend Buffer Type Interface */
2520
+ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft,
2521
+ const ggml_tensor * tensor) {
2522
+ ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
2523
+ size_t res = ggml_nbytes(tensor);
2524
+ switch (tensor->op) {
2525
+ case GGML_OP_ARGSORT:
2526
+ res = ROUNDUP_POW2(res * 2 + ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment,
2527
+ WEBGPU_STORAGE_BUF_BINDING_MULT);
2528
+ break;
2529
+ case GGML_OP_TOP_K:
2530
+ {
2531
+ const ggml_tensor * src0 = tensor->src[0];
2532
+ if (src0) {
2533
+ const size_t full = sizeof(int32_t) * ggml_nelements(src0);
2534
+ res = ROUNDUP_POW2(
2535
+ full * 2 + ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment,
2536
+ WEBGPU_STORAGE_BUF_BINDING_MULT);
2537
+ }
2538
+ }
2539
+ break;
2540
+ default:
2541
+ break;
2542
+ }
2543
+ return res;
2544
+ }
2545
+
2546
+ /* End GGML Backend Buffer Type Interface */
1869
2547
 
1870
2548
  /* GGML Backend Device Interface */
1871
2549
 
@@ -1883,7 +2561,7 @@ static void ggml_backend_webgpu_device_get_memory(ggml_backend_dev_t dev, size_t
1883
2561
  ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
1884
2562
  // TODO: for now, return maxBufferSize as both free and total memory
1885
2563
  // Track https://github.com/gpuweb/gpuweb/issues/5505 for updates.
1886
- uint64_t max_buffer_size = ctx->webgpu_ctx->limits.maxBufferSize;
2564
+ uint64_t max_buffer_size = ctx->webgpu_global_ctx->capabilities.limits.maxBufferSize;
1887
2565
  // If we're on a 32-bit system, clamp to UINTPTR_MAX
1888
2566
  #if UINTPTR_MAX < UINT64_MAX
1889
2567
  uint64_t max_ptr_size = static_cast<uint64_t>(UINTPTR_MAX);
@@ -1918,329 +2596,64 @@ static ggml_guid_t ggml_backend_webgpu_guid(void) {
1918
2596
  return reinterpret_cast<ggml_guid_t>((void *) guid_str);
1919
2597
  }
1920
2598
 
1921
- // Workgroup size is a common constant
1922
- static std::vector<wgpu::ConstantEntry> ggml_webgpu_wg_size_entry(uint32_t wg_size) {
1923
- std::vector<wgpu::ConstantEntry> constants(1);
1924
- constants[0].key = "wg_size";
1925
- constants[0].value = wg_size;
1926
- return constants;
1927
- }
1928
-
1929
- static void ggml_webgpu_init_memset_pipeline(webgpu_context & webgpu_ctx) {
2599
+ static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) {
1930
2600
  // we use the maximum workgroup size for the memset pipeline
1931
- size_t max_threads = WEBGPU_MAX_WG_SIZE * webgpu_ctx->limits.maxComputeWorkgroupsPerDimension;
2601
+ size_t max_threads = WEBGPU_MAX_WG_SIZE * ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
1932
2602
  // Size the bytes_per_thread so that the largest buffer size can be handled
1933
- webgpu_ctx->memset_bytes_per_thread = CEIL_DIV(webgpu_ctx->limits.maxStorageBufferBindingSize, max_threads);
2603
+ ctx->capabilities.memset_bytes_per_thread =
2604
+ CEIL_DIV(ctx->capabilities.limits.maxStorageBufferBindingSize, max_threads);
1934
2605
  std::vector<wgpu::ConstantEntry> constants(2);
1935
- constants[0].key = "wg_size";
1936
- constants[0].value = WEBGPU_MAX_WG_SIZE;
1937
- constants[1].key = "bytes_per_thread";
1938
- constants[1].value = webgpu_ctx->memset_bytes_per_thread;
1939
- webgpu_ctx->memset_pipelines[0] = ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_memset, "memset", constants);
1940
- }
1941
-
1942
- static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
1943
- // Q4/Q5/Q8 classic quantizations
1944
- webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] =
1945
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q4_0_f32, "mul_mat_q4_0_f32");
1946
- webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_1][GGML_TYPE_F32][0] =
1947
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q4_1_f32, "mul_mat_q4_1_f32");
1948
- webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_0][GGML_TYPE_F32][0] =
1949
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q5_0_f32, "mul_mat_q5_0_f32");
1950
- webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_1][GGML_TYPE_F32][0] =
1951
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q5_1_f32, "mul_mat_q5_1_f32");
1952
- webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q8_0][GGML_TYPE_F32][0] =
1953
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q8_0_f32, "mul_mat_q8_0_f32");
1954
-
1955
- // K-quantizations
1956
- webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q2_K][GGML_TYPE_F32][0] =
1957
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q2_k_f32, "mul_mat_q2_k_f32");
1958
- webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q3_K][GGML_TYPE_F32][0] =
1959
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q3_k_f32, "mul_mat_q3_k_f32");
1960
- webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_K][GGML_TYPE_F32][0] =
1961
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q4_k_f32, "mul_mat_q4_k_f32");
1962
- webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_K][GGML_TYPE_F32][0] =
1963
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q5_k_f32, "mul_mat_q5_k_f32");
1964
- webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q6_K][GGML_TYPE_F32][0] =
1965
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q6_k_f32, "mul_mat_q6_k_f32");
1966
-
1967
- // IQ quantizations (2-, 3-, 4-bit variants)
1968
- webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_XXS][GGML_TYPE_F32][0] =
1969
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq2_xxs_f32, "mul_mat_iq2_xxs_f32");
1970
- webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_XS][GGML_TYPE_F32][0] =
1971
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq2_xs_f32, "mul_mat_iq2_xs_f32");
1972
- webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_S][GGML_TYPE_F32][0] =
1973
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq2_s_f32, "mul_mat_iq2_s_f32");
1974
-
1975
- webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ3_XXS][GGML_TYPE_F32][0] =
1976
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq3_xxs_f32, "mul_mat_iq3_xxs_f32");
1977
- webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ3_S][GGML_TYPE_F32][0] =
1978
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq3_s_f32, "mul_mat_iq3_s_f32");
1979
-
1980
- // 1-bit and 4-bit IQ variants
1981
- webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ1_S][GGML_TYPE_F32][0] =
1982
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq1_s_f32, "mul_mat_iq1_s_f32");
1983
- webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ1_M][GGML_TYPE_F32][0] =
1984
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq1_m_f32, "mul_mat_iq1_m_f32");
1985
- webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ4_NL][GGML_TYPE_F32][0] =
1986
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq4_nl_f32, "mul_mat_iq4_nl_f32");
1987
- webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ4_XS][GGML_TYPE_F32][0] =
1988
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32");
1989
-
1990
- std::string proc_mul_mat_f32_f32;
1991
- std::string proc_mul_mat_f32_f32_vec;
1992
- std::string proc_mul_mat_f16_f32;
1993
- std::string proc_mul_mat_f16_f32_vec;
1994
- std::string proc_mul_mat_f16_f16;
1995
- std::string proc_mul_mat_f16_f16_vec;
1996
- std::string proc_mul_mat_q4_0_f32;
1997
- std::string proc_mul_mat_q4_0_f32_vec;
1998
-
1999
- std::vector<wgpu::ConstantEntry> mul_mat_constants;
2000
- #ifndef __EMSCRIPTEN__
2001
- if (webgpu_ctx->supports_subgroup_matrix) {
2002
- std::map<std::string, std::string> sg_matrix_repls;
2003
- sg_matrix_repls["WEBGPU_MAX_SUBGROUP_SIZE"] = std::to_string(webgpu_ctx->max_subgroup_size);
2004
- sg_matrix_repls["WEBGPU_TILE_K"] = std::to_string(WEBGPU_MUL_MAT_TILE_K);
2005
- sg_matrix_repls["WEBGPU_SUBGROUP_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M);
2006
- sg_matrix_repls["WEBGPU_SUBGROUP_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N);
2007
- sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M);
2008
- sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N);
2009
- sg_matrix_repls["WEBGPU_SG_MAT_M_SIZE"] = std::to_string(webgpu_ctx->sg_mat_m);
2010
- sg_matrix_repls["WEBGPU_SG_MAT_N_SIZE"] = std::to_string(webgpu_ctx->sg_mat_n);
2011
- sg_matrix_repls["WEBGPU_SG_MAT_K_SIZE"] = std::to_string(webgpu_ctx->sg_mat_k);
2012
-
2013
- proc_mul_mat_f32_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32, sg_matrix_repls);
2014
- proc_mul_mat_f32_f32_vec =
2015
- ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32_vec, sg_matrix_repls);
2016
- proc_mul_mat_f16_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32, sg_matrix_repls);
2017
- proc_mul_mat_f16_f32_vec =
2018
- ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32_vec, sg_matrix_repls);
2019
- proc_mul_mat_f16_f16 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16, sg_matrix_repls);
2020
- proc_mul_mat_f16_f16_vec =
2021
- ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16_vec, sg_matrix_repls);
2022
- proc_mul_mat_q4_0_f32 =
2023
- ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32, sg_matrix_repls);
2024
- proc_mul_mat_q4_0_f32_vec =
2025
- ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32_vec, sg_matrix_repls);
2026
- } else {
2027
- #endif
2028
- mul_mat_constants.push_back({ .key = "TILE_K", .value = WEBGPU_MUL_MAT_TILE_K });
2029
- mul_mat_constants.push_back({ .key = "WORKGROUP_SIZE_M", .value = WEBGPU_MUL_MAT_WG_SIZE_M });
2030
- mul_mat_constants.push_back({ .key = "WORKGROUP_SIZE_N", .value = WEBGPU_MUL_MAT_WG_SIZE_N });
2031
-
2032
- std::map<std::string, std::string> reg_repls;
2033
- reg_repls["WEBGPU_TILE_M"] = std::to_string(WEBGPU_MUL_MAT_TILE_M);
2034
- reg_repls["WEBGPU_TILE_N"] = std::to_string(WEBGPU_MUL_MAT_TILE_N);
2035
-
2036
- proc_mul_mat_f32_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32, reg_repls);
2037
- proc_mul_mat_f32_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32_vec, reg_repls);
2038
- proc_mul_mat_f16_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32, reg_repls);
2039
- proc_mul_mat_f16_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32_vec, reg_repls);
2040
- proc_mul_mat_f16_f16 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16, reg_repls);
2041
- proc_mul_mat_f16_f16_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16_vec, reg_repls);
2042
- proc_mul_mat_q4_0_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32, reg_repls);
2043
- proc_mul_mat_q4_0_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32_vec, reg_repls);
2044
- #ifndef __EMSCRIPTEN__
2045
- }
2046
- #endif
2047
-
2048
- webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
2049
- webgpu_ctx->device, proc_mul_mat_f32_f32.c_str(), "mul_mat_f32_f32", mul_mat_constants);
2050
- webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
2051
- webgpu_ctx->device, proc_mul_mat_f32_f32_vec.c_str(), "mul_mat_f32_f32_vec", mul_mat_constants);
2052
- webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
2053
- webgpu_ctx->device, proc_mul_mat_f16_f32.c_str(), "mul_mat_f16_f32", mul_mat_constants);
2054
- webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
2055
- webgpu_ctx->device, proc_mul_mat_f16_f32_vec.c_str(), "mul_mat_f16_f32_vec", mul_mat_constants);
2056
- webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline(
2057
- webgpu_ctx->device, proc_mul_mat_f16_f16.c_str(), "mul_mat_f16_f16", mul_mat_constants);
2058
- webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
2059
- webgpu_ctx->device, proc_mul_mat_f16_f16_vec.c_str(), "mul_mat_f16_f16_vec", mul_mat_constants);
2060
- webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
2061
- webgpu_ctx->device, proc_mul_mat_q4_0_f32.c_str(), "mul_mat_q4_0_f32", mul_mat_constants);
2062
- webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
2063
- webgpu_ctx->device, proc_mul_mat_q4_0_f32_vec.c_str(), "mul_mat_q4_0_f32_vec", mul_mat_constants);
2064
-
2065
- std::vector<wgpu::ConstantEntry> mul_mat_vec_constants(3);
2066
- mul_mat_vec_constants[0].key = "WORKGROUP_SIZE";
2067
- mul_mat_vec_constants[0].value = WEBGPU_MUL_MAT_VEC_WG_SIZE;
2068
- mul_mat_vec_constants[1].key = "TILE_K";
2069
- mul_mat_vec_constants[1].value = WEBGPU_MUL_MAT_VEC_TILE_K;
2070
- mul_mat_vec_constants[2].key = "OUTPUTS_PER_WG";
2071
- mul_mat_vec_constants[2].value = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG;
2072
-
2073
- webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
2074
- webgpu_ctx->device, wgsl_mul_mat_vec_f32_f32, "mul_mat_vec_f32_f32", mul_mat_vec_constants);
2075
- webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
2076
- webgpu_ctx->device, wgsl_mul_mat_vec_f32_f32_vec, "mul_mat_vec_f32_f32_vec", mul_mat_vec_constants);
2077
- webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
2078
- webgpu_ctx->device, wgsl_mul_mat_vec_f16_f32, "mul_mat_vec_f16_f32", mul_mat_vec_constants);
2079
- webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
2080
- webgpu_ctx->device, wgsl_mul_mat_vec_f16_f32_vec, "mul_mat_vec_f16_f32_vec", mul_mat_vec_constants);
2081
- webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline(
2082
- webgpu_ctx->device, wgsl_mul_mat_vec_f16_f16, "mul_mat_vec_f16_f16", mul_mat_vec_constants);
2083
- webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
2084
- webgpu_ctx->device, wgsl_mul_mat_vec_f16_f16_vec, "mul_mat_vec_f16_f16_vec", mul_mat_vec_constants);
2085
- webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
2086
- webgpu_ctx->device, wgsl_mul_mat_vec_q4_0_f32, "mul_mat_vec_q4_0_f32", mul_mat_vec_constants);
2087
- }
2088
-
2089
- static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) {
2090
- webgpu_ctx->set_rows_pipelines[0][0] = ggml_webgpu_create_pipeline(
2091
- webgpu_ctx->device, wgsl_set_rows_f16, "set_rows_f16", ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE));
2092
- webgpu_ctx->set_rows_pipelines[0][1] = ggml_webgpu_create_pipeline(
2093
- webgpu_ctx->device, wgsl_set_rows_f16_vec, "set_rows_f16_vec", ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE));
2094
- }
2095
-
2096
- static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) {
2097
- std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
2098
-
2099
- webgpu_ctx->get_rows_pipelines[GGML_TYPE_F32][0] =
2100
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_f32, "get_rows_f32", constants);
2101
- webgpu_ctx->get_rows_pipelines[GGML_TYPE_F32][1] =
2102
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_f32_vec, "get_rows_f32_vec", constants);
2103
-
2104
- webgpu_ctx->get_rows_pipelines[GGML_TYPE_F16][0] =
2105
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_f16, "get_rows_f16", constants);
2106
- webgpu_ctx->get_rows_pipelines[GGML_TYPE_I32][0] =
2107
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_i32, "get_rows_i32", constants);
2108
- webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_0][0] =
2109
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q4_0, "get_rows_q4_0", constants);
2110
- webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_1][0] =
2111
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q4_1, "get_rows_q4_1", constants);
2112
- webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_0][0] =
2113
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q5_0, "get_rows_q5_0", constants);
2114
- webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_1][0] =
2115
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q5_1, "get_rows_q5_1", constants);
2116
- webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q8_0][0] =
2117
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q8_0, "get_rows_q8_0", constants);
2118
-
2119
- webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q2_K][0] =
2120
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q2_k, "get_rows_q2_k", constants);
2121
- webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q3_K][0] =
2122
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q3_k, "get_rows_q3_k", constants);
2123
- webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_K][0] =
2124
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q4_k, "get_rows_q4_k", constants);
2125
- webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_K][0] =
2126
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q5_k, "get_rows_q5_k", constants);
2127
- webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q6_K][0] =
2128
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q6_k, "get_rows_q6_k", constants);
2129
-
2130
- webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_XXS][0] =
2131
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq2_xxs, "get_rows_iq2_xxs", constants);
2132
- webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_XS][0] =
2133
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq2_xs, "get_rows_iq2_xs", constants);
2134
- webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_S][0] =
2135
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq2_s, "get_rows_iq2_s", constants);
2136
- webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ3_XXS][0] =
2137
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq3_xxs, "get_rows_iq3_xxs", constants);
2138
- webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ3_S][0] =
2139
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq3_s, "get_rows_iq3_s", constants);
2140
- webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ1_S][0] =
2141
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq1_s, "get_rows_iq1_s", constants);
2142
- webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ1_M][0] =
2143
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq1_m, "get_rows_iq1_m", constants);
2144
- webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ4_NL][0] =
2145
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq4_nl, "get_rows_iq4_nl", constants);
2146
- webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ4_XS][0] =
2147
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq4_xs, "get_rows_iq4_xs", constants);
2606
+ constants[0].key = "wg_size";
2607
+ constants[0].value = WEBGPU_MAX_WG_SIZE;
2608
+ constants[1].key = "bytes_per_thread";
2609
+ constants[1].value = ctx->capabilities.memset_bytes_per_thread;
2610
+ ctx->memset_pipelines[0] = ggml_webgpu_create_pipeline(ctx->device, wgsl_memset, "memset", constants);
2148
2611
  }
2149
2612
 
2150
2613
  static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
2151
2614
  std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
2152
2615
 
2153
2616
  webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F32] =
2154
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_cpy_f32_f32, "cpy_f32_f32", constants);
2617
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_f32, "cpy_f32_f32", constants);
2618
+ webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_I32] =
2619
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_i32, "cpy_f32_i32", constants);
2155
2620
  webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F16] =
2156
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_cpy_f32_f16, "cpy_f32_f16", constants);
2621
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_f16, "cpy_f32_f16", constants);
2157
2622
  webgpu_ctx->cpy_pipelines[GGML_TYPE_F16][GGML_TYPE_F32] =
2158
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_cpy_f16_f32, "cpy_f16_f32", constants);
2623
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f32, "cpy_f16_f32", constants);
2159
2624
  webgpu_ctx->cpy_pipelines[GGML_TYPE_F16][GGML_TYPE_F16] =
2160
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_cpy_f16_f16, "cpy_f16_f16", constants);
2161
- }
2162
-
2163
- static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) {
2164
- std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
2165
-
2166
- webgpu_ctx->add_pipelines[GGML_TYPE_F32][0] =
2167
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_add_f32, "add_f32", constants);
2168
- webgpu_ctx->add_pipelines[GGML_TYPE_F16][0] =
2169
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_add_f16, "add_f16", constants);
2170
- webgpu_ctx->add_pipelines[GGML_TYPE_F32][1] =
2171
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_add_f32_inplace, "add_f32_inplace", constants);
2172
- webgpu_ctx->add_pipelines[GGML_TYPE_F16][1] =
2173
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_add_f16_inplace, "add_f16_inplace", constants);
2174
- }
2175
-
2176
- static void ggml_webgpu_init_sub_pipeline(webgpu_context & webgpu_ctx) {
2177
- std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
2178
-
2179
- webgpu_ctx->sub_pipelines[GGML_TYPE_F32][0] =
2180
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sub_f32, "sub_f32", constants);
2181
- webgpu_ctx->sub_pipelines[GGML_TYPE_F16][0] =
2182
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sub_f16, "sub_f16", constants);
2183
- webgpu_ctx->sub_pipelines[GGML_TYPE_F32][1] =
2184
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sub_f32_inplace, "sub_f32_inplace", constants);
2185
- webgpu_ctx->sub_pipelines[GGML_TYPE_F16][1] =
2186
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sub_f16_inplace, "sub_f16_inplace", constants);
2187
- }
2188
-
2189
- static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) {
2190
- std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
2191
-
2192
- webgpu_ctx->mul_pipelines[GGML_TYPE_F32][0] =
2193
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_f32, "mul_f32", constants);
2194
- webgpu_ctx->mul_pipelines[GGML_TYPE_F16][0] =
2195
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_f16, "mul_f16", constants);
2196
- webgpu_ctx->mul_pipelines[GGML_TYPE_F32][1] =
2197
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_f32_inplace, "mul_f32_inplace", constants);
2198
- webgpu_ctx->mul_pipelines[GGML_TYPE_F16][1] =
2199
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_f16_inplace, "mul_f16_inplace", constants);
2200
- }
2201
-
2202
- static void ggml_webgpu_init_div_pipeline(webgpu_context & webgpu_ctx) {
2203
- std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
2204
-
2205
- webgpu_ctx->div_pipelines[GGML_TYPE_F32][0] =
2206
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_div_f32, "div_f32", constants);
2207
- webgpu_ctx->div_pipelines[GGML_TYPE_F16][0] =
2208
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_div_f16, "div_f16", constants);
2209
- webgpu_ctx->div_pipelines[GGML_TYPE_F32][1] =
2210
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_div_f32_inplace, "div_f32_inplace", constants);
2211
- webgpu_ctx->div_pipelines[GGML_TYPE_F16][1] =
2212
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_div_f16_inplace, "div_f16_inplace", constants);
2625
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f16, "cpy_f16_f16", constants);
2213
2626
  }
2214
2627
 
2215
2628
  static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) {
2216
2629
  std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);
2217
2630
 
2218
2631
  webgpu_ctx->rms_norm_pipelines[0] =
2219
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rms_norm, "rms_norm", constants);
2220
- webgpu_ctx->rms_norm_pipelines[1] =
2221
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rms_norm_inplace, "rms_norm_inplace", constants);
2632
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rms_norm, "rms_norm", constants);
2633
+ webgpu_ctx->rms_norm_pipelines[1] = ggml_webgpu_create_pipeline(
2634
+ webgpu_ctx->global_ctx->device, wgsl_rms_norm_inplace, "rms_norm_inplace", constants);
2222
2635
  }
2223
2636
 
2224
2637
  static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) {
2225
2638
  std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
2226
2639
 
2227
2640
  webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][0] =
2228
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f32, "rope_f32", constants);
2229
- webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][1] =
2230
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f32_inplace, "rope_f32_inplace", constants);
2641
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f32, "rope_f32", constants);
2642
+ webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][1] = ggml_webgpu_create_pipeline(
2643
+ webgpu_ctx->global_ctx->device, wgsl_rope_f32_inplace, "rope_f32_inplace", constants);
2231
2644
  webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][0] =
2232
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f32_ff, "rope_f32_ff", constants);
2233
- webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][1] =
2234
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f32_ff_inplace, "rope_f32_ff_inplace", constants);
2645
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f32_ff, "rope_f32_ff", constants);
2646
+ webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][1] = ggml_webgpu_create_pipeline(
2647
+ webgpu_ctx->global_ctx->device, wgsl_rope_f32_ff_inplace, "rope_f32_ff_inplace", constants);
2235
2648
 
2236
2649
  webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][0] =
2237
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f16, "rope_f16", constants);
2238
- webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][1] =
2239
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f16_inplace, "rope_f16_inplace", constants);
2650
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f16, "rope_f16", constants);
2651
+ webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][1] = ggml_webgpu_create_pipeline(
2652
+ webgpu_ctx->global_ctx->device, wgsl_rope_f16_inplace, "rope_f16_inplace", constants);
2240
2653
  webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][0] =
2241
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f16_ff, "rope_f16_ff", constants);
2242
- webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][1] =
2243
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f16_ff_inplace, "rope_f16_ff_inplace", constants);
2654
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f16_ff, "rope_f16_ff", constants);
2655
+ webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][1] = ggml_webgpu_create_pipeline(
2656
+ webgpu_ctx->global_ctx->device, wgsl_rope_f16_ff_inplace, "rope_f16_ff_inplace", constants);
2244
2657
  }
2245
2658
 
2246
2659
  static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) {
@@ -2248,242 +2661,59 @@ static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) {
2248
2661
 
2249
2662
  // REGLU
2250
2663
  webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F32][0] =
2251
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_reglu_f32, "reglu_f32", constants);
2664
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f32, "reglu_f32", constants);
2252
2665
  webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F16][0] =
2253
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_reglu_f16, "reglu_f16", constants);
2666
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f16, "reglu_f16", constants);
2254
2667
  webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F32][1] =
2255
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_reglu_f32_split, "reglu_f32_split", constants);
2668
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f32_split, "reglu_f32_split", constants);
2256
2669
  webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F16][1] =
2257
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_reglu_f16_split, "reglu_f16_split", constants);
2670
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f16_split, "reglu_f16_split", constants);
2258
2671
 
2259
2672
  // GEGLU
2260
2673
  webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][0] =
2261
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_f32, "geglu_f32", constants);
2674
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f32, "geglu_f32", constants);
2262
2675
  webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][0] =
2263
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_f16, "geglu_f16", constants);
2676
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f16, "geglu_f16", constants);
2264
2677
  webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][1] =
2265
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_f32_split, "geglu_f32_split", constants);
2678
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f32_split, "geglu_f32_split", constants);
2266
2679
  webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][1] =
2267
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_f16_split, "geglu_f16_split", constants);
2680
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f16_split, "geglu_f16_split", constants);
2268
2681
 
2269
2682
  // SWIGLU
2270
2683
  webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][0] =
2271
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_f32, "swiglu_f32", constants);
2684
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_f32, "swiglu_f32", constants);
2272
2685
  webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][0] =
2273
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_f16, "swiglu_f16", constants);
2274
- webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][1] =
2275
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_f32_split, "swiglu_f32_split", constants);
2276
- webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][1] =
2277
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_f16_split, "swiglu_f16_split", constants);
2686
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_f16, "swiglu_f16", constants);
2687
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
2688
+ webgpu_ctx->global_ctx->device, wgsl_swiglu_f32_split, "swiglu_f32_split", constants);
2689
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
2690
+ webgpu_ctx->global_ctx->device, wgsl_swiglu_f16_split, "swiglu_f16_split", constants);
2278
2691
 
2279
2692
  // SWIGLU_OAI
2280
2693
  webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][0] =
2281
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_oai_f32, "swiglu_oai_f32", constants);
2282
- webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][1] =
2283
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_oai_f32_split, "swiglu_oai_f32_split", constants);
2694
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_oai_f32, "swiglu_oai_f32", constants);
2695
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
2696
+ webgpu_ctx->global_ctx->device, wgsl_swiglu_oai_f32_split, "swiglu_oai_f32_split", constants);
2284
2697
 
2285
2698
  // GEGLU_ERF
2286
2699
  webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][0] =
2287
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_erf_f32, "geglu_erf_f32", constants);
2700
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f32, "geglu_erf_f32", constants);
2288
2701
  webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][0] =
2289
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_erf_f16, "geglu_erf_f16", constants);
2290
- webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][1] =
2291
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_erf_f32_split, "geglu_erf_f32_split", constants);
2292
- webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][1] =
2293
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_erf_f16_split, "geglu_erf_f16_split", constants);
2702
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f16, "geglu_erf_f16", constants);
2703
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
2704
+ webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f32_split, "geglu_erf_f32_split", constants);
2705
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
2706
+ webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f16_split, "geglu_erf_f16_split", constants);
2294
2707
 
2295
2708
  // GEGLU_QUICK
2296
2709
  webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][0] =
2297
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_quick_f32, "geglu_quick_f32", constants);
2710
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f32, "geglu_quick_f32", constants);
2298
2711
  webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][0] =
2299
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_quick_f16, "geglu_quick_f16", constants);
2300
- webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][1] =
2301
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_quick_f32_split, "geglu_quick_f32_split", constants);
2302
- webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][1] =
2303
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_quick_f16_split, "geglu_quick_f16_split", constants);
2304
- }
2305
-
2306
- static void ggml_webgpu_init_unary_pipeline(webgpu_context & webgpu_ctx) {
2307
- std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
2308
-
2309
- // ABS
2310
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ABS][GGML_TYPE_F32][0] =
2311
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_abs_f32, "abs_f32", constants);
2312
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ABS][GGML_TYPE_F16][0] =
2313
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_abs_f16, "abs_f16", constants);
2314
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ABS][GGML_TYPE_F32][1] =
2315
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_abs_inplace_f32, "abs_inplace_f32", constants);
2316
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ABS][GGML_TYPE_F16][1] =
2317
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_abs_inplace_f16, "abs_inplace_f16", constants);
2318
-
2319
- // SGN
2320
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SGN][GGML_TYPE_F32][0] =
2321
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sgn_f32, "sgn_f32", constants);
2322
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SGN][GGML_TYPE_F16][0] =
2323
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sgn_f16, "sgn_f16", constants);
2324
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SGN][GGML_TYPE_F32][1] =
2325
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sgn_inplace_f32, "sgn_inplace_f32", constants);
2326
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SGN][GGML_TYPE_F16][1] =
2327
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sgn_inplace_f16, "sgn_inplace_f16", constants);
2328
-
2329
- // NEG
2330
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_NEG][GGML_TYPE_F32][0] =
2331
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_neg_f32, "neg_f32", constants);
2332
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_NEG][GGML_TYPE_F16][0] =
2333
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_neg_f16, "neg_f16", constants);
2334
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_NEG][GGML_TYPE_F32][1] =
2335
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_neg_inplace_f32, "neg_inplace_f32", constants);
2336
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_NEG][GGML_TYPE_F16][1] =
2337
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_neg_inplace_f16, "neg_inplace_f16", constants);
2338
-
2339
- // STEP
2340
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_STEP][GGML_TYPE_F32][0] =
2341
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_step_f32, "step_f32", constants);
2342
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_STEP][GGML_TYPE_F16][0] =
2343
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_step_f16, "step_f16", constants);
2344
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_STEP][GGML_TYPE_F32][1] =
2345
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_step_inplace_f32, "step_inplace_f32", constants);
2346
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_STEP][GGML_TYPE_F16][1] =
2347
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_step_inplace_f16, "step_inplace_f16", constants);
2348
-
2349
- // TANH
2350
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_TANH][GGML_TYPE_F32][0] =
2351
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_tanh_f32, "tanh_f32", constants);
2352
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_TANH][GGML_TYPE_F16][0] =
2353
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_tanh_f16, "tanh_f16", constants);
2354
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_TANH][GGML_TYPE_F32][1] =
2355
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_tanh_inplace_f32, "tanh_inplace_f32", constants);
2356
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_TANH][GGML_TYPE_F16][1] =
2357
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_tanh_inplace_f16, "tanh_inplace_f16", constants);
2358
-
2359
- // ELU
2360
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ELU][GGML_TYPE_F32][0] =
2361
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_elu_f32, "elu_f32", constants);
2362
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ELU][GGML_TYPE_F16][0] =
2363
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_elu_f16, "elu_f16", constants);
2364
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ELU][GGML_TYPE_F32][1] =
2365
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_elu_inplace_f32, "elu_inplace_f32", constants);
2366
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ELU][GGML_TYPE_F16][1] =
2367
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_elu_inplace_f16, "elu_inplace_f16", constants);
2368
-
2369
- // RELU
2370
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_RELU][GGML_TYPE_F32][0] =
2371
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_relu_f32, "relu_f32", constants);
2372
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_RELU][GGML_TYPE_F16][0] =
2373
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_relu_f16, "relu_f16", constants);
2374
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_RELU][GGML_TYPE_F32][1] =
2375
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_relu_inplace_f32, "relu_inplace_f32", constants);
2376
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_RELU][GGML_TYPE_F16][1] =
2377
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_relu_inplace_f16, "relu_inplace_f16", constants);
2378
-
2379
- // SIGMOID
2380
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F32][0] =
2381
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sigmoid_f32, "sigmoid_f32", constants);
2382
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F16][0] =
2383
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sigmoid_f16, "sigmoid_f16", constants);
2384
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F32][1] =
2385
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sigmoid_inplace_f32, "sigmoid_inplace_f32", constants);
2386
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F16][1] =
2387
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sigmoid_inplace_f16, "sigmoid_inplace_f16", constants);
2388
-
2389
- // GELU
2390
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU][GGML_TYPE_F32][0] =
2391
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_f32, "gelu_f32", constants);
2392
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU][GGML_TYPE_F16][0] =
2393
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_f16, "gelu_f16", constants);
2394
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU][GGML_TYPE_F32][1] =
2395
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_inplace_f32, "gelu_inplace_f32", constants);
2396
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU][GGML_TYPE_F16][1] =
2397
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_inplace_f16, "gelu_inplace_f16", constants);
2398
-
2399
- // GELU_QUICK
2400
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F32][0] =
2401
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_quick_f32, "gelu_quick_f32", constants);
2402
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F16][0] =
2403
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_quick_f16, "gelu_quick_f16", constants);
2404
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
2405
- webgpu_ctx->device, wgsl_gelu_quick_inplace_f32, "gelu_quick_inplace_f32", constants);
2406
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
2407
- webgpu_ctx->device, wgsl_gelu_quick_inplace_f16, "gelu_quick_inplace_f16", constants);
2408
-
2409
- // SILU
2410
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SILU][GGML_TYPE_F32][0] =
2411
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_silu_f32, "silu_f32", constants);
2412
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SILU][GGML_TYPE_F16][0] =
2413
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_silu_f16, "silu_f16", constants);
2414
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SILU][GGML_TYPE_F32][1] =
2415
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_silu_inplace_f32, "silu_inplace_f32", constants);
2416
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SILU][GGML_TYPE_F16][1] =
2417
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_silu_inplace_f16, "silu_inplace_f16", constants);
2418
-
2419
- // HARDSWISH
2420
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F32][0] =
2421
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_hardswish_f32, "hardswish_f32", constants);
2422
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F16][0] =
2423
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_hardswish_f16, "hardswish_f16", constants);
2424
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F32][1] =
2425
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_hardswish_inplace_f32, "hardswish_inplace_f32", constants);
2426
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F16][1] =
2427
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_hardswish_inplace_f16, "hardswish_inplace_f16", constants);
2428
-
2429
- // HARDSIGMOID
2430
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F32][0] =
2431
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_hardsigmoid_f32, "hardsigmoid_f32", constants);
2432
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F16][0] =
2433
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_hardsigmoid_f16, "hardsigmoid_f16", constants);
2434
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
2435
- webgpu_ctx->device, wgsl_hardsigmoid_inplace_f32, "hardsigmoid_inplace_f32", constants);
2436
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
2437
- webgpu_ctx->device, wgsl_hardsigmoid_inplace_f16, "hardsigmoid_inplace_f16", constants);
2438
-
2439
- // EXP
2440
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_EXP][GGML_TYPE_F32][0] =
2441
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_exp_f32, "exp_f32", constants);
2442
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_EXP][GGML_TYPE_F16][0] =
2443
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_exp_f16, "exp_f16", constants);
2444
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_EXP][GGML_TYPE_F32][1] =
2445
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_exp_inplace_f32, "exp_inplace_f32", constants);
2446
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_EXP][GGML_TYPE_F16][1] =
2447
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_exp_inplace_f16, "exp_inplace_f16", constants);
2448
-
2449
- // GELU_ERF
2450
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F32][0] =
2451
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_erf_f32, "gelu_erf_f32", constants);
2452
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F16][0] =
2453
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_erf_f16, "gelu_erf_f16", constants);
2454
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F32][1] =
2455
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_erf_inplace_f32, "gelu_erf_inplace_f32", constants);
2456
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F16][1] =
2457
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_erf_inplace_f16, "gelu_erf_inplace_f16", constants);
2458
-
2459
- // XIELU
2460
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_XIELU][GGML_TYPE_F32][0] =
2461
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_xielu_f32, "xielu_f32", constants);
2462
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_XIELU][GGML_TYPE_F16][0] =
2463
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_xielu_f16, "xielu_f16", constants);
2464
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_XIELU][GGML_TYPE_F32][1] =
2465
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_xielu_inplace_f32, "xielu_inplace_f32", constants);
2466
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_XIELU][GGML_TYPE_F16][1] =
2467
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_xielu_inplace_f16, "xielu_inplace_f16", constants);
2468
-
2469
- // CEIL
2470
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_CEIL][GGML_TYPE_F32][0] =
2471
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_ceil_f32, "ceil_f32", constants);
2472
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_CEIL][GGML_TYPE_F16][0] =
2473
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_ceil_f16, "ceil_f16", constants);
2474
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_CEIL][GGML_TYPE_F32][1] =
2475
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_ceil_inplace_f32, "ceil_inplace_f32", constants);
2476
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_CEIL][GGML_TYPE_F16][1] =
2477
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_ceil_inplace_f16, "ceil_inplace_f16", constants);
2478
- }
2479
-
2480
- static void ggml_webgpu_init_scale_pipeline(webgpu_context & webgpu_ctx) {
2481
- std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
2482
-
2483
- webgpu_ctx->scale_pipelines[0] =
2484
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_scale_f32, "scale_f32", constants);
2485
- webgpu_ctx->scale_pipelines[1] =
2486
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_scale_f32_inplace, "scale_f32_inplace", constants);
2712
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f16, "geglu_quick_f16", constants);
2713
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
2714
+ webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f32_split, "geglu_quick_f32_split", constants);
2715
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
2716
+ webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f16_split, "geglu_quick_f16_split", constants);
2487
2717
  }
2488
2718
 
2489
2719
  static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) {
@@ -2491,56 +2721,239 @@ static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) {
2491
2721
 
2492
2722
  // f32 (no mask)
2493
2723
  webgpu_ctx->soft_max_pipelines[2][0][0] =
2494
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_soft_max_f32, "soft_max_f32", constants);
2495
- webgpu_ctx->soft_max_pipelines[2][0][1] =
2496
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_soft_max_f32_inplace, "soft_max_f32_inplace", constants);
2497
- webgpu_ctx->soft_max_pipelines[2][1][0] =
2498
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_soft_max_f32_sink, "soft_max_f32_sink", constants);
2724
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32, "soft_max_f32", constants);
2725
+ webgpu_ctx->soft_max_pipelines[2][0][1] = ggml_webgpu_create_pipeline(
2726
+ webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_inplace, "soft_max_f32_inplace", constants);
2727
+ webgpu_ctx->soft_max_pipelines[2][1][0] = ggml_webgpu_create_pipeline(
2728
+ webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_sink, "soft_max_f32_sink", constants);
2499
2729
  webgpu_ctx->soft_max_pipelines[2][1][1] = ggml_webgpu_create_pipeline(
2500
- webgpu_ctx->device, wgsl_soft_max_f32_sink_inplace, "soft_max_f32_sink_inplace", constants);
2730
+ webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_sink_inplace, "soft_max_f32_sink_inplace", constants);
2501
2731
 
2502
2732
  // f32 mask (mask_type = 0)
2503
- webgpu_ctx->soft_max_pipelines[0][0][0] =
2504
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_soft_max_f32_mask_f32, "soft_max_f32_mask_f32", constants);
2733
+ webgpu_ctx->soft_max_pipelines[0][0][0] = ggml_webgpu_create_pipeline(
2734
+ webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32, "soft_max_f32_mask_f32", constants);
2505
2735
  webgpu_ctx->soft_max_pipelines[0][0][1] = ggml_webgpu_create_pipeline(
2506
- webgpu_ctx->device, wgsl_soft_max_f32_mask_f32_inplace, "soft_max_f32_mask_f32_inplace", constants);
2736
+ webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_inplace, "soft_max_f32_mask_f32_inplace", constants);
2507
2737
  webgpu_ctx->soft_max_pipelines[0][1][0] = ggml_webgpu_create_pipeline(
2508
- webgpu_ctx->device, wgsl_soft_max_f32_mask_f32_sink, "soft_max_f32_mask_f32_sink", constants);
2509
- webgpu_ctx->soft_max_pipelines[0][1][1] = ggml_webgpu_create_pipeline(
2510
- webgpu_ctx->device, wgsl_soft_max_f32_mask_f32_sink_inplace, "soft_max_f32_mask_f32_sink_inplace", constants);
2738
+ webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_sink, "soft_max_f32_mask_f32_sink", constants);
2739
+ webgpu_ctx->soft_max_pipelines[0][1][1] =
2740
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_sink_inplace,
2741
+ "soft_max_f32_mask_f32_sink_inplace", constants);
2511
2742
 
2512
2743
  // f16 mask (mask_type = 1)
2513
- webgpu_ctx->soft_max_pipelines[1][0][0] =
2514
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_soft_max_f32_mask_f16, "soft_max_f32_mask_f16", constants);
2744
+ webgpu_ctx->soft_max_pipelines[1][0][0] = ggml_webgpu_create_pipeline(
2745
+ webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16, "soft_max_f32_mask_f16", constants);
2515
2746
  webgpu_ctx->soft_max_pipelines[1][0][1] = ggml_webgpu_create_pipeline(
2516
- webgpu_ctx->device, wgsl_soft_max_f32_mask_f16_inplace, "soft_max_f32_mask_f16_inplace", constants);
2747
+ webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_inplace, "soft_max_f32_mask_f16_inplace", constants);
2517
2748
  webgpu_ctx->soft_max_pipelines[1][1][0] = ggml_webgpu_create_pipeline(
2518
- webgpu_ctx->device, wgsl_soft_max_f32_mask_f16_sink, "soft_max_f32_mask_f16_sink", constants);
2519
- webgpu_ctx->soft_max_pipelines[1][1][1] = ggml_webgpu_create_pipeline(
2520
- webgpu_ctx->device, wgsl_soft_max_f32_mask_f16_sink_inplace, "soft_max_f32_mask_f16_sink_inplace", constants);
2749
+ webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_sink, "soft_max_f32_mask_f16_sink", constants);
2750
+ webgpu_ctx->soft_max_pipelines[1][1][1] =
2751
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_sink_inplace,
2752
+ "soft_max_f32_mask_f16_sink_inplace", constants);
2521
2753
  }
2522
2754
 
2523
- // TODO: move most initialization logic here
2524
- static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) {
2755
+ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
2756
+ wgpu::RequestAdapterOptions options = {};
2757
+
2758
+ #ifndef __EMSCRIPTEN__
2759
+ // TODO: track need for these toggles: https://issues.chromium.org/issues/42251215
2760
+ const char * const adapterEnabledToggles[] = { "vulkan_enable_f16_on_nvidia", "use_vulkan_memory_model" };
2761
+ wgpu::DawnTogglesDescriptor adapterTogglesDesc;
2762
+ adapterTogglesDesc.enabledToggles = adapterEnabledToggles;
2763
+ adapterTogglesDesc.enabledToggleCount = 2;
2764
+ options.nextInChain = &adapterTogglesDesc;
2765
+ #endif
2766
+
2767
+ ctx->webgpu_global_ctx->instance.WaitAny(
2768
+ ctx->webgpu_global_ctx->instance.RequestAdapter(
2769
+ &options, wgpu::CallbackMode::AllowSpontaneous,
2770
+ [&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) {
2771
+ if (status != wgpu::RequestAdapterStatus::Success) {
2772
+ GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
2773
+ return;
2774
+ }
2775
+ ctx->webgpu_global_ctx->adapter = std::move(adapter);
2776
+ }),
2777
+ UINT64_MAX);
2778
+ GGML_ASSERT(ctx->webgpu_global_ctx->adapter != nullptr);
2779
+
2780
+ ctx->webgpu_global_ctx->adapter.GetLimits(&ctx->webgpu_global_ctx->capabilities.limits);
2781
+
2782
+ wgpu::AdapterInfo info{};
2783
+ #ifndef __EMSCRIPTEN__
2784
+ wgpu::AdapterPropertiesSubgroupMatrixConfigs subgroup_matrix_configs{};
2785
+ if (ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
2786
+ info.nextInChain = &subgroup_matrix_configs;
2787
+ }
2788
+ #endif
2789
+ ctx->webgpu_global_ctx->adapter.GetInfo(&info);
2790
+ wgpu::SupportedFeatures features;
2791
+ ctx->webgpu_global_ctx->adapter.GetFeatures(&features);
2792
+ // we require f16 support
2793
+ GGML_ASSERT(ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16));
2794
+
2795
+ #ifndef __EMSCRIPTEN__
2796
+ // Only support square f16 matrices of size 8 or 16 for now
2797
+ bool valid_subgroup_matrix_config = false;
2798
+ if (ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
2799
+ for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) {
2800
+ const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i];
2801
+ if (config.M == config.N && config.N == config.K && (config.K == 8 || config.K == 16) &&
2802
+ config.componentType == wgpu::SubgroupMatrixComponentType::F16 &&
2803
+ config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) {
2804
+ ctx->webgpu_global_ctx->capabilities.sg_mat_m = config.M;
2805
+ ctx->webgpu_global_ctx->capabilities.sg_mat_n = config.N;
2806
+ ctx->webgpu_global_ctx->capabilities.sg_mat_k = config.K;
2807
+ valid_subgroup_matrix_config = true;
2808
+ break;
2809
+ }
2810
+ }
2811
+ }
2812
+ ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix = valid_subgroup_matrix_config;
2813
+ #endif
2814
+
2815
+ // For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate.
2816
+ // Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter.
2817
+ ctx->webgpu_global_ctx->capabilities.max_subgroup_size = info.subgroupMaxSize;
2818
+ // Initialize device
2819
+ std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16 };
2820
+
2821
+ #ifndef __EMSCRIPTEN__
2822
+ required_features.push_back(wgpu::FeatureName::ImplicitDeviceSynchronization);
2823
+ if (ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) {
2824
+ required_features.push_back(wgpu::FeatureName::Subgroups);
2825
+ required_features.push_back(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix);
2826
+ }
2827
+ #endif
2828
+
2829
+ #ifdef GGML_WEBGPU_GPU_PROFILE
2830
+ required_features.push_back(wgpu::FeatureName::TimestampQuery);
2831
+ #endif
2832
+
2833
+ wgpu::DeviceDescriptor dev_desc;
2834
+ dev_desc.requiredLimits = &ctx->webgpu_global_ctx->capabilities.limits;
2835
+ dev_desc.requiredFeatures = required_features.data();
2836
+ dev_desc.requiredFeatureCount = required_features.size();
2837
+ dev_desc.SetDeviceLostCallback(
2838
+ wgpu::CallbackMode::AllowSpontaneous,
2839
+ [](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
2840
+ if (reason == wgpu::DeviceLostReason::Destroyed) {
2841
+ return;
2842
+ }
2843
+ GGML_UNUSED(device);
2844
+ GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason),
2845
+ std::string(message).c_str());
2846
+ });
2847
+ dev_desc.SetUncapturedErrorCallback(
2848
+ [](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) {
2849
+ GGML_UNUSED(device);
2850
+ GGML_ABORT("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast<int>(reason),
2851
+ std::string(message).c_str());
2852
+ });
2853
+
2854
+ #ifndef __EMSCRIPTEN__
2855
+ // Enable Dawn-specific toggles to increase native performance
2856
+ // TODO: Maybe WebGPU needs a "fast" mode where you can request compilers skip adding checks like these,
2857
+ // only for native performance?
2858
+ const char * const deviceEnabledToggles[] = { "skip_validation", "disable_robustness", "disable_workgroup_init",
2859
+ "disable_polyfills_on_integer_div_and_mod" };
2860
+ const char * const deviceDisabledToggles[] = { "timestamp_quantization" };
2861
+ wgpu::DawnTogglesDescriptor deviceTogglesDesc;
2862
+ deviceTogglesDesc.enabledToggles = deviceEnabledToggles;
2863
+ deviceTogglesDesc.enabledToggleCount = 4;
2864
+ deviceTogglesDesc.disabledToggles = deviceDisabledToggles;
2865
+ deviceTogglesDesc.disabledToggleCount = 1;
2866
+
2867
+ dev_desc.nextInChain = &deviceTogglesDesc;
2868
+ #endif
2869
+
2870
+ ctx->webgpu_global_ctx->instance.WaitAny(
2871
+ ctx->webgpu_global_ctx->adapter.RequestDevice(
2872
+ &dev_desc, wgpu::CallbackMode::AllowSpontaneous,
2873
+ [ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
2874
+ if (status != wgpu::RequestDeviceStatus::Success) {
2875
+ GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", std::string(message).c_str());
2876
+ return;
2877
+ }
2878
+ ctx->webgpu_global_ctx->device = std::move(device);
2879
+ }),
2880
+ UINT64_MAX);
2881
+ GGML_ASSERT(ctx->webgpu_global_ctx->device != nullptr);
2882
+
2883
+ ggml_webgpu_init_memset_pipeline(ctx->webgpu_global_ctx);
2884
+ ctx->webgpu_global_ctx->memset_buf_pool.init(ctx->webgpu_global_ctx->device, 1, WEBGPU_PARAMS_BUF_SIZE_BYTES,
2885
+ wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
2886
+ wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite);
2887
+ ctx->webgpu_global_ctx->queue = ctx->webgpu_global_ctx->device.GetQueue();
2888
+
2889
+ #ifdef GGML_WEBGPU_GPU_PROFILE
2890
+ // Initialize buffer pool for timestamp queries, used for profiling
2891
+ ctx->webgpu_global_ctx->timestamp_query_buf_pool.init(
2892
+ ctx->webgpu_global_ctx->device, WEBGPU_NUM_TIMESTAMP_QUERY_BUFS, WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES,
2893
+ wgpu::BufferUsage::QueryResolve | wgpu::BufferUsage::CopySrc,
2894
+ wgpu::BufferUsage::MapRead | wgpu::BufferUsage::CopyDst);
2895
+ #endif
2896
+
2897
+ GGML_LOG_INFO(
2898
+ "ggml_webgpu: adapter_info: vendor_id: %u | vendor: %s | architecture: %s | device_id: %u | name: %s | "
2899
+ "device_desc: %s\n",
2900
+ info.vendorID, std::string(info.vendor).c_str(), std::string(info.architecture).c_str(), info.deviceID,
2901
+ std::string(info.device).c_str(), std::string(info.description).c_str());
2902
+ return true;
2903
+ }
2904
+
2905
+ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) {
2906
+ ggml_backend_webgpu_device_context * dev_ctx = (ggml_backend_webgpu_device_context *) dev->context;
2907
+ webgpu_context webgpu_ctx = std::make_shared<webgpu_context_struct>();
2908
+ webgpu_ctx->global_ctx = dev_ctx->webgpu_global_ctx;
2909
+ webgpu_ctx->shader_lib = std::make_unique<ggml_webgpu_shader_lib>(dev_ctx->webgpu_global_ctx->device);
2910
+ webgpu_ctx->param_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES,
2911
+ wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
2912
+ wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite, true);
2913
+ ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->set_rows_dev_error_buf,
2914
+ WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
2915
+ wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, "set_rows_dev_error_buf");
2916
+ ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->set_rows_host_error_buf,
2917
+ WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
2918
+ wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "set_rows_host_error_buf");
2919
+
2920
+ ggml_webgpu_init_cpy_pipeline(webgpu_ctx);
2921
+ ggml_webgpu_init_rms_norm_pipeline(webgpu_ctx);
2922
+ ggml_webgpu_init_rope_pipeline(webgpu_ctx);
2923
+ ggml_webgpu_init_glu_pipeline(webgpu_ctx);
2924
+ ggml_webgpu_init_soft_max_pipeline(webgpu_ctx);
2925
+ #ifdef GGML_WEBGPU_DEBUG
2926
+ // Initialize debug buffers
2927
+ ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->global_ctx->debug_host_buf,
2928
+ WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
2929
+ wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "debug_host_buf");
2930
+ ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->global_ctx->debug_dev_buf,
2931
+ WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
2932
+ wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, "debug_dev_buf");
2933
+ #endif
2934
+ return webgpu_ctx;
2935
+ }
2936
+
2937
+ static ggml_backend_t ggml_backend_webgpu_backend_init(ggml_backend_dev_t dev, const char * params) {
2525
2938
  GGML_UNUSED(params);
2526
2939
 
2527
- WEBGPU_LOG_DEBUG("ggml_backend_webgpu_device_init()");
2940
+ WEBGPU_LOG_DEBUG("ggml_backend_webgpu_backend_init()");
2528
2941
 
2529
- ggml_backend_webgpu_device_context * dev_ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
2530
- webgpu_context webgpu_ctx = dev_ctx->webgpu_ctx;
2942
+ ggml_backend_webgpu_device_context * dev_ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
2531
2943
 
2532
- static ggml_backend_webgpu_context backend_ctx;
2533
- backend_ctx.name = GGML_WEBGPU_NAME + std::string(": ") + dev_ctx->device_name;
2534
- backend_ctx.webgpu_ctx = webgpu_ctx;
2944
+ auto * backend_ctx = new ggml_backend_webgpu_context();
2945
+ backend_ctx->name = GGML_WEBGPU_NAME + std::string(": ") + dev_ctx->device_name;
2946
+ backend_ctx->webgpu_ctx = initialize_webgpu_context(dev);
2535
2947
 
2536
2948
  // See GGML Backend Interface section
2537
- static ggml_backend backend = {
2949
+ auto * backend = new ggml_backend();
2950
+ *backend = {
2538
2951
  /* .guid = */ ggml_backend_webgpu_guid(),
2539
2952
  /* .interface = */ ggml_backend_webgpu_i,
2540
2953
  /* .device = */ dev,
2541
- /* .context = */ &backend_ctx,
2954
+ /* .context = */ backend_ctx,
2542
2955
  };
2543
- return &backend;
2956
+ return backend;
2544
2957
  }
2545
2958
 
2546
2959
  static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggml_backend_dev_t dev) {
@@ -2549,15 +2962,16 @@ static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggm
2549
2962
  static struct ggml_backend_buffer_type ggml_backend_webgpu_buffer_type = {
2550
2963
  /* .iface = */ {
2551
2964
  /* .get_name = */ ggml_backend_webgpu_buffer_type_get_name,
2552
- /* .alloc_buffer = */ ggml_backend_webgpu_buffer_type_alloc_buffer,
2553
- /* .get_alignment = */ ggml_backend_webgpu_buffer_type_get_alignment,
2554
- /* .get_max_size = */ ggml_backend_webgpu_buffer_type_get_max_size,
2555
- /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
2556
- /* .is_host = */ NULL, // defaults to false
2965
+ /* .alloc_buffer = */
2966
+ ggml_backend_webgpu_buffer_type_alloc_buffer, /* .get_alignment = */
2967
+ ggml_backend_webgpu_buffer_type_get_alignment, /* .get_max_size = */
2968
+ ggml_backend_webgpu_buffer_type_get_max_size, /* .get_alloc_size = */
2969
+ ggml_backend_webgpu_buffer_type_get_alloc_size, /* .is_host = */ NULL, // defaults to false
2557
2970
  },
2558
2971
  /* .device = */
2559
2972
  dev,
2560
- /* .context = */ NULL,
2973
+ /* .context = */
2974
+ NULL
2561
2975
  };
2562
2976
 
2563
2977
  return &ggml_backend_webgpu_buffer_type;
@@ -2598,16 +3012,16 @@ static bool ggml_webgpu_supported_qtype(ggml_type type) {
2598
3012
  static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
2599
3013
  ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
2600
3014
 
2601
- webgpu_context webgpu_ctx = ctx->webgpu_ctx;
2602
-
2603
3015
  ggml_tensor * src0 = op->src[0];
2604
3016
  ggml_tensor * src1 = op->src[1];
2605
3017
  ggml_tensor * src2 = op->src[2];
2606
3018
 
2607
3019
  // on smaller devices (or CI), tensors may be larger than the max storage buffer size
2608
- if (ggml_nbytes(op) > webgpu_ctx->limits.maxStorageBufferBindingSize ||
2609
- (src0 != nullptr && ggml_nbytes(src0) > webgpu_ctx->limits.maxStorageBufferBindingSize) ||
2610
- (src1 != nullptr && ggml_nbytes(src1) > webgpu_ctx->limits.maxStorageBufferBindingSize)) {
3020
+ if (ggml_nbytes(op) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize ||
3021
+ (src0 != nullptr &&
3022
+ ggml_nbytes(src0) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize) ||
3023
+ (src1 != nullptr &&
3024
+ ggml_nbytes(src1) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize)) {
2611
3025
  return false;
2612
3026
  }
2613
3027
 
@@ -2624,23 +3038,30 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
2624
3038
  case GGML_OP_SUB:
2625
3039
  case GGML_OP_MUL:
2626
3040
  case GGML_OP_DIV:
2627
- // TODO: support non-contiguous tensors, e.g. for MOE_EXPERT_REDUCE
2628
- // see https://github.com/ggml-org/llama.cpp/pull/16857
2629
3041
  supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type) &&
2630
- (src1->type == op->type) && ggml_is_contiguous(src0) && ggml_is_contiguous(src1);
3042
+ (src1->type == op->type);
3043
+ break;
3044
+ case GGML_OP_CONCAT:
3045
+ supports_op = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32);
3046
+ break;
3047
+ case GGML_OP_REPEAT:
3048
+ supports_op = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32 || src0->type == GGML_TYPE_I16);
2631
3049
  break;
2632
3050
  case GGML_OP_CPY:
2633
3051
  case GGML_OP_CONT:
2634
- supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
2635
- (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
3052
+ supports_op = ((op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
3053
+ (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) ||
3054
+ (op->type == GGML_TYPE_I32 && src0->type == GGML_TYPE_F32);
2636
3055
  break;
2637
3056
  case GGML_OP_SET_ROWS:
2638
- supports_op = (op->type == GGML_TYPE_F16 && src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I64);
3057
+ supports_op = ((op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32) && src0->type == GGML_TYPE_F32 &&
3058
+ (src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32));
2639
3059
  break;
2640
3060
  case GGML_OP_GET_ROWS:
2641
- if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_I32 ||
2642
- ggml_webgpu_supported_qtype(src0->type)) {
3061
+ if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_webgpu_supported_qtype(src0->type)) {
2643
3062
  supports_op = (op->type == GGML_TYPE_F32);
3063
+ } else if (src0->type == GGML_TYPE_I32) {
3064
+ supports_op = op->type == GGML_TYPE_I32;
2644
3065
  }
2645
3066
  break;
2646
3067
  case GGML_OP_MUL_MAT:
@@ -2684,17 +3105,19 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
2684
3105
  }
2685
3106
  case GGML_OP_FLASH_ATTN_EXT:
2686
3107
  {
2687
- if (!webgpu_ctx->supports_subgroup_matrix) {
3108
+ #ifndef __EMSCRIPTEN__
3109
+ if (!ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) {
2688
3110
  break;
2689
3111
  }
2690
3112
  // Head dimensions must fit in workgroup memory with minimum tile sizes
2691
- size_t limit_bytes = webgpu_ctx->limits.maxComputeWorkgroupStorageSize;
3113
+ size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
2692
3114
  const bool has_mask = op->src[3] != nullptr;
2693
- const bool kv_direct = src1->type == GGML_TYPE_F16 && (src0->ne[0] % webgpu_ctx->sg_mat_k) == 0 &&
3115
+ const bool kv_direct = src1->type == GGML_TYPE_F16 &&
3116
+ (src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k) == 0 &&
2694
3117
  (src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0;
2695
3118
  const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
2696
- webgpu_ctx->sg_mat_m, webgpu_ctx->sg_mat_n, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0],
2697
- has_mask, kv_direct);
3119
+ ctx->webgpu_global_ctx->capabilities.sg_mat_m, ctx->webgpu_global_ctx->capabilities.sg_mat_n,
3120
+ (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, kv_direct);
2698
3121
  if (min_bytes > limit_bytes) {
2699
3122
  break;
2700
3123
  }
@@ -2703,6 +3126,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
2703
3126
  (src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16 ||
2704
3127
  src1->type == GGML_TYPE_Q4_0 || src1->type == GGML_TYPE_Q8_0) &&
2705
3128
  src2->type == src1->type && op->type == GGML_TYPE_F32;
3129
+ #endif
2706
3130
  break;
2707
3131
  }
2708
3132
  case GGML_OP_RMS_NORM:
@@ -2753,9 +3177,14 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
2753
3177
  case GGML_UNARY_OP_HARDSIGMOID:
2754
3178
  case GGML_UNARY_OP_EXP:
2755
3179
  case GGML_UNARY_OP_GELU_ERF:
2756
- case GGML_UNARY_OP_XIELU:
3180
+ case GGML_UNARY_OP_SOFTPLUS:
3181
+ case GGML_UNARY_OP_EXPM1:
3182
+ case GGML_UNARY_OP_FLOOR:
2757
3183
  case GGML_UNARY_OP_CEIL:
2758
- supports_op = supports_op =
3184
+ case GGML_UNARY_OP_ROUND:
3185
+ case GGML_UNARY_OP_TRUNC:
3186
+ case GGML_UNARY_OP_XIELU:
3187
+ supports_op =
2759
3188
  (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
2760
3189
  break;
2761
3190
  default:
@@ -2763,14 +3192,56 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
2763
3192
  }
2764
3193
  }
2765
3194
  break;
2766
-
3195
+ case GGML_OP_CLAMP:
3196
+ supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
3197
+ break;
3198
+ case GGML_OP_FILL:
3199
+ supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
3200
+ break;
3201
+ case GGML_OP_LOG:
3202
+ supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
3203
+ break;
3204
+ case GGML_OP_SQR:
3205
+ supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
3206
+ break;
3207
+ case GGML_OP_SQRT:
3208
+ supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
3209
+ break;
3210
+ case GGML_OP_SIN:
3211
+ supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
3212
+ break;
3213
+ case GGML_OP_COS:
3214
+ supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
3215
+ break;
3216
+ case GGML_OP_PAD:
3217
+ supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
3218
+ break;
3219
+ case GGML_OP_ARGMAX:
3220
+ supports_op = op->type == GGML_TYPE_I32 && src0->type == GGML_TYPE_F32;
3221
+ break;
3222
+ case GGML_OP_ARGSORT:
3223
+ supports_op = op->type == GGML_TYPE_I32 && src0->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(src0);
3224
+ break;
3225
+ case GGML_OP_TOP_K:
3226
+ supports_op = op->type == GGML_TYPE_I32 && src0->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(src0);
3227
+ break;
3228
+ case GGML_OP_CUMSUM:
3229
+ supports_op = op->type == GGML_TYPE_F32 && src0->type == op->type;
3230
+ break;
3231
+ case GGML_OP_SUM:
3232
+ case GGML_OP_SUM_ROWS:
3233
+ supports_op = op->type == GGML_TYPE_F32 && src0->type == op->type && ggml_is_contiguous_rows(src0);
3234
+ break;
2767
3235
  default:
2768
3236
  break;
2769
3237
  }
2770
- if (ggml_nbytes(op) > webgpu_ctx->limits.maxStorageBufferBindingSize ||
2771
- (src0 != nullptr && ggml_nbytes(src0) > webgpu_ctx->limits.maxStorageBufferBindingSize) ||
2772
- (src1 != nullptr && ggml_nbytes(src1) > webgpu_ctx->limits.maxStorageBufferBindingSize) ||
2773
- (src2 != nullptr && ggml_nbytes(src2) > webgpu_ctx->limits.maxStorageBufferBindingSize)) {
3238
+ if (ggml_nbytes(op) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize ||
3239
+ (src0 != nullptr &&
3240
+ ggml_nbytes(src0) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize) ||
3241
+ (src1 != nullptr &&
3242
+ ggml_nbytes(src1) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize) ||
3243
+ (src2 != nullptr &&
3244
+ ggml_nbytes(src2) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize)) {
2774
3245
  supports_op = false;
2775
3246
  WEBGPU_LOG_DEBUG("ggml_webgpu op not supported due to size: ");
2776
3247
  }
@@ -2795,7 +3266,7 @@ static struct ggml_backend_device_i ggml_backend_webgpu_device_i = {
2795
3266
  /* .get_memory = */ ggml_backend_webgpu_device_get_memory,
2796
3267
  /* .get_type = */ ggml_backend_webgpu_device_get_type,
2797
3268
  /* .get_props = */ ggml_backend_webgpu_device_get_props,
2798
- /* .init_backend = */ ggml_backend_webgpu_device_init,
3269
+ /* .init_backend = */ ggml_backend_webgpu_backend_init,
2799
3270
  /* .get_buffer_type = */ ggml_backend_webgpu_device_get_buffer_type,
2800
3271
  /* .get_host_buffer_type = */ NULL,
2801
3272
  /* .buffer_from_host_ptr = */ NULL,
@@ -2821,8 +3292,6 @@ static size_t ggml_backend_webgpu_reg_get_device_count(ggml_backend_reg_t reg) {
2821
3292
  return ctx->device_count;
2822
3293
  }
2823
3294
 
2824
- // TODO: Does this need to be thread safe? Is it only called once?
2825
- // TODO: move most logic to device_init function so backend can be freed/initialized properly
2826
3295
  // Only one device is supported for now
2827
3296
  static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t reg, size_t index) {
2828
3297
  GGML_ASSERT(index == 0);
@@ -2832,191 +3301,12 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
2832
3301
 
2833
3302
  ggml_backend_webgpu_reg_context * reg_ctx = static_cast<ggml_backend_webgpu_reg_context *>(reg->context);
2834
3303
 
2835
- webgpu_context ctx = reg_ctx->webgpu_ctx;
2836
-
2837
- wgpu::RequestAdapterOptions options = {};
2838
-
2839
- #ifndef __EMSCRIPTEN__
2840
- // TODO: track need for these toggles: https://issues.chromium.org/issues/42251215
2841
- const char * const adapterEnabledToggles[] = { "vulkan_enable_f16_on_nvidia", "use_vulkan_memory_model" };
2842
- wgpu::DawnTogglesDescriptor adapterTogglesDesc;
2843
- adapterTogglesDesc.enabledToggles = adapterEnabledToggles;
2844
- adapterTogglesDesc.enabledToggleCount = 2;
2845
- options.nextInChain = &adapterTogglesDesc;
2846
- #endif
2847
-
2848
- ctx->instance.WaitAny(ctx->instance.RequestAdapter(
2849
- &options, wgpu::CallbackMode::AllowSpontaneous,
2850
- [&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) {
2851
- if (status != wgpu::RequestAdapterStatus::Success) {
2852
- GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
2853
- return;
2854
- }
2855
- ctx->adapter = std::move(adapter);
2856
- }),
2857
- UINT64_MAX);
2858
- GGML_ASSERT(ctx->adapter != nullptr);
2859
-
2860
- ctx->adapter.GetLimits(&ctx->limits);
2861
-
2862
- wgpu::AdapterInfo info{};
2863
- #ifndef __EMSCRIPTEN__
2864
- wgpu::AdapterPropertiesSubgroupMatrixConfigs subgroup_matrix_configs{};
2865
- if (ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
2866
- info.nextInChain = &subgroup_matrix_configs;
2867
- }
2868
- #endif
2869
- ctx->adapter.GetInfo(&info);
2870
-
2871
- wgpu::SupportedFeatures features;
2872
- ctx->adapter.GetFeatures(&features);
2873
- // we require f16 support
2874
- GGML_ASSERT(ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16));
2875
-
2876
- #ifndef __EMSCRIPTEN__
2877
- // Only support square f16 matrices of size 8 or 16 for now
2878
- bool valid_subgroup_matrix_config = false;
2879
- if (ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
2880
- for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) {
2881
- const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i];
2882
- if (config.M == config.N && config.N == config.K && (config.K == 8 || config.K == 16) &&
2883
- config.componentType == wgpu::SubgroupMatrixComponentType::F16 &&
2884
- config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) {
2885
- ctx->sg_mat_m = config.M;
2886
- ctx->sg_mat_n = config.N;
2887
- ctx->sg_mat_k = config.K;
2888
- valid_subgroup_matrix_config = true;
2889
- break;
2890
- }
2891
- }
2892
- }
2893
-
2894
- ctx->supports_subgroup_matrix = valid_subgroup_matrix_config;
2895
- #endif
2896
- // For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate.
2897
- // Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter.
2898
- ctx->max_subgroup_size = info.subgroupMaxSize;
2899
-
2900
- // Initialize device
2901
- std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16 };
2902
-
2903
- #ifndef __EMSCRIPTEN__
2904
- required_features.push_back(wgpu::FeatureName::ImplicitDeviceSynchronization);
2905
- if (ctx->supports_subgroup_matrix) {
2906
- required_features.push_back(wgpu::FeatureName::Subgroups);
2907
- required_features.push_back(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix);
2908
- }
2909
- #endif
2910
-
2911
- #ifdef GGML_WEBGPU_GPU_PROFILE
2912
- required_features.push_back(wgpu::FeatureName::TimestampQuery);
2913
- #endif
2914
-
2915
- wgpu::DeviceDescriptor dev_desc;
2916
- dev_desc.requiredLimits = &ctx->limits;
2917
- dev_desc.requiredFeatures = required_features.data();
2918
- dev_desc.requiredFeatureCount = required_features.size();
2919
- dev_desc.SetDeviceLostCallback(
2920
- wgpu::CallbackMode::AllowSpontaneous,
2921
- [](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
2922
- GGML_UNUSED(device);
2923
- GGML_UNUSED(reason);
2924
- GGML_UNUSED(message);
2925
- //TODO: uncomment once proper free logic is in place
2926
- //GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason),
2927
- //std::string(message).c_str());
2928
- });
2929
- dev_desc.SetUncapturedErrorCallback(
2930
- [](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) {
2931
- GGML_UNUSED(device);
2932
- GGML_ABORT("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast<int>(reason),
2933
- std::string(message).c_str());
2934
- });
2935
-
2936
- #ifndef __EMSCRIPTEN__
2937
- // Enable Dawn-specific toggles to increase native performance
2938
- // TODO: Maybe WebGPU needs a "fast" mode where you can request compilers skip adding checks like these,
2939
- // only for native performance?
2940
- const char * const deviceEnabledToggles[] = { "skip_validation", "disable_robustness", "disable_workgroup_init",
2941
- "disable_polyfills_on_integer_div_and_mod" };
2942
- const char * const deviceDisabledToggles[] = { "timestamp_quantization" };
2943
- wgpu::DawnTogglesDescriptor deviceTogglesDesc;
2944
- deviceTogglesDesc.enabledToggles = deviceEnabledToggles;
2945
- deviceTogglesDesc.enabledToggleCount = 4;
2946
- deviceTogglesDesc.disabledToggles = deviceDisabledToggles;
2947
- deviceTogglesDesc.disabledToggleCount = 1;
2948
-
2949
- dev_desc.nextInChain = &deviceTogglesDesc;
2950
- #endif
2951
-
2952
- ctx->instance.WaitAny(ctx->adapter.RequestDevice(
2953
- &dev_desc, wgpu::CallbackMode::AllowSpontaneous,
2954
- [ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
2955
- if (status != wgpu::RequestDeviceStatus::Success) {
2956
- GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n",
2957
- std::string(message).c_str());
2958
- return;
2959
- }
2960
- ctx->device = std::move(device);
2961
- }),
2962
- UINT64_MAX);
2963
- GGML_ASSERT(ctx->device != nullptr);
2964
-
2965
- // Initialize (compute) queue
2966
- ctx->queue = ctx->device.GetQueue();
2967
-
2968
- // Create buffer pool for shader parameters
2969
- ctx->param_buf_pool.init(ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES,
2970
- wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
2971
- wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite);
2972
-
2973
- #ifdef GGML_WEBGPU_GPU_PROFILE
2974
- // Initialize buffer pool for timestamp queries (profiling)
2975
- ctx->timestamp_query_buf_pool.init(ctx->device, WEBGPU_NUM_TIMESTAMP_QUERY_BUFS,
2976
- WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES,
2977
- wgpu::BufferUsage::QueryResolve | wgpu::BufferUsage::CopySrc,
2978
- wgpu::BufferUsage::MapRead | wgpu::BufferUsage::CopyDst);
2979
- #endif
2980
-
2981
- ctx->set_rows_error_buf_pool.init(ctx->device, WEBGPU_NUM_SET_ROWS_ERROR_BUFS, WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
2982
- wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage,
2983
- wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead);
2984
-
2985
- ggml_webgpu_init_memset_pipeline(ctx);
2986
- ggml_webgpu_init_mul_mat_pipeline(ctx);
2987
- ggml_webgpu_init_set_rows_pipeline(ctx);
2988
- ggml_webgpu_init_get_rows_pipeline(ctx);
2989
- ggml_webgpu_init_cpy_pipeline(ctx);
2990
- ggml_webgpu_init_add_pipeline(ctx);
2991
- ggml_webgpu_init_sub_pipeline(ctx);
2992
- ggml_webgpu_init_mul_pipeline(ctx);
2993
- ggml_webgpu_init_div_pipeline(ctx);
2994
- ggml_webgpu_init_rms_norm_pipeline(ctx);
2995
- ggml_webgpu_init_rope_pipeline(ctx);
2996
- ggml_webgpu_init_glu_pipeline(ctx);
2997
- ggml_webgpu_init_scale_pipeline(ctx);
2998
- ggml_webgpu_init_soft_max_pipeline(ctx);
2999
- ggml_webgpu_init_unary_pipeline(ctx);
3000
-
3001
- #ifdef GGML_WEBGPU_DEBUG
3002
- // Initialize debug buffers
3003
- ggml_webgpu_create_buffer(ctx->device, ctx->debug_host_buf, WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
3004
- wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "debug_host_buf");
3005
- ggml_webgpu_create_buffer(ctx->device, ctx->debug_dev_buf, WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
3006
- wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, "debug_dev_buf");
3007
- #endif
3304
+ create_webgpu_device(reg_ctx);
3008
3305
 
3009
3306
  static ggml_backend_webgpu_device_context device_ctx;
3010
- device_ctx.webgpu_ctx = ctx;
3011
- device_ctx.device_name = GGML_WEBGPU_NAME;
3012
- device_ctx.device_desc = info.description;
3013
-
3014
- GGML_LOG_INFO(
3015
- "ggml_webgpu: adapter_info: vendor_id: %u | vendor: %s | architecture: %s | device_id: %u | name: %s | "
3016
- "device_desc: %s\n",
3017
- info.vendorID, std::string(info.vendor).c_str(), std::string(info.architecture).c_str(), info.deviceID,
3018
- std::string(info.device).c_str(), std::string(info.description).c_str());
3019
-
3307
+ device_ctx.device_name = GGML_WEBGPU_NAME;
3308
+ device_ctx.device_desc = GGML_WEBGPU_NAME;
3309
+ device_ctx.webgpu_global_ctx = reg_ctx->webgpu_global_ctx;
3020
3310
  // See GGML Backend Device Interface section
3021
3311
  static ggml_backend_device device = {
3022
3312
  /* .iface = */ ggml_backend_webgpu_device_i,
@@ -3024,7 +3314,7 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
3024
3314
  /* .context = */ &device_ctx,
3025
3315
  };
3026
3316
 
3027
- WEBGPU_CPU_PROFILE_TOTAL_END(reg_get_device, ctx);
3317
+ WEBGPU_CPU_PROFILE_TOTAL_END(reg_get_device, reg_ctx->webgpu_global_ctx);
3028
3318
  return &device;
3029
3319
  }
3030
3320
 
@@ -3040,10 +3330,7 @@ static const struct ggml_backend_reg_i ggml_backend_webgpu_reg_i = {
3040
3330
  ggml_backend_reg_t ggml_backend_webgpu_reg() {
3041
3331
  WEBGPU_LOG_DEBUG("ggml_backend_webgpu_reg()");
3042
3332
 
3043
- webgpu_context webgpu_ctx = std::make_shared<webgpu_context_struct>();
3044
-
3045
3333
  static ggml_backend_webgpu_reg_context ctx;
3046
- ctx.webgpu_ctx = webgpu_ctx;
3047
3334
  ctx.name = GGML_WEBGPU_NAME;
3048
3335
  ctx.device_count = 1;
3049
3336
 
@@ -3060,15 +3347,17 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() {
3060
3347
  instance_descriptor.nextInChain = &instanceTogglesDesc;
3061
3348
  #endif
3062
3349
 
3063
- webgpu_ctx->instance = wgpu::CreateInstance(&instance_descriptor);
3350
+ wgpu::Instance inst = wgpu::CreateInstance(&instance_descriptor);
3351
+ ctx.webgpu_global_ctx = webgpu_global_context(new webgpu_global_context_struct());
3352
+ ctx.webgpu_global_ctx->instance = std::move(inst);
3064
3353
 
3065
3354
  #ifdef __EMSCRIPTEN__
3066
- if (webgpu_ctx->instance == nullptr) {
3355
+ if (ctx.webgpu_global_ctx->instance == nullptr) {
3067
3356
  GGML_LOG_ERROR("ggml_webgpu: Failed to create WebGPU instance. Make sure either -sASYNCIFY or -sJSPI is set\n");
3068
3357
  return nullptr;
3069
3358
  }
3070
3359
  #endif
3071
- GGML_ASSERT(webgpu_ctx->instance != nullptr);
3360
+ GGML_ASSERT(ctx.webgpu_global_ctx->instance != nullptr);
3072
3361
 
3073
3362
  static ggml_backend_reg reg = {
3074
3363
  /* .api_version = */ GGML_BACKEND_API_VERSION,
@@ -3081,7 +3370,7 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() {
3081
3370
  ggml_backend_t ggml_backend_webgpu_init(void) {
3082
3371
  ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_webgpu_reg(), 0);
3083
3372
 
3084
- return ggml_backend_webgpu_device_init(dev, nullptr);
3373
+ return ggml_backend_webgpu_backend_init(dev, nullptr);
3085
3374
  }
3086
3375
 
3087
3376
  GGML_BACKEND_DL_IMPL(ggml_backend_webgpu_reg)