whispercpp 1.3.4 → 1.3.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (891)
  1. checksums.yaml +4 -4
  2. data/LICENSE +1 -1
  3. data/README.md +158 -44
  4. data/ext/extconf.rb +3 -2
  5. data/ext/ruby_whisper.c +34 -6
  6. data/ext/ruby_whisper.h +67 -0
  7. data/ext/ruby_whisper_context.c +236 -144
  8. data/ext/ruby_whisper_context_params.c +163 -0
  9. data/ext/ruby_whisper_model.c +12 -13
  10. data/ext/ruby_whisper_params.c +47 -24
  11. data/ext/ruby_whisper_segment.c +84 -20
  12. data/ext/ruby_whisper_token.c +371 -0
  13. data/ext/ruby_whisper_transcribe.cpp +5 -2
  14. data/ext/ruby_whisper_vad_context.c +122 -0
  15. data/ext/ruby_whisper_vad_context_detect.cpp +51 -0
  16. data/ext/ruby_whisper_vad_params.c +0 -1
  17. data/ext/ruby_whisper_vad_segment.c +138 -0
  18. data/ext/ruby_whisper_vad_segments.c +105 -0
  19. data/ext/sources/CMakeLists.txt +4 -1
  20. data/ext/sources/bindings/javascript/package.json +1 -1
  21. data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
  22. data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
  23. data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  24. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  25. data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
  26. data/ext/sources/examples/addon.node/vad-example.js +2 -2
  27. data/ext/sources/examples/bench/bench.cpp +23 -18
  28. data/ext/sources/examples/cli/cli.cpp +129 -112
  29. data/ext/sources/examples/common-ggml.cpp +2 -0
  30. data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
  31. data/ext/sources/examples/miniaudio.h +4507 -2131
  32. data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
  33. data/ext/sources/examples/server/server.cpp +28 -15
  34. data/ext/sources/examples/talk-llama/CMakeLists.txt +8 -3
  35. data/ext/sources/examples/talk-llama/llama-adapter.cpp +5 -2
  36. data/ext/sources/examples/talk-llama/llama-adapter.h +7 -0
  37. data/ext/sources/examples/talk-llama/llama-arch.cpp +2378 -1988
  38. data/ext/sources/examples/talk-llama/llama-arch.h +109 -2
  39. data/ext/sources/examples/talk-llama/llama-batch.cpp +78 -34
  40. data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
  41. data/ext/sources/examples/talk-llama/llama-chat.cpp +100 -4
  42. data/ext/sources/examples/talk-llama/llama-chat.h +5 -0
  43. data/ext/sources/examples/talk-llama/llama-context.cpp +1088 -403
  44. data/ext/sources/examples/talk-llama/llama-context.h +70 -23
  45. data/ext/sources/examples/talk-llama/llama-cparams.h +6 -0
  46. data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
  47. data/ext/sources/examples/talk-llama/llama-grammar.cpp +295 -60
  48. data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
  49. data/ext/sources/examples/talk-llama/llama-graph.cpp +925 -155
  50. data/ext/sources/examples/talk-llama/llama-graph.h +234 -23
  51. data/ext/sources/examples/talk-llama/llama-hparams.cpp +79 -38
  52. data/ext/sources/examples/talk-llama/llama-hparams.h +118 -18
  53. data/ext/sources/examples/talk-llama/llama-impl.cpp +11 -7
  54. data/ext/sources/examples/talk-llama/llama-impl.h +14 -2
  55. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +8 -4
  56. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +405 -140
  57. data/ext/sources/examples/talk-llama/llama-kv-cache.h +24 -10
  58. data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
  59. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
  60. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
  61. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
  62. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +42 -31
  63. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
  64. data/ext/sources/examples/talk-llama/llama-mmap.cpp +197 -45
  65. data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
  66. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +606 -116
  67. data/ext/sources/examples/talk-llama/llama-model-loader.h +41 -5
  68. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +61 -44
  69. data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
  70. data/ext/sources/examples/talk-llama/llama-model.cpp +2756 -13643
  71. data/ext/sources/examples/talk-llama/llama-model.h +112 -18
  72. data/ext/sources/examples/talk-llama/llama-quant.cpp +582 -365
  73. data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +1409 -199
  74. data/ext/sources/examples/talk-llama/llama-sampler.h +42 -0
  75. data/ext/sources/examples/talk-llama/llama-vocab.cpp +248 -82
  76. data/ext/sources/examples/talk-llama/llama-vocab.h +50 -40
  77. data/ext/sources/examples/talk-llama/llama.cpp +802 -21
  78. data/ext/sources/examples/talk-llama/llama.h +210 -39
  79. data/ext/sources/examples/talk-llama/models/afmoe.cpp +190 -0
  80. data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
  81. data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
  82. data/ext/sources/examples/talk-llama/models/arctic.cpp +137 -0
  83. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
  84. data/ext/sources/examples/talk-llama/models/baichuan.cpp +123 -0
  85. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +143 -0
  86. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +133 -0
  87. data/ext/sources/examples/talk-llama/models/bert.cpp +184 -0
  88. data/ext/sources/examples/talk-llama/models/bitnet.cpp +145 -0
  89. data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
  90. data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
  91. data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
  92. data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
  93. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
  94. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
  95. data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
  96. data/ext/sources/examples/talk-llama/models/dbrx.cpp +122 -0
  97. data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
  98. data/ext/sources/examples/talk-llama/models/deepseek.cpp +142 -0
  99. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +262 -0
  100. data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
  101. data/ext/sources/examples/talk-llama/models/dots1.cpp +132 -0
  102. data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
  103. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +148 -0
  104. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
  105. data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
  106. data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
  107. data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
  108. data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
  109. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +111 -0
  110. data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
  111. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
  112. data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
  113. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
  114. data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
  115. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
  116. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
  117. data/ext/sources/examples/talk-llama/models/glm4.cpp +157 -0
  118. data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
  119. data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
  120. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +195 -0
  121. data/ext/sources/examples/talk-llama/models/granite.cpp +210 -0
  122. data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
  123. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +139 -0
  124. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
  125. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +153 -0
  126. data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
  127. data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
  128. data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
  129. data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
  130. data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
  131. data/ext/sources/examples/talk-llama/models/lfm2.cpp +196 -0
  132. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
  133. data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
  134. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
  135. data/ext/sources/examples/talk-llama/models/llama.cpp +175 -0
  136. data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
  137. data/ext/sources/examples/talk-llama/models/mamba-base.cpp +289 -0
  138. data/ext/sources/examples/talk-llama/models/mamba.cpp +54 -0
  139. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +129 -0
  140. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +200 -0
  141. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +123 -0
  142. data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
  143. data/ext/sources/examples/talk-llama/models/models.h +704 -0
  144. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +109 -0
  145. data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
  146. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +162 -0
  147. data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
  148. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
  149. data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
  150. data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
  151. data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
  152. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
  153. data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
  154. data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
  155. data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
  156. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
  157. data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
  158. data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
  159. data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
  160. data/ext/sources/examples/talk-llama/models/plamo2.cpp +320 -0
  161. data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
  162. data/ext/sources/examples/talk-llama/models/plm.cpp +169 -0
  163. data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
  164. data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
  165. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
  166. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
  167. data/ext/sources/examples/talk-llama/models/qwen3.cpp +120 -0
  168. data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
  169. data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
  170. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +131 -0
  171. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +525 -0
  172. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +140 -0
  173. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +132 -0
  174. data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
  175. data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
  176. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +164 -0
  177. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
  178. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
  179. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +137 -0
  180. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
  181. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
  182. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
  183. data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
  184. data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
  185. data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
  186. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
  187. data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
  188. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
  189. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
  190. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
  191. data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
  192. data/ext/sources/examples/talk-llama/unicode.cpp +121 -79
  193. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
  194. data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
  195. data/ext/sources/ggml/CMakeLists.txt +90 -56
  196. data/ext/sources/ggml/include/ggml-alloc.h +9 -0
  197. data/ext/sources/ggml/include/ggml-backend.h +5 -2
  198. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  199. data/ext/sources/ggml/include/ggml-cpu.h +6 -0
  200. data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
  201. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  202. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  203. data/ext/sources/ggml/include/ggml-rpc.h +14 -12
  204. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  205. data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
  206. data/ext/sources/ggml/include/ggml.h +246 -21
  207. data/ext/sources/ggml/src/CMakeLists.txt +85 -11
  208. data/ext/sources/ggml/src/ggml-alloc.c +128 -50
  209. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  210. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  211. data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
  212. data/ext/sources/ggml/src/ggml-backend-reg.cpp +54 -88
  213. data/ext/sources/ggml/src/ggml-backend.cpp +76 -23
  214. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +18 -4
  215. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +11 -11
  216. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +58 -46
  217. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +139 -48
  218. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2427 -1785
  219. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -362
  220. data/ext/sources/ggml/src/ggml-cann/common.h +285 -211
  221. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +663 -831
  222. data/ext/sources/ggml/src/ggml-common.h +11 -0
  223. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +170 -95
  224. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -18
  225. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  226. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
  227. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  228. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +513 -27
  229. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +4192 -992
  230. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
  231. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  232. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1761 -49
  233. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
  234. data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  235. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
  236. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
  237. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +124 -24
  238. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +157 -28
  239. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  240. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  241. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +8 -3
  242. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +251 -80
  243. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +19 -0
  244. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +587 -119
  245. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
  246. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1093 -194
  247. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1284 -203
  248. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
  249. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1519 -527
  250. data/ext/sources/ggml/src/ggml-cpu/ops.h +6 -4
  251. data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
  252. data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
  253. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +3632 -781
  254. data/ext/sources/ggml/src/ggml-cpu/repack.h +129 -4
  255. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
  256. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +152 -46
  257. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
  258. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +152 -1
  259. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
  260. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +140 -0
  261. data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
  262. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
  263. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
  264. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +132 -6
  265. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
  266. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +33 -31
  267. data/ext/sources/ggml/src/ggml-cuda/common.cuh +474 -85
  268. data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
  269. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  270. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
  271. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +342 -246
  272. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
  273. data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
  274. data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
  275. data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
  276. data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
  277. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +98 -74
  278. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +973 -665
  279. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
  280. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1255 -0
  281. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +33 -40
  282. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +40 -18
  283. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  284. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +206 -45
  285. data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
  286. data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
  287. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
  288. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  289. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1688 -302
  290. data/ext/sources/ggml/src/ggml-cuda/mean.cu +12 -10
  291. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +908 -48
  292. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +88 -20
  293. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +502 -90
  294. data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
  295. data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
  296. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
  297. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +532 -193
  298. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +460 -104
  299. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +5 -2
  300. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +360 -122
  301. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +2 -1
  302. data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
  303. data/ext/sources/ggml/src/ggml-cuda/pad.cu +73 -39
  304. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +152 -1
  305. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
  306. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
  307. data/ext/sources/ggml/src/ggml-cuda/rope.cu +364 -149
  308. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
  309. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
  310. data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
  311. data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
  312. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +163 -41
  313. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
  314. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
  315. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +68 -50
  316. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  317. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
  318. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
  319. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
  320. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
  321. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
  322. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
  323. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
  324. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  325. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  326. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  327. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  328. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  329. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  330. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  331. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  332. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  333. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +22 -4
  334. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +95 -0
  335. data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
  336. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +275 -119
  337. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -7
  338. data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
  339. data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
  340. data/ext/sources/ggml/src/ggml-cuda/unary.cu +160 -11
  341. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +38 -0
  342. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
  343. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
  344. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  345. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +22 -1
  346. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
  347. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +117 -0
  348. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3325 -0
  349. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +46 -0
  350. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +813 -0
  351. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
  352. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +891 -0
  353. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  354. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
  355. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +713 -0
  356. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
  357. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.c +63 -0
  358. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +182 -0
  359. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
  360. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  361. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
  362. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  363. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +155 -0
  364. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +63 -0
  365. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  366. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  367. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
  368. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
  369. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
  370. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  371. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
  372. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  373. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  374. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  375. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  376. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
  377. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  378. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  379. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +26 -0
  380. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1199 -0
  381. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2670 -0
  382. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +497 -0
  383. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
  384. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +419 -0
  385. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
  386. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  387. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +382 -0
  388. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +293 -0
  389. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  390. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  391. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  392. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  393. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
  394. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
  395. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +14 -13
  396. data/ext/sources/ggml/src/ggml-impl.h +129 -6
  397. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  398. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +15 -4
  399. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  400. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +173 -34
  401. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +912 -344
  402. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +124 -59
  403. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +588 -144
  404. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +396 -23
  405. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1724 -421
  406. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +16 -3
  407. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +333 -114
  408. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +3050 -1539
  409. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
  410. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +30 -1
  411. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4279 -497
  412. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
  413. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
  414. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  415. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +267 -0
  416. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  417. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  418. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +113 -0
  419. data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  420. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
  421. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  422. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  423. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  424. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
  425. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  426. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  427. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  428. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +140 -0
  429. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  430. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
  431. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
  432. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  433. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  434. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  435. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
  436. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  437. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  438. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  439. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  440. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  441. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
  442. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  443. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  444. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  445. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
  446. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  447. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  448. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
  449. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +116 -0
  450. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  451. data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  452. data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  453. data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  454. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  455. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  456. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +39 -0
  457. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  458. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  459. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  460. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
  461. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  462. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
  463. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  464. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
  465. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
  466. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  467. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  468. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  469. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  470. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  471. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  472. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  473. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  474. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  475. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  476. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  477. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  478. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  479. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  480. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  481. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  482. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  483. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
  484. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  485. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  486. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  487. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  488. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  489. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  490. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
  491. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
  492. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
  493. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
  494. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  495. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  496. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  497. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  498. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  499. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
  500. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  501. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
  502. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
  503. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
  504. data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
  505. data/ext/sources/ggml/src/ggml-quants.c +96 -5
  506. data/ext/sources/ggml/src/ggml-quants.h +3 -0
  507. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
  508. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +59 -87
  509. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +81 -0
  510. data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
  511. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +7 -0
  512. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -29
  513. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
  514. data/ext/sources/ggml/src/ggml-sycl/common.hpp +427 -20
  515. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
  516. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +103 -1
  517. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
  518. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  519. data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  520. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
  521. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
  522. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +867 -50
  523. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +401 -358
  524. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +12 -2
  525. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
  526. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
  527. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
  528. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
  529. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
  530. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  531. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
  532. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
  533. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +645 -155
  534. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
  535. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +221 -66
  536. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
  537. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  538. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
  539. data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
  540. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
  541. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
  542. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +5 -0
  543. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
  544. data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  545. data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  546. data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
  547. data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
  548. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +457 -281
  549. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  550. data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
  551. data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
  552. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
  553. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
  554. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  555. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  556. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  557. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  558. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  559. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  560. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  561. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  562. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  563. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  564. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  565. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
  566. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
  567. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
  568. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
  569. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
  570. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
  571. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
  572. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
  573. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
  574. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
  575. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
  576. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
  577. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
  578. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
  579. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
  580. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
  581. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
  582. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
  583. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
  584. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
  585. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
  586. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
  587. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
  588. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
  589. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
  590. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
  591. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
  592. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
  593. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
  594. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
  595. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
  596. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
  597. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
  598. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
  599. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
  600. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
  601. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +71 -0
  602. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  603. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  604. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  605. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  606. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  607. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  608. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  609. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  610. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  611. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  612. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  613. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  614. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  615. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  616. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  617. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  618. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  619. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  620. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  621. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  622. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  623. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  624. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
  625. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
  626. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  627. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
  628. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  629. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  630. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  631. data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
  632. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  633. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  634. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  635. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  636. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  637. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  638. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  639. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
  640. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  641. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  642. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  643. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
  644. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
  645. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +39 -19
  646. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5994 -3055
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +18 -10
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  663. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  664. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
  665. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
  666. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  667. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  668. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  669. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
  670. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
  671. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
  672. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
  673. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  674. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
  675. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
  676. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  677. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  678. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  679. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  680. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  681. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
  682. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
  683. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
  684. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  685. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  686. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
  687. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
  688. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  689. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  690. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  691. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
  692. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  693. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  694. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
  695. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  696. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  697. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
  698. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  699. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  700. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  701. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
  702. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  703. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +386 -160
  704. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +82 -20
  705. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +400 -174
  706. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +123 -37
  707. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  708. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +10 -9
  709. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  710. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
  711. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  712. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
  713. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
  714. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  715. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
  716. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  717. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
  718. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
  719. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
  720. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
  721. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
  722. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
  723. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  724. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
  725. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
  726. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
  727. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
  728. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +13 -10
  729. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  730. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  731. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  732. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
  733. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +77 -29
  734. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
  735. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
  736. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
  737. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
  738. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
  739. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
  740. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
  741. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
  742. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
  743. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  744. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
  745. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
  746. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
  747. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
  748. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
  749. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
  750. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
  751. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +88 -105
  752. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +41 -26
  753. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
  754. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +74 -0
  755. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +92 -230
  756. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
  757. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  758. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
  759. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  760. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  761. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  762. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
  763. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
  764. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  765. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
  766. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  767. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  768. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  769. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  770. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +49 -4
  771. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  772. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
  773. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
  774. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +207 -0
  775. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  776. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +8 -49
  777. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +8 -32
  778. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +8 -32
  779. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +33 -0
  780. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +8 -38
  781. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  782. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
  783. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  784. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  785. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  786. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  787. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  788. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
  789. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
  790. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
  791. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
  792. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
  793. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
  794. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  795. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
  796. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
  797. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  798. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +50 -0
  799. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
  800. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  801. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  802. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
  803. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
  804. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  805. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
  806. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  807. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
  808. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
  809. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
  810. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
  811. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
  812. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  813. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
  814. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
  815. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +384 -180
  816. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
  817. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
  818. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1374 -0
  819. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +2544 -726
  820. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  821. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  822. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  823. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  824. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
  825. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
  826. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
  827. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +107 -0
  828. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  829. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +73 -15
  830. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +636 -0
  831. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
  832. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  833. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +72 -261
  834. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +766 -0
  835. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +147 -0
  836. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +196 -0
  837. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
  838. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  839. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  840. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
  841. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  842. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +63 -0
  843. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
  844. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  845. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  846. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
  847. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
  848. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +91 -0
  849. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +469 -0
  850. data/ext/sources/ggml/src/ggml.c +590 -64
  851. data/ext/sources/ggml/src/gguf.cpp +229 -44
  852. data/ext/sources/include/whisper.h +1 -0
  853. data/ext/sources/src/CMakeLists.txt +3 -1
  854. data/ext/sources/src/whisper.cpp +106 -62
  855. data/ext/sources/tests/CMakeLists.txt +2 -2
  856. data/ext/sources/tests/test-vad-full.cpp +4 -2
  857. data/ext/sources/tests/test-vad.cpp +1 -1
  858. data/extsources.rb +1 -0
  859. data/lib/whisper/model/uri.rb +17 -18
  860. data/sig/whisper.rbs +162 -4
  861. data/test/test_context_params.rb +82 -0
  862. data/test/test_params.rb +16 -8
  863. data/test/test_segment.rb +0 -1
  864. data/test/test_token.rb +81 -0
  865. data/test/test_vad.rb +1 -1
  866. data/test/test_vad_context.rb +100 -0
  867. data/test/test_vad_segment.rb +19 -0
  868. data/test/test_vad_segments.rb +16 -0
  869. data/test/test_whisper.rb +27 -0
  870. data/whispercpp.gemspec +1 -1
  871. metadata +502 -37
  872. data/ext/sources/build-xcframework.sh +0 -571
  873. data/ext/sources/examples/talk-llama/llama-sampling.h +0 -32
  874. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  875. data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
  876. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
  877. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
  878. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
  879. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
  880. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  881. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
  882. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
  883. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
  884. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
  885. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  886. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  887. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  888. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  889. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
  890. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
  891. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
@@ -7,37 +7,102 @@
7
7
 
8
8
  #include "ggml-backend-impl.h"
9
9
  #include "ggml-impl.h"
10
- #include "ggml-wgsl-shaders.hpp"
10
+ #include "ggml-webgpu-shader-lib.hpp"
11
+
12
+ #ifdef __EMSCRIPTEN__
13
+ # include <emscripten/emscripten.h>
14
+ #endif
11
15
 
12
16
  #include <webgpu/webgpu_cpp.h>
13
17
 
18
+ #include <atomic>
14
19
  #include <condition_variable>
20
+ #include <cstdint>
15
21
  #include <cstring>
16
- #include <iostream>
22
+ #ifdef GGML_WEBGPU_GPU_PROFILE
23
+ # include <iomanip>
24
+ #endif
25
+ #if defined(GGML_WEBGPU_DEBUG) || defined(GGML_WEBGPU_CPU_PROFILE) || defined(GGML_WEBGPU_GPU_PROFILE)
26
+ # include <iostream>
27
+ #endif
28
+ #include <map>
29
+ #include <memory>
17
30
  #include <mutex>
31
+ #include <optional>
18
32
  #include <string>
33
+ #include <utility>
19
34
  #include <vector>
20
35
 
36
+ #define ROUNDUP_POW2(x, pow2) (((x) + ((pow2) - 1)) & ~((pow2) - 1))
37
+ #define CEIL_DIV(M, N) (((M) + (N) - 1) / (N))
38
+
39
+ // Return a rectangular grid of workgroups with minimal over-provisioned workgroups.
40
+ // Assumes that the total number of workgroups does not exceed max_per_dim^2.
41
+ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim, uint32_t & wg_x, uint32_t & wg_y) {
42
+ wg_y = std::max(1u, CEIL_DIV(total_wg, max_per_dim));
43
+ wg_x = CEIL_DIV(total_wg, wg_y);
44
+ }
45
+
21
46
  #ifdef GGML_WEBGPU_DEBUG
22
47
  # define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl
23
- # define WEBGPU_DEBUG_BUF_ELEMS 32
48
+ # define WEBGPU_DEBUG_BUF_ELEMS 512
24
49
  #else
25
50
  # define WEBGPU_LOG_DEBUG(msg) ((void) 0)
26
51
  #endif // GGML_WEBGPU_DEBUG
27
52
 
53
+ #ifdef GGML_WEBGPU_CPU_PROFILE
54
+ // total timing (aggregated)
55
+ # define WEBGPU_CPU_PROFILE_TOTAL_START(id) auto cpu_total_start_##id = std::chrono::high_resolution_clock::now();
56
+
57
+ # define WEBGPU_CPU_PROFILE_TOTAL_END(id, ctx) \
58
+ auto cpu_total_end_##id = std::chrono::high_resolution_clock::now(); \
59
+ double cpu_total_time_##id = \
60
+ std::chrono::duration<double, std::milli>(cpu_total_end_##id - cpu_total_start_##id).count(); \
61
+ (ctx)->cpu_time_ms[#id] += cpu_total_time_##id;
62
+ // fine-grained timing (not included in totals)
63
+ # define WEBGPU_CPU_PROFILE_DETAIL_START(id) auto cpu_detail_start_##id = std::chrono::high_resolution_clock::now();
64
+
65
+ # define WEBGPU_CPU_PROFILE_DETAIL_END(id, ctx) \
66
+ auto cpu_detail_end_##id = std::chrono::high_resolution_clock::now(); \
67
+ double cpu_detail_time_##id = \
68
+ std::chrono::duration<double, std::milli>(cpu_detail_end_##id - cpu_detail_start_##id).count(); \
69
+ (ctx)->cpu_detail_ms[#id] += cpu_detail_time_##id;
70
+ #else
71
+ # define WEBGPU_CPU_PROFILE_TOTAL_START(id)
72
+ # define WEBGPU_CPU_PROFILE_TOTAL_END(id, ctx)
73
+ # define WEBGPU_CPU_PROFILE_DETAIL_START(id)
74
+ # define WEBGPU_CPU_PROFILE_DETAIL_END(id, ctx)
75
+ #endif // GGML_WEBGPU_CPU_PROFILE
76
+
77
+ #ifdef GGML_WEBGPU_GPU_PROFILE
78
+ # define WEBGPU_NUM_TIMESTAMP_QUERY_BUFS 32
79
+ # define WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES 16 // e.g. enough for two timestamps
80
+ #endif
81
+
28
82
  /* Constants */
29
83
 
30
- #define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 16
31
- #define WEBGPU_MUL_MAT_WG_SIZE 64
32
- #define WEBGPU_NUM_PARAM_BUFS 100
84
+ #define WEBGPU_NUM_PARAM_BUFS 96u
85
+ #define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 32u
86
+ #define WEBGPU_WAIT_ANY_TIMEOUT_MS 0
87
+ // Maximum number of in-flight submissions per-thread, to avoid exhausting the
88
+ // parameter buffer pool
89
+ #define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD (WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE)
33
90
  #define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters
34
- #define WEBGPU_NUM_SET_ROWS_ERROR_BUFS 32
35
91
  #define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4
36
- #define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4
92
+ #define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4
93
+
94
+ // For operations which process a row in parallel, this seems like a reasonable
95
+ // default
96
+ #define WEBGPU_ROW_SPLIT_WG_SIZE 64
97
+
98
+ // Track https://github.com/gpuweb/gpuweb/issues/5315 for fixes to
99
+ // implementations so this can be removed, necessary only for get_rows right now
100
+ #define WEBGPU_MAX_WG_SIZE 288
37
101
 
38
102
  /* End Constants */
39
103
 
40
- // This is a "fake" base pointer, since WebGPU buffers do not have pointers to their locations.
104
+ // This is a "fake" base pointer, since WebGPU buffers do not have pointers to
105
+ // their locations.
41
106
  static void * const webgpu_ptr_base = (void *) (uintptr_t) 0x1000; // NOLINT
42
107
 
43
108
  // Always returns the base offset of a tensor, regardless of views.
@@ -57,14 +122,98 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device,
57
122
  wgpu::BufferUsage usage,
58
123
  const char * label);
59
124
 
60
- struct webgpu_pool_bufs {
61
- wgpu::Buffer host_buf;
62
- wgpu::Buffer dev_buf;
125
+ // Holds a pool of parameter buffers for WebGPU operations
126
+ struct webgpu_buf_pool {
127
+ std::vector<wgpu::Buffer> free;
128
+
129
+ // The pool must be synchronized because
130
+ // 1. The memset pool is shared globally by every ggml buffer,
131
+ // since allocating a pool per ggml buffer would consume too much memory.
132
+ // 2. For the per-thread buffer pools in webgpu_context,
133
+ // buffers are allocated and freed in Dawn callbacks,
134
+ // which can run on a different thread than the calling thread.
135
+ std::mutex mutex;
136
+ std::condition_variable cv;
137
+ size_t cur_pool_size;
138
+ size_t max_pool_size;
139
+ wgpu::Device device;
140
+ wgpu::BufferUsage dev_buf_usage;
141
+ size_t buf_size;
142
+ bool should_grow;
143
+
144
+ void init(wgpu::Device device,
145
+ int num_bufs,
146
+ size_t buf_size,
147
+ wgpu::BufferUsage dev_buf_usage,
148
+ bool should_grow = false,
149
+ size_t max_pool_size = WEBGPU_NUM_PARAM_BUFS * 2) {
150
+ this->max_pool_size = max_pool_size;
151
+ this->cur_pool_size = num_bufs;
152
+ this->device = device;
153
+ this->dev_buf_usage = dev_buf_usage;
154
+ this->buf_size = buf_size;
155
+ this->should_grow = should_grow;
156
+ for (int i = 0; i < num_bufs; i++) {
157
+ wgpu::Buffer dev_buf;
158
+ ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf");
159
+ free.push_back(dev_buf);
160
+ }
161
+ }
162
+
163
+ wgpu::Buffer alloc_bufs() {
164
+ std::unique_lock<std::mutex> lock(mutex);
165
+ if (!free.empty()) {
166
+ wgpu::Buffer buf = free.back();
167
+ free.pop_back();
168
+ return buf;
169
+ }
170
+
171
+ // Try growing the pool if no free buffers
172
+ if (free.empty() && cur_pool_size < max_pool_size && should_grow) {
173
+ cur_pool_size++;
174
+ wgpu::Buffer dev_buf;
175
+ ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf");
176
+
177
+ if (!dev_buf) {
178
+ GGML_ABORT("webgpu_buf_pool: failed to allocate buffers");
179
+ }
180
+ return dev_buf;
181
+ }
182
+ cv.wait(lock, [this] { return !free.empty(); });
183
+ wgpu::Buffer buf = free.back();
184
+ free.pop_back();
185
+ return buf;
186
+ }
187
+
188
+ void free_bufs(std::vector<wgpu::Buffer> bufs) {
189
+ std::lock_guard<std::mutex> lock(mutex);
190
+ free.insert(free.end(), bufs.begin(), bufs.end());
191
+ cv.notify_all();
192
+ }
193
+
194
+ void cleanup() {
195
+ std::lock_guard<std::mutex> lock(mutex);
196
+ for (auto & buf : free) {
197
+ if (buf) {
198
+ buf.Destroy();
199
+ }
200
+ }
201
+ free.clear();
202
+ }
203
+
204
+ ~webgpu_buf_pool() { this->cleanup(); }
205
+ };
206
+
207
+ #ifdef GGML_WEBGPU_GPU_PROFILE
208
+ struct webgpu_gpu_profile_bufs {
209
+ wgpu::Buffer host_buf;
210
+ wgpu::Buffer dev_buf;
211
+ wgpu::QuerySet query_set;
63
212
  };
64
213
 
65
214
  // Holds a pool of parameter buffers for WebGPU operations
66
- struct webgpu_buf_pool {
67
- std::vector<webgpu_pool_bufs> free;
215
+ struct webgpu_gpu_profile_buf_pool {
216
+ std::vector<webgpu_gpu_profile_bufs> free;
68
217
 
69
218
  std::mutex mutex;
70
219
 
@@ -78,21 +227,28 @@ struct webgpu_buf_pool {
78
227
  for (int i = 0; i < num_bufs; i++) {
79
228
  wgpu::Buffer host_buf;
80
229
  wgpu::Buffer dev_buf;
81
- ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_pool_buf");
82
- ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf");
83
- free.push_back({ host_buf, dev_buf });
230
+ ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_profile_buf");
231
+ ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_profile_buf");
232
+ // Create a query set for 2 timestamps
233
+ wgpu::QuerySetDescriptor ts_query_set_desc = {};
234
+
235
+ ts_query_set_desc.type = wgpu::QueryType::Timestamp;
236
+ ts_query_set_desc.count = 2;
237
+ wgpu::QuerySet ts_query_set = device.CreateQuerySet(&ts_query_set_desc);
238
+
239
+ free.push_back({ host_buf, dev_buf, ts_query_set });
84
240
  }
85
241
  }
86
242
 
87
- webgpu_pool_bufs alloc_bufs() {
243
+ webgpu_gpu_profile_bufs alloc_bufs() {
88
244
  std::unique_lock<std::mutex> lock(mutex);
89
245
  cv.wait(lock, [this] { return !free.empty(); });
90
- webgpu_pool_bufs bufs = free.back();
246
+ webgpu_gpu_profile_bufs bufs = free.back();
91
247
  free.pop_back();
92
248
  return bufs;
93
249
  }
94
250
 
95
- void free_bufs(std::vector<webgpu_pool_bufs> bufs) {
251
+ void free_bufs(std::vector<webgpu_gpu_profile_bufs> bufs) {
96
252
  std::lock_guard<std::mutex> lock(mutex);
97
253
  free.insert(free.end(), bufs.begin(), bufs.end());
98
254
  cv.notify_all();
@@ -103,101 +259,163 @@ struct webgpu_buf_pool {
103
259
  for (auto & bufs : free) {
104
260
  bufs.host_buf.Destroy();
105
261
  bufs.dev_buf.Destroy();
262
+ bufs.query_set.Destroy();
106
263
  }
107
264
  free.clear();
108
265
  }
266
+
267
+ ~webgpu_gpu_profile_buf_pool() { this->cleanup(); }
109
268
  };
269
+ #endif
110
270
 
111
- // All the base objects needed to run operations on a WebGPU device
112
- struct webgpu_context_struct {
271
+ struct webgpu_command {
272
+ uint32_t num_kernels;
273
+ wgpu::CommandBuffer commands;
274
+ std::vector<wgpu::Buffer> params_bufs;
275
+ #ifdef GGML_WEBGPU_GPU_PROFILE
276
+ webgpu_gpu_profile_bufs timestamp_query_bufs;
277
+ std::string pipeline_name;
278
+ #endif
279
+ };
280
+
281
+ struct webgpu_capabilities {
282
+ wgpu::Limits limits;
283
+ bool supports_subgroup_matrix = false;
284
+
285
+ uint32_t sg_mat_m = 0;
286
+ uint32_t sg_mat_n = 0;
287
+ uint32_t sg_mat_k = 0;
288
+
289
+ uint32_t subgroup_size = 0;
290
+ uint32_t max_subgroup_size = 0;
291
+ size_t memset_bytes_per_thread;
292
+ };
293
+
294
+ // Stores global webgpu members
295
+ struct webgpu_global_context_struct {
113
296
  wgpu::Instance instance;
114
297
  wgpu::Adapter adapter;
115
298
  wgpu::Device device;
116
299
  wgpu::Queue queue;
117
- wgpu::Limits limits;
118
-
119
- // Separate this out from limits since on some Metal systems, the limit returned by
120
- // querying the limits is higher than the actual allowed maximum.
121
- uint32_t max_wg_size_x;
122
300
 
301
+ webgpu_capabilities capabilities;
302
+ // Shared buffer to move data from device to host
303
+ wgpu::Buffer get_tensor_staging_buf;
304
+ // Global mutex for pipeline and staging buffer, will be refactored to exclude pipeline caches.
123
305
  std::recursive_mutex mutex;
124
306
 
125
- webgpu_buf_pool param_buf_pool;
126
- webgpu_buf_pool set_rows_error_buf_pool;
127
-
128
- wgpu::ComputePipeline memset_pipeline;
129
- wgpu::ComputePipeline mul_mat_pipeline[30][2];
130
- wgpu::ComputePipeline set_rows_pipeline;
131
- wgpu::ComputePipeline get_rows_pipeline[30];
132
- wgpu::ComputePipeline get_rows_f32_no_vec_pipeline;
133
- wgpu::ComputePipeline cpy_pipeline;
134
- wgpu::ComputePipeline add_pipeline[2];
135
- wgpu::ComputePipeline add_ip_pipeline[2];
136
- wgpu::ComputePipeline mul_pipeline[2];
137
- wgpu::ComputePipeline mul_ip_pipeline[2];
138
- wgpu::ComputePipeline rms_norm_pipeline;
139
- wgpu::ComputePipeline rms_norm_ip_pipeline;
140
-
141
- size_t memset_bytes_per_thread;
142
-
143
- // Staging buffer for reading data from the GPU
144
- wgpu::Buffer get_tensor_staging_buf;
307
+ webgpu_buf_pool memset_buf_pool;
308
+ std::map<int, webgpu_pipeline> memset_pipelines; // variant or type index
145
309
 
146
- // Command buffers which need to be submitted
147
- std::vector<wgpu::CommandBuffer> staged_command_bufs;
148
-
149
- // Parameter buffers associated with the staged command buffers
150
- std::vector<webgpu_pool_bufs> staged_param_bufs;
151
- // Buffers associated with set_rows operations, used to store potential errors
152
- std::vector<webgpu_pool_bufs> staged_set_row_error_bufs;
310
+ #ifdef GGML_WEBGPU_CPU_PROFILE
311
+ // Profiling: labeled CPU time in ms (total)
312
+ std::unordered_map<std::string, double> cpu_time_ms;
313
+ // Profiling: detailed CPU time in ms
314
+ std::unordered_map<std::string, double> cpu_detail_ms;
315
+ #endif
153
316
 
154
- std::vector<wgpu::FutureWaitInfo> callback_futures;
317
+ #ifdef GGML_WEBGPU_GPU_PROFILE
318
+ // Profiling: per-shader GPU time in ms
319
+ std::unordered_map<std::string, double> shader_gpu_time_ms;
320
+ // Profiling: pool of timestamp query buffers (one per operation)
321
+ webgpu_gpu_profile_buf_pool timestamp_query_buf_pool;
322
+ #endif
155
323
 
156
324
  #ifdef GGML_WEBGPU_DEBUG
157
325
  wgpu::Buffer debug_host_buf;
158
326
  wgpu::Buffer debug_dev_buf;
159
327
  #endif
328
+
329
+ ~webgpu_global_context_struct() {
330
+ if (this->get_tensor_staging_buf) {
331
+ this->get_tensor_staging_buf.Destroy();
332
+ this->get_tensor_staging_buf = nullptr;
333
+ }
334
+ #ifdef GGML_WEBGPU_DEBUG
335
+ if (this->debug_host_buf) {
336
+ this->debug_host_buf.Destroy();
337
+ this->debug_host_buf = nullptr;
338
+ }
339
+ if (this->debug_dev_buf) {
340
+ this->debug_dev_buf.Destroy();
341
+ this->debug_dev_buf = nullptr;
342
+ }
343
+ #endif
344
+ }
345
+ };
346
+
347
+ typedef std::shared_ptr<webgpu_global_context_struct> webgpu_global_context;
348
+
349
+ struct webgpu_submission {
350
+ wgpu::FutureWaitInfo submit_done;
351
+ #ifdef GGML_WEBGPU_GPU_PROFILE
352
+ std::vector<wgpu::FutureWaitInfo> profile_futures;
353
+ #endif
354
+ };
355
+
356
+ // All the base objects needed to run operations on a WebGPU device
357
+ struct webgpu_context_struct {
358
+ // Points to global instances owned by ggml_backend_webgpu_reg_context
359
+ webgpu_global_context global_ctx;
360
+
361
+ std::unique_ptr<ggml_webgpu_shader_lib> shader_lib;
362
+
363
+ webgpu_buf_pool param_buf_pool;
364
+ wgpu::Buffer set_rows_dev_error_buf;
365
+ wgpu::Buffer set_rows_host_error_buf;
366
+
367
+ std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines; // src_type, dst_type
368
+
369
+ std::map<int, webgpu_pipeline> rms_norm_pipelines; // inplace
370
+ std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> rope_pipelines; // type, ff, inplace
371
+ std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> glu_pipelines; // glu_op, type, split
372
+
373
+ std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> soft_max_pipelines; // mask_type, has_sink, inplace
374
+
375
+ size_t memset_bytes_per_thread;
160
376
  };
161
377
 
162
378
  typedef std::shared_ptr<webgpu_context_struct> webgpu_context;
163
379
 
380
+ // Metadata required for the ggml backend registration/discovery interface
164
381
  struct ggml_backend_webgpu_reg_context {
165
- webgpu_context webgpu_ctx;
166
- size_t device_count;
167
- const char * name;
382
+ // Since the Instance is a global entrypoint into the WebGPU API, it lives here
383
+ webgpu_global_context webgpu_global_ctx;
384
+ size_t device_count;
385
+ const char * name;
168
386
  };
169
387
 
388
+ // Per-device struct for the global logical device interface
170
389
  struct ggml_backend_webgpu_device_context {
171
- webgpu_context webgpu_ctx;
172
- std::string device_name;
173
- std::string device_desc;
390
+ webgpu_global_context webgpu_global_ctx;
391
+ std::string device_name;
392
+ std::string device_desc;
174
393
  };
175
394
 
395
+ // Per-thread data required to actually run WebGPU operations in a backend instance
176
396
  struct ggml_backend_webgpu_context {
177
397
  webgpu_context webgpu_ctx;
178
398
  std::string name;
179
399
  };
180
400
 
401
+ // Per-thread data related to buffers
181
402
  struct ggml_backend_webgpu_buffer_context {
182
- webgpu_context webgpu_ctx;
183
- wgpu::Buffer buffer;
184
-
185
- ggml_backend_webgpu_buffer_context(webgpu_context ctx, wgpu::Buffer buf) :
186
- webgpu_ctx(std::move(ctx)),
187
- buffer(std::move(buf)) {}
403
+ wgpu::Buffer buffer;
404
+ std::string label;
405
+ webgpu_global_context global_ctx;
406
+
407
+ ggml_backend_webgpu_buffer_context(wgpu::Buffer buf, std::string lbl, webgpu_global_context global_ctx_) :
408
+ buffer(std::move(buf)),
409
+ label(std::move(lbl)),
410
+ global_ctx(std::move(global_ctx_)) {}
188
411
  };
189
412
 
190
- /* End struct definitions */
191
-
192
413
  /* WebGPU object initializations */
193
414
 
194
- static void ggml_webgpu_create_pipeline(wgpu::Device & device,
195
- wgpu::ComputePipeline & pipeline,
196
- const char * shader_code,
197
- const char * label,
198
- const std::vector<wgpu::ConstantEntry> & constants = {}) {
199
- WEBGPU_LOG_DEBUG("ggml_webgpu_create_pipeline()");
200
-
415
+ static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device,
416
+ const char * shader_code,
417
+ const char * label,
418
+ const std::vector<wgpu::ConstantEntry> & constants = {}) {
201
419
  wgpu::ShaderSourceWGSL shader_source;
202
420
  shader_source.code = shader_code;
203
421
 
@@ -215,7 +433,7 @@ static void ggml_webgpu_create_pipeline(wgpu::Device &
215
433
  pipeline_desc.compute.constants = constants.data();
216
434
  pipeline_desc.compute.constantCount = constants.size();
217
435
  }
218
- pipeline = device.CreateComputePipeline(&pipeline_desc);
436
+ return { device.CreateComputePipeline(&pipeline_desc), label };
219
437
  }
220
438
 
221
439
  static void ggml_webgpu_create_buffer(wgpu::Device & device,
@@ -223,8 +441,6 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device,
223
441
  size_t size,
224
442
  wgpu::BufferUsage usage,
225
443
  const char * label) {
226
- WEBGPU_LOG_DEBUG("ggml_webgpu_create_buffer()");
227
-
228
444
  wgpu::BufferDescriptor buffer_desc;
229
445
  buffer_desc.size = size;
230
446
  buffer_desc.usage = usage;
@@ -239,88 +455,113 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device,
239
455
 
240
456
  /** WebGPU Actions */
241
457
 
242
- // Wait for the queue to finish processing all submitted work
243
- static void ggml_backend_webgpu_wait_on_submission(webgpu_context & ctx) {
244
- std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
245
- if (ctx->callback_futures.empty()) {
246
- // no existing callbacks, wait on queue submission
247
- ctx->instance.WaitAny(
248
- ctx->queue.OnSubmittedWorkDone(wgpu::CallbackMode::AllowSpontaneous,
249
- [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
250
- if (status != wgpu::QueueWorkDoneStatus::Success) {
251
- GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n",
252
- std::string(message).c_str());
253
- }
254
- }),
255
- UINT64_MAX);
256
- } else {
257
- // existing callbacks, wait on them
258
- ctx->instance.WaitAny(ctx->callback_futures.size(), ctx->callback_futures.data(), UINT64_MAX);
259
- ctx->callback_futures.clear();
458
+ static bool ggml_backend_webgpu_handle_wait_status(wgpu::WaitStatus status, bool allow_timeout = false) {
459
+ switch (status) {
460
+ case wgpu::WaitStatus::Success:
461
+ return true;
462
+ case wgpu::WaitStatus::TimedOut:
463
+ if (allow_timeout) {
464
+ return false;
465
+ }
466
+ GGML_LOG_ERROR("ggml_webgpu: WaitAny timed out unexpectedly\n");
467
+ return false;
468
+ case wgpu::WaitStatus::Error:
469
+ GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
470
+ return false;
471
+ default:
472
+ GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n");
473
+ return false;
260
474
  }
261
475
  }
262
476
 
263
- static void ggml_backend_webgpu_submit_queue(webgpu_context & ctx) {
264
- std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
265
- WEBGPU_LOG_DEBUG("ggml_backend_webgpu_submit_queue()");
266
- if (ctx->staged_command_bufs.empty()) {
267
- // Nothing to submit
477
+ #ifdef GGML_WEBGPU_GPU_PROFILE
478
+ static void ggml_backend_webgpu_erase_completed_futures(std::vector<wgpu::FutureWaitInfo> & futures) {
479
+ futures.erase(std::remove_if(futures.begin(), futures.end(),
480
+ [](const wgpu::FutureWaitInfo & info) { return info.completed; }),
481
+ futures.end());
482
+ }
483
+
484
+ static void ggml_backend_webgpu_wait_profile_futures(webgpu_global_context & ctx,
485
+ std::vector<wgpu::FutureWaitInfo> & futures,
486
+ bool block) {
487
+ if (futures.empty()) {
268
488
  return;
269
489
  }
270
- ctx->queue.Submit(ctx->staged_command_bufs.size(), ctx->staged_command_bufs.data());
271
490
 
272
- // If there are SET_ROWS operations in this submission, copy their error buffers to the host.
273
- if (ctx->staged_set_row_error_bufs.size() > 0) {
274
- wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
275
- for (auto & error_bufs : ctx->staged_set_row_error_bufs) {
276
- // Copy the error buffer to the host buffer
277
- encoder.CopyBufferToBuffer(error_bufs.dev_buf, 0, error_bufs.host_buf, 0, error_bufs.host_buf.GetSize());
491
+ uint64_t timeout_ms = block ? UINT64_MAX : 0;
492
+ if (block) {
493
+ while (!futures.empty()) {
494
+ auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms);
495
+ if (ggml_backend_webgpu_handle_wait_status(waitStatus)) {
496
+ ggml_backend_webgpu_erase_completed_futures(futures);
497
+ }
498
+ }
499
+ } else {
500
+ auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms);
501
+ if (ggml_backend_webgpu_handle_wait_status(waitStatus, true)) {
502
+ ggml_backend_webgpu_erase_completed_futures(futures);
278
503
  }
279
- wgpu::CommandBuffer commands = encoder.Finish();
280
- ctx->queue.Submit(1, &commands);
281
504
  }
505
+ }
506
+ #endif
282
507
 
283
- ctx->staged_command_bufs.clear();
284
- std::vector<webgpu_pool_bufs> staged_param_bufs = std::move(ctx->staged_param_bufs);
285
- std::vector<webgpu_pool_bufs> staged_set_row_error_bufs = std::move(ctx->staged_set_row_error_bufs);
508
+ // Wait for the queue to finish processing all submitted work
509
+ static void ggml_backend_webgpu_wait(webgpu_global_context & ctx,
510
+ std::vector<webgpu_submission> & subs,
511
+ bool block = true) {
512
+ // If we have too many in-flight submissions, wait on the oldest one first.
513
+ if (subs.empty()) {
514
+ return;
515
+ }
516
+ while (subs.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD) {
517
+ auto waitStatus = ctx->instance.WaitAny(1, &subs[0].submit_done, UINT64_MAX);
518
+ if (ggml_backend_webgpu_handle_wait_status(waitStatus)) {
519
+ #ifdef GGML_WEBGPU_GPU_PROFILE
520
+ ggml_backend_webgpu_wait_profile_futures(ctx, subs[0].profile_futures, true);
521
+ #endif
522
+ subs.erase(subs.begin());
523
+ }
524
+ }
286
525
 
287
- // Free the staged parameter buffers once the submission completes
288
- wgpu::Future p_f = ctx->queue.OnSubmittedWorkDone(
289
- wgpu::CallbackMode::AllowSpontaneous,
290
- [ctx, staged_param_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
291
- if (status != wgpu::QueueWorkDoneStatus::Success) {
292
- GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", std::string(message).c_str());
293
- }
294
- // Free the staged buffers
295
- ctx->param_buf_pool.free_bufs(staged_param_bufs);
296
- });
297
- ctx->callback_futures.push_back({ p_f });
526
+ if (subs.empty()) {
527
+ return;
528
+ }
298
529
 
299
- // Check for errrors in SET_ROWS operations
300
- for (auto & error_bufs : staged_set_row_error_bufs) {
301
- wgpu::Future f = error_bufs.host_buf.MapAsync(
302
- wgpu::MapMode::Read, 0, error_bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous,
303
- [ctx, error_bufs](wgpu::MapAsyncStatus status, wgpu::StringView message) {
304
- if (status != wgpu::MapAsyncStatus::Success) {
305
- GGML_LOG_ERROR("ggml_webgpu: Failed to map error buffer: %s\n", std::string(message).c_str());
306
- } else {
307
- const uint32_t * error_data = (const uint32_t *) error_bufs.host_buf.GetConstMappedRange();
308
- if (*error_data) {
309
- GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported.");
310
- }
311
- // We can't unmap in here due to WebGPU reentrancy limitations.
312
- ctx->set_rows_error_buf_pool.free_bufs({ error_bufs });
313
- }
314
- });
315
- ctx->callback_futures.push_back({ f });
530
+ if (block) {
531
+ for (auto & sub : subs) {
532
+ while (!sub.submit_done.completed) {
533
+ auto waitStatus = ctx->instance.WaitAny(1, &sub.submit_done, UINT64_MAX);
534
+ ggml_backend_webgpu_handle_wait_status(waitStatus);
535
+ }
536
+ #ifdef GGML_WEBGPU_GPU_PROFILE
537
+ ggml_backend_webgpu_wait_profile_futures(ctx, sub.profile_futures, true);
538
+ #endif
539
+ }
540
+ subs.clear();
541
+ } else {
542
+ // Poll each submit future once and remove completed submissions.
543
+ for (auto sub = subs.begin(); sub != subs.end();) {
544
+ auto waitStatus = ctx->instance.WaitAny(1, &sub->submit_done, 0);
545
+ ggml_backend_webgpu_handle_wait_status(waitStatus, true);
546
+ #ifdef GGML_WEBGPU_GPU_PROFILE
547
+ ggml_backend_webgpu_wait_profile_futures(ctx, sub->profile_futures, false);
548
+ if (sub->submit_done.completed && sub->profile_futures.empty()) {
549
+ #else
550
+ if (sub->submit_done.completed) {
551
+ #endif
552
+ sub = subs.erase(sub);
553
+ } else {
554
+ ++sub;
555
+ }
556
+ }
316
557
  }
317
558
  }
318
559
 
319
- static void ggml_backend_webgpu_map_buffer(webgpu_context & ctx,
320
- wgpu::Buffer & buffer,
321
- wgpu::MapMode mode,
322
- size_t offset,
323
- size_t size) {
560
+ static void ggml_backend_webgpu_map_buffer(webgpu_global_context & ctx,
561
+ wgpu::Buffer & buffer,
562
+ wgpu::MapMode mode,
563
+ size_t offset,
564
+ size_t size) {
324
565
  ctx->instance.WaitAny(buffer.MapAsync(mode, offset, size, wgpu::CallbackMode::AllowSpontaneous,
325
566
  [](wgpu::MapAsyncStatus status, wgpu::StringView message) {
326
567
  if (status != wgpu::MapAsyncStatus::Success) {
@@ -335,100 +576,178 @@ static void ggml_backend_webgpu_map_buffer(webgpu_context & ctx,
335
576
  // This function adds debugging information to shaders, as WebGPU does not support printing directly.
336
577
  // To use, add a bind group entry to the setup for the shader you are debugging, add the buffer and
337
578
  // debug statements in the shader, and then call this function after encoding the commands and submitting them.
338
- static void ggml_backend_webgpu_debug(webgpu_context & ctx) {
339
- ggml_backend_webgpu_submit_queue(ctx);
579
+ static void ggml_backend_webgpu_debug(webgpu_global_context & ctx) {
340
580
  wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
341
581
  encoder.CopyBufferToBuffer(ctx->debug_dev_buf, 0, ctx->debug_host_buf, 0, ctx->debug_host_buf.GetSize());
342
582
  wgpu::CommandBuffer commands = encoder.Finish();
343
583
  ctx->queue.Submit(1, &commands);
344
-
345
584
  ggml_backend_webgpu_map_buffer(ctx, ctx->debug_host_buf, wgpu::MapMode::Read, 0, ctx->debug_host_buf.GetSize());
346
- const uint32_t * debug_data = (const uint32_t *) ctx->debug_host_buf.GetConstMappedRange();
347
- std::cout << "debug data:";
348
- for (size_t i = 0; i < WEBGPU_DEBUG_BUF_ELEMS; i++) {
349
- std::cout << " " << i << ": " << debug_data[i];
350
- }
351
- std::cout << "\n";
585
+ const float * debug_data = (const float *) ctx->debug_host_buf.GetConstMappedRange();
586
+ std::cout << "debug[0]: " << debug_data[0] << "\n";
352
587
  ctx->debug_host_buf.Unmap();
353
588
  }
354
589
  #endif
355
590
 
356
- static void ggml_backend_webgpu_build_and_enqueue(webgpu_context & ctx,
357
- wgpu::ComputePipeline & pipeline,
358
- std::vector<uint32_t> params,
359
- std::vector<wgpu::BindGroupEntry> bind_group_entries,
360
- uint32_t wg_x,
361
- const char * bind_group_label = nullptr,
362
- bool submit_and_wait = false) {
363
- webgpu_pool_bufs params_bufs = ctx->param_buf_pool.alloc_bufs();
364
-
365
- ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0, params_bufs.host_buf.GetSize());
366
- uint32_t * _params = (uint32_t *) params_bufs.host_buf.GetMappedRange();
367
- for (size_t i = 0; i < params.size(); i++) {
368
- _params[i] = params[i];
369
- };
591
+ static webgpu_submission ggml_backend_webgpu_submit(webgpu_global_context & ctx,
592
+ std::vector<webgpu_command> & commands,
593
+ webgpu_buf_pool & param_buf_pool) {
594
+ std::vector<wgpu::CommandBuffer> command_buffers;
595
+ std::vector<wgpu::Buffer> params_bufs;
596
+ webgpu_submission submission;
597
+ #ifdef GGML_WEBGPU_GPU_PROFILE
598
+ std::vector<std::pair<std::string, webgpu_gpu_profile_bufs>> pipeline_name_and_ts_bufs;
599
+ #endif
600
+
601
+ for (const auto & command : commands) {
602
+ command_buffers.push_back(command.commands);
603
+ params_bufs.insert(params_bufs.end(), command.params_bufs.begin(), command.params_bufs.end());
604
+ }
605
+ ctx->queue.Submit(command_buffers.size(), command_buffers.data());
606
+
607
+ wgpu::Future p_f = ctx->queue.OnSubmittedWorkDone(
608
+ wgpu::CallbackMode::AllowSpontaneous,
609
+ [&param_buf_pool, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
610
+ if (status != wgpu::QueueWorkDoneStatus::Success) {
611
+ GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", std::string(message).c_str());
612
+ }
613
+ // Free the staged buffers
614
+ param_buf_pool.free_bufs(params_bufs);
615
+ });
616
+ submission.submit_done = { p_f };
370
617
 
371
- params_bufs.host_buf.Unmap();
618
+ #ifdef GGML_WEBGPU_GPU_PROFILE
619
+ for (const auto & command : commands) {
620
+ auto label = command.pipeline_name;
621
+ auto ts_bufs = command.timestamp_query_bufs;
372
622
 
373
- uint32_t params_bufs_binding_num = bind_group_entries.size();
374
- bind_group_entries.push_back({ .binding = params_bufs_binding_num,
375
- .buffer = params_bufs.dev_buf,
376
- .offset = 0,
377
- .size = params_bufs.dev_buf.GetSize() });
623
+ wgpu::Future f = ts_bufs.host_buf.MapAsync(
624
+ wgpu::MapMode::Read, 0, ts_bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous,
625
+ [ctx, ts_bufs, label](wgpu::MapAsyncStatus status, wgpu::StringView message) {
626
+ if (status != wgpu::MapAsyncStatus::Success) {
627
+ GGML_LOG_ERROR("ggml_webgpu: Failed to map timestamp buffer: %s\n", std::string(message).c_str());
628
+ } else {
629
+ const uint64_t * ts_data = (const uint64_t *) ts_bufs.host_buf.GetConstMappedRange();
630
+ // WebGPU timestamps are in ns; convert to ms
631
+ double elapsed_ms = double(ts_data[1] - ts_data[0]) * 1e-6;
632
+ ctx->shader_gpu_time_ms[label] += elapsed_ms;
633
+ }
634
+ // We can't unmap in here due to WebGPU reentrancy limitations.
635
+ ctx->timestamp_query_buf_pool.free_bufs({ ts_bufs });
636
+ });
637
+ submission.profile_futures.push_back({ f });
638
+ }
639
+ #endif
640
+ return submission;
641
+ }
378
642
 
379
- wgpu::BindGroupDescriptor bind_group_desc;
380
- bind_group_desc.layout = pipeline.GetBindGroupLayout(0);
381
- bind_group_desc.entryCount = bind_group_entries.size();
382
- bind_group_desc.entries = bind_group_entries.data();
383
- if (bind_group_label) {
384
- bind_group_desc.label = bind_group_label;
643
+ static webgpu_command ggml_backend_webgpu_build_multi(
644
+ webgpu_global_context & ctx,
645
+ webgpu_buf_pool & param_buf_pool,
646
+ const std::vector<webgpu_pipeline> & pipelines,
647
+ const std::vector<std::vector<uint32_t>> & params_list,
648
+ const std::vector<std::vector<wgpu::BindGroupEntry>> & bind_group_entries_list,
649
+ const std::vector<std::pair<uint32_t, uint32_t>> & workgroups_list) {
650
+ GGML_ASSERT(pipelines.size() == params_list.size());
651
+ GGML_ASSERT(pipelines.size() == bind_group_entries_list.size());
652
+ GGML_ASSERT(pipelines.size() == workgroups_list.size());
653
+
654
+ std::vector<wgpu::Buffer> params_bufs_list;
655
+ std::vector<wgpu::BindGroup> bind_groups;
656
+
657
+ for (size_t i = 0; i < pipelines.size(); i++) {
658
+ wgpu::Buffer params_bufs = param_buf_pool.alloc_bufs();
659
+
660
+ std::vector<wgpu::BindGroupEntry> entries = bind_group_entries_list[i];
661
+ uint32_t params_binding_num = entries.size();
662
+ entries.push_back(
663
+ { .binding = params_binding_num, .buffer = params_bufs, .offset = 0, .size = params_bufs.GetSize() });
664
+
665
+ wgpu::BindGroupDescriptor bind_group_desc;
666
+ bind_group_desc.layout = pipelines[i].pipeline.GetBindGroupLayout(0);
667
+ bind_group_desc.entryCount = entries.size();
668
+ bind_group_desc.entries = entries.data();
669
+ bind_group_desc.label = pipelines[i].name.c_str();
670
+ bind_groups.push_back(ctx->device.CreateBindGroup(&bind_group_desc));
671
+
672
+ params_bufs_list.push_back(params_bufs);
385
673
  }
386
- wgpu::BindGroup bind_group = ctx->device.CreateBindGroup(&bind_group_desc);
387
674
 
388
675
  wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
389
- encoder.CopyBufferToBuffer(params_bufs.host_buf, 0, params_bufs.dev_buf, 0, params_bufs.dev_buf.GetSize());
676
+ for (size_t i = 0; i < params_bufs_list.size(); i++) {
677
+ ctx->queue.WriteBuffer(params_bufs_list[i], 0, params_list[i].data(), params_list[i].size() * sizeof(uint32_t));
678
+ }
679
+
680
+ #ifdef GGML_WEBGPU_GPU_PROFILE
681
+ webgpu_gpu_profile_bufs ts_bufs = ctx->timestamp_query_buf_pool.alloc_bufs();
682
+ if (ts_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) {
683
+ ts_bufs.host_buf.Unmap();
684
+ }
685
+
686
+ wgpu::PassTimestampWrites ts_writes = { .querySet = ts_bufs.query_set,
687
+ .beginningOfPassWriteIndex = 0,
688
+ .endOfPassWriteIndex = 1 };
689
+ wgpu::ComputePassDescriptor pass_desc = { .timestampWrites = &ts_writes };
690
+ wgpu::ComputePassEncoder pass = encoder.BeginComputePass(&pass_desc);
691
+ #else
390
692
  wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
391
- pass.SetPipeline(pipeline);
392
- pass.SetBindGroup(0, bind_group);
393
- pass.DispatchWorkgroups(wg_x, 1, 1);
693
+ #endif
694
+ for (size_t i = 0; i < pipelines.size(); i++) {
695
+ pass.SetPipeline(pipelines[i].pipeline);
696
+ pass.SetBindGroup(0, bind_groups[i]);
697
+ pass.DispatchWorkgroups(workgroups_list[i].first, workgroups_list[i].second, 1);
698
+ }
394
699
  pass.End();
700
+
701
+ #ifdef GGML_WEBGPU_GPU_PROFILE
702
+ encoder.ResolveQuerySet(ts_bufs.query_set, 0, 2, ts_bufs.dev_buf, 0);
703
+ encoder.CopyBufferToBuffer(ts_bufs.dev_buf, 0, ts_bufs.host_buf, 0, ts_bufs.host_buf.GetSize());
704
+ #endif
705
+
395
706
  wgpu::CommandBuffer commands = encoder.Finish();
396
- if (submit_and_wait) {
397
- // Submit and wait immediately
398
- ctx->queue.Submit(1, &commands);
399
- ctx->instance.WaitAny(ctx->queue.OnSubmittedWorkDone(
400
- wgpu::CallbackMode::AllowSpontaneous,
401
- [ctx, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
402
- if (status != wgpu::QueueWorkDoneStatus::Success) {
403
- GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", message.data);
404
- }
405
- ctx->param_buf_pool.free_bufs({ params_bufs });
406
- }),
407
- UINT64_MAX);
408
- } else {
409
- // Lock the context mutex when pushing to the staging vectors.
410
- std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
411
- // Enqueue commands and only submit if we have enough staged commands
412
- ctx->staged_command_bufs.push_back(commands);
413
- ctx->staged_param_bufs.push_back(params_bufs);
414
- if (ctx->staged_command_bufs.size() == WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) {
415
- ggml_backend_webgpu_submit_queue(ctx);
416
- }
417
- }
707
+ webgpu_command result = {};
708
+ result.commands = commands;
709
+ result.params_bufs = params_bufs_list;
710
+ result.num_kernels = pipelines.size();
711
+ #ifdef GGML_WEBGPU_GPU_PROFILE
712
+ result.timestamp_query_bufs = ts_bufs;
713
+ // TODO: handle multiple pipeline names
714
+ result.pipeline_name = pipelines.front().name;
715
+ #endif
716
+ return result;
717
+ }
718
+
719
+ static webgpu_command ggml_backend_webgpu_build(webgpu_global_context & ctx,
720
+ webgpu_buf_pool & param_buf_pool,
721
+ webgpu_pipeline & pipeline,
722
+ std::vector<uint32_t> params,
723
+ std::vector<wgpu::BindGroupEntry> bind_group_entries,
724
+ uint32_t wg_x,
725
+ uint32_t wg_y = 1) {
726
+ return ggml_backend_webgpu_build_multi(ctx, param_buf_pool,
727
+ {
728
+ pipeline
729
+ },
730
+ { std::move(params) }, { std::move(bind_group_entries) },
731
+ { { wg_x, wg_y } });
418
732
  }
419
733
 
420
- static void ggml_backend_webgpu_buffer_memset(webgpu_context & ctx,
421
- wgpu::Buffer & buf,
422
- uint32_t value,
423
- size_t offset,
424
- size_t size) {
734
+ static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx,
735
+ wgpu::Buffer & buf,
736
+ uint32_t value,
737
+ size_t offset,
738
+ size_t size) {
425
739
  std::vector<uint32_t> params = { (uint32_t) offset, (uint32_t) size, value };
426
740
  std::vector<wgpu::BindGroupEntry> entries = {
427
741
  { .binding = 0, .buffer = buf, .offset = 0, .size = buf.GetSize() }
428
742
  };
429
- size_t bytes_per_wg = ctx->max_wg_size_x * ctx->memset_bytes_per_thread;
430
- uint32_t wg_x = ((size + 3) + bytes_per_wg - 1) / bytes_per_wg;
431
- ggml_backend_webgpu_build_and_enqueue(ctx, ctx->memset_pipeline, params, entries, wg_x, "MEMSET", true);
743
+ size_t bytes_per_wg = WEBGPU_MAX_WG_SIZE * ctx->capabilities.memset_bytes_per_thread;
744
+ uint32_t wg_x = CEIL_DIV(size + 3, bytes_per_wg);
745
+
746
+ webgpu_command command =
747
+ ggml_backend_webgpu_build(ctx, ctx->memset_buf_pool, ctx->memset_pipelines[0], params, entries, wg_x);
748
+ std::vector<webgpu_command> commands = { command };
749
+ std::vector<webgpu_submission> sub = { ggml_backend_webgpu_submit(ctx, commands, ctx->memset_buf_pool) };
750
+ ggml_backend_webgpu_wait(ctx, sub);
432
751
  }
433
752
 
434
753
  /** End WebGPU Actions */
@@ -444,8 +763,48 @@ static void ggml_backend_webgpu_free(ggml_backend_t backend) {
444
763
  ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *) backend->context;
445
764
  WEBGPU_LOG_DEBUG("ggml_backend_webgpu_free(" << ctx->name << ")");
446
765
 
447
- // TODO: cleanup
448
- GGML_UNUSED(ctx);
766
+ #ifdef GGML_WEBGPU_CPU_PROFILE
767
+ std::cout << "\n[ggml_webgpu cpu profiling summary]\n";
768
+ double total_cpu = 0.0;
769
+ for (const auto & kv : ctx->webgpu_ctx->global_ctx->cpu_time_ms) {
770
+ total_cpu += kv.second;
771
+ }
772
+ std::cout << "ggml_webgpu: total cpu time: " << total_cpu << " ms\n";
773
+ std::cout << "ggml_webgpu: cpu breakdown:\n";
774
+ for (const auto & kv : ctx->webgpu_ctx->global_ctx->cpu_time_ms) {
775
+ double pct = (total_cpu > 0.0) ? (kv.second / total_cpu * 100.0) : 0.0;
776
+ std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n";
777
+ }
778
+ if (ctx->webgpu_ctx->global_ctx->cpu_detail_ms.size() > 0) {
779
+ std::cout << "ggml_webgpu: cpu detailed breakdown:\n";
780
+ }
781
+ for (const auto & kv : ctx->webgpu_ctx->global_ctx->cpu_detail_ms) {
782
+ double pct = (total_cpu > 0.0) ? (kv.second / total_cpu * 100.0) : 0.0;
783
+ std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n";
784
+ }
785
+ #endif
786
+
787
+ #ifdef GGML_WEBGPU_GPU_PROFILE
788
+ std::cout << "\n[ggml_webgpu gpu profiling summary]\n";
789
+ double total_gpu = 0.0;
790
+ for (const auto & kv : ctx->webgpu_ctx->global_ctx->shader_gpu_time_ms) {
791
+ total_gpu += kv.second;
792
+ }
793
+ std::cout << "ggml_webgpu: total gpu time (all shaders): " << total_gpu << " ms\n";
794
+ std::cout << "\nggml_webgpu: gpu breakdown:\n";
795
+ for (const auto & kv : ctx->webgpu_ctx->global_ctx->shader_gpu_time_ms) {
796
+ double pct = (total_gpu > 0.0) ? (kv.second / total_gpu * 100.0) : 0.0;
797
+ std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << std::fixed << std::setprecision(2)
798
+ << pct << "%)\n";
799
+ }
800
+ #endif
801
+
802
+ #if defined(GGML_WEBGPU_CPU_PROFILE) && defined(GGML_WEBGPU_GPU_PROFILE)
803
+ std::cout << "ggml_webgpu: gpu/cpu ratio: " << (total_cpu > 0.0 ? total_gpu / total_cpu : 0.0) << "\n";
804
+ #endif
805
+
806
+ delete ctx;
807
+ delete backend;
449
808
  }
450
809
 
451
810
  static size_t ggml_webgpu_tensor_offset(const ggml_tensor * tensor) {
@@ -457,19 +816,18 @@ static wgpu::Buffer ggml_webgpu_tensor_buf(const ggml_tensor * tensor) {
457
816
  return ctx->buffer;
458
817
  }
459
818
 
460
- static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, ggml_tensor * t) {
819
+ static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, const ggml_tensor * t) {
461
820
  size_t offset = ggml_webgpu_tensor_offset(t);
462
- return offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
821
+ return offset & (ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1);
463
822
  }
464
823
 
465
- static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, ggml_tensor * t) {
824
+ static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, const ggml_tensor * t) {
466
825
  size_t offset = ggml_webgpu_tensor_offset(t);
467
- return offset & ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
826
+ return offset & ~(ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1);
468
827
  }
469
828
 
470
829
  static size_t ggml_webgpu_tensor_binding_size(webgpu_context & ctx, ggml_tensor * t) {
471
- return (ggml_nbytes(t) + ggml_webgpu_tensor_misalignment(ctx, t) + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) &
472
- ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1);
830
+ return ROUNDUP_POW2(ggml_nbytes(t) + ggml_webgpu_tensor_misalignment(ctx, t), WEBGPU_STORAGE_BUF_BINDING_MULT);
473
831
  }
474
832
 
475
833
  // Used to determine if two tensors are the same for in-place operations
@@ -478,7 +836,31 @@ static bool ggml_webgpu_tensor_equal(ggml_tensor * a, ggml_tensor * b) {
478
836
  (ggml_webgpu_tensor_offset(a) == ggml_webgpu_tensor_offset(b));
479
837
  }
480
838
 
481
- static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
839
+ // Used to determine if two tensors share the same buffer and their byte ranges overlap,
840
+ static bool ggml_webgpu_tensor_overlap(ggml_tensor * a, ggml_tensor * b) {
841
+ return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) &&
842
+ ggml_webgpu_tensor_offset(a) < (ggml_webgpu_tensor_offset(b) + ggml_nbytes(b)) &&
843
+ ggml_webgpu_tensor_offset(b) < (ggml_webgpu_tensor_offset(a) + ggml_nbytes(a));
844
+ }
845
+
846
+ struct binary_overlap_flags {
847
+ bool inplace; // src0 == dst
848
+ bool overlap; // src1 == dst
849
+ bool src_overlap;
850
+ };
851
+
852
+ static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0,
853
+ ggml_tensor * src1,
854
+ ggml_tensor * dst) {
855
+ binary_overlap_flags flags = {};
856
+ flags.inplace = ggml_webgpu_tensor_equal(src0, dst);
857
+ flags.overlap = ggml_webgpu_tensor_overlap(src1, dst);
858
+ flags.src_overlap = ggml_webgpu_tensor_overlap(src0, src1);
859
+
860
+ return flags;
861
+ }
862
+
863
+ static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
482
864
  uint32_t ne = (uint32_t) ggml_nelements(dst);
483
865
 
484
866
  std::vector<uint32_t> params = {
@@ -489,8 +871,9 @@ static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor
489
871
  (uint32_t) (src->nb[2] / ggml_type_size(src->type)), (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
490
872
  (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
491
873
  (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
492
- // Logical shape — same for both tensors even if permuted
493
- (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) src->ne[3]
874
+ // Logical shapes
875
+ (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) dst->ne[0],
876
+ (uint32_t) dst->ne[1], (uint32_t) dst->ne[2]
494
877
  };
495
878
 
496
879
  std::vector<wgpu::BindGroupEntry> entries = {
@@ -504,36 +887,49 @@ static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor
504
887
  .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
505
888
  };
506
889
 
507
- size_t max_wg_size = ctx->max_wg_size_x;
508
- uint32_t wg_x = (ne + max_wg_size - 1) / max_wg_size;
509
- ggml_backend_webgpu_build_and_enqueue(ctx, ctx->cpy_pipeline, params, entries, wg_x, ggml_op_name(dst->op));
890
+ uint32_t wg_x = CEIL_DIV(ne, WEBGPU_MAX_WG_SIZE);
891
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, ctx->cpy_pipelines[src->type][dst->type],
892
+ params, entries, wg_x);
510
893
  }
511
894
 
512
- static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, ggml_tensor * dst) {
513
- // For set rows specifically, we need to check if src and idx are empty tensors.
514
- if (ggml_is_empty(src) || ggml_is_empty(idx)) {
515
- return;
516
- }
895
+ static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
896
+ ggml_webgpu_shader_lib_context shader_lib_ctx = {
897
+ .src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
898
+ };
517
899
 
518
- webgpu_pool_bufs error_bufs = ctx->set_rows_error_buf_pool.alloc_bufs();
519
- if (error_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) {
520
- error_bufs.host_buf.Unmap();
521
- }
900
+ webgpu_pipeline pipeline = ctx->shader_lib->get_pad_pipeline(shader_lib_ctx);
901
+
902
+ auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
903
+
904
+ const uint32_t ne = (uint32_t) ggml_nelements(dst);
522
905
 
523
906
  std::vector<uint32_t> params = {
907
+ ne,
524
908
  (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
525
- (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)),
526
909
  (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
527
- // Convert byte-strides to element-strides
528
- (uint32_t) (src->nb[1] / ggml_type_size(src->type)), (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
529
- (uint32_t) (src->nb[3] / ggml_type_size(src->type)), (uint32_t) (idx->nb[0] / ggml_type_size(idx->type)),
530
- (uint32_t) (idx->nb[1] / ggml_type_size(idx->type)), (uint32_t) (idx->nb[2] / ggml_type_size(idx->type)),
531
- (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
532
- (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
533
- // Shape of src
534
- (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) src->ne[3],
535
- // Shape of idx
536
- (uint32_t) (idx->ne[1]), (uint32_t) (idx->ne[2])
910
+ // Strides (in elements)
911
+ (uint32_t) (src->nb[0] / ggml_type_size(src->type)),
912
+ (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
913
+ (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
914
+ (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
915
+ // Shapes
916
+ (uint32_t) src->ne[0],
917
+ (uint32_t) src->ne[1],
918
+ (uint32_t) src->ne[2],
919
+ (uint32_t) src->ne[3],
920
+ (uint32_t) dst->ne[0],
921
+ (uint32_t) dst->ne[1],
922
+ (uint32_t) dst->ne[2],
923
+ (uint32_t) dst->ne[3],
924
+ // Pad sizes
925
+ (uint32_t) ggml_get_op_params_i32(dst, 0),
926
+ (uint32_t) ggml_get_op_params_i32(dst, 1),
927
+ (uint32_t) ggml_get_op_params_i32(dst, 2),
928
+ (uint32_t) ggml_get_op_params_i32(dst, 3),
929
+ (uint32_t) ggml_get_op_params_i32(dst, 4),
930
+ (uint32_t) ggml_get_op_params_i32(dst, 5),
931
+ (uint32_t) ggml_get_op_params_i32(dst, 6),
932
+ (uint32_t) ggml_get_op_params_i32(dst, 7),
537
933
  };
538
934
 
539
935
  std::vector<wgpu::BindGroupEntry> entries = {
@@ -542,26 +938,36 @@ static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_t
542
938
  .offset = ggml_webgpu_tensor_align_offset(ctx, src),
543
939
  .size = ggml_webgpu_tensor_binding_size(ctx, src) },
544
940
  { .binding = 1,
545
- .buffer = ggml_webgpu_tensor_buf(idx),
546
- .offset = ggml_webgpu_tensor_align_offset(ctx, idx),
547
- .size = ggml_webgpu_tensor_binding_size(ctx, idx) },
548
- { .binding = 2,
549
941
  .buffer = ggml_webgpu_tensor_buf(dst),
550
942
  .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
551
- .size = ggml_webgpu_tensor_binding_size(ctx, dst) },
552
- { .binding = 3, .buffer = error_bufs.dev_buf, .offset = 0, .size = error_bufs.dev_buf.GetSize() }
943
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
553
944
  };
554
945
 
555
- size_t max_wg_size = ctx->max_wg_size_x;
556
- uint32_t wg_x = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size;
946
+ uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
947
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
948
+ }
557
949
 
558
- std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
559
- ctx->staged_set_row_error_bufs.push_back(error_bufs);
950
+ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
951
+ ggml_tensor * src,
952
+ ggml_tensor * idx,
953
+ ggml_tensor * dst) {
954
+ // For set rows specifically, we need to check if src and idx are empty
955
+ // tensors.
956
+ if (ggml_is_empty(src) || ggml_is_empty(idx)) {
957
+ return std::nullopt;
958
+ }
560
959
 
561
- ggml_backend_webgpu_build_and_enqueue(ctx, ctx->set_rows_pipeline, params, entries, wg_x, ggml_op_name(dst->op));
562
- }
960
+ ggml_webgpu_shader_lib_context shader_lib_ctx = {
961
+ .src0 = src,
962
+ .src1 = idx,
963
+ .dst = dst,
964
+ .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
965
+ };
966
+
967
+ webgpu_pipeline pipeline = ctx->shader_lib->get_set_rows_pipeline(shader_lib_ctx);
968
+
969
+ auto * decisions = static_cast<ggml_webgpu_set_rows_shader_decisions *>(pipeline.context.get());
563
970
 
564
- static void ggml_webgpu_get_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, ggml_tensor * dst) {
565
971
  std::vector<uint32_t> params = {
566
972
  (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
567
973
  (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)),
@@ -572,8 +978,8 @@ static void ggml_webgpu_get_rows(webgpu_context & ctx, ggml_tensor * src, ggml_t
572
978
  (uint32_t) (idx->nb[1] / ggml_type_size(idx->type)), (uint32_t) (idx->nb[2] / ggml_type_size(idx->type)),
573
979
  (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
574
980
  (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
575
- // Shape of dst
576
- (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3],
981
+ // Shape of src
982
+ (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) src->ne[3],
577
983
  // Shape of idx
578
984
  (uint32_t) (idx->ne[1]), (uint32_t) (idx->ne[2])
579
985
  };
@@ -593,43 +999,177 @@ static void ggml_webgpu_get_rows(webgpu_context & ctx, ggml_tensor * src, ggml_t
593
999
  .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
594
1000
  };
595
1001
 
596
- size_t max_wg_size = ctx->max_wg_size_x;
597
- uint32_t wg_x = (dst->ne[1] * dst->ne[2] * dst->ne[3] + max_wg_size - 1) / max_wg_size;
1002
+ if (decisions->i64_idx) {
1003
+ entries.push_back({ .binding = 3,
1004
+ .buffer = ctx->set_rows_dev_error_buf,
1005
+ .offset = 0,
1006
+ .size = ctx->set_rows_dev_error_buf.GetSize() });
1007
+ }
598
1008
 
599
- wgpu::ComputePipeline pipeline = ctx->get_rows_pipeline[src->type];
600
- if (src->type == GGML_TYPE_F32 && dst->ne[0] % 4 != 0) {
601
- pipeline = ctx->get_rows_f32_no_vec_pipeline;
1009
+ uint32_t threads;
1010
+ if (decisions->vec4) {
1011
+ threads = (src->ne[1] * src->ne[2] * src->ne[3]) * (src->ne[0] / 4);
1012
+ } else {
1013
+ threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3];
602
1014
  }
603
- ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op));
1015
+ uint32_t wg_x = CEIL_DIV(threads, decisions->wg_size);
1016
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, 1);
604
1017
  }
605
1018
 
606
- static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
607
- std::vector<uint32_t> params = {
608
- (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
609
- (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
610
- (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
611
- (uint32_t) dst->ne[1], // number of rows in result (M)
612
- (uint32_t) dst->ne[0], // number of columns in result (N)
613
- (uint32_t) src0->ne[0], // number of columns in src0/src1 (K)
614
- (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 1
615
- (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 1
616
- (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 2
617
- (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 2
618
- (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 3
619
- (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 3
620
- (uint32_t) src0->ne[2], // batch size in dimension 2
621
- (uint32_t) src0->ne[3], // batch size in dimension 3
622
- (uint32_t) (src1->ne[2] / src0->ne[2]), // broadcast in dimension 2
623
- (uint32_t) (src1->ne[3] / src0->ne[3]) // broadcast in dimension 3
1019
+ // Workgroup size is a common constant
1020
+ static std::vector<wgpu::ConstantEntry> ggml_webgpu_wg_size_entry(uint32_t wg_size) {
1021
+ std::vector<wgpu::ConstantEntry> constants(1);
1022
+ constants[0].key = "wg_size";
1023
+ constants[0].value = wg_size;
1024
+ return constants;
1025
+ }
1026
+
1027
+ static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx,
1028
+ ggml_tensor * src,
1029
+ ggml_tensor * idx,
1030
+ ggml_tensor * dst) {
1031
+ ggml_webgpu_shader_lib_context shader_lib_ctx = {
1032
+ .src0 = src,
1033
+ .src1 = nullptr,
1034
+ .dst = dst,
1035
+ .max_wg_size = WEBGPU_MAX_WG_SIZE,
624
1036
  };
625
1037
 
1038
+ webgpu_pipeline pipeline = ctx->shader_lib->get_get_rows_pipeline(shader_lib_ctx);
1039
+ auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
1040
+
1041
+ std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
1042
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)),
1043
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1044
+ (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
1045
+ (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
1046
+ (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
1047
+ (uint32_t) (idx->nb[0] / ggml_type_size(idx->type)),
1048
+ (uint32_t) (idx->nb[1] / ggml_type_size(idx->type)),
1049
+ (uint32_t) (idx->nb[2] / ggml_type_size(idx->type)),
1050
+ (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
1051
+ (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
1052
+ (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
1053
+ (uint32_t) dst->ne[0],
1054
+ (uint32_t) dst->ne[1],
1055
+ (uint32_t) dst->ne[2],
1056
+ (uint32_t) dst->ne[3],
1057
+ (uint32_t) (idx->ne[1]),
1058
+ (uint32_t) (idx->ne[2]) };
1059
+
626
1060
  std::vector<wgpu::BindGroupEntry> entries = {
627
1061
  { .binding = 0,
628
- .buffer = ggml_webgpu_tensor_buf(src0),
629
- .offset = ggml_webgpu_tensor_align_offset(ctx, src0),
630
- .size = ggml_webgpu_tensor_binding_size(ctx, src0) },
1062
+ .buffer = ggml_webgpu_tensor_buf(src),
1063
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src),
1064
+ .size = ggml_webgpu_tensor_binding_size(ctx, src) },
631
1065
  { .binding = 1,
632
- .buffer = ggml_webgpu_tensor_buf(src1),
1066
+ .buffer = ggml_webgpu_tensor_buf(idx),
1067
+ .offset = ggml_webgpu_tensor_align_offset(ctx, idx),
1068
+ .size = ggml_webgpu_tensor_binding_size(ctx, idx) },
1069
+ { .binding = 2,
1070
+ .buffer = ggml_webgpu_tensor_buf(dst),
1071
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
1072
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
1073
+ };
1074
+
1075
+ uint32_t wg_x = CEIL_DIV(dst->ne[1] * dst->ne[2] * dst->ne[3], decisions->wg_size);
1076
+
1077
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
1078
+ }
1079
+
1080
+ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
1081
+ ggml_tensor * src0,
1082
+ ggml_tensor * src1,
1083
+ ggml_tensor * dst) {
1084
+ // Determine if this is a mat-vec operation
1085
+ bool is_vec = (dst->ne[1] == 1);
1086
+
1087
+ // Determine if we should use fast path
1088
+ bool use_fast = false;
1089
+ switch (src1->type) {
1090
+ case GGML_TYPE_F16:
1091
+ use_fast = (src0->type == GGML_TYPE_F16);
1092
+ break;
1093
+ case GGML_TYPE_F32:
1094
+ // TODO: implement better mat-mat for k-quants, mat-vec for all k-quants except q6_K
1095
+ switch (src0->type) {
1096
+ case GGML_TYPE_F32:
1097
+ case GGML_TYPE_F16:
1098
+ case GGML_TYPE_Q4_0:
1099
+ case GGML_TYPE_Q4_1:
1100
+ case GGML_TYPE_Q5_0:
1101
+ case GGML_TYPE_Q5_1:
1102
+ case GGML_TYPE_Q8_0:
1103
+ case GGML_TYPE_Q8_1:
1104
+ case GGML_TYPE_Q6_K:
1105
+ use_fast = true;
1106
+ break;
1107
+ case GGML_TYPE_Q2_K:
1108
+ case GGML_TYPE_Q3_K:
1109
+ case GGML_TYPE_Q4_K:
1110
+ case GGML_TYPE_Q5_K:
1111
+ // we don't have fast mat-vec for these types, but we do have (semi) fast mat-mat
1112
+ use_fast = !is_vec;
1113
+ break;
1114
+ default:
1115
+ break;
1116
+ }
1117
+ break;
1118
+ default:
1119
+ break;
1120
+ }
1121
+
1122
+ ggml_webgpu_shader_lib_context shader_lib_ctx = {
1123
+ .src0 = src0,
1124
+ .src1 = src1,
1125
+ .dst = dst,
1126
+ .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
1127
+ .supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix,
1128
+ .sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m,
1129
+ .sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n,
1130
+ .sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k,
1131
+ .max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size,
1132
+ };
1133
+
1134
+ // Get or create pipeline
1135
+ webgpu_pipeline pipeline;
1136
+
1137
+ if (use_fast && is_vec) {
1138
+ pipeline = ctx->shader_lib->get_mul_mat_vec_pipeline(shader_lib_ctx);
1139
+ } else if (use_fast) {
1140
+ pipeline = ctx->shader_lib->get_mul_mat_fast_pipeline(shader_lib_ctx);
1141
+ } else {
1142
+ pipeline = ctx->shader_lib->get_mul_mat_legacy_pipeline(shader_lib_ctx);
1143
+ }
1144
+
1145
+ // Build params
1146
+ std::vector<uint32_t> params = {
1147
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
1148
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
1149
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1150
+ (uint32_t) dst->ne[0],
1151
+ (uint32_t) dst->ne[1],
1152
+ (uint32_t) src0->ne[0],
1153
+ (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
1154
+ (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
1155
+ (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
1156
+ (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
1157
+ (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
1158
+ (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
1159
+ (uint32_t) src0->ne[2],
1160
+ (uint32_t) src0->ne[3],
1161
+ (uint32_t) (src1->ne[2] / src0->ne[2]),
1162
+ (uint32_t) (src1->ne[3] / src0->ne[3])
1163
+ };
1164
+
1165
+ // Build bind group entries
1166
+ std::vector<wgpu::BindGroupEntry> entries = {
1167
+ { .binding = 0,
1168
+ .buffer = ggml_webgpu_tensor_buf(src0),
1169
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src0),
1170
+ .size = ggml_webgpu_tensor_binding_size(ctx, src0) },
1171
+ { .binding = 1,
1172
+ .buffer = ggml_webgpu_tensor_buf(src1),
633
1173
  .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
634
1174
  .size = ggml_webgpu_tensor_binding_size(ctx, src1) },
635
1175
  { .binding = 2,
@@ -638,23 +1178,281 @@ static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_t
638
1178
  .size = ggml_webgpu_tensor_binding_size(ctx, dst) },
639
1179
  };
640
1180
 
641
- uint32_t wg_x =
642
- (dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3] + WEBGPU_MUL_MAT_WG_SIZE - 1) / WEBGPU_MUL_MAT_WG_SIZE;
643
- ggml_backend_webgpu_build_and_enqueue(ctx, ctx->mul_mat_pipeline[src0->type][src1->type], params, entries, wg_x,
644
- ggml_op_name(dst->op));
1181
+ // Calculate workgroup dimensions
1182
+ uint32_t wg_x = 1;
1183
+ uint32_t wg_y = 1;
1184
+ const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
1185
+
1186
+ if (use_fast && is_vec) {
1187
+ auto * decisions = static_cast<ggml_webgpu_mul_mat_vec_shader_decisions *>(pipeline.context.get());
1188
+
1189
+ uint32_t batches = dst->ne[2] * dst->ne[3];
1190
+ uint32_t output_groups = CEIL_DIV(dst->ne[0], decisions->outputs_per_wg);
1191
+ uint32_t total_wg = output_groups * batches;
1192
+ compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
1193
+ } else if (use_fast) {
1194
+ auto * decisions = static_cast<ggml_webgpu_mul_mat_shader_decisions *>(pipeline.context.get());
1195
+
1196
+ // Fast-path tiled/subgroup calculations
1197
+ uint32_t wg_m;
1198
+ uint32_t wg_n;
1199
+ if (decisions->use_subgroup_matrix) {
1200
+ uint32_t wg_m_sg_tile =
1201
+ decisions->subgroup_m * decisions->subgroup_matrix_m * ctx->global_ctx->capabilities.sg_mat_m;
1202
+ wg_m = CEIL_DIV(dst->ne[0], wg_m_sg_tile);
1203
+ uint32_t wg_n_sg_tile =
1204
+ decisions->subgroup_n * decisions->subgroup_matrix_n * ctx->global_ctx->capabilities.sg_mat_n;
1205
+ wg_n = CEIL_DIV(dst->ne[1], wg_n_sg_tile);
1206
+ } else {
1207
+ uint32_t tile_m_s = decisions->tile_m * decisions->wg_size_m;
1208
+ uint32_t tile_n_s = decisions->tile_n * decisions->wg_size_n;
1209
+ wg_m = CEIL_DIV(dst->ne[0], tile_m_s);
1210
+ wg_n = CEIL_DIV(dst->ne[1], tile_n_s);
1211
+ }
1212
+ uint32_t total_wg = wg_m * wg_n * dst->ne[2] * dst->ne[3];
1213
+ compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
1214
+
1215
+ } else { // legacy
1216
+ auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
1217
+ uint32_t wg_size = decisions->wg_size;
1218
+ uint32_t total_wg = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], wg_size);
1219
+ compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
1220
+ }
1221
+
1222
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y);
645
1223
  }
646
1224
 
647
- static void ggml_webgpu_binary_op(webgpu_context & ctx,
648
- ggml_tensor * src0,
649
- ggml_tensor * src1,
650
- ggml_tensor * dst,
651
- wgpu::ComputePipeline & pipeline,
652
- bool in_place) {
1225
+ #ifndef __EMSCRIPTEN__
1226
+ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx,
1227
+ ggml_tensor * Q,
1228
+ ggml_tensor * K,
1229
+ ggml_tensor * V,
1230
+ ggml_tensor * mask,
1231
+ ggml_tensor * sinks,
1232
+ ggml_tensor * dst) {
1233
+ float scale = *(float *) dst->op_params;
1234
+ float max_bias;
1235
+ memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
1236
+ float logit_softcap;
1237
+ memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));
1238
+ if (logit_softcap != 0.0f) {
1239
+ scale /= logit_softcap;
1240
+ }
1241
+ float n_head_log2 = float(1u << (uint32_t) floor(log2(Q->ne[2])));
1242
+ float m0 = powf(2.0f, -(max_bias) / n_head_log2);
1243
+ float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
1244
+
1245
+ const int has_mask = (mask != nullptr);
1246
+ const int has_sinks = (sinks != nullptr);
1247
+
653
1248
  std::vector<uint32_t> params = {
654
- (uint32_t) ggml_nelements(dst),
1249
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, Q) / ggml_type_size(Q->type)),
1250
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type)),
1251
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type)),
1252
+ has_mask ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)) : 0,
1253
+ has_sinks ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, sinks) / ggml_type_size(sinks->type)) : 0,
1254
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1255
+ (uint32_t) Q->ne[2], // number of heads
1256
+ (uint32_t) Q->ne[1], // sequence length (Q)
1257
+ (uint32_t) K->ne[1], // sequence length (K/V)
1258
+ (uint32_t) (Q->nb[1] / ggml_type_size(Q->type)), // stride (elements/blocks) of Q in dimension 1
1259
+ (uint32_t) (Q->nb[2] / ggml_type_size(Q->type)), // stride (elements/blocks) of Q in dimension 2
1260
+ (uint32_t) (Q->nb[3] / ggml_type_size(Q->type)), // stride (elements/blocks) of Q in dimension 3
1261
+ (uint32_t) (K->nb[1] / ggml_type_size(K->type)), // stride (elements/blocks) of K in dimension 1
1262
+ (uint32_t) (K->nb[2] / ggml_type_size(K->type)), // stride (elements/blocks) of K in dimension 2
1263
+ (uint32_t) (K->nb[3] / ggml_type_size(K->type)), // stride (elements/blocks) of K in dimension 3
1264
+ (uint32_t) (V->nb[1] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 1
1265
+ (uint32_t) (V->nb[2] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 2
1266
+ (uint32_t) (V->nb[3] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 3
1267
+ has_mask ? (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)) : 0, // stride of mask dim 3
1268
+ (uint32_t) (Q->ne[2] / K->ne[2]), // repeat factor for K/V in dim 2 (MHA/MQA/GQA)
1269
+ *(uint32_t *) &scale, // scale (possibly adjusted for logit softcap)
1270
+ *(uint32_t *) &max_bias,
1271
+ *(uint32_t *) &logit_softcap,
1272
+ *(uint32_t *) &n_head_log2,
1273
+ *(uint32_t *) &m0,
1274
+ *(uint32_t *) &m1
1275
+
1276
+ };
1277
+ std::vector<wgpu::BindGroupEntry> entries = {
1278
+ { .binding = 0,
1279
+ .buffer = ggml_webgpu_tensor_buf(Q),
1280
+ .offset = ggml_webgpu_tensor_align_offset(ctx, Q),
1281
+ .size = ggml_webgpu_tensor_binding_size(ctx, Q) },
1282
+ { .binding = 1,
1283
+ .buffer = ggml_webgpu_tensor_buf(K),
1284
+ .offset = ggml_webgpu_tensor_align_offset(ctx, K),
1285
+ .size = ggml_webgpu_tensor_binding_size(ctx, K) },
1286
+ { .binding = 2,
1287
+ .buffer = ggml_webgpu_tensor_buf(V),
1288
+ .offset = ggml_webgpu_tensor_align_offset(ctx, V),
1289
+ .size = ggml_webgpu_tensor_binding_size(ctx, V) }
1290
+ };
1291
+ uint32_t binding_index = 3;
1292
+ if (has_mask) {
1293
+ entries.push_back({ .binding = binding_index++,
1294
+ .buffer = ggml_webgpu_tensor_buf(mask),
1295
+ .offset = ggml_webgpu_tensor_align_offset(ctx, mask),
1296
+ .size = ggml_webgpu_tensor_binding_size(ctx, mask) });
1297
+ }
1298
+ if (has_sinks) {
1299
+ entries.push_back({ .binding = binding_index++,
1300
+ .buffer = ggml_webgpu_tensor_buf(sinks),
1301
+ .offset = ggml_webgpu_tensor_align_offset(ctx, sinks),
1302
+ .size = ggml_webgpu_tensor_binding_size(ctx, sinks) });
1303
+ }
1304
+ entries.push_back({ .binding = binding_index++,
1305
+ .buffer = ggml_webgpu_tensor_buf(dst),
1306
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
1307
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
1308
+
1309
+ ggml_webgpu_shader_lib_context shader_lib_ctx = {
1310
+ .src0 = Q,
1311
+ .src1 = K,
1312
+ .src2 = V,
1313
+ .src3 = mask,
1314
+ .src4 = sinks,
1315
+ .dst = dst,
1316
+ .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
1317
+ .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
1318
+ .sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m,
1319
+ .sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n,
1320
+ .sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k,
1321
+ .max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size,
1322
+ };
1323
+
1324
+ webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline(shader_lib_ctx);
1325
+
1326
+ auto * decisions = static_cast<ggml_webgpu_flash_attn_shader_decisions *>(pipeline.context.get());
1327
+
1328
+ uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions->q_tile);
1329
+ uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches
1330
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
1331
+ }
1332
+ #endif
1333
+
1334
+ static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
1335
+ bool is_unary = dst->op == GGML_OP_UNARY;
1336
+ bool inplace = ggml_webgpu_tensor_equal(src, dst) || (dst->op == GGML_OP_FILL);
1337
+
1338
+ ggml_webgpu_shader_lib_context shader_lib_ctx = {
1339
+ .src0 = src,
1340
+ .src1 = nullptr,
1341
+ .dst = dst,
1342
+ .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
1343
+ .inplace = inplace,
1344
+ };
1345
+
1346
+ webgpu_pipeline pipeline = ctx->shader_lib->get_unary_pipeline(shader_lib_ctx);
1347
+
1348
+ auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
1349
+
1350
+ uint32_t ne = (uint32_t) ggml_nelements(dst);
1351
+
1352
+ std::vector<uint32_t> params = { ne,
1353
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
1354
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1355
+ (uint32_t) (src->nb[0] / ggml_type_size(src->type)),
1356
+ (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
1357
+ (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
1358
+ (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
1359
+ (uint32_t) src->ne[0],
1360
+ (uint32_t) src->ne[1],
1361
+ (uint32_t) src->ne[2] };
1362
+
1363
+ ggml_tensor * effective_src = src;
1364
+ if (is_unary) {
1365
+ ggml_unary_op unary_op = ggml_get_unary_op(dst);
1366
+ switch (unary_op) {
1367
+ case GGML_UNARY_OP_XIELU:
1368
+ {
1369
+ // Get float parameters and reinterpret their bit patterns as uint32_t
1370
+ // for passing through the params buffer
1371
+ float alpha_n = ggml_get_op_params_f32(dst, 1);
1372
+ float alpha_p = ggml_get_op_params_f32(dst, 2);
1373
+ float beta = ggml_get_op_params_f32(dst, 3);
1374
+ float eps = ggml_get_op_params_f32(dst, 4);
1375
+ params.push_back(*reinterpret_cast<const uint32_t *>(&alpha_n));
1376
+ params.push_back(*reinterpret_cast<const uint32_t *>(&alpha_p));
1377
+ params.push_back(*reinterpret_cast<const uint32_t *>(&beta));
1378
+ params.push_back(*reinterpret_cast<const uint32_t *>(&eps));
1379
+ break;
1380
+ }
1381
+ default:
1382
+ break;
1383
+ }
1384
+ } else if (dst->op == GGML_OP_CLAMP) {
1385
+ float clamp_min = ggml_get_op_params_f32(dst, 0);
1386
+ float clamp_max = ggml_get_op_params_f32(dst, 1);
1387
+ params.push_back(*reinterpret_cast<const uint32_t *>(&clamp_min));
1388
+ params.push_back(*reinterpret_cast<const uint32_t *>(&clamp_max));
1389
+ } else if (dst->op == GGML_OP_FILL) {
1390
+ float fill_val = ggml_get_op_params_f32(dst, 0);
1391
+ params.push_back(*reinterpret_cast<const uint32_t *>(&fill_val));
1392
+ effective_src = dst; // fill simply fills dst
1393
+ }
1394
+
1395
+ std::vector<wgpu::BindGroupEntry> entries = {
1396
+ { .binding = 0,
1397
+ .buffer = ggml_webgpu_tensor_buf(effective_src),
1398
+ .offset = ggml_webgpu_tensor_align_offset(ctx, effective_src),
1399
+ .size = ggml_webgpu_tensor_binding_size(ctx, effective_src) },
1400
+ };
1401
+ if (!inplace) {
1402
+ entries.push_back({ .binding = 1,
1403
+ .buffer = ggml_webgpu_tensor_buf(dst),
1404
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
1405
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
1406
+ }
1407
+
1408
+ uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
1409
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
1410
+ }
1411
+
1412
+ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
1413
+ ggml_tensor * src0,
1414
+ ggml_tensor * src1,
1415
+ ggml_tensor * dst) {
1416
+ binary_overlap_flags flags = ggml_webgpu_detect_binary_overlap(src0, src1, dst);
1417
+
1418
+ ggml_webgpu_shader_lib_context shader_lib_ctx = {
1419
+ .src0 = src0,
1420
+ .src1 = src1,
1421
+ .dst = dst,
1422
+ .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
1423
+ .inplace = flags.inplace,
1424
+ .overlap = flags.overlap,
1425
+ .src_overlap = flags.src_overlap,
1426
+ };
1427
+
1428
+ webgpu_pipeline pipeline = ctx->shader_lib->get_binary_pipeline(shader_lib_ctx);
1429
+
1430
+ auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
1431
+
1432
+ uint32_t ne = (uint32_t) ggml_nelements(dst);
1433
+
1434
+ size_t src0_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, src0);
1435
+ size_t src1_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, src1);
1436
+
1437
+ uint32_t offset_merged_src0 = 0;
1438
+ uint32_t offset_merged_src1 = 0;
1439
+ if (flags.src_overlap) {
1440
+ size_t min_off = std::min(src0_webgpu_tensor_align_offset, src1_webgpu_tensor_align_offset);
1441
+ offset_merged_src0 = (uint32_t) ((src0_webgpu_tensor_align_offset - min_off) / ggml_type_size(src0->type));
1442
+ offset_merged_src1 = (uint32_t) ((src1_webgpu_tensor_align_offset - min_off) / ggml_type_size(src0->type));
1443
+ }
1444
+
1445
+ std::vector<uint32_t> params = {
1446
+ ne,
655
1447
  (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
656
1448
  (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
657
1449
  (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1450
+ offset_merged_src0,
1451
+ offset_merged_src1,
1452
+ (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),
1453
+ (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
1454
+ (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
1455
+ (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
658
1456
  (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)),
659
1457
  (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
660
1458
  (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
@@ -668,87 +1466,709 @@ static void ggml_webgpu_binary_op(webgpu_context & ctx,
668
1466
  (uint32_t) src1->ne[3],
669
1467
  };
670
1468
 
671
- std::vector<wgpu::BindGroupEntry> entries = {
672
- { .binding = 0,
673
- .buffer = ggml_webgpu_tensor_buf(src0),
674
- .offset = ggml_webgpu_tensor_align_offset(ctx, src0),
675
- .size = ggml_webgpu_tensor_binding_size(ctx, src0) },
676
- { .binding = 1,
677
- .buffer = ggml_webgpu_tensor_buf(src1),
678
- .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
679
- .size = ggml_webgpu_tensor_binding_size(ctx, src1) }
680
- };
681
- if (!in_place) {
682
- entries.push_back({ .binding = 2,
683
- .buffer = ggml_webgpu_tensor_buf(dst),
684
- .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
685
- .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
1469
+ std::vector<wgpu::BindGroupEntry> entries;
1470
+
1471
+ if (flags.src_overlap) {
1472
+ size_t merged_offset = std::min(src0_webgpu_tensor_align_offset, src1_webgpu_tensor_align_offset);
1473
+ size_t merged_end = std::max(src0_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, src0),
1474
+ src1_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, src1));
1475
+ entries.push_back({
1476
+ .binding = 0,
1477
+ .buffer = ggml_webgpu_tensor_buf(src0),
1478
+ .offset = merged_offset,
1479
+ .size = merged_end - merged_offset,
1480
+ });
1481
+ entries.push_back({
1482
+ .binding = 1,
1483
+ .buffer = ggml_webgpu_tensor_buf(dst),
1484
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
1485
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst),
1486
+ });
1487
+ } else {
1488
+ entries.push_back({
1489
+ .binding = 0,
1490
+ .buffer = ggml_webgpu_tensor_buf(src0),
1491
+ .offset = src0_webgpu_tensor_align_offset,
1492
+ .size = ggml_webgpu_tensor_binding_size(ctx, src0),
1493
+ });
1494
+ entries.push_back({
1495
+ .binding = 1,
1496
+ .buffer = ggml_webgpu_tensor_buf(src1),
1497
+ .offset = src1_webgpu_tensor_align_offset,
1498
+ .size = ggml_webgpu_tensor_binding_size(ctx, src1),
1499
+ });
1500
+ if (!flags.inplace && !flags.overlap) {
1501
+ entries.push_back({
1502
+ .binding = 2,
1503
+ .buffer = ggml_webgpu_tensor_buf(dst),
1504
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
1505
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst),
1506
+ });
1507
+ }
1508
+ }
1509
+
1510
+ uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
1511
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
1512
+ }
1513
+
1514
+ static webgpu_command ggml_webgpu_concat(webgpu_context & ctx,
1515
+ ggml_tensor * src0,
1516
+ ggml_tensor * src1,
1517
+ ggml_tensor * dst) {
1518
+ uint32_t ne = (uint32_t) ggml_nelements(dst);
1519
+ uint32_t dim = (uint32_t) dst->op_params[0];
1520
+
1521
+ std::vector<uint32_t> params = {
1522
+ ne,
1523
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
1524
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
1525
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1526
+ (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),
1527
+ (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
1528
+ (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
1529
+ (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
1530
+ (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)),
1531
+ (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
1532
+ (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
1533
+ (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
1534
+ (uint32_t) dst->ne[0],
1535
+ (uint32_t) dst->ne[1],
1536
+ (uint32_t) dst->ne[2],
1537
+ (uint32_t) dst->ne[3],
1538
+ dim,
1539
+ (uint32_t) src0->ne[dim]
1540
+ };
1541
+
1542
+ std::vector<wgpu::BindGroupEntry> entries = {
1543
+ { .binding = 0,
1544
+ .buffer = ggml_webgpu_tensor_buf(src0),
1545
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src0),
1546
+ .size = ggml_webgpu_tensor_binding_size(ctx, src0) },
1547
+ { .binding = 1,
1548
+ .buffer = ggml_webgpu_tensor_buf(src1),
1549
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
1550
+ .size = ggml_webgpu_tensor_binding_size(ctx, src1) },
1551
+ { .binding = 2,
1552
+ .buffer = ggml_webgpu_tensor_buf(dst),
1553
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
1554
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
1555
+ };
1556
+
1557
+ ggml_webgpu_shader_lib_context shader_lib_ctx = {
1558
+ .src0 = src0,
1559
+ .src1 = src1,
1560
+ .dst = dst,
1561
+ .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
1562
+ };
1563
+
1564
+ webgpu_pipeline pipeline = ctx->shader_lib->get_concat_pipeline(shader_lib_ctx);
1565
+ auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
1566
+ uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
1567
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
1568
+ }
1569
+
1570
+ static webgpu_command ggml_webgpu_repeat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * dst) {
1571
+ uint32_t ne = (uint32_t) ggml_nelements(dst);
1572
+
1573
+ std::vector<uint32_t> params = { ne,
1574
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) /
1575
+ ggml_type_size(src0->type)),
1576
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1577
+ (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),
1578
+ (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
1579
+ (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
1580
+ (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
1581
+ (uint32_t) (src0->ne[0]),
1582
+ (uint32_t) (src0->ne[1]),
1583
+ (uint32_t) (src0->ne[2]),
1584
+ (uint32_t) (src0->ne[3]),
1585
+ (uint32_t) (dst->ne[0]),
1586
+ (uint32_t) (dst->ne[1]),
1587
+ (uint32_t) (dst->ne[2]) };
1588
+
1589
+ std::vector<wgpu::BindGroupEntry> entries = {
1590
+ { .binding = 0,
1591
+ .buffer = ggml_webgpu_tensor_buf(src0),
1592
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src0),
1593
+ .size = ggml_webgpu_tensor_binding_size(ctx, src0) },
1594
+ { .binding = 1,
1595
+ .buffer = ggml_webgpu_tensor_buf(dst),
1596
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
1597
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
1598
+ };
1599
+
1600
+ ggml_webgpu_shader_lib_context shader_lib_ctx = {
1601
+ .src0 = src0,
1602
+ .dst = dst,
1603
+ .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
1604
+ };
1605
+
1606
+ webgpu_pipeline pipeline = ctx->shader_lib->get_repeat_pipeline(shader_lib_ctx);
1607
+ auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
1608
+ uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
1609
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
1610
+ }
1611
+
1612
+ static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
1613
+ int inplace = ggml_webgpu_tensor_equal(src, dst);
1614
+
1615
+ std::vector<uint32_t> params = {
1616
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
1617
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1618
+ (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
1619
+ (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
1620
+ (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
1621
+ (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
1622
+ (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
1623
+ (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
1624
+ (uint32_t) src->ne[0],
1625
+ (uint32_t) src->ne[1],
1626
+ (uint32_t) src->ne[2],
1627
+ (uint32_t) src->ne[3],
1628
+ *(uint32_t *) dst->op_params // epsilon, treated as f32 in the shader
1629
+ };
1630
+
1631
+ std::vector<wgpu::BindGroupEntry> entries = {
1632
+ { .binding = 0,
1633
+ .buffer = ggml_webgpu_tensor_buf(src),
1634
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src),
1635
+ .size = ggml_webgpu_tensor_binding_size(ctx, src) }
1636
+ };
1637
+ if (!inplace) {
1638
+ entries.push_back({ .binding = 1,
1639
+ .buffer = ggml_webgpu_tensor_buf(dst),
1640
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
1641
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
1642
+ }
1643
+
1644
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, ctx->rms_norm_pipelines[inplace], params,
1645
+ entries, ggml_nrows(src));
1646
+ }
1647
+
1648
+ static webgpu_command ggml_webgpu_rope(webgpu_context & ctx,
1649
+ ggml_tensor * src0,
1650
+ ggml_tensor * src1,
1651
+ ggml_tensor * src2,
1652
+ ggml_tensor * dst) {
1653
+ const int inplace = ggml_webgpu_tensor_equal(src0, dst);
1654
+ const int has_freq_factor = (src2 != nullptr);
1655
+
1656
+ const int n_dims = ((int32_t *) dst->op_params)[1];
1657
+ const int mode = ((int32_t *) dst->op_params)[2];
1658
+ const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
1659
+
1660
+ float freq_base;
1661
+ float freq_scale;
1662
+ float ext_factor;
1663
+ float attn_factor;
1664
+ float beta_fast;
1665
+ float beta_slow;
1666
+ memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
1667
+ memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
1668
+ memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
1669
+ memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
1670
+ memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
1671
+ memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
1672
+
1673
+ int sections[4];
1674
+ memcpy(sections, (int32_t *) dst->op_params + 11, 4 * sizeof(int));
1675
+
1676
+ float theta_scale = powf(freq_base, -2.0f / n_dims);
1677
+
1678
+ float corr_dims[2];
1679
+ ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
1680
+
1681
+ std::vector<uint32_t> params = {
1682
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
1683
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
1684
+ src2 != nullptr ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)) : 0,
1685
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1686
+ (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
1687
+ (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
1688
+ (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
1689
+ (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
1690
+ (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
1691
+ (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
1692
+ (uint32_t) ggml_nelements(src0) / 2,
1693
+ (uint32_t) src0->ne[0],
1694
+ (uint32_t) src0->ne[1],
1695
+ (uint32_t) src0->ne[2],
1696
+ (uint32_t) n_dims,
1697
+ (uint32_t) mode,
1698
+ *(uint32_t *) &theta_scale,
1699
+ *(uint32_t *) &attn_factor,
1700
+ *(uint32_t *) &freq_scale,
1701
+ *(uint32_t *) &ext_factor,
1702
+ *(uint32_t *) &corr_dims[0],
1703
+ *(uint32_t *) &corr_dims[1],
1704
+ (uint32_t) sections[0],
1705
+ (uint32_t) sections[1],
1706
+ (uint32_t) sections[2],
1707
+ (uint32_t) sections[3]
1708
+ };
1709
+
1710
+ std::vector<wgpu::BindGroupEntry> entries = {
1711
+ { .binding = 0,
1712
+ .buffer = ggml_webgpu_tensor_buf(src0),
1713
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src0),
1714
+ .size = ggml_webgpu_tensor_binding_size(ctx, src0) },
1715
+ { .binding = 1,
1716
+ .buffer = ggml_webgpu_tensor_buf(src1),
1717
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
1718
+ .size = ggml_webgpu_tensor_binding_size(ctx, src1) }
1719
+ };
1720
+ uint32_t dst_binding = 2;
1721
+ if (has_freq_factor) {
1722
+ dst_binding = 3;
1723
+ entries.push_back({ .binding = 2,
1724
+ .buffer = ggml_webgpu_tensor_buf(src2),
1725
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src2),
1726
+ .size = ggml_webgpu_tensor_binding_size(ctx, src2) });
1727
+ }
1728
+ if (!inplace) {
1729
+ entries.push_back({ .binding = dst_binding,
1730
+ .buffer = ggml_webgpu_tensor_buf(dst),
1731
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
1732
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
1733
+ }
1734
+
1735
+ webgpu_pipeline pipeline = ctx->rope_pipelines[dst->type][has_freq_factor][inplace];
1736
+ uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
1737
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
1738
+ }
1739
+
1740
+ static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
1741
+ const int split = (src1 != nullptr);
1742
+
1743
+ std::vector<uint32_t> params = {
1744
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
1745
+ src1 != nullptr ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)) : 0,
1746
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1747
+ (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
1748
+ (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
1749
+ (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
1750
+ src1 != nullptr ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) :
1751
+ (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
1752
+ src1 != nullptr ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) :
1753
+ (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
1754
+ src1 != nullptr ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) :
1755
+ (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
1756
+ (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
1757
+ (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
1758
+ (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
1759
+ (uint32_t) ggml_nelements(dst),
1760
+ (uint32_t) dst->ne[0],
1761
+ (uint32_t) dst->ne[1],
1762
+ (uint32_t) dst->ne[2],
1763
+ (uint32_t) ((int32_t *) dst->op_params)[1], // swapped
1764
+ *(uint32_t *) &dst->op_params[2], // alpha, for swiglu_oai
1765
+ *(uint32_t *) &dst->op_params[3], // limit, for swiglu_oai
1766
+ };
1767
+
1768
+ std::vector<wgpu::BindGroupEntry> entries = {
1769
+ { .binding = 0,
1770
+ .buffer = ggml_webgpu_tensor_buf(src0),
1771
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src0),
1772
+ .size = ggml_webgpu_tensor_binding_size(ctx, src0) },
1773
+ };
1774
+ uint32_t dst_binding = 1;
1775
+ if (split) {
1776
+ dst_binding = 2;
1777
+ entries.push_back({ .binding = 1,
1778
+ .buffer = ggml_webgpu_tensor_buf(src1),
1779
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
1780
+ .size = ggml_webgpu_tensor_binding_size(ctx, src1) });
1781
+ }
1782
+ entries.push_back({ .binding = dst_binding,
1783
+ .buffer = ggml_webgpu_tensor_buf(dst),
1784
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
1785
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
1786
+
1787
+ webgpu_pipeline pipeline = ctx->glu_pipelines[ggml_get_glu_op(dst)][dst->type][split];
1788
+ uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
1789
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
1790
+ }
1791
+
1792
+ static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
1793
+ bool inplace = ggml_webgpu_tensor_equal(src, dst);
1794
+
1795
+ ggml_webgpu_shader_lib_context shader_lib_ctx = {
1796
+ .src0 = src,
1797
+ .src1 = nullptr,
1798
+ .dst = dst,
1799
+ .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
1800
+ .inplace = inplace,
1801
+ };
1802
+
1803
+ webgpu_pipeline pipeline = ctx->shader_lib->get_scale_pipeline(shader_lib_ctx);
1804
+ auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
1805
+
1806
+ // params unchanged
1807
+ std::vector<uint32_t> params = {
1808
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
1809
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1810
+ (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
1811
+ (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
1812
+ (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
1813
+ (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
1814
+ (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
1815
+ (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
1816
+ (uint32_t) ggml_nelements(dst),
1817
+ (uint32_t) src->ne[0],
1818
+ (uint32_t) src->ne[1],
1819
+ (uint32_t) src->ne[2],
1820
+ *(uint32_t *) dst->op_params, // scale
1821
+ *(uint32_t *) &dst->op_params[1] // bias
1822
+ };
1823
+
1824
+ // bindgroups unchanged
1825
+ std::vector<wgpu::BindGroupEntry> entries = {
1826
+ { .binding = 0,
1827
+ .buffer = ggml_webgpu_tensor_buf(src),
1828
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src),
1829
+ .size = ggml_webgpu_tensor_binding_size(ctx, src) }
1830
+ };
1831
+
1832
+ if (!inplace) {
1833
+ entries.push_back({ .binding = 1,
1834
+ .buffer = ggml_webgpu_tensor_buf(dst),
1835
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
1836
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
1837
+ }
1838
+
1839
+ uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size);
1840
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
1841
+ }
1842
+
1843
+ static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx,
1844
+ ggml_tensor * src0,
1845
+ ggml_tensor * src1,
1846
+ ggml_tensor * src2,
1847
+ ggml_tensor * dst) {
1848
+ const int inplace = ggml_webgpu_tensor_equal(src0, dst);
1849
+ const int mask_type = (src1 != nullptr) ? src1->type : 2; // use 2 for no mask here
1850
+ const int has_sink = (src2 != nullptr);
1851
+ float max_bias;
1852
+ memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
1853
+ float n_head_log2 = float(1u << (uint32_t) floor(log2(src0->ne[2])));
1854
+ float m0 = powf(2.0f, -(max_bias) / n_head_log2);
1855
+ float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
1856
+
1857
+ std::vector<uint32_t> params = {
1858
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
1859
+ mask_type < 2 ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)) : 0,
1860
+ has_sink ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)) : 0,
1861
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1862
+ (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
1863
+ (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
1864
+ (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
1865
+ mask_type < 2 ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) : 0,
1866
+ mask_type < 2 ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) : 0,
1867
+ mask_type < 2 ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) : 0,
1868
+ (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
1869
+ (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
1870
+ (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
1871
+ (uint32_t) ggml_nelements(dst),
1872
+ (uint32_t) src0->ne[0],
1873
+ (uint32_t) src0->ne[1],
1874
+ (uint32_t) src0->ne[2],
1875
+ mask_type < 2 ? (uint32_t) src1->ne[2] : 0,
1876
+ mask_type < 2 ? (uint32_t) src1->ne[3] : 0,
1877
+ *(uint32_t *) dst->op_params, // scale
1878
+ *(uint32_t *) &max_bias,
1879
+ *(uint32_t *) &n_head_log2,
1880
+ *(uint32_t *) &m0,
1881
+ *(uint32_t *) &m1
1882
+ };
1883
+
1884
+ std::vector<wgpu::BindGroupEntry> entries = {
1885
+ { .binding = 0,
1886
+ .buffer = ggml_webgpu_tensor_buf(src0),
1887
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src0),
1888
+ .size = ggml_webgpu_tensor_binding_size(ctx, src0) }
1889
+ };
1890
+ uint32_t binding_num = 1;
1891
+ if (mask_type < 2) {
1892
+ entries.push_back({ .binding = binding_num,
1893
+ .buffer = ggml_webgpu_tensor_buf(src1),
1894
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
1895
+ .size = ggml_webgpu_tensor_binding_size(ctx, src1) });
1896
+ binding_num++;
1897
+ }
1898
+ if (has_sink) {
1899
+ entries.push_back({ .binding = binding_num,
1900
+ .buffer = ggml_webgpu_tensor_buf(src2),
1901
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src2),
1902
+ .size = ggml_webgpu_tensor_binding_size(ctx, src2) });
1903
+ binding_num++;
1904
+ }
1905
+ if (!inplace) {
1906
+ entries.push_back({ .binding = binding_num,
1907
+ .buffer = ggml_webgpu_tensor_buf(dst),
1908
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
1909
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
1910
+ }
1911
+
1912
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool,
1913
+ ctx->soft_max_pipelines[mask_type][has_sink][inplace], params, entries,
1914
+ ggml_nrows(dst));
1915
+ }
1916
+
1917
+ static webgpu_command ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
1918
+ std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
1919
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1920
+ (uint32_t) src->ne[0] };
1921
+
1922
+ std::vector<wgpu::BindGroupEntry> entries = {
1923
+ { .binding = 0,
1924
+ .buffer = ggml_webgpu_tensor_buf(src),
1925
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src),
1926
+ .size = ggml_webgpu_tensor_binding_size(ctx, src) },
1927
+ { .binding = 1,
1928
+ .buffer = ggml_webgpu_tensor_buf(dst),
1929
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
1930
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
1931
+ };
1932
+
1933
+ ggml_webgpu_shader_lib_context shader_lib_ctx = {
1934
+ .src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
1935
+ };
1936
+
1937
+ webgpu_pipeline pipeline = ctx->shader_lib->get_argmax_pipeline(shader_lib_ctx);
1938
+ uint32_t wg_x = ggml_nelements(dst);
1939
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
1940
+ }
1941
+
1942
+ static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
1943
+ bool is_top_k = dst->op == GGML_OP_TOP_K;
1944
+
1945
+ ggml_webgpu_shader_lib_context shader_lib_ctx = {
1946
+ .src0 = src,
1947
+ .src1 = nullptr,
1948
+ .dst = dst,
1949
+ .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
1950
+ .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
1951
+ };
1952
+
1953
+ webgpu_pipeline argsort_pipeline = ctx->shader_lib->get_argsort_pipeline(shader_lib_ctx);
1954
+ auto * argsort_decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(argsort_pipeline.context.get());
1955
+
1956
+ webgpu_pipeline argsort_merge_pipeline = ctx->shader_lib->get_argsort_merge_pipeline(shader_lib_ctx);
1957
+
1958
+ const uint32_t src_ne0 = (uint32_t) src->ne[0];
1959
+ const uint32_t nrows = (uint32_t) ggml_nrows(src);
1960
+ const uint32_t npr = CEIL_DIV(src_ne0, argsort_decisions->wg_size);
1961
+ const uint32_t block_size =
1962
+ is_top_k ? std::min(argsort_decisions->wg_size, (uint32_t) dst->ne[0]) : argsort_decisions->wg_size;
1963
+ uint32_t out_ne0 = src_ne0;
1964
+ if (is_top_k) {
1965
+ if (npr > 1) {
1966
+ const uint32_t last_tile = src_ne0 - (npr - 1) * argsort_decisions->wg_size;
1967
+ out_ne0 = (npr - 1) * block_size + std::min(last_tile, block_size);
1968
+ } else {
1969
+ out_ne0 = block_size;
1970
+ }
1971
+ }
1972
+
1973
+ uint32_t merge_len = block_size;
1974
+ uint32_t merge_passes = 0;
1975
+ while (merge_len < out_ne0) {
1976
+ merge_len <<= 1;
1977
+ merge_passes++;
1978
+ }
1979
+
1980
+ const bool start_in_tmp = (merge_passes % 2) == 1;
1981
+
1982
+ const size_t dst_offset = ggml_webgpu_tensor_offset(dst);
1983
+ const size_t idx_nbytes = out_ne0 * ggml_nrows(dst) * sizeof(int32_t);
1984
+ const size_t tmp_offset =
1985
+ ROUNDUP_POW2(dst_offset + idx_nbytes, ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
1986
+ const size_t tmp_binding_size = ROUNDUP_POW2(idx_nbytes, WEBGPU_STORAGE_BUF_BINDING_MULT);
1987
+ const size_t dst_binding_size =
1988
+ ROUNDUP_POW2(idx_nbytes + ggml_webgpu_tensor_misalignment(ctx, dst), WEBGPU_STORAGE_BUF_BINDING_MULT);
1989
+
1990
+ const uint32_t offset_src = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type));
1991
+ const uint32_t offset_dst = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type));
1992
+ const uint32_t offset_tmp = 0;
1993
+ const uint32_t stride_src1 = (uint32_t) (src->nb[1] / ggml_type_size(src->type));
1994
+ const uint32_t stride_src2 = (uint32_t) (src->nb[2] / ggml_type_size(src->type));
1995
+ const uint32_t stride_src3 = (uint32_t) (src->nb[3] / ggml_type_size(src->type));
1996
+ const uint32_t stride_idx1 = out_ne0;
1997
+ const uint32_t stride_idx2 = out_ne0 * (uint32_t) dst->ne[1];
1998
+ const uint32_t stride_idx3 = stride_idx2 * (uint32_t) dst->ne[2];
1999
+
2000
+ std::vector<webgpu_pipeline> pipelines;
2001
+ std::vector<std::vector<uint32_t>> params_list;
2002
+ std::vector<std::vector<wgpu::BindGroupEntry>> entries_list;
2003
+ std::vector<std::pair<uint32_t, uint32_t>> workgroups_list;
2004
+
2005
+ const uint32_t init_offset = start_in_tmp ? offset_tmp : offset_dst;
2006
+ const size_t init_align_offset = start_in_tmp ? tmp_offset : ggml_webgpu_tensor_align_offset(ctx, dst);
2007
+ const size_t init_binding_size = start_in_tmp ? tmp_binding_size : dst_binding_size;
2008
+
2009
+ std::vector<uint32_t> init_params = {
2010
+ offset_src, init_offset, stride_src1, stride_src2, stride_src3, stride_idx1,
2011
+ stride_idx2, stride_idx3, src_ne0, (uint32_t) src->ne[1], (uint32_t) src->ne[2], out_ne0,
2012
+ block_size, npr, nrows
2013
+ };
2014
+
2015
+ const uint32_t total_wg_init = npr * nrows;
2016
+ const uint32_t max_wg = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
2017
+ const uint32_t wg_x_init = std::min(total_wg_init, max_wg);
2018
+ const uint32_t wg_y_init = CEIL_DIV(total_wg_init, wg_x_init);
2019
+ std::vector<wgpu::BindGroupEntry> init_entries = {
2020
+ { .binding = 0,
2021
+ .buffer = ggml_webgpu_tensor_buf(src),
2022
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src),
2023
+ .size = ggml_webgpu_tensor_binding_size(ctx, src) },
2024
+ { .binding = 1, .buffer = ggml_webgpu_tensor_buf(dst), .offset = init_align_offset, .size = init_binding_size }
2025
+ };
2026
+
2027
+ pipelines.push_back(argsort_pipeline);
2028
+ params_list.push_back(std::move(init_params));
2029
+ entries_list.push_back(std::move(init_entries));
2030
+ workgroups_list.push_back({ wg_x_init, wg_y_init });
2031
+
2032
+ if (merge_passes == 0) {
2033
+ return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_buf_pool, pipelines, params_list,
2034
+ entries_list, workgroups_list);
2035
+ }
2036
+
2037
+ bool in_is_tmp = start_in_tmp;
2038
+ uint32_t len = block_size;
2039
+ while (len < out_ne0) {
2040
+ const uint32_t nm = CEIL_DIV(out_ne0, 2 * len);
2041
+
2042
+ const bool out_is_tmp = !in_is_tmp;
2043
+ const uint32_t offset_in = in_is_tmp ? offset_tmp : offset_dst;
2044
+ const uint32_t offset_out = out_is_tmp ? offset_tmp : offset_dst;
2045
+ const size_t align_in = in_is_tmp ? tmp_offset : ggml_webgpu_tensor_align_offset(ctx, dst);
2046
+ const size_t align_out = out_is_tmp ? tmp_offset : ggml_webgpu_tensor_align_offset(ctx, dst);
2047
+ const size_t size_in = in_is_tmp ? tmp_binding_size : dst_binding_size;
2048
+ const size_t size_out = out_is_tmp ? tmp_binding_size : dst_binding_size;
2049
+ const uint32_t top_k_out = (is_top_k && nm == 1) ? (uint32_t) dst->ne[0] : out_ne0;
2050
+ const uint32_t stride_out1 = top_k_out;
2051
+ const uint32_t stride_out2 = top_k_out * (uint32_t) dst->ne[1];
2052
+ const uint32_t stride_out3 = stride_out2 * (uint32_t) dst->ne[2];
2053
+
2054
+ std::vector<uint32_t> merge_params = { offset_src,
2055
+ offset_in,
2056
+ offset_out,
2057
+ stride_src1,
2058
+ stride_src2,
2059
+ stride_src3,
2060
+ stride_idx1,
2061
+ stride_idx2,
2062
+ stride_idx3,
2063
+ stride_out1,
2064
+ stride_out2,
2065
+ stride_out3,
2066
+ out_ne0,
2067
+ (uint32_t) src->ne[1],
2068
+ (uint32_t) src->ne[2],
2069
+ top_k_out,
2070
+ len,
2071
+ nm,
2072
+ nrows };
2073
+
2074
+ std::vector<wgpu::BindGroupEntry> merge_entries = {
2075
+ { .binding = 0,
2076
+ .buffer = ggml_webgpu_tensor_buf(src),
2077
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src),
2078
+ .size = ggml_webgpu_tensor_binding_size(ctx, src) },
2079
+ { .binding = 1, .buffer = ggml_webgpu_tensor_buf(dst), .offset = align_in, .size = size_in },
2080
+ { .binding = 2, .buffer = ggml_webgpu_tensor_buf(dst), .offset = align_out, .size = size_out }
2081
+ };
2082
+
2083
+ const uint32_t total_wg_merge = nm * nrows;
2084
+ const uint32_t wg_x_merge = std::min(total_wg_merge, max_wg);
2085
+ const uint32_t wg_y_merge = CEIL_DIV(total_wg_merge, wg_x_merge);
2086
+ workgroups_list.push_back({ wg_x_merge, wg_y_merge });
2087
+ pipelines.push_back(argsort_merge_pipeline);
2088
+ params_list.push_back(std::move(merge_params));
2089
+ entries_list.push_back(std::move(merge_entries));
2090
+
2091
+ len <<= 1;
2092
+ in_is_tmp = !in_is_tmp;
686
2093
  }
687
2094
 
688
- size_t max_wg_size = ctx->max_wg_size_x;
689
- uint32_t wg_x = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size;
690
- ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op));
2095
+ return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_buf_pool, pipelines, params_list, entries_list,
2096
+ workgroups_list);
691
2097
  }
692
2098
 
693
- static void ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
694
- bool in_place = ggml_webgpu_tensor_equal(src, dst);
2099
+ static webgpu_command ggml_webgpu_cumsum(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
2100
+ std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
2101
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
2102
+ (uint32_t) src->ne[0] };
695
2103
 
696
- uint32_t eps;
697
- memcpy(&eps, dst->op_params, sizeof(float));
2104
+ std::vector<wgpu::BindGroupEntry> entries = {
2105
+ { .binding = 0,
2106
+ .buffer = ggml_webgpu_tensor_buf(src),
2107
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src),
2108
+ .size = ggml_webgpu_tensor_binding_size(ctx, src) },
2109
+ { .binding = 1,
2110
+ .buffer = ggml_webgpu_tensor_buf(dst),
2111
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
2112
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
2113
+ };
698
2114
 
699
- std::vector<uint32_t> params = {
700
- (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
2115
+ ggml_webgpu_shader_lib_context shader_lib_ctx = {
2116
+ .src0 = src,
2117
+ .src1 = nullptr,
2118
+ .dst = dst,
2119
+ .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
701
2120
  };
702
- if (!in_place) {
703
- params.push_back((uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)));
704
- }
705
- params.push_back((uint32_t) (src->nb[1] / ggml_type_size(src->type)));
706
- params.push_back((uint32_t) (src->nb[2] / ggml_type_size(src->type)));
707
- params.push_back((uint32_t) (src->nb[3] / ggml_type_size(src->type)));
708
- if (!in_place) {
709
- params.push_back((uint32_t) (dst->nb[1] / ggml_type_size(dst->type)));
710
- params.push_back((uint32_t) (dst->nb[2] / ggml_type_size(dst->type)));
711
- params.push_back((uint32_t) (dst->nb[3] / ggml_type_size(dst->type)));
712
- }
713
- params.push_back((uint32_t) src->ne[0]);
714
- params.push_back((uint32_t) src->ne[1]);
715
- params.push_back((uint32_t) src->ne[2]);
716
- params.push_back((uint32_t) src->ne[3]);
717
- params.push_back(eps); // epsilon, will be bitcast to float in shader
2121
+
2122
+ webgpu_pipeline pipeline = ctx->shader_lib->get_cumsum_pipeline(shader_lib_ctx);
2123
+ uint32_t wg_x = ggml_nrows(dst);
2124
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
2125
+ }
2126
+
2127
+ static webgpu_command ggml_webgpu_sum_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
2128
+ bool total_sum = dst->op == GGML_OP_SUM;
2129
+ std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
2130
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
2131
+ total_sum ? 0 : (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
2132
+ total_sum ? 0 : (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
2133
+ total_sum ? 0 : (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
2134
+ total_sum ? static_cast<uint32_t>(ggml_nelements(src)) : (uint32_t) src->ne[0],
2135
+ total_sum ? 1 : (uint32_t) src->ne[1],
2136
+ total_sum ? 1 : (uint32_t) src->ne[2] };
718
2137
 
719
2138
  std::vector<wgpu::BindGroupEntry> entries = {
720
2139
  { .binding = 0,
721
2140
  .buffer = ggml_webgpu_tensor_buf(src),
722
2141
  .offset = ggml_webgpu_tensor_align_offset(ctx, src),
723
- .size = ggml_webgpu_tensor_binding_size(ctx, src) }
2142
+ .size = ggml_webgpu_tensor_binding_size(ctx, src) },
2143
+ { .binding = 1,
2144
+ .buffer = ggml_webgpu_tensor_buf(dst),
2145
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
2146
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
724
2147
  };
725
- if (!in_place) {
726
- entries.push_back({ .binding = 1,
727
- .buffer = ggml_webgpu_tensor_buf(dst),
728
- .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
729
- .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
730
- }
731
2148
 
732
- wgpu::ComputePipeline pipeline;
733
- if (in_place) {
734
- pipeline = ctx->rms_norm_ip_pipeline;
735
- } else {
736
- pipeline = ctx->rms_norm_pipeline;
737
- }
738
- size_t max_wg_size = ctx->max_wg_size_x;
739
- uint32_t wg_x = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size;
740
- ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op));
2149
+ ggml_webgpu_shader_lib_context shader_lib_ctx = {
2150
+ .src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
2151
+ };
2152
+
2153
+ webgpu_pipeline pipeline = ctx->shader_lib->get_sum_rows_pipeline(shader_lib_ctx);
2154
+
2155
+ uint32_t wg_x = total_sum ? 1 : ggml_nrows(dst);
2156
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
741
2157
  }
742
2158
 
743
- // Returns true if node has enqueued work into the queue, false otherwise
744
- static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
2159
+ // Returns the encoded command, or std::nullopt if the operation is a no-op
2160
+ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
745
2161
  if (ggml_is_empty(node)) {
746
- return false;
2162
+ return std::nullopt;
2163
+ }
2164
+ if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
2165
+ return std::nullopt;
747
2166
  }
748
2167
  WEBGPU_LOG_DEBUG("ggml_webgpu_encode_node(" << node << ", " << ggml_op_name(node->op) << ")");
749
2168
 
750
2169
  ggml_tensor * src0 = node->src[0];
751
2170
  ggml_tensor * src1 = node->src[1];
2171
+ ggml_tensor * src2 = node->src[2];
752
2172
 
753
2173
  switch (node->op) {
754
2174
  // no-ops
@@ -757,55 +2177,122 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
757
2177
  case GGML_OP_PERMUTE:
758
2178
  case GGML_OP_TRANSPOSE:
759
2179
  case GGML_OP_RESHAPE:
760
- return false;
2180
+ return std::nullopt;
761
2181
  case GGML_OP_CPY:
762
- ggml_webgpu_cpy(ctx, src0, node);
763
- break;
2182
+ case GGML_OP_CONT:
2183
+ return ggml_webgpu_cpy(ctx, src0, node);
764
2184
  case GGML_OP_SET_ROWS:
765
- ggml_webgpu_set_rows(ctx, src0, src1, node);
766
- break;
2185
+ return ggml_webgpu_set_rows(ctx, src0, src1, node);
767
2186
  case GGML_OP_GET_ROWS:
768
- ggml_webgpu_get_rows(ctx, src0, src1, node);
769
- break;
2187
+ return ggml_webgpu_get_rows(ctx, src0, src1, node);
770
2188
  case GGML_OP_MUL_MAT:
771
- ggml_webgpu_mul_mat(ctx, src0, src1, node);
772
- break;
2189
+ return ggml_webgpu_mul_mat(ctx, src0, src1, node);
2190
+ case GGML_OP_FLASH_ATTN_EXT:
2191
+ #ifndef __EMSCRIPTEN__
2192
+ return ggml_webgpu_flash_attn(ctx, src0, src1, src2, node->src[3], node->src[4], node);
2193
+ #else
2194
+ return std::nullopt;
2195
+ #endif
773
2196
  case GGML_OP_ADD:
774
- if (ggml_webgpu_tensor_equal(src0, node)) {
775
- ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_ip_pipeline[node->type], true);
776
- } else {
777
- ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipeline[node->type], false);
778
- }
779
- break;
2197
+ case GGML_OP_SUB:
780
2198
  case GGML_OP_MUL:
781
- if (ggml_webgpu_tensor_equal(src0, node)) {
782
- ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_ip_pipeline[node->type], true);
783
- } else {
784
- ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_pipeline[node->type], false);
785
- }
786
- break;
2199
+ case GGML_OP_DIV:
2200
+ return ggml_webgpu_binary_op(ctx, src0, src1, node);
2201
+ case GGML_OP_CONCAT:
2202
+ return ggml_webgpu_concat(ctx, src0, src1, node);
2203
+ case GGML_OP_REPEAT:
2204
+ return ggml_webgpu_repeat(ctx, src0, node);
787
2205
  case GGML_OP_RMS_NORM:
788
- ggml_webgpu_rms_norm(ctx, src0, node);
789
- break;
2206
+ return ggml_webgpu_rms_norm(ctx, src0, node);
2207
+ case GGML_OP_ROPE:
2208
+ return ggml_webgpu_rope(ctx, src0, src1, src2, node);
2209
+ case GGML_OP_GLU:
2210
+ return ggml_webgpu_glu(ctx, src0, src1, node);
2211
+ case GGML_OP_SCALE:
2212
+ return ggml_webgpu_scale(ctx, src0, node);
2213
+ case GGML_OP_SOFT_MAX:
2214
+ return ggml_webgpu_soft_max(ctx, src0, src1, src2, node);
2215
+ case GGML_OP_UNARY:
2216
+ case GGML_OP_CLAMP:
2217
+ case GGML_OP_FILL:
2218
+ case GGML_OP_LOG:
2219
+ case GGML_OP_SQR:
2220
+ case GGML_OP_SQRT:
2221
+ case GGML_OP_SIN:
2222
+ case GGML_OP_COS:
2223
+ return ggml_webgpu_unary_op(ctx, src0, node);
2224
+ case GGML_OP_PAD:
2225
+ return ggml_webgpu_pad(ctx, src0, node);
2226
+ case GGML_OP_ARGMAX:
2227
+ return ggml_webgpu_argmax(ctx, src0, node);
2228
+ case GGML_OP_ARGSORT:
2229
+ case GGML_OP_TOP_K:
2230
+ // we reuse the same argsort implementation for top_k
2231
+ return ggml_webgpu_argsort(ctx, src0, node);
2232
+ case GGML_OP_CUMSUM:
2233
+ return ggml_webgpu_cumsum(ctx, src0, node);
2234
+ case GGML_OP_SUM:
2235
+ case GGML_OP_SUM_ROWS:
2236
+ return ggml_webgpu_sum_rows(ctx, src0, node);
790
2237
  default:
791
- return false;
2238
+ return std::nullopt;
792
2239
  }
793
- return true;
794
2240
  }
795
2241
 
796
2242
  static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
797
2243
  WEBGPU_LOG_DEBUG("ggml_backend_webgpu_graph_compute(" << cgraph->n_nodes << " nodes)");
798
2244
 
799
- ggml_backend_webgpu_context * backend_ctx = static_cast<ggml_backend_webgpu_context *>(backend->context);
2245
+ ggml_backend_webgpu_context * backend_ctx = (ggml_backend_webgpu_context *) backend->context;
800
2246
  webgpu_context ctx = backend_ctx->webgpu_ctx;
801
2247
 
2248
+ WEBGPU_CPU_PROFILE_TOTAL_START(graph_compute);
2249
+
2250
+ std::vector<webgpu_command> commands;
2251
+ std::vector<webgpu_submission> subs;
2252
+ uint32_t num_batched_kernels = 0;
2253
+ bool contains_set_rows = false;
2254
+
802
2255
  for (int i = 0; i < cgraph->n_nodes; i++) {
803
- ggml_webgpu_encode_node(ctx, cgraph->nodes[i]);
2256
+ if (cgraph->nodes[i]->op == GGML_OP_SET_ROWS) {
2257
+ contains_set_rows = true;
2258
+ }
2259
+ if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) {
2260
+ commands.push_back(*cmd);
2261
+ num_batched_kernels += cmd.value().num_kernels;
2262
+ }
2263
+
2264
+ if (num_batched_kernels >= WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) {
2265
+ num_batched_kernels = 0;
2266
+ subs.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool));
2267
+ // Process events and check for completed submissions
2268
+ ctx->global_ctx->instance.ProcessEvents();
2269
+ ggml_backend_webgpu_wait(ctx->global_ctx, subs, false);
2270
+ commands.clear();
2271
+ }
2272
+ }
2273
+ if (!commands.empty()) {
2274
+ subs.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool));
2275
+ commands.clear();
804
2276
  }
805
2277
 
806
- ggml_backend_webgpu_submit_queue(ctx);
807
- ggml_backend_webgpu_wait_on_submission(ctx);
2278
+ // If there are SET_ROWS operations in this graph, copy the error buffers to the host for checking.
2279
+ if (contains_set_rows) {
2280
+ wgpu::CommandEncoder encoder = ctx->global_ctx->device.CreateCommandEncoder();
2281
+ encoder.CopyBufferToBuffer(ctx->set_rows_dev_error_buf, 0, ctx->set_rows_host_error_buf, 0,
2282
+ ctx->set_rows_host_error_buf.GetSize());
2283
+ wgpu::CommandBuffer set_rows_commands = encoder.Finish();
2284
+ ctx->global_ctx->queue.Submit(1, &set_rows_commands);
2285
+ ggml_backend_webgpu_map_buffer(ctx->global_ctx, ctx->set_rows_host_error_buf, wgpu::MapMode::Read, 0,
2286
+ ctx->set_rows_host_error_buf.GetSize());
2287
+ const uint32_t * error_data = (const uint32_t *) ctx->set_rows_host_error_buf.GetConstMappedRange();
2288
+ if (*error_data) {
2289
+ GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported.");
2290
+ }
2291
+ ctx->set_rows_host_error_buf.Unmap();
2292
+ }
808
2293
 
2294
+ ggml_backend_webgpu_wait(ctx->global_ctx, subs);
2295
+ WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx->global_ctx);
809
2296
  return GGML_STATUS_SUCCESS;
810
2297
  }
811
2298
 
@@ -831,9 +2318,11 @@ static ggml_backend_i ggml_backend_webgpu_i = {
831
2318
  /* GGML Backend Buffer Interface */
832
2319
 
833
2320
  static void ggml_backend_webgpu_buffer_free_buffer(ggml_backend_buffer_t buffer) {
834
- WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_free_buffer()");
835
2321
  ggml_backend_webgpu_buffer_context * ctx = static_cast<ggml_backend_webgpu_buffer_context *>(buffer->context);
836
- ctx->buffer.Destroy();
2322
+ if (ctx != nullptr && ctx->buffer != nullptr) {
2323
+ ctx->buffer.Destroy();
2324
+ delete ctx;
2325
+ }
837
2326
  }
838
2327
 
839
2328
  // Returns the "fake" base pointer.
@@ -848,20 +2337,25 @@ static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffe
848
2337
  size_t offset,
849
2338
  size_t size) {
850
2339
  if (size == 0) {
851
- WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor: size is zero, nothing to do.");
2340
+ WEBGPU_LOG_DEBUG(
2341
+ "ggml_backend_webgpu_buffer_memset_tensor: size is zero, "
2342
+ "nothing to do.");
852
2343
  return;
853
2344
  }
854
2345
 
855
- WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor(" << buffer << ", " << tensor << ", " << value << ", "
856
- << offset << ", " << size << ")");
2346
+ WEBGPU_CPU_PROFILE_TOTAL_START(memset_tensor);
857
2347
 
858
2348
  ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
859
2349
 
2350
+ WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor(" << buf_ctx->label << ", " << tensor << ", " << value
2351
+ << ", " << offset << ", " << size << ")");
2352
+
860
2353
  size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
861
2354
 
862
2355
  // This is a trick to set all bytes of a u32 to the same 1 byte value.
863
2356
  uint32_t val32 = (uint32_t) value * 0x01010101;
864
- ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, val32, total_offset, size);
2357
+ ggml_backend_webgpu_buffer_memset(buf_ctx->global_ctx, buf_ctx->buffer, val32, total_offset, size);
2358
+ WEBGPU_CPU_PROFILE_TOTAL_END(memset_tensor, buf_ctx->global_ctx);
865
2359
  }
866
2360
 
867
2361
  static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
@@ -869,14 +2363,15 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
869
2363
  const void * data,
870
2364
  size_t offset,
871
2365
  size_t size) {
872
- WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_set_tensor(" << buffer << ", " << tensor << ", " << data << ", "
873
- << offset << ", " << size << ")");
874
- ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
875
- webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx;
2366
+ WEBGPU_CPU_PROFILE_TOTAL_START(set_tensor);
2367
+ ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
2368
+
2369
+ WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_set_tensor(" << buf_ctx->label << ", " << tensor << ", " << data
2370
+ << ", " << offset << ", " << size << ")");
876
2371
 
877
2372
  size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
878
2373
 
879
- webgpu_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size / 4) * 4);
2374
+ buf_ctx->global_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size / 4) * 4);
880
2375
 
881
2376
  if (size % 4 != 0) {
882
2377
  // If size is not a multiple of 4, we need to memset the remaining bytes
@@ -889,12 +2384,21 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
889
2384
  ((uint8_t *) &val32)[i] = ((const uint8_t *) data)[size - remaining_size + i];
890
2385
  }
891
2386
  // memset the remaining bytes
892
- ggml_backend_webgpu_buffer_memset(webgpu_ctx, buf_ctx->buffer, val32, total_offset + (size - remaining_size),
893
- remaining_size);
2387
+ ggml_backend_webgpu_buffer_memset(buf_ctx->global_ctx, buf_ctx->buffer, val32,
2388
+ total_offset + (size - remaining_size), remaining_size);
894
2389
  } else {
895
2390
  // wait for WriteBuffer to complete
896
- ggml_backend_webgpu_wait_on_submission(webgpu_ctx);
2391
+ buf_ctx->global_ctx->instance.WaitAny(buf_ctx->global_ctx->queue.OnSubmittedWorkDone(
2392
+ wgpu::CallbackMode::AllowSpontaneous,
2393
+ [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
2394
+ if (status != wgpu::QueueWorkDoneStatus::Success) {
2395
+ GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n",
2396
+ std::string(message).c_str());
2397
+ }
2398
+ }),
2399
+ UINT64_MAX);
897
2400
  }
2401
+ WEBGPU_CPU_PROFILE_TOTAL_END(set_tensor, buf_ctx->global_ctx);
898
2402
  }
899
2403
 
900
2404
  static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
@@ -902,54 +2406,60 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
902
2406
  void * data,
903
2407
  size_t offset,
904
2408
  size_t size) {
905
- WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", "
906
- << offset << ", " << size << ")");
907
-
908
- ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
909
- webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx;
910
- wgpu::Device device = webgpu_ctx->device;
2409
+ WEBGPU_CPU_PROFILE_TOTAL_START(get_tensor);
2410
+ ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
2411
+ WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_get_tensor(" << buf_ctx->label << ", " << tensor << ", " << data
2412
+ << ", " << offset << ", " << size << ")");
2413
+ wgpu::Device device = buf_ctx->global_ctx->device;
911
2414
 
912
2415
  size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
913
2416
 
914
2417
  size_t final_size = size;
915
2418
  if (size % 4 != 0) {
916
- // If size is not a multiple of 4, we need to round it up to the next multiple of 4
2419
+ // If size is not a multiple of 4, we need to round it up to the next
2420
+ // multiple of 4
917
2421
  final_size = size + (4 - (size % 4));
918
2422
  }
919
2423
 
920
- std::lock_guard<std::recursive_mutex> lock(webgpu_ctx->mutex);
2424
+ std::lock_guard<std::recursive_mutex> lock(buf_ctx->global_ctx->mutex);
921
2425
 
922
- if (webgpu_ctx->get_tensor_staging_buf == nullptr || webgpu_ctx->get_tensor_staging_buf.GetSize() < final_size) {
2426
+ if (buf_ctx->global_ctx->get_tensor_staging_buf == nullptr ||
2427
+ buf_ctx->global_ctx->get_tensor_staging_buf.GetSize() < final_size) {
923
2428
  // Create a new staging buffer if it doesn't exist or is too small
924
- if (webgpu_ctx->get_tensor_staging_buf) {
925
- webgpu_ctx->get_tensor_staging_buf.Destroy();
2429
+ if (buf_ctx->global_ctx->get_tensor_staging_buf) {
2430
+ buf_ctx->global_ctx->get_tensor_staging_buf.Destroy();
926
2431
  }
927
- ggml_webgpu_create_buffer(device, webgpu_ctx->get_tensor_staging_buf, final_size,
2432
+ ggml_webgpu_create_buffer(device, buf_ctx->global_ctx->get_tensor_staging_buf, final_size,
928
2433
  wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "get_tensor_staging_buf");
929
2434
  }
930
2435
 
931
2436
  // Copy the data from the buffer to the staging buffer
932
2437
  wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
933
- encoder.CopyBufferToBuffer(buf_ctx->buffer, total_offset, webgpu_ctx->get_tensor_staging_buf, 0, final_size);
2438
+ encoder.CopyBufferToBuffer(buf_ctx->buffer, total_offset, buf_ctx->global_ctx->get_tensor_staging_buf, 0,
2439
+ final_size);
934
2440
  wgpu::CommandBuffer commands = encoder.Finish();
935
2441
 
936
2442
  // Submit the command buffer to the queue
937
- webgpu_ctx->queue.Submit(1, &commands);
2443
+ buf_ctx->global_ctx->queue.Submit(1, &commands);
938
2444
 
939
2445
  // Map the staging buffer to read the data
940
- ggml_backend_webgpu_map_buffer(webgpu_ctx, webgpu_ctx->get_tensor_staging_buf, wgpu::MapMode::Read, 0, final_size);
2446
+ ggml_backend_webgpu_map_buffer(buf_ctx->global_ctx, buf_ctx->global_ctx->get_tensor_staging_buf,
2447
+ wgpu::MapMode::Read, 0, final_size);
941
2448
  // Must specify size here since the staging buffer might be larger than the tensor size
942
- const void * mapped_range = webgpu_ctx->get_tensor_staging_buf.GetConstMappedRange(0, final_size);
2449
+ const void * mapped_range = buf_ctx->global_ctx->get_tensor_staging_buf.GetConstMappedRange(0, final_size);
943
2450
 
944
2451
  // Copy the data from the mapped range to the output buffer
945
2452
  std::memcpy(data, mapped_range, size);
946
- webgpu_ctx->get_tensor_staging_buf.Unmap();
2453
+ buf_ctx->global_ctx->get_tensor_staging_buf.Unmap();
2454
+ WEBGPU_CPU_PROFILE_TOTAL_END(get_tensor, buf_ctx->global_ctx);
947
2455
  }
948
2456
 
949
2457
  static void ggml_backend_webgpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
950
2458
  WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_clear(" << buffer << ", " << (uint32_t) value << ")");
2459
+ WEBGPU_CPU_PROFILE_TOTAL_START(clear);
951
2460
  ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
952
- ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, value, 0, buffer->size);
2461
+ ggml_backend_webgpu_buffer_memset(buf_ctx->global_ctx, buf_ctx->buffer, value, 0, buffer->size);
2462
+ WEBGPU_CPU_PROFILE_TOTAL_END(clear, buf_ctx->global_ctx);
953
2463
  }
954
2464
 
955
2465
  static ggml_backend_buffer_i ggml_backend_webgpu_buffer_interface = {
@@ -961,7 +2471,8 @@ static ggml_backend_buffer_i ggml_backend_webgpu_buffer_interface = {
961
2471
  /* .get_tensor = */ ggml_backend_webgpu_buffer_get_tensor,
962
2472
  /* .cpy_tensor = */ NULL, // TODO: optional, implement this
963
2473
  /* .clear = */ ggml_backend_webgpu_buffer_clear,
964
- /* .reset = */ NULL, // TODO: optional, think it coordinates with .init_tensor
2474
+ /* .reset = */ NULL, // TODO: optional, think it coordinates with
2475
+ // .init_tensor
965
2476
  };
966
2477
 
967
2478
  /* End GGML Backend Buffer Interface */
@@ -975,29 +2486,61 @@ static const char * ggml_backend_webgpu_buffer_type_get_name(ggml_backend_buffer
975
2486
 
976
2487
  static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
977
2488
  size_t size) {
978
- WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_type_alloc_buffer(" << size << ")");
979
- ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
2489
+ static std::atomic<int> buffer_count;
2490
+ int buffer_id = buffer_count++;
2491
+ std::string buf_name = "tensor_buf" + std::to_string(buffer_id);
2492
+ WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_type_alloc_buffer_" << buffer_id << ": " << size << " bytes");
980
2493
 
981
- wgpu::Buffer buf;
982
- ggml_webgpu_create_buffer(ctx->webgpu_ctx->device, buf,
983
- (size + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1),
2494
+ ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
2495
+ wgpu::Buffer buf;
2496
+ ggml_webgpu_create_buffer(ctx->webgpu_global_ctx->device, buf, ROUNDUP_POW2(size, WEBGPU_STORAGE_BUF_BINDING_MULT),
984
2497
  wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst,
985
- "allocated_buffer");
2498
+ buf_name.c_str());
986
2499
 
987
- ggml_backend_webgpu_buffer_context * buf_ctx = new ggml_backend_webgpu_buffer_context(ctx->webgpu_ctx, buf);
2500
+ ggml_backend_webgpu_buffer_context * buf_ctx =
2501
+ new ggml_backend_webgpu_buffer_context(buf, buf_name, ctx->webgpu_global_ctx);
988
2502
 
989
2503
  return ggml_backend_buffer_init(buft, ggml_backend_webgpu_buffer_interface, buf_ctx, size);
990
2504
  }
991
2505
 
992
2506
  static size_t ggml_backend_webgpu_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
993
- ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
994
- return ctx->webgpu_ctx->limits.minStorageBufferOffsetAlignment;
2507
+ ggml_backend_webgpu_device_context * dev_ctx =
2508
+ static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
2509
+ return dev_ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment;
995
2510
  }
996
2511
 
997
- // maxBufferSize might be larger, but you can't bind more than maxStorageBufferBindingSize to a single binding.
2512
+ // maxBufferSize might be larger, but you can't bind more than
2513
+ // maxStorageBufferBindingSize to a single binding.
998
2514
  static size_t ggml_backend_webgpu_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
2515
+ ggml_backend_webgpu_device_context * dev_ctx =
2516
+ static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
2517
+ return dev_ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize;
2518
+ }
2519
+
2520
+ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft,
2521
+ const ggml_tensor * tensor) {
999
2522
  ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
1000
- return ctx->webgpu_ctx->limits.maxStorageBufferBindingSize;
2523
+ size_t res = ggml_nbytes(tensor);
2524
+ switch (tensor->op) {
2525
+ case GGML_OP_ARGSORT:
2526
+ res = ROUNDUP_POW2(res * 2 + ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment,
2527
+ WEBGPU_STORAGE_BUF_BINDING_MULT);
2528
+ break;
2529
+ case GGML_OP_TOP_K:
2530
+ {
2531
+ const ggml_tensor * src0 = tensor->src[0];
2532
+ if (src0) {
2533
+ const size_t full = sizeof(int32_t) * ggml_nelements(src0);
2534
+ res = ROUNDUP_POW2(
2535
+ full * 2 + ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment,
2536
+ WEBGPU_STORAGE_BUF_BINDING_MULT);
2537
+ }
2538
+ }
2539
+ break;
2540
+ default:
2541
+ break;
2542
+ }
2543
+ return res;
1001
2544
  }
1002
2545
 
1003
2546
  /* End GGML Backend Buffer Type Interface */
@@ -1016,9 +2559,18 @@ static const char * ggml_backend_webgpu_device_get_description(ggml_backend_dev_
1016
2559
 
1017
2560
  static void ggml_backend_webgpu_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
1018
2561
  ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
1019
- // TODO: what do we actually want to return here? maxBufferSize might not be the full available memory.
1020
- *free = ctx->webgpu_ctx->limits.maxBufferSize;
1021
- *total = ctx->webgpu_ctx->limits.maxBufferSize;
2562
+ // TODO: for now, return maxBufferSize as both free and total memory
2563
+ // Track https://github.com/gpuweb/gpuweb/issues/5505 for updates.
2564
+ uint64_t max_buffer_size = ctx->webgpu_global_ctx->capabilities.limits.maxBufferSize;
2565
+ // If we're on a 32-bit system, clamp to UINTPTR_MAX
2566
+ #if UINTPTR_MAX < UINT64_MAX
2567
+ uint64_t max_ptr_size = static_cast<uint64_t>(UINTPTR_MAX);
2568
+ if (max_buffer_size > max_ptr_size) {
2569
+ max_buffer_size = max_ptr_size;
2570
+ }
2571
+ #endif
2572
+ *free = static_cast<size_t>(max_buffer_size);
2573
+ *total = static_cast<size_t>(max_buffer_size);
1022
2574
  }
1023
2575
 
1024
2576
  static enum ggml_backend_dev_type ggml_backend_webgpu_device_get_type(ggml_backend_dev_t dev) {
@@ -1044,205 +2596,382 @@ static ggml_guid_t ggml_backend_webgpu_guid(void) {
1044
2596
  return reinterpret_cast<ggml_guid_t>((void *) guid_str);
1045
2597
  }
1046
2598
 
1047
- // The max workgroup size is a common constant
1048
- static std::vector<wgpu::ConstantEntry> ggml_webgpu_max_wg_size_entry(webgpu_context & webgpu_ctx) {
1049
- std::vector<wgpu::ConstantEntry> constants(1);
1050
- constants[0].key = "wg_size";
1051
- constants[0].value = webgpu_ctx->max_wg_size_x;
1052
- return constants;
1053
- }
1054
-
1055
- static void ggml_webgpu_init_memset_pipeline(webgpu_context & webgpu_ctx) {
2599
+ static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) {
1056
2600
  // we use the maximum workgroup size for the memset pipeline
1057
- size_t max_wg_size = webgpu_ctx->max_wg_size_x;
1058
- size_t max_threads = max_wg_size * webgpu_ctx->limits.maxComputeWorkgroupsPerDimension;
2601
+ size_t max_threads = WEBGPU_MAX_WG_SIZE * ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
1059
2602
  // Size the bytes_per_thread so that the largest buffer size can be handled
1060
- webgpu_ctx->memset_bytes_per_thread =
1061
- (webgpu_ctx->limits.maxStorageBufferBindingSize + max_threads - 1) / max_threads;
2603
+ ctx->capabilities.memset_bytes_per_thread =
2604
+ CEIL_DIV(ctx->capabilities.limits.maxStorageBufferBindingSize, max_threads);
1062
2605
  std::vector<wgpu::ConstantEntry> constants(2);
1063
- constants[0].key = "wg_size";
1064
- constants[0].value = max_wg_size;
1065
- constants[1].key = "bytes_per_thread";
1066
- constants[1].value = webgpu_ctx->memset_bytes_per_thread;
1067
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->memset_pipeline, wgsl_memset, "memset", constants);
1068
- }
1069
-
1070
- static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
1071
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F32][GGML_TYPE_F32],
1072
- wgsl_mul_mat_f32_f32, "mul_mat_f32_f32");
1073
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F16][GGML_TYPE_F16],
1074
- wgsl_mul_mat_f16_f16, "mul_mat_f16_f16");
1075
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F16][GGML_TYPE_F32],
1076
- wgsl_mul_mat_f16_f32, "mul_mat_f16_f32");
1077
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_0][GGML_TYPE_F32],
1078
- wgsl_mul_mat_q4_0_f32, "mul_mat_q4_0_f32");
1079
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_1][GGML_TYPE_F32],
1080
- wgsl_mul_mat_q4_1_f32, "mul_mat_q4_1_f32");
1081
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_0][GGML_TYPE_F32],
1082
- wgsl_mul_mat_q5_0_f32, "mul_mat_q5_0_f32");
1083
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_1][GGML_TYPE_F32],
1084
- wgsl_mul_mat_q5_1_f32, "mul_mat_q5_1_f32");
1085
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q8_0][GGML_TYPE_F32],
1086
- wgsl_mul_mat_q8_0_f32, "mul_mat_q8_0_f32");
1087
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q2_K][GGML_TYPE_F32],
1088
- wgsl_mul_mat_q2_k_f32, "mul_mat_q2_k_f32");
1089
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q3_K][GGML_TYPE_F32],
1090
- wgsl_mul_mat_q3_k_f32, "mul_mat_q3_k_f32");
1091
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_K][GGML_TYPE_F32],
1092
- wgsl_mul_mat_q4_k_f32, "mul_mat_q4_k_f32");
1093
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_K][GGML_TYPE_F32],
1094
- wgsl_mul_mat_q5_k_f32, "mul_mat_q5_k_f32");
1095
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q6_K][GGML_TYPE_F32],
1096
- wgsl_mul_mat_q6_k_f32, "mul_mat_q6_k_f32");
1097
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_XXS][GGML_TYPE_F32],
1098
- wgsl_mul_mat_iq2_xxs_f32, "mul_mat_iq2_xxs_f32");
1099
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_XS][GGML_TYPE_F32],
1100
- wgsl_mul_mat_iq2_xs_f32, "mul_mat_iq2_xs_f32");
1101
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_S][GGML_TYPE_F32],
1102
- wgsl_mul_mat_iq2_s_f32, "mul_mat_iq2_s_f32");
1103
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ3_XXS][GGML_TYPE_F32],
1104
- wgsl_mul_mat_iq3_xxs_f32, "mul_mat_iq3_xxs_f32");
1105
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ3_S][GGML_TYPE_F32],
1106
- wgsl_mul_mat_iq3_s_f32, "mul_mat_iq3_s_f32");
1107
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ1_S][GGML_TYPE_F32],
1108
- wgsl_mul_mat_iq1_s_f32, "mul_mat_iq1_s_f32");
1109
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ1_M][GGML_TYPE_F32],
1110
- wgsl_mul_mat_iq1_m_f32, "mul_mat_iq1_m_f32");
1111
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ4_NL][GGML_TYPE_F32],
1112
- wgsl_mul_mat_iq4_nl_f32, "mul_mat_iq4_nl_f32");
1113
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ4_XS][GGML_TYPE_F32],
1114
- wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32");
1115
- }
1116
-
1117
- static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) {
1118
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline, wgsl_set_rows, "set_rows",
1119
- ggml_webgpu_max_wg_size_entry(webgpu_ctx));
1120
- }
1121
-
1122
- static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) {
1123
- std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
1124
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_F32], wgsl_get_rows_f32_vec,
1125
- "get_rows_f32_vec", constants);
1126
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_f32_no_vec_pipeline, wgsl_get_rows_f32,
1127
- "get_rows_f32", constants);
1128
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_F16], wgsl_get_rows_f16,
1129
- "get_rows_f16", constants);
1130
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_I32], wgsl_get_rows_i32,
1131
- "get_rows_i32", constants);
1132
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q4_0], wgsl_get_rows_q4_0,
1133
- "get_rows_q4_0", constants);
1134
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q4_1], wgsl_get_rows_q4_1,
1135
- "get_rows_q4_1", constants);
1136
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q5_0], wgsl_get_rows_q5_0,
1137
- "get_rows_q5_0", constants);
1138
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q5_1], wgsl_get_rows_q5_1,
1139
- "get_rows_q5_1", constants);
1140
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q8_0], wgsl_get_rows_q8_0,
1141
- "get_rows_q8_0", constants);
1142
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q2_K], wgsl_get_rows_q2_k,
1143
- "get_rows_q2_k", constants);
1144
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q3_K], wgsl_get_rows_q3_k,
1145
- "get_rows_q3_k", constants);
1146
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q4_K], wgsl_get_rows_q4_k,
1147
- "get_rows_q4_k", constants);
1148
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q5_K], wgsl_get_rows_q5_k,
1149
- "get_rows_q5_k", constants);
1150
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q6_K], wgsl_get_rows_q6_k,
1151
- "get_rows_q6_k", constants);
1152
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ2_XXS],
1153
- wgsl_get_rows_iq2_xxs, "get_rows_iq2_xxs", constants);
1154
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ2_XS],
1155
- wgsl_get_rows_iq2_xs, "get_rows_iq2_xs", constants);
1156
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ2_S], wgsl_get_rows_iq2_s,
1157
- "get_rows_iq2_s", constants);
1158
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ3_XXS],
1159
- wgsl_get_rows_iq3_xxs, "get_rows_iq3_xxs", constants);
1160
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ3_S], wgsl_get_rows_iq3_s,
1161
- "get_rows_iq3_s", constants);
1162
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ1_S], wgsl_get_rows_iq1_s,
1163
- "get_rows_iq1_s", constants);
1164
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ1_M], wgsl_get_rows_iq1_m,
1165
- "get_rows_iq1_m", constants);
1166
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ4_NL],
1167
- wgsl_get_rows_iq4_nl, "get_rows_iq4_nl", constants);
1168
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ4_XS],
1169
- wgsl_get_rows_iq4_xs, "get_rows_iq4_xs", constants);
2606
+ constants[0].key = "wg_size";
2607
+ constants[0].value = WEBGPU_MAX_WG_SIZE;
2608
+ constants[1].key = "bytes_per_thread";
2609
+ constants[1].value = ctx->capabilities.memset_bytes_per_thread;
2610
+ ctx->memset_pipelines[0] = ggml_webgpu_create_pipeline(ctx->device, wgsl_memset, "memset", constants);
1170
2611
  }
1171
2612
 
1172
2613
  static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
1173
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline, wgsl_cpy, "cpy",
1174
- ggml_webgpu_max_wg_size_entry(webgpu_ctx));
1175
- }
1176
-
1177
- static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) {
1178
- std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
1179
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32], wgsl_add_f32, "add_f32",
1180
- constants);
1181
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16], wgsl_add_f16, "add_f16",
1182
- constants);
1183
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_ip_pipeline[GGML_TYPE_F32], wgsl_add_in_place_f32,
1184
- "add_in_place_f32", constants);
1185
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_ip_pipeline[GGML_TYPE_F16], wgsl_add_in_place_f16,
1186
- "add_in_place_f16", constants);
1187
- }
1188
-
1189
- static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) {
1190
- std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
1191
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32], wgsl_mul_f32, "mul_f32",
1192
- constants);
1193
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16], wgsl_mul_f16, "mul_f16",
1194
- constants);
1195
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_ip_pipeline[GGML_TYPE_F32], wgsl_mul_in_place_f32,
1196
- "mul_in_place_f32", constants);
1197
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_ip_pipeline[GGML_TYPE_F16], wgsl_mul_in_place_f16,
1198
- "mul_in_place_f16", constants);
2614
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
2615
+
2616
+ webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F32] =
2617
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_f32, "cpy_f32_f32", constants);
2618
+ webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_I32] =
2619
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_i32, "cpy_f32_i32", constants);
2620
+ webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F16] =
2621
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_f16, "cpy_f32_f16", constants);
2622
+ webgpu_ctx->cpy_pipelines[GGML_TYPE_F16][GGML_TYPE_F32] =
2623
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f32, "cpy_f16_f32", constants);
2624
+ webgpu_ctx->cpy_pipelines[GGML_TYPE_F16][GGML_TYPE_F16] =
2625
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f16, "cpy_f16_f16", constants);
1199
2626
  }
1200
2627
 
1201
2628
  static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) {
1202
- std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
1203
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_pipeline, wgsl_rms_norm, "rms_norm",
1204
- constants);
1205
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_ip_pipeline, wgsl_rms_norm_in_place,
1206
- "rms_norm_in_place", constants);
2629
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);
2630
+
2631
+ webgpu_ctx->rms_norm_pipelines[0] =
2632
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rms_norm, "rms_norm", constants);
2633
+ webgpu_ctx->rms_norm_pipelines[1] = ggml_webgpu_create_pipeline(
2634
+ webgpu_ctx->global_ctx->device, wgsl_rms_norm_inplace, "rms_norm_inplace", constants);
2635
+ }
2636
+
2637
+ static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) {
2638
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
2639
+
2640
+ webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][0] =
2641
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f32, "rope_f32", constants);
2642
+ webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][1] = ggml_webgpu_create_pipeline(
2643
+ webgpu_ctx->global_ctx->device, wgsl_rope_f32_inplace, "rope_f32_inplace", constants);
2644
+ webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][0] =
2645
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f32_ff, "rope_f32_ff", constants);
2646
+ webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][1] = ggml_webgpu_create_pipeline(
2647
+ webgpu_ctx->global_ctx->device, wgsl_rope_f32_ff_inplace, "rope_f32_ff_inplace", constants);
2648
+
2649
+ webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][0] =
2650
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f16, "rope_f16", constants);
2651
+ webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][1] = ggml_webgpu_create_pipeline(
2652
+ webgpu_ctx->global_ctx->device, wgsl_rope_f16_inplace, "rope_f16_inplace", constants);
2653
+ webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][0] =
2654
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f16_ff, "rope_f16_ff", constants);
2655
+ webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][1] = ggml_webgpu_create_pipeline(
2656
+ webgpu_ctx->global_ctx->device, wgsl_rope_f16_ff_inplace, "rope_f16_ff_inplace", constants);
2657
+ }
2658
+
2659
+ static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) {
2660
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
2661
+
2662
+ // REGLU
2663
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F32][0] =
2664
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f32, "reglu_f32", constants);
2665
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F16][0] =
2666
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f16, "reglu_f16", constants);
2667
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F32][1] =
2668
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f32_split, "reglu_f32_split", constants);
2669
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F16][1] =
2670
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f16_split, "reglu_f16_split", constants);
2671
+
2672
+ // GEGLU
2673
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][0] =
2674
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f32, "geglu_f32", constants);
2675
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][0] =
2676
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f16, "geglu_f16", constants);
2677
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][1] =
2678
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f32_split, "geglu_f32_split", constants);
2679
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][1] =
2680
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f16_split, "geglu_f16_split", constants);
2681
+
2682
+ // SWIGLU
2683
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][0] =
2684
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_f32, "swiglu_f32", constants);
2685
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][0] =
2686
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_f16, "swiglu_f16", constants);
2687
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
2688
+ webgpu_ctx->global_ctx->device, wgsl_swiglu_f32_split, "swiglu_f32_split", constants);
2689
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
2690
+ webgpu_ctx->global_ctx->device, wgsl_swiglu_f16_split, "swiglu_f16_split", constants);
2691
+
2692
+ // SWIGLU_OAI
2693
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][0] =
2694
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_oai_f32, "swiglu_oai_f32", constants);
2695
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
2696
+ webgpu_ctx->global_ctx->device, wgsl_swiglu_oai_f32_split, "swiglu_oai_f32_split", constants);
2697
+
2698
+ // GEGLU_ERF
2699
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][0] =
2700
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f32, "geglu_erf_f32", constants);
2701
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][0] =
2702
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f16, "geglu_erf_f16", constants);
2703
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
2704
+ webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f32_split, "geglu_erf_f32_split", constants);
2705
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
2706
+ webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f16_split, "geglu_erf_f16_split", constants);
2707
+
2708
+ // GEGLU_QUICK
2709
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][0] =
2710
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f32, "geglu_quick_f32", constants);
2711
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][0] =
2712
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f16, "geglu_quick_f16", constants);
2713
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
2714
+ webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f32_split, "geglu_quick_f32_split", constants);
2715
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
2716
+ webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f16_split, "geglu_quick_f16_split", constants);
2717
+ }
2718
+
2719
+ static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) {
2720
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);
2721
+
2722
+ // f32 (no mask)
2723
+ webgpu_ctx->soft_max_pipelines[2][0][0] =
2724
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32, "soft_max_f32", constants);
2725
+ webgpu_ctx->soft_max_pipelines[2][0][1] = ggml_webgpu_create_pipeline(
2726
+ webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_inplace, "soft_max_f32_inplace", constants);
2727
+ webgpu_ctx->soft_max_pipelines[2][1][0] = ggml_webgpu_create_pipeline(
2728
+ webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_sink, "soft_max_f32_sink", constants);
2729
+ webgpu_ctx->soft_max_pipelines[2][1][1] = ggml_webgpu_create_pipeline(
2730
+ webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_sink_inplace, "soft_max_f32_sink_inplace", constants);
2731
+
2732
+ // f32 mask (mask_type = 0)
2733
+ webgpu_ctx->soft_max_pipelines[0][0][0] = ggml_webgpu_create_pipeline(
2734
+ webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32, "soft_max_f32_mask_f32", constants);
2735
+ webgpu_ctx->soft_max_pipelines[0][0][1] = ggml_webgpu_create_pipeline(
2736
+ webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_inplace, "soft_max_f32_mask_f32_inplace", constants);
2737
+ webgpu_ctx->soft_max_pipelines[0][1][0] = ggml_webgpu_create_pipeline(
2738
+ webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_sink, "soft_max_f32_mask_f32_sink", constants);
2739
+ webgpu_ctx->soft_max_pipelines[0][1][1] =
2740
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_sink_inplace,
2741
+ "soft_max_f32_mask_f32_sink_inplace", constants);
2742
+
2743
+ // f16 mask (mask_type = 1)
2744
+ webgpu_ctx->soft_max_pipelines[1][0][0] = ggml_webgpu_create_pipeline(
2745
+ webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16, "soft_max_f32_mask_f16", constants);
2746
+ webgpu_ctx->soft_max_pipelines[1][0][1] = ggml_webgpu_create_pipeline(
2747
+ webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_inplace, "soft_max_f32_mask_f16_inplace", constants);
2748
+ webgpu_ctx->soft_max_pipelines[1][1][0] = ggml_webgpu_create_pipeline(
2749
+ webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_sink, "soft_max_f32_mask_f16_sink", constants);
2750
+ webgpu_ctx->soft_max_pipelines[1][1][1] =
2751
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_sink_inplace,
2752
+ "soft_max_f32_mask_f16_sink_inplace", constants);
2753
+ }
2754
+
2755
+ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
2756
+ wgpu::RequestAdapterOptions options = {};
2757
+
2758
+ #ifndef __EMSCRIPTEN__
2759
+ // TODO: track need for these toggles: https://issues.chromium.org/issues/42251215
2760
+ const char * const adapterEnabledToggles[] = { "vulkan_enable_f16_on_nvidia", "use_vulkan_memory_model" };
2761
+ wgpu::DawnTogglesDescriptor adapterTogglesDesc;
2762
+ adapterTogglesDesc.enabledToggles = adapterEnabledToggles;
2763
+ adapterTogglesDesc.enabledToggleCount = 2;
2764
+ options.nextInChain = &adapterTogglesDesc;
2765
+ #endif
2766
+
2767
+ ctx->webgpu_global_ctx->instance.WaitAny(
2768
+ ctx->webgpu_global_ctx->instance.RequestAdapter(
2769
+ &options, wgpu::CallbackMode::AllowSpontaneous,
2770
+ [&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) {
2771
+ if (status != wgpu::RequestAdapterStatus::Success) {
2772
+ GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
2773
+ return;
2774
+ }
2775
+ ctx->webgpu_global_ctx->adapter = std::move(adapter);
2776
+ }),
2777
+ UINT64_MAX);
2778
+ GGML_ASSERT(ctx->webgpu_global_ctx->adapter != nullptr);
2779
+
2780
+ ctx->webgpu_global_ctx->adapter.GetLimits(&ctx->webgpu_global_ctx->capabilities.limits);
2781
+
2782
+ wgpu::AdapterInfo info{};
2783
+ #ifndef __EMSCRIPTEN__
2784
+ wgpu::AdapterPropertiesSubgroupMatrixConfigs subgroup_matrix_configs{};
2785
+ if (ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
2786
+ info.nextInChain = &subgroup_matrix_configs;
2787
+ }
2788
+ #endif
2789
+ ctx->webgpu_global_ctx->adapter.GetInfo(&info);
2790
+ wgpu::SupportedFeatures features;
2791
+ ctx->webgpu_global_ctx->adapter.GetFeatures(&features);
2792
+ // we require f16 support
2793
+ GGML_ASSERT(ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16));
2794
+
2795
+ #ifndef __EMSCRIPTEN__
2796
+ // Only support square f16 matrices of size 8 or 16 for now
2797
+ bool valid_subgroup_matrix_config = false;
2798
+ if (ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
2799
+ for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) {
2800
+ const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i];
2801
+ if (config.M == config.N && config.N == config.K && (config.K == 8 || config.K == 16) &&
2802
+ config.componentType == wgpu::SubgroupMatrixComponentType::F16 &&
2803
+ config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) {
2804
+ ctx->webgpu_global_ctx->capabilities.sg_mat_m = config.M;
2805
+ ctx->webgpu_global_ctx->capabilities.sg_mat_n = config.N;
2806
+ ctx->webgpu_global_ctx->capabilities.sg_mat_k = config.K;
2807
+ valid_subgroup_matrix_config = true;
2808
+ break;
2809
+ }
2810
+ }
2811
+ }
2812
+ ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix = valid_subgroup_matrix_config;
2813
+ #endif
2814
+
2815
+ // For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate.
2816
+ // Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter.
2817
+ ctx->webgpu_global_ctx->capabilities.max_subgroup_size = info.subgroupMaxSize;
2818
+ // Initialize device
2819
+ std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16 };
2820
+
2821
+ #ifndef __EMSCRIPTEN__
2822
+ required_features.push_back(wgpu::FeatureName::ImplicitDeviceSynchronization);
2823
+ if (ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) {
2824
+ required_features.push_back(wgpu::FeatureName::Subgroups);
2825
+ required_features.push_back(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix);
2826
+ }
2827
+ #endif
2828
+
2829
+ #ifdef GGML_WEBGPU_GPU_PROFILE
2830
+ required_features.push_back(wgpu::FeatureName::TimestampQuery);
2831
+ #endif
2832
+
2833
+ wgpu::DeviceDescriptor dev_desc;
2834
+ dev_desc.requiredLimits = &ctx->webgpu_global_ctx->capabilities.limits;
2835
+ dev_desc.requiredFeatures = required_features.data();
2836
+ dev_desc.requiredFeatureCount = required_features.size();
2837
+ dev_desc.SetDeviceLostCallback(
2838
+ wgpu::CallbackMode::AllowSpontaneous,
2839
+ [](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
2840
+ if (reason == wgpu::DeviceLostReason::Destroyed) {
2841
+ return;
2842
+ }
2843
+ GGML_UNUSED(device);
2844
+ GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason),
2845
+ std::string(message).c_str());
2846
+ });
2847
+ dev_desc.SetUncapturedErrorCallback(
2848
+ [](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) {
2849
+ GGML_UNUSED(device);
2850
+ GGML_ABORT("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast<int>(reason),
2851
+ std::string(message).c_str());
2852
+ });
2853
+
2854
+ #ifndef __EMSCRIPTEN__
2855
+ // Enable Dawn-specific toggles to increase native performance
2856
+ // TODO: Maybe WebGPU needs a "fast" mode where you can request compilers skip adding checks like these,
2857
+ // only for native performance?
2858
+ const char * const deviceEnabledToggles[] = { "skip_validation", "disable_robustness", "disable_workgroup_init",
2859
+ "disable_polyfills_on_integer_div_and_mod" };
2860
+ const char * const deviceDisabledToggles[] = { "timestamp_quantization" };
2861
+ wgpu::DawnTogglesDescriptor deviceTogglesDesc;
2862
+ deviceTogglesDesc.enabledToggles = deviceEnabledToggles;
2863
+ deviceTogglesDesc.enabledToggleCount = 4;
2864
+ deviceTogglesDesc.disabledToggles = deviceDisabledToggles;
2865
+ deviceTogglesDesc.disabledToggleCount = 1;
2866
+
2867
+ dev_desc.nextInChain = &deviceTogglesDesc;
2868
+ #endif
2869
+
2870
+ ctx->webgpu_global_ctx->instance.WaitAny(
2871
+ ctx->webgpu_global_ctx->adapter.RequestDevice(
2872
+ &dev_desc, wgpu::CallbackMode::AllowSpontaneous,
2873
+ [ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
2874
+ if (status != wgpu::RequestDeviceStatus::Success) {
2875
+ GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", std::string(message).c_str());
2876
+ return;
2877
+ }
2878
+ ctx->webgpu_global_ctx->device = std::move(device);
2879
+ }),
2880
+ UINT64_MAX);
2881
+ GGML_ASSERT(ctx->webgpu_global_ctx->device != nullptr);
2882
+
2883
+ ggml_webgpu_init_memset_pipeline(ctx->webgpu_global_ctx);
2884
+ ctx->webgpu_global_ctx->memset_buf_pool.init(ctx->webgpu_global_ctx->device, 1, WEBGPU_PARAMS_BUF_SIZE_BYTES,
2885
+ wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
2886
+ wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite);
2887
+ ctx->webgpu_global_ctx->queue = ctx->webgpu_global_ctx->device.GetQueue();
2888
+
2889
+ #ifdef GGML_WEBGPU_GPU_PROFILE
2890
+ // Initialize buffer pool for timestamp queries, used for profiling
2891
+ ctx->webgpu_global_ctx->timestamp_query_buf_pool.init(
2892
+ ctx->webgpu_global_ctx->device, WEBGPU_NUM_TIMESTAMP_QUERY_BUFS, WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES,
2893
+ wgpu::BufferUsage::QueryResolve | wgpu::BufferUsage::CopySrc,
2894
+ wgpu::BufferUsage::MapRead | wgpu::BufferUsage::CopyDst);
2895
+ #endif
2896
+
2897
+ GGML_LOG_INFO(
2898
+ "ggml_webgpu: adapter_info: vendor_id: %u | vendor: %s | architecture: %s | device_id: %u | name: %s | "
2899
+ "device_desc: %s\n",
2900
+ info.vendorID, std::string(info.vendor).c_str(), std::string(info.architecture).c_str(), info.deviceID,
2901
+ std::string(info.device).c_str(), std::string(info.description).c_str());
2902
+ return true;
1207
2903
  }
1208
2904
 
1209
- static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) {
2905
+ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) {
2906
+ ggml_backend_webgpu_device_context * dev_ctx = (ggml_backend_webgpu_device_context *) dev->context;
2907
+ webgpu_context webgpu_ctx = std::make_shared<webgpu_context_struct>();
2908
+ webgpu_ctx->global_ctx = dev_ctx->webgpu_global_ctx;
2909
+ webgpu_ctx->shader_lib = std::make_unique<ggml_webgpu_shader_lib>(dev_ctx->webgpu_global_ctx->device);
2910
+ webgpu_ctx->param_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES,
2911
+ wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
2912
+ wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite, true);
2913
+ ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->set_rows_dev_error_buf,
2914
+ WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
2915
+ wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, "set_rows_dev_error_buf");
2916
+ ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->set_rows_host_error_buf,
2917
+ WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
2918
+ wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "set_rows_host_error_buf");
2919
+
2920
+ ggml_webgpu_init_cpy_pipeline(webgpu_ctx);
2921
+ ggml_webgpu_init_rms_norm_pipeline(webgpu_ctx);
2922
+ ggml_webgpu_init_rope_pipeline(webgpu_ctx);
2923
+ ggml_webgpu_init_glu_pipeline(webgpu_ctx);
2924
+ ggml_webgpu_init_soft_max_pipeline(webgpu_ctx);
2925
+ #ifdef GGML_WEBGPU_DEBUG
2926
+ // Initialize debug buffers
2927
+ ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->global_ctx->debug_host_buf,
2928
+ WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
2929
+ wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "debug_host_buf");
2930
+ ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->global_ctx->debug_dev_buf,
2931
+ WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
2932
+ wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, "debug_dev_buf");
2933
+ #endif
2934
+ return webgpu_ctx;
2935
+ }
2936
+
2937
+ static ggml_backend_t ggml_backend_webgpu_backend_init(ggml_backend_dev_t dev, const char * params) {
1210
2938
  GGML_UNUSED(params);
1211
2939
 
1212
- WEBGPU_LOG_DEBUG("ggml_backend_webgpu_device_init()");
2940
+ WEBGPU_LOG_DEBUG("ggml_backend_webgpu_backend_init()");
1213
2941
 
1214
- ggml_backend_webgpu_device_context * dev_ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
1215
- webgpu_context webgpu_ctx = dev_ctx->webgpu_ctx;
2942
+ ggml_backend_webgpu_device_context * dev_ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
1216
2943
 
1217
- static ggml_backend_webgpu_context backend_ctx;
1218
- backend_ctx.name = GGML_WEBGPU_NAME + std::string(": ") + dev_ctx->device_name;
1219
- backend_ctx.webgpu_ctx = webgpu_ctx;
2944
+ auto * backend_ctx = new ggml_backend_webgpu_context();
2945
+ backend_ctx->name = GGML_WEBGPU_NAME + std::string(": ") + dev_ctx->device_name;
2946
+ backend_ctx->webgpu_ctx = initialize_webgpu_context(dev);
1220
2947
 
1221
2948
  // See GGML Backend Interface section
1222
- static ggml_backend backend = {
2949
+ auto * backend = new ggml_backend();
2950
+ *backend = {
1223
2951
  /* .guid = */ ggml_backend_webgpu_guid(),
1224
2952
  /* .interface = */ ggml_backend_webgpu_i,
1225
2953
  /* .device = */ dev,
1226
- /* .context = */ &backend_ctx,
2954
+ /* .context = */ backend_ctx,
1227
2955
  };
1228
-
1229
- return &backend;
2956
+ return backend;
1230
2957
  }
1231
2958
 
1232
2959
  static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggml_backend_dev_t dev) {
1233
2960
  // See GGML Backend Buffer Type Interface section
2961
+
1234
2962
  static struct ggml_backend_buffer_type ggml_backend_webgpu_buffer_type = {
1235
2963
  /* .iface = */ {
1236
2964
  /* .get_name = */ ggml_backend_webgpu_buffer_type_get_name,
1237
- /* .alloc_buffer = */ ggml_backend_webgpu_buffer_type_alloc_buffer,
1238
- /* .get_alignment = */ ggml_backend_webgpu_buffer_type_get_alignment,
1239
- /* .get_max_size = */ ggml_backend_webgpu_buffer_type_get_max_size,
1240
- /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
1241
- /* .is_host = */ NULL, // defaults to false
2965
+ /* .alloc_buffer = */
2966
+ ggml_backend_webgpu_buffer_type_alloc_buffer, /* .get_alignment = */
2967
+ ggml_backend_webgpu_buffer_type_get_alignment, /* .get_max_size = */
2968
+ ggml_backend_webgpu_buffer_type_get_max_size, /* .get_alloc_size = */
2969
+ ggml_backend_webgpu_buffer_type_get_alloc_size, /* .is_host = */ NULL, // defaults to false
1242
2970
  },
1243
2971
  /* .device = */
1244
2972
  dev,
1245
- /* .context = */ NULL,
2973
+ /* .context = */
2974
+ NULL
1246
2975
  };
1247
2976
 
1248
2977
  return &ggml_backend_webgpu_buffer_type;
@@ -1283,14 +3012,16 @@ static bool ggml_webgpu_supported_qtype(ggml_type type) {
1283
3012
  static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
1284
3013
  ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
1285
3014
 
1286
- webgpu_context webgpu_ctx = ctx->webgpu_ctx;
1287
-
1288
3015
  ggml_tensor * src0 = op->src[0];
1289
3016
  ggml_tensor * src1 = op->src[1];
3017
+ ggml_tensor * src2 = op->src[2];
3018
+
1290
3019
  // on smaller devices (or CI), tensors may be larger than the max storage buffer size
1291
- if (ggml_nbytes(op) > webgpu_ctx->limits.maxStorageBufferBindingSize ||
1292
- (src0 != nullptr && ggml_nbytes(src0) > webgpu_ctx->limits.maxStorageBufferBindingSize) ||
1293
- (src1 != nullptr && ggml_nbytes(src1) > webgpu_ctx->limits.maxStorageBufferBindingSize)) {
3020
+ if (ggml_nbytes(op) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize ||
3021
+ (src0 != nullptr &&
3022
+ ggml_nbytes(src0) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize) ||
3023
+ (src1 != nullptr &&
3024
+ ggml_nbytes(src1) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize)) {
1294
3025
  return false;
1295
3026
  }
1296
3027
 
@@ -1304,28 +3035,43 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
1304
3035
  supports_op = true;
1305
3036
  break;
1306
3037
  case GGML_OP_ADD:
3038
+ case GGML_OP_SUB:
1307
3039
  case GGML_OP_MUL:
1308
- supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (op->src[0]->type == op->type) &&
1309
- (op->src[1]->type == op->type);
3040
+ case GGML_OP_DIV:
3041
+ supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type) &&
3042
+ (src1->type == op->type);
3043
+ break;
3044
+ case GGML_OP_CONCAT:
3045
+ supports_op = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32);
3046
+ break;
3047
+ case GGML_OP_REPEAT:
3048
+ supports_op = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32 || src0->type == GGML_TYPE_I16);
1310
3049
  break;
1311
3050
  case GGML_OP_CPY:
3051
+ case GGML_OP_CONT:
3052
+ supports_op = ((op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
3053
+ (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) ||
3054
+ (op->type == GGML_TYPE_I32 && src0->type == GGML_TYPE_F32);
3055
+ break;
1312
3056
  case GGML_OP_SET_ROWS:
1313
- supports_op = (op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_I64);
3057
+ supports_op = ((op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32) && src0->type == GGML_TYPE_F32 &&
3058
+ (src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32));
1314
3059
  break;
1315
3060
  case GGML_OP_GET_ROWS:
1316
- if (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16 ||
1317
- op->src[0]->type == GGML_TYPE_I32 || ggml_webgpu_supported_qtype(op->src[0]->type)) {
3061
+ if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_webgpu_supported_qtype(src0->type)) {
1318
3062
  supports_op = (op->type == GGML_TYPE_F32);
3063
+ } else if (src0->type == GGML_TYPE_I32) {
3064
+ supports_op = op->type == GGML_TYPE_I32;
1319
3065
  }
1320
3066
  break;
1321
3067
  case GGML_OP_MUL_MAT:
1322
3068
  {
1323
- switch (op->src[1]->type) {
3069
+ switch (src1->type) {
1324
3070
  case GGML_TYPE_F16:
1325
- supports_op = (op->src[0]->type == GGML_TYPE_F16);
3071
+ supports_op |= (src0->type == GGML_TYPE_F16);
1326
3072
  break;
1327
3073
  case GGML_TYPE_F32:
1328
- switch (op->src[0]->type) {
3074
+ switch (src0->type) {
1329
3075
  case GGML_TYPE_F32:
1330
3076
  case GGML_TYPE_F16:
1331
3077
  case GGML_TYPE_Q4_0:
@@ -1357,19 +3103,160 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
1357
3103
  }
1358
3104
  break;
1359
3105
  }
3106
+ case GGML_OP_FLASH_ATTN_EXT:
3107
+ {
3108
+ #ifndef __EMSCRIPTEN__
3109
+ if (!ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) {
3110
+ break;
3111
+ }
3112
+ // Head dimensions must fit in workgroup memory with minimum tile sizes
3113
+ size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
3114
+ const bool has_mask = op->src[3] != nullptr;
3115
+ const bool kv_direct = src1->type == GGML_TYPE_F16 &&
3116
+ (src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k) == 0 &&
3117
+ (src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0;
3118
+ const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
3119
+ ctx->webgpu_global_ctx->capabilities.sg_mat_m, ctx->webgpu_global_ctx->capabilities.sg_mat_n,
3120
+ (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, kv_direct);
3121
+ if (min_bytes > limit_bytes) {
3122
+ break;
3123
+ }
3124
+
3125
+ supports_op = src0->type == GGML_TYPE_F32 &&
3126
+ (src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16 ||
3127
+ src1->type == GGML_TYPE_Q4_0 || src1->type == GGML_TYPE_Q8_0) &&
3128
+ src2->type == src1->type && op->type == GGML_TYPE_F32;
3129
+ #endif
3130
+ break;
3131
+ }
1360
3132
  case GGML_OP_RMS_NORM:
1361
- supports_op = op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
3133
+ supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
3134
+ break;
3135
+ case GGML_OP_ROPE:
3136
+ supports_op = op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16;
3137
+ break;
3138
+ case GGML_OP_GLU:
3139
+ switch (ggml_get_glu_op(op)) {
3140
+ case GGML_GLU_OP_REGLU:
3141
+ case GGML_GLU_OP_GEGLU:
3142
+ case GGML_GLU_OP_SWIGLU:
3143
+ case GGML_GLU_OP_GEGLU_ERF:
3144
+ case GGML_GLU_OP_GEGLU_QUICK:
3145
+ supports_op = op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16;
3146
+ break;
3147
+ case GGML_GLU_OP_SWIGLU_OAI:
3148
+ supports_op = op->type == GGML_TYPE_F32;
3149
+ break;
3150
+ default:
3151
+ break;
3152
+ }
3153
+ break;
3154
+ case GGML_OP_SCALE:
3155
+ supports_op = op->type == GGML_TYPE_F32;
3156
+ break;
3157
+ case GGML_OP_SOFT_MAX:
3158
+ supports_op = op->type == GGML_TYPE_F32;
3159
+ break;
3160
+ case GGML_OP_UNARY:
3161
+ {
3162
+ const ggml_unary_op UNARY_OP = ggml_get_unary_op(op);
3163
+
3164
+ switch (UNARY_OP) {
3165
+ case GGML_UNARY_OP_ABS:
3166
+ case GGML_UNARY_OP_SGN:
3167
+ case GGML_UNARY_OP_NEG:
3168
+ case GGML_UNARY_OP_STEP:
3169
+ case GGML_UNARY_OP_TANH:
3170
+ case GGML_UNARY_OP_ELU:
3171
+ case GGML_UNARY_OP_RELU:
3172
+ case GGML_UNARY_OP_SIGMOID:
3173
+ case GGML_UNARY_OP_GELU:
3174
+ case GGML_UNARY_OP_GELU_QUICK:
3175
+ case GGML_UNARY_OP_SILU:
3176
+ case GGML_UNARY_OP_HARDSWISH:
3177
+ case GGML_UNARY_OP_HARDSIGMOID:
3178
+ case GGML_UNARY_OP_EXP:
3179
+ case GGML_UNARY_OP_GELU_ERF:
3180
+ case GGML_UNARY_OP_SOFTPLUS:
3181
+ case GGML_UNARY_OP_EXPM1:
3182
+ case GGML_UNARY_OP_FLOOR:
3183
+ case GGML_UNARY_OP_CEIL:
3184
+ case GGML_UNARY_OP_ROUND:
3185
+ case GGML_UNARY_OP_TRUNC:
3186
+ case GGML_UNARY_OP_XIELU:
3187
+ supports_op =
3188
+ (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
3189
+ break;
3190
+ default:
3191
+ break;
3192
+ }
3193
+ }
3194
+ break;
3195
+ case GGML_OP_CLAMP:
3196
+ supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
3197
+ break;
3198
+ case GGML_OP_FILL:
3199
+ supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
3200
+ break;
3201
+ case GGML_OP_LOG:
3202
+ supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
3203
+ break;
3204
+ case GGML_OP_SQR:
3205
+ supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
3206
+ break;
3207
+ case GGML_OP_SQRT:
3208
+ supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
3209
+ break;
3210
+ case GGML_OP_SIN:
3211
+ supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
3212
+ break;
3213
+ case GGML_OP_COS:
3214
+ supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
3215
+ break;
3216
+ case GGML_OP_PAD:
3217
+ supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
3218
+ break;
3219
+ case GGML_OP_ARGMAX:
3220
+ supports_op = op->type == GGML_TYPE_I32 && src0->type == GGML_TYPE_F32;
3221
+ break;
3222
+ case GGML_OP_ARGSORT:
3223
+ supports_op = op->type == GGML_TYPE_I32 && src0->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(src0);
3224
+ break;
3225
+ case GGML_OP_TOP_K:
3226
+ supports_op = op->type == GGML_TYPE_I32 && src0->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(src0);
3227
+ break;
3228
+ case GGML_OP_CUMSUM:
3229
+ supports_op = op->type == GGML_TYPE_F32 && src0->type == op->type;
3230
+ break;
3231
+ case GGML_OP_SUM:
3232
+ case GGML_OP_SUM_ROWS:
3233
+ supports_op = op->type == GGML_TYPE_F32 && src0->type == op->type && ggml_is_contiguous_rows(src0);
1362
3234
  break;
1363
3235
  default:
1364
3236
  break;
1365
3237
  }
1366
- #ifdef GGML_WEBGPU_DEBUG
3238
+ if (ggml_nbytes(op) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize ||
3239
+ (src0 != nullptr &&
3240
+ ggml_nbytes(src0) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize) ||
3241
+ (src1 != nullptr &&
3242
+ ggml_nbytes(src1) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize) ||
3243
+ (src2 != nullptr &&
3244
+ ggml_nbytes(src2) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize)) {
3245
+ supports_op = false;
3246
+ WEBGPU_LOG_DEBUG("ggml_webgpu op not supported due to size: ");
3247
+ }
3248
+
1367
3249
  if (!supports_op) {
1368
- WEBGPU_LOG_DEBUG("not supported: " << ggml_op_name(op->op) << " with types dst: " << ggml_type_name(op->type)
1369
- << ", src0: " << (op->src[0] ? ggml_type_name(op->src[0]->type) : "null")
1370
- << ", src1: " << (op->src[1] ? ggml_type_name(op->src[1]->type) : "null"));
3250
+ WEBGPU_LOG_DEBUG("ggml_webgpu op not supported: "
3251
+ << ggml_op_name(op->op) << " with types dst: " << ggml_type_name(op->type)
3252
+ << ", src0: " << (op->src[0] ? ggml_type_name(op->src[0]->type) : "null")
3253
+ << ", src1: " << (op->src[1] ? ggml_type_name(op->src[1]->type) : "null"));
3254
+ } else {
3255
+ WEBGPU_LOG_DEBUG("ggml_webgpu op supported: "
3256
+ << ggml_op_name(op->op) << " with types dst: " << ggml_type_name(op->type)
3257
+ << ", src0: " << (op->src[0] ? ggml_type_name(op->src[0]->type) : "null")
3258
+ << ", src1: " << (op->src[1] ? ggml_type_name(op->src[1]->type) : "null"));
1371
3259
  }
1372
- #endif
1373
3260
  return supports_op;
1374
3261
  }
1375
3262
 
@@ -1379,7 +3266,7 @@ static struct ggml_backend_device_i ggml_backend_webgpu_device_i = {
1379
3266
  /* .get_memory = */ ggml_backend_webgpu_device_get_memory,
1380
3267
  /* .get_type = */ ggml_backend_webgpu_device_get_type,
1381
3268
  /* .get_props = */ ggml_backend_webgpu_device_get_props,
1382
- /* .init_backend = */ ggml_backend_webgpu_device_init,
3269
+ /* .init_backend = */ ggml_backend_webgpu_backend_init,
1383
3270
  /* .get_buffer_type = */ ggml_backend_webgpu_device_get_buffer_type,
1384
3271
  /* .get_host_buffer_type = */ NULL,
1385
3272
  /* .buffer_from_host_ptr = */ NULL,
@@ -1405,113 +3292,29 @@ static size_t ggml_backend_webgpu_reg_get_device_count(ggml_backend_reg_t reg) {
1405
3292
  return ctx->device_count;
1406
3293
  }
1407
3294
 
1408
- // TODO: Does this need to be thread safe? Is it only called once?
1409
3295
  // Only one device is supported for now
1410
3296
  static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t reg, size_t index) {
1411
3297
  GGML_ASSERT(index == 0);
1412
3298
  WEBGPU_LOG_DEBUG("ggml_backend_reg_get_device()");
1413
3299
 
1414
- ggml_backend_webgpu_reg_context * reg_ctx = static_cast<ggml_backend_webgpu_reg_context *>(reg->context);
1415
-
1416
- webgpu_context ctx = reg_ctx->webgpu_ctx;
1417
-
1418
- wgpu::RequestAdapterOptions options = {};
1419
- ctx->instance.WaitAny(ctx->instance.RequestAdapter(
1420
- &options, wgpu::CallbackMode::AllowSpontaneous,
1421
- [&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) {
1422
- if (status != wgpu::RequestAdapterStatus::Success) {
1423
- GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
1424
- return;
1425
- }
1426
- ctx->adapter = std::move(adapter);
1427
- }),
1428
- UINT64_MAX);
1429
- GGML_ASSERT(ctx->adapter != nullptr);
1430
-
1431
- ctx->adapter.GetLimits(&ctx->limits);
1432
- ctx->max_wg_size_x = 288; // default value
1433
-
1434
- wgpu::AdapterInfo info{};
1435
- ctx->adapter.GetInfo(&info);
3300
+ WEBGPU_CPU_PROFILE_TOTAL_START(reg_get_device);
1436
3301
 
1437
- // Initialize device
1438
- std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16,
1439
- wgpu::FeatureName::ImplicitDeviceSynchronization };
1440
- wgpu::DeviceDescriptor dev_desc;
1441
- dev_desc.requiredLimits = &ctx->limits;
1442
- dev_desc.requiredFeatures = required_features.data();
1443
- dev_desc.requiredFeatureCount = required_features.size();
1444
- dev_desc.SetDeviceLostCallback(
1445
- wgpu::CallbackMode::AllowSpontaneous,
1446
- [](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
1447
- GGML_UNUSED(device);
1448
- GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason),
1449
- std::string(message).c_str());
1450
- });
1451
- dev_desc.SetUncapturedErrorCallback(
1452
- [](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) {
1453
- GGML_UNUSED(device);
1454
- GGML_LOG_ERROR("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast<int>(reason),
1455
- std::string(message).c_str());
1456
- });
1457
- ctx->instance.WaitAny(ctx->adapter.RequestDevice(
1458
- &dev_desc, wgpu::CallbackMode::AllowSpontaneous,
1459
- [ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
1460
- if (status != wgpu::RequestDeviceStatus::Success) {
1461
- GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n",
1462
- std::string(message).c_str());
1463
- return;
1464
- }
1465
- ctx->device = std::move(device);
1466
- }),
1467
- UINT64_MAX);
1468
- GGML_ASSERT(ctx->device != nullptr);
1469
-
1470
- // Initialize (compute) queue
1471
- ctx->queue = ctx->device.GetQueue();
1472
-
1473
- // Create buffer pool for shader parameters
1474
- ctx->param_buf_pool.init(ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES,
1475
- wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
1476
- wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite);
1477
- ctx->set_rows_error_buf_pool.init(ctx->device, WEBGPU_NUM_SET_ROWS_ERROR_BUFS, WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
1478
- wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage,
1479
- wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead);
1480
-
1481
- ggml_webgpu_init_memset_pipeline(ctx);
1482
- ggml_webgpu_init_mul_mat_pipeline(ctx);
1483
- ggml_webgpu_init_set_rows_pipeline(ctx);
1484
- ggml_webgpu_init_get_rows_pipeline(ctx);
1485
- ggml_webgpu_init_cpy_pipeline(ctx);
1486
- ggml_webgpu_init_add_pipeline(ctx);
1487
- ggml_webgpu_init_mul_pipeline(ctx);
1488
- ggml_webgpu_init_rms_norm_pipeline(ctx);
3302
+ ggml_backend_webgpu_reg_context * reg_ctx = static_cast<ggml_backend_webgpu_reg_context *>(reg->context);
1489
3303
 
1490
- #ifdef GGML_WEBGPU_DEBUG
1491
- // Initialize debug buffers
1492
- ggml_webgpu_create_buffer(ctx->device, ctx->debug_host_buf, WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
1493
- wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "debug_host_buf");
1494
- ggml_webgpu_create_buffer(ctx->device, ctx->debug_dev_buf, WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
1495
- wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, "debug_dev_buf");
1496
- #endif
3304
+ create_webgpu_device(reg_ctx);
1497
3305
 
1498
3306
  static ggml_backend_webgpu_device_context device_ctx;
1499
- device_ctx.webgpu_ctx = ctx;
1500
- device_ctx.device_name = GGML_WEBGPU_NAME;
1501
- device_ctx.device_desc = info.description;
1502
-
1503
- GGML_LOG_INFO(
1504
- "ggml_webgpu: adapter_info: vendor_id: %u | vendor: %s | architecture: %s | device_id: %u | name: %s | "
1505
- "device_desc: %s\n",
1506
- info.vendorID, std::string(info.vendor).c_str(), std::string(info.architecture).c_str(), info.deviceID,
1507
- std::string(info.device).c_str(), std::string(info.description).c_str());
1508
-
3307
+ device_ctx.device_name = GGML_WEBGPU_NAME;
3308
+ device_ctx.device_desc = GGML_WEBGPU_NAME;
3309
+ device_ctx.webgpu_global_ctx = reg_ctx->webgpu_global_ctx;
1509
3310
  // See GGML Backend Device Interface section
1510
3311
  static ggml_backend_device device = {
1511
3312
  /* .iface = */ ggml_backend_webgpu_device_i,
1512
3313
  /* .reg = */ reg,
1513
3314
  /* .context = */ &device_ctx,
1514
3315
  };
3316
+
3317
+ WEBGPU_CPU_PROFILE_TOTAL_END(reg_get_device, reg_ctx->webgpu_global_ctx);
1515
3318
  return &device;
1516
3319
  }
1517
3320
 
@@ -1527,10 +3330,7 @@ static const struct ggml_backend_reg_i ggml_backend_webgpu_reg_i = {
1527
3330
  ggml_backend_reg_t ggml_backend_webgpu_reg() {
1528
3331
  WEBGPU_LOG_DEBUG("ggml_backend_webgpu_reg()");
1529
3332
 
1530
- webgpu_context webgpu_ctx = std::make_shared<webgpu_context_struct>();
1531
-
1532
3333
  static ggml_backend_webgpu_reg_context ctx;
1533
- ctx.webgpu_ctx = webgpu_ctx;
1534
3334
  ctx.name = GGML_WEBGPU_NAME;
1535
3335
  ctx.device_count = 1;
1536
3336
 
@@ -1538,8 +3338,26 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() {
1538
3338
  std::vector<wgpu::InstanceFeatureName> instance_features = { wgpu::InstanceFeatureName::TimedWaitAny };
1539
3339
  instance_descriptor.requiredFeatures = instance_features.data();
1540
3340
  instance_descriptor.requiredFeatureCount = instance_features.size();
1541
- webgpu_ctx->instance = wgpu::CreateInstance(&instance_descriptor);
1542
- GGML_ASSERT(webgpu_ctx->instance != nullptr);
3341
+
3342
+ #ifndef __EMSCRIPTEN__
3343
+ const char * const instanceEnabledToggles[] = { "allow_unsafe_apis" };
3344
+ wgpu::DawnTogglesDescriptor instanceTogglesDesc;
3345
+ instanceTogglesDesc.enabledToggles = instanceEnabledToggles;
3346
+ instanceTogglesDesc.enabledToggleCount = 1;
3347
+ instance_descriptor.nextInChain = &instanceTogglesDesc;
3348
+ #endif
3349
+
3350
+ wgpu::Instance inst = wgpu::CreateInstance(&instance_descriptor);
3351
+ ctx.webgpu_global_ctx = webgpu_global_context(new webgpu_global_context_struct());
3352
+ ctx.webgpu_global_ctx->instance = std::move(inst);
3353
+
3354
+ #ifdef __EMSCRIPTEN__
3355
+ if (ctx.webgpu_global_ctx->instance == nullptr) {
3356
+ GGML_LOG_ERROR("ggml_webgpu: Failed to create WebGPU instance. Make sure either -sASYNCIFY or -sJSPI is set\n");
3357
+ return nullptr;
3358
+ }
3359
+ #endif
3360
+ GGML_ASSERT(ctx.webgpu_global_ctx->instance != nullptr);
1543
3361
 
1544
3362
  static ggml_backend_reg reg = {
1545
3363
  /* .api_version = */ GGML_BACKEND_API_VERSION,
@@ -1552,7 +3370,7 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() {
1552
3370
  ggml_backend_t ggml_backend_webgpu_init(void) {
1553
3371
  ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_webgpu_reg(), 0);
1554
3372
 
1555
- return ggml_backend_webgpu_device_init(dev, nullptr);
3373
+ return ggml_backend_webgpu_backend_init(dev, nullptr);
1556
3374
  }
1557
3375
 
1558
3376
  GGML_BACKEND_DL_IMPL(ggml_backend_webgpu_reg)