whispercpp 1.3.4 → 1.3.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (630)
  1. checksums.yaml +4 -4
  2. data/README.md +60 -43
  3. data/ext/extconf.rb +2 -2
  4. data/ext/ruby_whisper.c +14 -2
  5. data/ext/ruby_whisper.h +39 -0
  6. data/ext/ruby_whisper_context.c +22 -22
  7. data/ext/ruby_whisper_model.c +12 -12
  8. data/ext/ruby_whisper_params.c +47 -23
  9. data/ext/ruby_whisper_segment.c +84 -19
  10. data/ext/ruby_whisper_token.c +351 -0
  11. data/ext/ruby_whisper_transcribe.cpp +1 -1
  12. data/ext/ruby_whisper_vad_context.c +75 -0
  13. data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
  14. data/ext/ruby_whisper_vad_segment.c +139 -0
  15. data/ext/ruby_whisper_vad_segments.c +106 -0
  16. data/ext/sources/CMakeLists.txt +4 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
  19. data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
  20. data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  21. data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
  22. data/ext/sources/examples/addon.node/vad-example.js +2 -2
  23. data/ext/sources/examples/cli/cli.cpp +121 -112
  24. data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
  25. data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
  26. data/ext/sources/examples/server/server.cpp +10 -11
  27. data/ext/sources/examples/talk-llama/CMakeLists.txt +5 -1
  28. data/ext/sources/examples/talk-llama/llama-adapter.cpp +12 -3
  29. data/ext/sources/examples/talk-llama/llama-adapter.h +7 -1
  30. data/ext/sources/examples/talk-llama/llama-arch.cpp +2046 -1974
  31. data/ext/sources/examples/talk-llama/llama-arch.h +67 -2
  32. data/ext/sources/examples/talk-llama/llama-batch.cpp +75 -33
  33. data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
  34. data/ext/sources/examples/talk-llama/llama-chat.cpp +79 -3
  35. data/ext/sources/examples/talk-llama/llama-chat.h +4 -0
  36. data/ext/sources/examples/talk-llama/llama-context.cpp +775 -78
  37. data/ext/sources/examples/talk-llama/llama-context.h +57 -9
  38. data/ext/sources/examples/talk-llama/llama-cparams.h +1 -0
  39. data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
  40. data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
  41. data/ext/sources/examples/talk-llama/llama-graph.cpp +381 -64
  42. data/ext/sources/examples/talk-llama/llama-graph.h +103 -13
  43. data/ext/sources/examples/talk-llama/llama-hparams.cpp +26 -2
  44. data/ext/sources/examples/talk-llama/llama-hparams.h +41 -10
  45. data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
  46. data/ext/sources/examples/talk-llama/llama-impl.h +1 -1
  47. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +5 -3
  48. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +145 -65
  49. data/ext/sources/examples/talk-llama/llama-kv-cache.h +22 -7
  50. data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
  51. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
  52. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +32 -19
  53. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
  54. data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
  55. data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
  56. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +91 -9
  57. data/ext/sources/examples/talk-llama/llama-model-loader.h +6 -0
  58. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
  59. data/ext/sources/examples/talk-llama/llama-model.cpp +1529 -13134
  60. data/ext/sources/examples/talk-llama/llama-model.h +44 -3
  61. data/ext/sources/examples/talk-llama/llama-quant.cpp +8 -23
  62. data/ext/sources/examples/talk-llama/llama-sampling.cpp +1294 -198
  63. data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
  64. data/ext/sources/examples/talk-llama/llama-vocab.cpp +133 -37
  65. data/ext/sources/examples/talk-llama/llama-vocab.h +45 -40
  66. data/ext/sources/examples/talk-llama/llama.cpp +729 -2
  67. data/ext/sources/examples/talk-llama/llama.h +152 -14
  68. data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
  69. data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
  70. data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
  71. data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
  72. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
  73. data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
  74. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
  75. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
  76. data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
  77. data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
  78. data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
  79. data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
  80. data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
  81. data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
  82. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
  83. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
  84. data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
  85. data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
  86. data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
  87. data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
  88. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
  89. data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
  90. data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
  91. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
  92. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
  93. data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
  94. data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
  95. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
  96. data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
  97. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
  98. data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
  99. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
  100. data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
  101. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
  102. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
  103. data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
  104. data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
  105. data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
  106. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
  107. data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
  108. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
  109. data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
  110. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
  111. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
  112. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
  113. data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
  114. data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
  115. data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
  116. data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
  117. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
  118. data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
  119. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
  120. data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
  121. data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
  122. data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
  123. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
  124. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
  125. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
  126. data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
  127. data/ext/sources/examples/talk-llama/models/models.h +569 -0
  128. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
  129. data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
  130. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
  131. data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
  132. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
  133. data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
  134. data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
  135. data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
  136. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
  137. data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
  138. data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
  139. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
  140. data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
  141. data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
  142. data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
  143. data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
  144. data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
  145. data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
  146. data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
  147. data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
  148. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
  149. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
  150. data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
  151. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
  152. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
  153. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
  154. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
  155. data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
  156. data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
  157. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
  158. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
  159. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
  160. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
  161. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
  162. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
  163. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
  164. data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
  165. data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
  166. data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
  167. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
  168. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
  169. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
  170. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
  171. data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
  172. data/ext/sources/examples/talk-llama/unicode.cpp +102 -16
  173. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
  174. data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
  175. data/ext/sources/ggml/CMakeLists.txt +82 -54
  176. data/ext/sources/ggml/include/ggml-alloc.h +9 -0
  177. data/ext/sources/ggml/include/ggml-backend.h +4 -1
  178. data/ext/sources/ggml/include/ggml-cpu.h +1 -0
  179. data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
  180. data/ext/sources/ggml/include/ggml-rpc.h +8 -11
  181. data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
  182. data/ext/sources/ggml/include/ggml.h +190 -12
  183. data/ext/sources/ggml/src/CMakeLists.txt +82 -11
  184. data/ext/sources/ggml/src/ggml-alloc.c +124 -41
  185. data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
  186. data/ext/sources/ggml/src/ggml-backend-reg.cpp +27 -3
  187. data/ext/sources/ggml/src/ggml-backend.cpp +71 -21
  188. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
  189. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -9
  190. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +57 -45
  191. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
  192. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2179 -1696
  193. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -317
  194. data/ext/sources/ggml/src/ggml-cann/common.h +283 -208
  195. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +626 -776
  196. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +156 -86
  197. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -0
  198. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  199. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +428 -26
  200. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1004 -0
  201. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
  202. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  203. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +108 -49
  204. data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  205. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6 -6
  206. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +50 -2
  207. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +5 -3
  208. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +195 -71
  209. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
  210. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +573 -106
  211. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
  212. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +298 -112
  213. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
  214. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +819 -125
  215. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
  216. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +708 -431
  217. data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -4
  218. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +671 -31
  219. data/ext/sources/ggml/src/ggml-cpu/repack.h +14 -0
  220. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +41 -43
  221. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
  222. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
  223. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
  224. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +124 -1
  225. data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
  226. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
  227. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
  228. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
  229. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
  230. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +1 -1
  231. data/ext/sources/ggml/src/ggml-cuda/common.cuh +353 -80
  232. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  233. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
  234. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +339 -246
  235. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
  236. data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
  237. data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
  238. data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
  239. data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
  240. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +31 -21
  241. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +663 -596
  242. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
  243. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1241 -0
  244. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +30 -37
  245. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +14 -13
  246. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  247. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +83 -37
  248. data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
  249. data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
  250. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1155 -164
  251. data/ext/sources/ggml/src/ggml-cuda/mean.cu +5 -4
  252. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +741 -48
  253. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +60 -12
  254. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +381 -42
  255. data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
  256. data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
  257. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
  258. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +498 -171
  259. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +375 -79
  260. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +3 -2
  261. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +241 -95
  262. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
  263. data/ext/sources/ggml/src/ggml-cuda/pad.cu +64 -33
  264. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +151 -0
  265. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
  266. data/ext/sources/ggml/src/ggml-cuda/rope.cu +192 -77
  267. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
  268. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
  269. data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
  270. data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
  271. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +203 -6
  272. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
  273. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
  274. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -20
  275. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
  276. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  277. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  278. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  279. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  280. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  281. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  282. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  283. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  284. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  285. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +19 -1
  286. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
  287. data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
  288. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +168 -76
  289. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +11 -4
  290. data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
  291. data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
  292. data/ext/sources/ggml/src/ggml-cuda/unary.cu +105 -11
  293. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +36 -0
  294. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
  295. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  296. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +12 -1
  297. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
  302. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
  303. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  304. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
  305. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
  306. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  307. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
  308. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
  309. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
  310. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
  311. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  312. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
  313. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
  314. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
  315. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
  316. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
  317. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
  318. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
  319. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
  320. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
  321. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
  322. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
  323. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
  324. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
  325. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  326. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
  327. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
  328. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
  329. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +8 -13
  330. data/ext/sources/ggml/src/ggml-impl.h +67 -6
  331. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +2 -2
  332. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +29 -20
  333. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +652 -285
  334. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +103 -56
  335. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +496 -118
  336. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +231 -9
  337. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1227 -224
  338. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +12 -0
  339. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +14 -8
  340. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1972 -704
  341. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
  342. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +11 -0
  343. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1430 -120
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +63 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +13 -0
  365. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
  366. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
  367. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
  368. data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
  369. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +6 -0
  370. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +0 -9
  371. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
  372. data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
  373. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
  374. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +34 -0
  375. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  376. data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  377. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
  378. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
  379. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +76 -3
  380. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +333 -300
  381. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +10 -2
  382. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +335 -110
  383. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
  384. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +156 -0
  385. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
  386. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
  387. data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
  388. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
  389. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
  390. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
  391. data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  392. data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  393. data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
  394. data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
  395. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +30 -17
  396. data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
  397. data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
  398. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
  399. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
  400. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  401. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  402. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +58 -0
  403. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
  404. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5013 -2859
  405. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  406. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
  407. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
  408. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  409. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
  410. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  411. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
  412. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
  413. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  414. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  415. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  416. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  417. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  418. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  419. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
  420. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  421. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  422. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
  423. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
  424. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  425. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  426. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  427. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
  428. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
  429. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
  430. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
  431. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  432. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
  433. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
  434. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  435. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  436. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  437. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  438. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  439. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
  440. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
  441. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
  442. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  443. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  444. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
  445. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
  446. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  447. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  448. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  449. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
  450. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  451. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  452. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
  453. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  454. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  455. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
  456. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  457. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  458. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
  459. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  460. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +39 -17
  461. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +19 -1
  462. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +45 -7
  463. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +50 -12
  464. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +1 -1
  465. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  466. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  467. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
  468. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
  469. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  470. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
  471. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  472. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
  473. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
  474. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
  475. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
  476. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
  477. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
  478. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  479. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
  480. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
  481. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
  482. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
  483. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
  484. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  485. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  486. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  487. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
  488. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +70 -25
  489. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
  490. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
  491. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
  492. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
  493. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
  494. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
  495. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
  496. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
  497. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
  498. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  499. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
  500. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
  501. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
  502. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
  503. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
  504. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
  505. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
  506. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +78 -103
  507. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +34 -23
  508. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
  509. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
  510. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +88 -228
  511. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
  512. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  513. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
  514. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  515. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  516. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  517. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
  518. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
  519. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  520. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
  521. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  522. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  523. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  524. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  525. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +50 -4
  526. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  527. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
  528. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
  529. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
  530. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  531. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -50
  532. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -33
  533. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -33
  534. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
  535. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
  536. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  537. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
  538. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  539. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  540. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  541. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  542. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
  543. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
  544. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
  545. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
  546. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
  547. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
  548. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  549. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
  550. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
  551. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  552. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
  553. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
  554. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  555. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  556. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
  557. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
  558. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  559. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
  560. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  561. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
  562. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
  563. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
  564. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
  565. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
  566. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  567. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
  568. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
  569. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -151
  570. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
  571. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
  572. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
  573. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1964 -435
  574. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  575. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
  576. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
  577. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +33 -10
  578. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
  579. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +1 -1
  580. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  581. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +6 -6
  582. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
  583. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
  584. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
  585. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
  586. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
  587. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  588. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
  589. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
  590. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  591. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
  592. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
  593. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
  594. data/ext/sources/ggml/src/ggml.c +425 -33
  595. data/ext/sources/include/whisper.h +1 -0
  596. data/ext/sources/src/CMakeLists.txt +3 -1
  597. data/ext/sources/src/whisper.cpp +101 -35
  598. data/ext/sources/tests/CMakeLists.txt +2 -2
  599. data/ext/sources/tests/test-vad-full.cpp +4 -2
  600. data/ext/sources/tests/test-vad.cpp +1 -1
  601. data/extsources.rb +1 -0
  602. data/lib/whisper/model/uri.rb +17 -18
  603. data/sig/whisper.rbs +119 -2
  604. data/test/test_params.rb +16 -8
  605. data/test/test_segment.rb +0 -1
  606. data/test/test_token.rb +70 -0
  607. data/test/test_vad.rb +1 -1
  608. data/test/test_vad_context.rb +50 -0
  609. data/test/test_vad_segment.rb +19 -0
  610. data/test/test_vad_segments.rb +16 -0
  611. data/test/test_whisper.rb +7 -0
  612. data/whispercpp.gemspec +1 -1
  613. metadata +287 -34
  614. data/ext/sources/build-xcframework.sh +0 -571
  615. data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
  618. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
  619. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
  620. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
  621. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
  622. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
  623. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
  624. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  625. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  626. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  627. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  628. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
  629. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
  630. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
@@ -7,34 +7,110 @@
7
7
 
8
8
  #include "ggml-backend-impl.h"
9
9
  #include "ggml-impl.h"
10
+ #include "ggml-webgpu-shader-lib.hpp"
10
11
  #include "ggml-wgsl-shaders.hpp"
12
+ #include "pre_wgsl.hpp"
13
+
14
+ #ifdef __EMSCRIPTEN__
15
+ # include <emscripten/emscripten.h>
16
+ #endif
11
17
 
12
18
  #include <webgpu/webgpu_cpp.h>
13
19
 
20
+ #include <atomic>
14
21
  #include <condition_variable>
22
+ #include <cstdint>
15
23
  #include <cstring>
16
24
  #include <iostream>
25
+ #include <map>
17
26
  #include <mutex>
27
+ #include <optional>
18
28
  #include <string>
19
29
  #include <vector>
20
30
 
31
+ #define ROUNDUP_POW2(x, pow2) (((x) + ((pow2) - 1)) & ~((pow2) - 1))
32
+ #define CEIL_DIV(M, N) (((M) + (N) - 1) / (N))
33
+
21
34
  #ifdef GGML_WEBGPU_DEBUG
22
35
  # define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl
23
- # define WEBGPU_DEBUG_BUF_ELEMS 32
36
+ # define WEBGPU_DEBUG_BUF_ELEMS 512
24
37
  #else
25
38
  # define WEBGPU_LOG_DEBUG(msg) ((void) 0)
26
39
  #endif // GGML_WEBGPU_DEBUG
27
40
 
41
+ #ifdef GGML_WEBGPU_CPU_PROFILE
42
+ // total timing (aggregated)
43
+ # define WEBGPU_CPU_PROFILE_TOTAL_START(id) auto cpu_total_start_##id = std::chrono::high_resolution_clock::now();
44
+
45
+ # define WEBGPU_CPU_PROFILE_TOTAL_END(id, ctx) \
46
+ auto cpu_total_end_##id = std::chrono::high_resolution_clock::now(); \
47
+ double cpu_total_time_##id = \
48
+ std::chrono::duration<double, std::milli>(cpu_total_end_##id - cpu_total_start_##id).count(); \
49
+ (ctx)->cpu_time_ms[#id] += cpu_total_time_##id;
50
+
51
+ // fine-grained timing (not included in totals)
52
+ # define WEBGPU_CPU_PROFILE_DETAIL_START(id) auto cpu_detail_start_##id = std::chrono::high_resolution_clock::now();
53
+
54
+ # define WEBGPU_CPU_PROFILE_DETAIL_END(id, ctx) \
55
+ auto cpu_detail_end_##id = std::chrono::high_resolution_clock::now(); \
56
+ double cpu_detail_time_##id = \
57
+ std::chrono::duration<double, std::milli>(cpu_detail_end_##id - cpu_detail_start_##id).count(); \
58
+ (ctx)->cpu_detail_ms[#id] += cpu_detail_time_##id;
59
+ #else
60
+ # define WEBGPU_CPU_PROFILE_TOTAL_START(id)
61
+ # define WEBGPU_CPU_PROFILE_TOTAL_END(id, ctx)
62
+ # define WEBGPU_CPU_PROFILE_DETAIL_START(id)
63
+ # define WEBGPU_CPU_PROFILE_DETAIL_END(id, ctx)
64
+ #endif // GGML_WEBGPU_CPU_PROFILE
65
+
66
+ #ifdef GGML_WEBGPU_GPU_PROFILE
67
+ # define WEBGPU_NUM_TIMESTAMP_QUERY_BUFS 24
68
+ # define WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES 16 // e.g. enough for two timestamps
69
+ #endif
70
+
28
71
  /* Constants */
29
72
 
30
- #define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 16
31
- #define WEBGPU_MUL_MAT_WG_SIZE 64
32
- #define WEBGPU_NUM_PARAM_BUFS 100
73
+ // Track https://github.com/gpuweb/gpuweb/issues/5315 for fixes to implementations so this can be removed.
74
+ #define WEBGPU_MAX_WG_SIZE 288
75
+
76
+ #define WEBGPU_MUL_MAT_WG_SIZE 256
77
+ #define WEBGPU_NUM_PARAM_BUFS 32u
78
+ #define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 8u
79
+ #define WEBGPU_WAIT_ANY_TIMEOUT_MS 0
80
+ // Maximum number of in-flight submissions per-thread, to avoid exhausting the parameter buffer pool
81
+ #define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE
33
82
  #define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters
34
83
  #define WEBGPU_NUM_SET_ROWS_ERROR_BUFS 32
35
84
  #define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4
36
85
  #define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4
37
86
 
87
+ // For operations which process a row in parallel, this seems like a reasonable default
88
+ #define WEBGPU_ROW_SPLIT_WG_SIZE 64
89
+
90
+ // Matrix multiplication parameters
91
+
92
+ // Register tiling parameters
93
+ #define WEBGPU_MUL_MAT_TILE_M 8
94
+ #define WEBGPU_MUL_MAT_TILE_N 8
95
+ #define WEBGPU_MUL_MAT_WG_SIZE_M 8
96
+ #define WEBGPU_MUL_MAT_WG_SIZE_N 8
97
+ #define WEBGPU_MUL_MAT_TILE_K 32
98
+
99
+ // Subgroup matrix parameters
100
+ // The number of subgroups in the M dimension
101
+ #define WEBGPU_MUL_MAT_SUBGROUP_M 2
102
+ // The number of subgroups in the N dimension
103
+ #define WEBGPU_MUL_MAT_SUBGROUP_N 2
104
+ // The number of subgroup matrices each subgroup accumulates over
105
+ #define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4
106
+ #define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2
107
+
108
+ // Matrix-vector multiplication parameters
109
+ #define WEBGPU_MUL_MAT_VEC_WG_SIZE 256
110
+ // Must be multiple of 4 to work with vectorized paths, and must divide mul_mat_vec wg size
111
+ #define WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG 64
112
+ #define WEBGPU_MUL_MAT_VEC_TILE_K 256
113
+
38
114
  /* End Constants */
39
115
 
40
116
  // This is a "fake" base pointer, since WebGPU buffers do not have pointers to their locations.
@@ -62,6 +138,11 @@ struct webgpu_pool_bufs {
62
138
  wgpu::Buffer dev_buf;
63
139
  };
64
140
 
141
+ // The futures to wait on for a single queue submission
142
+ struct webgpu_submission_futures {
143
+ std::vector<wgpu::FutureWaitInfo> futures;
144
+ };
145
+
65
146
  // Holds a pool of parameter buffers for WebGPU operations
66
147
  struct webgpu_buf_pool {
67
148
  std::vector<webgpu_pool_bufs> free;
@@ -108,6 +189,124 @@ struct webgpu_buf_pool {
108
189
  }
109
190
  };
110
191
 
192
+ #ifdef GGML_WEBGPU_GPU_PROFILE
193
+ struct webgpu_gpu_profile_bufs {
194
+ wgpu::Buffer host_buf;
195
+ wgpu::Buffer dev_buf;
196
+ wgpu::QuerySet query_set;
197
+ };
198
+
199
+ // Holds a pool of parameter buffers for WebGPU operations
200
+ struct webgpu_gpu_profile_buf_pool {
201
+ std::vector<webgpu_gpu_profile_bufs> free;
202
+
203
+ std::mutex mutex;
204
+
205
+ std::condition_variable cv;
206
+
207
+ void init(wgpu::Device device,
208
+ int num_bufs,
209
+ size_t buf_size,
210
+ wgpu::BufferUsage dev_buf_usage,
211
+ wgpu::BufferUsage host_buf_usage) {
212
+ for (int i = 0; i < num_bufs; i++) {
213
+ wgpu::Buffer host_buf;
214
+ wgpu::Buffer dev_buf;
215
+ ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_profile_buf");
216
+ ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_profile_buf");
217
+ // Create a query set for 2 timestamps
218
+ wgpu::QuerySetDescriptor ts_query_set_desc = {};
219
+
220
+ ts_query_set_desc.type = wgpu::QueryType::Timestamp;
221
+ ts_query_set_desc.count = 2;
222
+ wgpu::QuerySet ts_query_set = device.CreateQuerySet(&ts_query_set_desc);
223
+
224
+ free.push_back({ host_buf, dev_buf, ts_query_set });
225
+ }
226
+ }
227
+
228
+ webgpu_gpu_profile_bufs alloc_bufs() {
229
+ std::unique_lock<std::mutex> lock(mutex);
230
+ cv.wait(lock, [this] { return !free.empty(); });
231
+ webgpu_gpu_profile_bufs bufs = free.back();
232
+ free.pop_back();
233
+ return bufs;
234
+ }
235
+
236
+ void free_bufs(std::vector<webgpu_gpu_profile_bufs> bufs) {
237
+ std::lock_guard<std::mutex> lock(mutex);
238
+ free.insert(free.end(), bufs.begin(), bufs.end());
239
+ cv.notify_all();
240
+ }
241
+
242
+ void cleanup() {
243
+ std::lock_guard<std::mutex> lock(mutex);
244
+ for (auto & bufs : free) {
245
+ bufs.host_buf.Destroy();
246
+ bufs.dev_buf.Destroy();
247
+ bufs.query_set.Destroy();
248
+ }
249
+ free.clear();
250
+ }
251
+ };
252
+ #endif
253
+
254
+ struct webgpu_pipeline {
255
+ wgpu::ComputePipeline pipeline;
256
+ std::string name;
257
+ void * context = nullptr;
258
+ };
259
+
260
+ struct webgpu_command {
261
+ wgpu::CommandBuffer commands;
262
+ webgpu_pool_bufs params_bufs;
263
+ std::optional<webgpu_pool_bufs> set_rows_error_bufs;
264
+ #ifdef GGML_WEBGPU_GPU_PROFILE
265
+ webgpu_gpu_profile_bufs timestamp_query_bufs;
266
+ std::string pipeline_name;
267
+ #endif
268
+ };
269
+
270
+ struct flash_attn_pipeline_key {
271
+ int q_type;
272
+ int kv_type;
273
+ int dst_type;
274
+ uint32_t head_dim_qk;
275
+ uint32_t head_dim_v;
276
+ bool kv_direct;
277
+ bool has_mask;
278
+ bool has_sinks;
279
+ bool uses_logit_softcap;
280
+
281
+ bool operator==(const flash_attn_pipeline_key & other) const {
282
+ return q_type == other.q_type && kv_type == other.kv_type && dst_type == other.dst_type &&
283
+ head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v && kv_direct == other.kv_direct &&
284
+ has_mask == other.has_mask && has_sinks == other.has_sinks &&
285
+ uses_logit_softcap == other.uses_logit_softcap;
286
+ }
287
+ };
288
+
289
+ // Same hash combine function as in boost
290
+ template <typename T> inline void ggml_webgpu_hash_combine(size_t & seed, const T & value) {
291
+ seed ^= std::hash<T>{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
292
+ }
293
+
294
+ struct flash_attn_pipeline_key_hash {
295
+ size_t operator()(const flash_attn_pipeline_key & key) const {
296
+ size_t seed = 0;
297
+ ggml_webgpu_hash_combine(seed, key.q_type);
298
+ ggml_webgpu_hash_combine(seed, key.kv_type);
299
+ ggml_webgpu_hash_combine(seed, key.dst_type);
300
+ ggml_webgpu_hash_combine(seed, key.head_dim_qk);
301
+ ggml_webgpu_hash_combine(seed, key.head_dim_v);
302
+ ggml_webgpu_hash_combine(seed, key.kv_direct);
303
+ ggml_webgpu_hash_combine(seed, key.has_mask);
304
+ ggml_webgpu_hash_combine(seed, key.has_sinks);
305
+ ggml_webgpu_hash_combine(seed, key.uses_logit_softcap);
306
+ return seed;
307
+ }
308
+ };
309
+
111
310
  // All the base objects needed to run operations on a WebGPU device
112
311
  struct webgpu_context_struct {
113
312
  wgpu::Instance instance;
@@ -116,47 +315,68 @@ struct webgpu_context_struct {
116
315
  wgpu::Queue queue;
117
316
  wgpu::Limits limits;
118
317
 
119
- // Separate this out from limits since on some Metal systems, the limit returned by
120
- // querying the limits is higher than the actual allowed maximum.
121
- uint32_t max_wg_size_x;
318
+ uint32_t max_subgroup_size;
319
+
320
+ bool supports_subgroup_matrix = false;
321
+ uint32_t sg_mat_m;
322
+ uint32_t sg_mat_n;
323
+ uint32_t sg_mat_k;
122
324
 
123
325
  std::recursive_mutex mutex;
326
+ std::atomic_uint inflight_threads = 0;
124
327
 
125
328
  webgpu_buf_pool param_buf_pool;
126
329
  webgpu_buf_pool set_rows_error_buf_pool;
127
330
 
128
- wgpu::ComputePipeline memset_pipeline;
129
- wgpu::ComputePipeline mul_mat_pipeline[30][2];
130
- wgpu::ComputePipeline set_rows_pipeline;
131
- wgpu::ComputePipeline get_rows_pipeline[30];
132
- wgpu::ComputePipeline get_rows_f32_no_vec_pipeline;
133
- wgpu::ComputePipeline cpy_pipeline;
134
- wgpu::ComputePipeline add_pipeline[2];
135
- wgpu::ComputePipeline add_ip_pipeline[2];
136
- wgpu::ComputePipeline mul_pipeline[2];
137
- wgpu::ComputePipeline mul_ip_pipeline[2];
138
- wgpu::ComputePipeline rms_norm_pipeline;
139
- wgpu::ComputePipeline rms_norm_ip_pipeline;
331
+ pre_wgsl::Preprocessor p;
140
332
 
141
- size_t memset_bytes_per_thread;
333
+ std::map<int, webgpu_pipeline> memset_pipelines; // variant or type index
142
334
 
143
- // Staging buffer for reading data from the GPU
144
- wgpu::Buffer get_tensor_staging_buf;
335
+ std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> mul_mat_pipelines; // src0_type, src1_type, vectorized
336
+ std::map<int, std::map<int, std::map<int, webgpu_pipeline>>>
337
+ mul_mat_vec_pipelines; // src0_type, src1_type, vectorized
338
+
339
+ std::unordered_map<flash_attn_pipeline_key, webgpu_pipeline, flash_attn_pipeline_key_hash> flash_attn_pipelines;
145
340
 
146
- // Command buffers which need to be submitted
147
- std::vector<wgpu::CommandBuffer> staged_command_bufs;
341
+ std::map<int, std::map<int, webgpu_pipeline>> set_rows_pipelines; // dst_type, vectorized
342
+ std::map<int, std::map<int, webgpu_pipeline>> get_rows_pipelines; // src_type, vectorized
148
343
 
149
- // Parameter buffers associated with the staged command buffers
150
- std::vector<webgpu_pool_bufs> staged_param_bufs;
151
- // Buffers associated with set_rows operations, used to store potential errors
152
- std::vector<webgpu_pool_bufs> staged_set_row_error_bufs;
344
+ std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines; // src_type, dst_type
345
+ std::map<int, std::map<int, webgpu_pipeline>> add_pipelines; // type, inplace
346
+ std::map<int, std::map<int, webgpu_pipeline>> sub_pipelines; // type, inplace
347
+ std::map<int, std::map<int, webgpu_pipeline>> mul_pipelines; // type, inplace
348
+ std::map<int, std::map<int, webgpu_pipeline>> div_pipelines; // type, inplace
153
349
 
154
- std::vector<wgpu::FutureWaitInfo> callback_futures;
350
+ std::map<int, webgpu_pipeline> rms_norm_pipelines; // inplace
351
+ std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> rope_pipelines; // type, ff, inplace
352
+ std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> glu_pipelines; // glu_op, type, split
353
+ std::map<int, webgpu_pipeline> scale_pipelines; // inplace
354
+ std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> soft_max_pipelines; // mask_type, has_sink, inplace
355
+ std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> unary_pipelines; // unary_op, type, inplace
356
+
357
+ size_t memset_bytes_per_thread;
358
+
359
+ // Staging buffer for reading data from the GPU
360
+ wgpu::Buffer get_tensor_staging_buf;
155
361
 
156
362
  #ifdef GGML_WEBGPU_DEBUG
157
363
  wgpu::Buffer debug_host_buf;
158
364
  wgpu::Buffer debug_dev_buf;
159
365
  #endif
366
+
367
+ #ifdef GGML_WEBGPU_CPU_PROFILE
368
+ // Profiling: labeled CPU time in ms (total)
369
+ std::unordered_map<std::string, double> cpu_time_ms;
370
+ // Profiling: detailed CPU time in ms
371
+ std::unordered_map<std::string, double> cpu_detail_ms;
372
+ #endif
373
+
374
+ #ifdef GGML_WEBGPU_GPU_PROFILE
375
+ // Profiling: per-shader GPU time in ms
376
+ std::unordered_map<std::string, double> shader_gpu_time_ms;
377
+ // Profiling: pool of timestamp query buffers (one per operation)
378
+ webgpu_gpu_profile_buf_pool timestamp_query_buf_pool;
379
+ #endif
160
380
  };
161
381
 
162
382
  typedef std::shared_ptr<webgpu_context_struct> webgpu_context;
@@ -181,23 +401,39 @@ struct ggml_backend_webgpu_context {
181
401
  struct ggml_backend_webgpu_buffer_context {
182
402
  webgpu_context webgpu_ctx;
183
403
  wgpu::Buffer buffer;
404
+ std::string label;
184
405
 
185
- ggml_backend_webgpu_buffer_context(webgpu_context ctx, wgpu::Buffer buf) :
406
+ ggml_backend_webgpu_buffer_context(webgpu_context ctx, wgpu::Buffer buf, std::string lbl) :
186
407
  webgpu_ctx(std::move(ctx)),
187
- buffer(std::move(buf)) {}
408
+ buffer(std::move(buf)),
409
+ label(std::move(lbl)) {}
188
410
  };
189
411
 
190
- /* End struct definitions */
191
-
192
412
  /* WebGPU object initializations */
193
413
 
194
- static void ggml_webgpu_create_pipeline(wgpu::Device & device,
195
- wgpu::ComputePipeline & pipeline,
196
- const char * shader_code,
197
- const char * label,
198
- const std::vector<wgpu::ConstantEntry> & constants = {}) {
199
- WEBGPU_LOG_DEBUG("ggml_webgpu_create_pipeline()");
414
+ // Process a WGSL shader string, replacing tokens of the form {{KEY}} with
415
+ // the corresponding values provided in `repls`.
416
+ static std::string ggml_webgpu_process_shader_repls(const char * src,
417
+ const std::map<std::string, std::string> & repls) {
418
+ if (!src) {
419
+ return std::string();
420
+ }
421
+ std::string s = src;
422
+ for (const auto & kv : repls) {
423
+ std::string token = "{{" + kv.first + "}}";
424
+ size_t pos = 0;
425
+ while ((pos = s.find(token, pos)) != std::string::npos) {
426
+ s.replace(pos, token.length(), kv.second);
427
+ pos += kv.second.length();
428
+ }
429
+ }
430
+ return s;
431
+ }
200
432
 
433
+ static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device,
434
+ const char * shader_code,
435
+ const char * label,
436
+ const std::vector<wgpu::ConstantEntry> & constants = {}) {
201
437
  wgpu::ShaderSourceWGSL shader_source;
202
438
  shader_source.code = shader_code;
203
439
 
@@ -215,7 +451,7 @@ static void ggml_webgpu_create_pipeline(wgpu::Device &
215
451
  pipeline_desc.compute.constants = constants.data();
216
452
  pipeline_desc.compute.constantCount = constants.size();
217
453
  }
218
- pipeline = device.CreateComputePipeline(&pipeline_desc);
454
+ return { device.CreateComputePipeline(&pipeline_desc), label };
219
455
  }
220
456
 
221
457
  static void ggml_webgpu_create_buffer(wgpu::Device & device,
@@ -223,8 +459,6 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device,
223
459
  size_t size,
224
460
  wgpu::BufferUsage usage,
225
461
  const char * label) {
226
- WEBGPU_LOG_DEBUG("ggml_webgpu_create_buffer()");
227
-
228
462
  wgpu::BufferDescriptor buffer_desc;
229
463
  buffer_desc.size = size;
230
464
  buffer_desc.usage = usage;
@@ -240,79 +474,35 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device,
240
474
  /** WebGPU Actions */
241
475
 
242
476
  // Wait for the queue to finish processing all submitted work
243
- static void ggml_backend_webgpu_wait_on_submission(webgpu_context & ctx) {
244
- std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
245
- if (ctx->callback_futures.empty()) {
246
- // no existing callbacks, wait on queue submission
247
- ctx->instance.WaitAny(
248
- ctx->queue.OnSubmittedWorkDone(wgpu::CallbackMode::AllowSpontaneous,
249
- [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
250
- if (status != wgpu::QueueWorkDoneStatus::Success) {
251
- GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n",
252
- std::string(message).c_str());
253
- }
254
- }),
255
- UINT64_MAX);
256
- } else {
257
- // existing callbacks, wait on them
258
- ctx->instance.WaitAny(ctx->callback_futures.size(), ctx->callback_futures.data(), UINT64_MAX);
259
- ctx->callback_futures.clear();
477
+ static void ggml_backend_webgpu_wait(webgpu_context & ctx,
478
+ std::vector<webgpu_submission_futures> & futures,
479
+ bool block = true) {
480
+ // If we have too many in-flight submissions, wait on the oldest one first. If there are many threads,
481
+ // inflight_max may be 0, meaning that we must wait on all futures.
482
+ uint64_t timeout_ms = block ? UINT64_MAX : 0;
483
+ uint32_t inflight_threads = ctx->inflight_threads;
484
+ uint32_t inflight_max = WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD / std::max(inflight_threads, 1u);
485
+ while (futures.size() >= inflight_max && futures.size() > 0) {
486
+ ctx->instance.WaitAny(futures[0].futures.size(), futures[0].futures.data(), UINT64_MAX);
487
+ futures.erase(futures.begin());
260
488
  }
261
- }
262
-
263
- static void ggml_backend_webgpu_submit_queue(webgpu_context & ctx) {
264
- std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
265
- WEBGPU_LOG_DEBUG("ggml_backend_webgpu_submit_queue()");
266
- if (ctx->staged_command_bufs.empty()) {
267
- // Nothing to submit
268
- return;
269
- }
270
- ctx->queue.Submit(ctx->staged_command_bufs.size(), ctx->staged_command_bufs.data());
271
-
272
- // If there are SET_ROWS operations in this submission, copy their error buffers to the host.
273
- if (ctx->staged_set_row_error_bufs.size() > 0) {
274
- wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
275
- for (auto & error_bufs : ctx->staged_set_row_error_bufs) {
276
- // Copy the error buffer to the host buffer
277
- encoder.CopyBufferToBuffer(error_bufs.dev_buf, 0, error_bufs.host_buf, 0, error_bufs.host_buf.GetSize());
489
+ size_t i = 0;
490
+ while (i < futures.size()) {
491
+ auto waitStatus = ctx->instance.WaitAny(futures[i].futures.size(), futures[i].futures.data(), timeout_ms);
492
+ switch (waitStatus) {
493
+ case wgpu::WaitStatus::Success:
494
+ futures.erase(futures.begin() + i);
495
+ break;
496
+ case wgpu::WaitStatus::TimedOut:
497
+ i++;
498
+ break;
499
+ case wgpu::WaitStatus::Error:
500
+ GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
501
+ break;
502
+ default:
503
+ GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n");
504
+ break;
278
505
  }
279
- wgpu::CommandBuffer commands = encoder.Finish();
280
- ctx->queue.Submit(1, &commands);
281
- }
282
-
283
- ctx->staged_command_bufs.clear();
284
- std::vector<webgpu_pool_bufs> staged_param_bufs = std::move(ctx->staged_param_bufs);
285
- std::vector<webgpu_pool_bufs> staged_set_row_error_bufs = std::move(ctx->staged_set_row_error_bufs);
286
-
287
- // Free the staged parameter buffers once the submission completes
288
- wgpu::Future p_f = ctx->queue.OnSubmittedWorkDone(
289
- wgpu::CallbackMode::AllowSpontaneous,
290
- [ctx, staged_param_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
291
- if (status != wgpu::QueueWorkDoneStatus::Success) {
292
- GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", std::string(message).c_str());
293
- }
294
- // Free the staged buffers
295
- ctx->param_buf_pool.free_bufs(staged_param_bufs);
296
- });
297
- ctx->callback_futures.push_back({ p_f });
298
-
299
- // Check for errrors in SET_ROWS operations
300
- for (auto & error_bufs : staged_set_row_error_bufs) {
301
- wgpu::Future f = error_bufs.host_buf.MapAsync(
302
- wgpu::MapMode::Read, 0, error_bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous,
303
- [ctx, error_bufs](wgpu::MapAsyncStatus status, wgpu::StringView message) {
304
- if (status != wgpu::MapAsyncStatus::Success) {
305
- GGML_LOG_ERROR("ggml_webgpu: Failed to map error buffer: %s\n", std::string(message).c_str());
306
- } else {
307
- const uint32_t * error_data = (const uint32_t *) error_bufs.host_buf.GetConstMappedRange();
308
- if (*error_data) {
309
- GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported.");
310
- }
311
- // We can't unmap in here due to WebGPU reentrancy limitations.
312
- ctx->set_rows_error_buf_pool.free_bufs({ error_bufs });
313
- }
314
- });
315
- ctx->callback_futures.push_back({ f });
316
506
  }
317
507
  }
318
508
 
@@ -336,30 +526,97 @@ static void ggml_backend_webgpu_map_buffer(webgpu_context & ctx,
336
526
  // To use, add a bind group entry to the setup for the shader you are debugging, add the buffer and
337
527
  // debug statements in the shader, and then call this function after encoding the commands and submitting them.
338
528
  static void ggml_backend_webgpu_debug(webgpu_context & ctx) {
339
- ggml_backend_webgpu_submit_queue(ctx);
340
529
  wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
341
530
  encoder.CopyBufferToBuffer(ctx->debug_dev_buf, 0, ctx->debug_host_buf, 0, ctx->debug_host_buf.GetSize());
342
531
  wgpu::CommandBuffer commands = encoder.Finish();
343
532
  ctx->queue.Submit(1, &commands);
344
-
345
533
  ggml_backend_webgpu_map_buffer(ctx, ctx->debug_host_buf, wgpu::MapMode::Read, 0, ctx->debug_host_buf.GetSize());
346
- const uint32_t * debug_data = (const uint32_t *) ctx->debug_host_buf.GetConstMappedRange();
347
- std::cout << "debug data:";
348
- for (size_t i = 0; i < WEBGPU_DEBUG_BUF_ELEMS; i++) {
349
- std::cout << " " << i << ": " << debug_data[i];
350
- }
351
- std::cout << "\n";
534
+ const float * debug_data = (const float *) ctx->debug_host_buf.GetConstMappedRange();
535
+ std::cout << "debug[0]: " << debug_data[0] << "\n";
352
536
  ctx->debug_host_buf.Unmap();
353
537
  }
354
538
  #endif
355
539
 
356
- static void ggml_backend_webgpu_build_and_enqueue(webgpu_context & ctx,
357
- wgpu::ComputePipeline & pipeline,
358
- std::vector<uint32_t> params,
359
- std::vector<wgpu::BindGroupEntry> bind_group_entries,
360
- uint32_t wg_x,
361
- const char * bind_group_label = nullptr,
362
- bool submit_and_wait = false) {
540
+ static webgpu_submission_futures ggml_backend_webgpu_submit(webgpu_context ctx, std::vector<webgpu_command> commands) {
541
+ std::vector<wgpu::CommandBuffer> command_buffers;
542
+ std::vector<webgpu_pool_bufs> params_bufs;
543
+ std::vector<webgpu_pool_bufs> set_rows_error_bufs;
544
+ #ifdef GGML_WEBGPU_GPU_PROFILE
545
+ std::vector<std::pair<std::string, webgpu_gpu_profile_bufs>> pipeline_name_and_ts_bufs;
546
+ #endif
547
+
548
+ for (const auto & command : commands) {
549
+ command_buffers.push_back(command.commands);
550
+ params_bufs.push_back(command.params_bufs);
551
+ if (command.set_rows_error_bufs) {
552
+ set_rows_error_bufs.push_back(command.set_rows_error_bufs.value());
553
+ }
554
+ }
555
+ ctx->queue.Submit(command_buffers.size(), command_buffers.data());
556
+
557
+ std::vector<wgpu::FutureWaitInfo> futures;
558
+
559
+ wgpu::Future p_f = ctx->queue.OnSubmittedWorkDone(
560
+ wgpu::CallbackMode::AllowSpontaneous,
561
+ [ctx, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
562
+ if (status != wgpu::QueueWorkDoneStatus::Success) {
563
+ GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", std::string(message).c_str());
564
+ }
565
+ // Free the staged buffers
566
+ ctx->param_buf_pool.free_bufs({ params_bufs });
567
+ });
568
+ futures.push_back({ p_f });
569
+
570
+ for (const auto & bufs : set_rows_error_bufs) {
571
+ wgpu::Future f = bufs.host_buf.MapAsync(
572
+ wgpu::MapMode::Read, 0, bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous,
573
+ [ctx, bufs](wgpu::MapAsyncStatus status, wgpu::StringView message) {
574
+ if (status != wgpu::MapAsyncStatus::Success) {
575
+ GGML_LOG_ERROR("ggml_webgpu: Failed to map error buffer: %s\n", std::string(message).c_str());
576
+ } else {
577
+ const uint32_t * error_data = (const uint32_t *) bufs.host_buf.GetConstMappedRange();
578
+ if (*error_data) {
579
+ GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported.");
580
+ }
581
+ // We can't unmap in here due to WebGPU reentrancy limitations.
582
+ ctx->set_rows_error_buf_pool.free_bufs({ bufs });
583
+ }
584
+ });
585
+ futures.push_back({ f });
586
+ }
587
+
588
+ #ifdef GGML_WEBGPU_GPU_PROFILE
589
+ for (const auto & command : commands) {
590
+ auto label = command.pipeline_name;
591
+ auto ts_bufs = command.timestamp_query_bufs;
592
+
593
+ wgpu::Future f = ts_bufs.host_buf.MapAsync(
594
+ wgpu::MapMode::Read, 0, ts_bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous,
595
+ [ctx, ts_bufs, label](wgpu::MapAsyncStatus status, wgpu::StringView message) {
596
+ if (status != wgpu::MapAsyncStatus::Success) {
597
+ GGML_LOG_ERROR("ggml_webgpu: Failed to map timestamp buffer: %s\n", std::string(message).c_str());
598
+ } else {
599
+ const uint64_t * ts_data = (const uint64_t *) ts_bufs.host_buf.GetConstMappedRange();
600
+ // WebGPU timestamps are in ns; convert to ms
601
+ double elapsed_ms = double(ts_data[1] - ts_data[0]) * 1e-6;
602
+ ctx->shader_gpu_time_ms[label] += elapsed_ms;
603
+ // We can't unmap in here due to WebGPU reentrancy limitations.
604
+ ctx->timestamp_query_buf_pool.free_bufs({ ts_bufs });
605
+ }
606
+ });
607
+ futures.push_back({ f });
608
+ }
609
+ #endif
610
+ return { futures };
611
+ }
612
+
613
+ static webgpu_command ggml_backend_webgpu_build(webgpu_context & ctx,
614
+ webgpu_pipeline & pipeline,
615
+ std::vector<uint32_t> params,
616
+ std::vector<wgpu::BindGroupEntry> bind_group_entries,
617
+ uint32_t wg_x,
618
+ uint32_t wg_y = 1,
619
+ std::optional<webgpu_pool_bufs> set_rows_error_bufs = std::nullopt) {
363
620
  webgpu_pool_bufs params_bufs = ctx->param_buf_pool.alloc_bufs();
364
621
 
365
622
  ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0, params_bufs.host_buf.GetSize());
@@ -377,44 +634,58 @@ static void ggml_backend_webgpu_build_and_enqueue(webgpu_context &
377
634
  .size = params_bufs.dev_buf.GetSize() });
378
635
 
379
636
  wgpu::BindGroupDescriptor bind_group_desc;
380
- bind_group_desc.layout = pipeline.GetBindGroupLayout(0);
637
+ bind_group_desc.layout = pipeline.pipeline.GetBindGroupLayout(0);
381
638
  bind_group_desc.entryCount = bind_group_entries.size();
382
639
  bind_group_desc.entries = bind_group_entries.data();
383
- if (bind_group_label) {
384
- bind_group_desc.label = bind_group_label;
385
- }
640
+ bind_group_desc.label = pipeline.name.c_str();
386
641
  wgpu::BindGroup bind_group = ctx->device.CreateBindGroup(&bind_group_desc);
387
642
 
388
643
  wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
389
644
  encoder.CopyBufferToBuffer(params_bufs.host_buf, 0, params_bufs.dev_buf, 0, params_bufs.dev_buf.GetSize());
645
+
646
+ #ifdef GGML_WEBGPU_GPU_PROFILE
647
+ // --- Profiling: GPU timestamp queries ---
648
+ // Allocate a timestamp query buffer (2 timestamps: start/end)
649
+ webgpu_gpu_profile_bufs ts_bufs = ctx->timestamp_query_buf_pool.alloc_bufs();
650
+ if (ts_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) {
651
+ ts_bufs.host_buf.Unmap();
652
+ }
653
+
654
+ wgpu::PassTimestampWrites ts_writes = { .querySet = ts_bufs.query_set,
655
+ .beginningOfPassWriteIndex = 0,
656
+ .endOfPassWriteIndex = 1 };
657
+ wgpu::ComputePassDescriptor pass_desc = { .timestampWrites = &ts_writes };
658
+ wgpu::ComputePassEncoder pass = encoder.BeginComputePass(&pass_desc);
659
+ #else
390
660
  wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
391
- pass.SetPipeline(pipeline);
661
+ #endif
662
+ pass.SetPipeline(pipeline.pipeline);
392
663
  pass.SetBindGroup(0, bind_group);
393
- pass.DispatchWorkgroups(wg_x, 1, 1);
664
+ pass.DispatchWorkgroups(wg_x, wg_y, 1);
394
665
  pass.End();
395
- wgpu::CommandBuffer commands = encoder.Finish();
396
- if (submit_and_wait) {
397
- // Submit and wait immediately
398
- ctx->queue.Submit(1, &commands);
399
- ctx->instance.WaitAny(ctx->queue.OnSubmittedWorkDone(
400
- wgpu::CallbackMode::AllowSpontaneous,
401
- [ctx, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
402
- if (status != wgpu::QueueWorkDoneStatus::Success) {
403
- GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", message.data);
404
- }
405
- ctx->param_buf_pool.free_bufs({ params_bufs });
406
- }),
407
- UINT64_MAX);
408
- } else {
409
- // Lock the context mutex when pushing to the staging vectors.
410
- std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
411
- // Enqueue commands and only submit if we have enough staged commands
412
- ctx->staged_command_bufs.push_back(commands);
413
- ctx->staged_param_bufs.push_back(params_bufs);
414
- if (ctx->staged_command_bufs.size() == WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) {
415
- ggml_backend_webgpu_submit_queue(ctx);
416
- }
666
+
667
+ #ifdef GGML_WEBGPU_GPU_PROFILE
668
+ // Resolve the query set into the device buffer
669
+ encoder.ResolveQuerySet(ts_bufs.query_set, 0, 2, ts_bufs.dev_buf, 0);
670
+ encoder.CopyBufferToBuffer(ts_bufs.dev_buf, 0, ts_bufs.host_buf, 0, ts_bufs.host_buf.GetSize());
671
+ #endif
672
+
673
+ // If there are SET_ROWS operations in this submission, copy their error buffers to the host.
674
+ if (set_rows_error_bufs) {
675
+ encoder.CopyBufferToBuffer(set_rows_error_bufs->dev_buf, 0, set_rows_error_bufs->host_buf, 0,
676
+ set_rows_error_bufs->host_buf.GetSize());
417
677
  }
678
+
679
+ wgpu::CommandBuffer commands = encoder.Finish();
680
+ webgpu_command result = {};
681
+ result.commands = commands;
682
+ result.params_bufs = params_bufs;
683
+ result.set_rows_error_bufs = set_rows_error_bufs;
684
+ #ifdef GGML_WEBGPU_GPU_PROFILE
685
+ result.timestamp_query_bufs = ts_bufs;
686
+ result.pipeline_name = pipeline.name;
687
+ #endif
688
+ return result;
418
689
  }
419
690
 
420
691
  static void ggml_backend_webgpu_buffer_memset(webgpu_context & ctx,
@@ -426,9 +697,12 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_context & ctx,
426
697
  std::vector<wgpu::BindGroupEntry> entries = {
427
698
  { .binding = 0, .buffer = buf, .offset = 0, .size = buf.GetSize() }
428
699
  };
429
- size_t bytes_per_wg = ctx->max_wg_size_x * ctx->memset_bytes_per_thread;
430
- uint32_t wg_x = ((size + 3) + bytes_per_wg - 1) / bytes_per_wg;
431
- ggml_backend_webgpu_build_and_enqueue(ctx, ctx->memset_pipeline, params, entries, wg_x, "MEMSET", true);
700
+ size_t bytes_per_wg = WEBGPU_MAX_WG_SIZE * ctx->memset_bytes_per_thread;
701
+ uint32_t wg_x = CEIL_DIV(size + 3, bytes_per_wg);
702
+
703
+ webgpu_command command = ggml_backend_webgpu_build(ctx, ctx->memset_pipelines[0], params, entries, wg_x);
704
+ std::vector<webgpu_submission_futures> futures = { ggml_backend_webgpu_submit(ctx, { command }) };
705
+ ggml_backend_webgpu_wait(ctx, futures);
432
706
  }
433
707
 
434
708
  /** End WebGPU Actions */
@@ -440,12 +714,53 @@ static const char * ggml_backend_webgpu_name(ggml_backend_t backend) {
440
714
  return ctx->name.c_str();
441
715
  }
442
716
 
717
+ // TODO: implement proper cleanup
443
718
  static void ggml_backend_webgpu_free(ggml_backend_t backend) {
444
719
  ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *) backend->context;
445
720
  WEBGPU_LOG_DEBUG("ggml_backend_webgpu_free(" << ctx->name << ")");
446
721
 
447
- // TODO: cleanup
722
+ #ifdef GGML_WEBGPU_CPU_PROFILE
723
+ std::cout << "\n[ggml_webgpu cpu profiling summary]\n";
724
+ double total_cpu = 0.0;
725
+ for (const auto & kv : ctx->webgpu_ctx->cpu_time_ms) {
726
+ total_cpu += kv.second;
727
+ }
728
+ std::cout << "ggml_webgpu: total cpu time: " << total_cpu << " ms\n";
729
+ std::cout << "ggml_webgpu: cpu breakdown:\n";
730
+ for (const auto & kv : ctx->webgpu_ctx->cpu_time_ms) {
731
+ double pct = (total_cpu > 0.0) ? (kv.second / total_cpu * 100.0) : 0.0;
732
+ std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n";
733
+ }
734
+ if (ctx->webgpu_ctx->cpu_detail_ms.size() > 0) {
735
+ std::cout << "ggml_webgpu: cpu detailed breakdown:\n";
736
+ }
737
+ for (const auto & kv : ctx->webgpu_ctx->cpu_detail_ms) {
738
+ double pct = (total_cpu > 0.0) ? (kv.second / total_cpu * 100.0) : 0.0;
739
+ std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n";
740
+ }
741
+ #endif
742
+
743
+ #ifdef GGML_WEBGPU_GPU_PROFILE
744
+ std::cout << "\n[ggml_webgpu gpu profiling summary]\n";
745
+ double total_gpu = 0.0;
746
+ for (const auto & kv : ctx->webgpu_ctx->shader_gpu_time_ms) {
747
+ total_gpu += kv.second;
748
+ }
749
+ std::cout << "ggml_webgpu: total gpu time (all shaders): " << total_gpu << " ms\n";
750
+ std::cout << "\nggml_webgpu: gpu breakdown:\n";
751
+ for (const auto & kv : ctx->webgpu_ctx->shader_gpu_time_ms) {
752
+ double pct = (total_gpu > 0.0) ? (kv.second / total_gpu * 100.0) : 0.0;
753
+ std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n";
754
+ }
755
+ #endif
756
+
757
+ #if defined(GGML_WEBGPU_CPU_PROFILE) && defined(GGML_WEBGPU_GPU_PROFILE)
758
+ std::cout << "ggml_webgpu: gpu/cpu ratio: " << (total_cpu > 0.0 ? total_gpu / total_cpu : 0.0) << "\n";
759
+ #endif
760
+
761
+ #if !defined(GGML_WEBGPU_CPU_PROFILE) && !defined(GGML_WEBGPU_GPU_PROFILE)
448
762
  GGML_UNUSED(ctx);
763
+ #endif
449
764
  }
450
765
 
451
766
  static size_t ggml_webgpu_tensor_offset(const ggml_tensor * tensor) {
@@ -457,19 +772,18 @@ static wgpu::Buffer ggml_webgpu_tensor_buf(const ggml_tensor * tensor) {
457
772
  return ctx->buffer;
458
773
  }
459
774
 
460
- static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, ggml_tensor * t) {
775
+ static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, const ggml_tensor * t) {
461
776
  size_t offset = ggml_webgpu_tensor_offset(t);
462
777
  return offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
463
778
  }
464
779
 
465
- static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, ggml_tensor * t) {
780
+ static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, const ggml_tensor * t) {
466
781
  size_t offset = ggml_webgpu_tensor_offset(t);
467
782
  return offset & ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
468
783
  }
469
784
 
470
785
  static size_t ggml_webgpu_tensor_binding_size(webgpu_context & ctx, ggml_tensor * t) {
471
- return (ggml_nbytes(t) + ggml_webgpu_tensor_misalignment(ctx, t) + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) &
472
- ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1);
786
+ return ROUNDUP_POW2(ggml_nbytes(t) + ggml_webgpu_tensor_misalignment(ctx, t), WEBGPU_STORAGE_BUF_BINDING_MULT);
473
787
  }
474
788
 
475
789
  // Used to determine if two tensors are the same for in-place operations
@@ -478,7 +792,7 @@ static bool ggml_webgpu_tensor_equal(ggml_tensor * a, ggml_tensor * b) {
478
792
  (ggml_webgpu_tensor_offset(a) == ggml_webgpu_tensor_offset(b));
479
793
  }
480
794
 
481
- static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
795
+ static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
482
796
  uint32_t ne = (uint32_t) ggml_nelements(dst);
483
797
 
484
798
  std::vector<uint32_t> params = {
@@ -489,8 +803,9 @@ static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor
489
803
  (uint32_t) (src->nb[2] / ggml_type_size(src->type)), (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
490
804
  (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
491
805
  (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
492
- // Logical shape — same for both tensors even if permuted
493
- (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) src->ne[3]
806
+ // Logical shapes
807
+ (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) dst->ne[0],
808
+ (uint32_t) dst->ne[1], (uint32_t) dst->ne[2]
494
809
  };
495
810
 
496
811
  std::vector<wgpu::BindGroupEntry> entries = {
@@ -504,15 +819,17 @@ static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor
504
819
  .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
505
820
  };
506
821
 
507
- size_t max_wg_size = ctx->max_wg_size_x;
508
- uint32_t wg_x = (ne + max_wg_size - 1) / max_wg_size;
509
- ggml_backend_webgpu_build_and_enqueue(ctx, ctx->cpy_pipeline, params, entries, wg_x, ggml_op_name(dst->op));
822
+ uint32_t wg_x = CEIL_DIV(ne, WEBGPU_MAX_WG_SIZE);
823
+ return ggml_backend_webgpu_build(ctx, ctx->cpy_pipelines[src->type][dst->type], params, entries, wg_x);
510
824
  }
511
825
 
512
- static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, ggml_tensor * dst) {
826
+ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
827
+ ggml_tensor * src,
828
+ ggml_tensor * idx,
829
+ ggml_tensor * dst) {
513
830
  // For set rows specifically, we need to check if src and idx are empty tensors.
514
831
  if (ggml_is_empty(src) || ggml_is_empty(idx)) {
515
- return;
832
+ return std::nullopt;
516
833
  }
517
834
 
518
835
  webgpu_pool_bufs error_bufs = ctx->set_rows_error_buf_pool.alloc_bufs();
@@ -552,16 +869,24 @@ static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_t
552
869
  { .binding = 3, .buffer = error_bufs.dev_buf, .offset = 0, .size = error_bufs.dev_buf.GetSize() }
553
870
  };
554
871
 
555
- size_t max_wg_size = ctx->max_wg_size_x;
556
- uint32_t wg_x = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size;
872
+ int vectorized = src->ne[0] % 4 == 0;
873
+ webgpu_pipeline pipeline = ctx->set_rows_pipelines[0][vectorized];
874
+ uint32_t threads;
875
+ if (vectorized) {
876
+ threads = (src->ne[1] * src->ne[2] * src->ne[3]) * (src->ne[0] / 4);
877
+ } else {
878
+ threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3];
879
+ }
557
880
 
558
- std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
559
- ctx->staged_set_row_error_bufs.push_back(error_bufs);
881
+ uint32_t wg_x = CEIL_DIV(threads, WEBGPU_MAX_WG_SIZE);
560
882
 
561
- ggml_backend_webgpu_build_and_enqueue(ctx, ctx->set_rows_pipeline, params, entries, wg_x, ggml_op_name(dst->op));
883
+ return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, 1, error_bufs);
562
884
  }
563
885
 
564
- static void ggml_webgpu_get_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, ggml_tensor * dst) {
886
+ static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx,
887
+ ggml_tensor * src,
888
+ ggml_tensor * idx,
889
+ ggml_tensor * dst) {
565
890
  std::vector<uint32_t> params = {
566
891
  (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
567
892
  (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)),
@@ -593,23 +918,23 @@ static void ggml_webgpu_get_rows(webgpu_context & ctx, ggml_tensor * src, ggml_t
593
918
  .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
594
919
  };
595
920
 
596
- size_t max_wg_size = ctx->max_wg_size_x;
597
- uint32_t wg_x = (dst->ne[1] * dst->ne[2] * dst->ne[3] + max_wg_size - 1) / max_wg_size;
921
+ uint32_t wg_x = CEIL_DIV(dst->ne[1] * dst->ne[2] * dst->ne[3], WEBGPU_MAX_WG_SIZE);
598
922
 
599
- wgpu::ComputePipeline pipeline = ctx->get_rows_pipeline[src->type];
600
- if (src->type == GGML_TYPE_F32 && dst->ne[0] % 4 != 0) {
601
- pipeline = ctx->get_rows_f32_no_vec_pipeline;
602
- }
603
- ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op));
923
+ uint32_t vectorized = src->type == GGML_TYPE_F32 && dst->ne[0] % 4 == 0;
924
+ webgpu_pipeline pipeline = ctx->get_rows_pipelines[src->type][vectorized];
925
+ return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
604
926
  }
605
927
 
606
- static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
928
+ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
929
+ ggml_tensor * src0,
930
+ ggml_tensor * src1,
931
+ ggml_tensor * dst) {
607
932
  std::vector<uint32_t> params = {
608
933
  (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
609
934
  (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
610
935
  (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
611
- (uint32_t) dst->ne[1], // number of rows in result (M)
612
- (uint32_t) dst->ne[0], // number of columns in result (N)
936
+ (uint32_t) dst->ne[0], // number of rows in result (M, transposed)
937
+ (uint32_t) dst->ne[1], // number of columns in result (N)
613
938
  (uint32_t) src0->ne[0], // number of columns in src0/src1 (K)
614
939
  (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 1
615
940
  (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 1
@@ -638,18 +963,269 @@ static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_t
638
963
  .size = ggml_webgpu_tensor_binding_size(ctx, dst) },
639
964
  };
640
965
 
641
- uint32_t wg_x =
642
- (dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3] + WEBGPU_MUL_MAT_WG_SIZE - 1) / WEBGPU_MUL_MAT_WG_SIZE;
643
- ggml_backend_webgpu_build_and_enqueue(ctx, ctx->mul_mat_pipeline[src0->type][src1->type], params, entries, wg_x,
644
- ggml_op_name(dst->op));
966
+ webgpu_pipeline pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][0];
967
+
968
+ uint32_t wg_x = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], WEBGPU_MUL_MAT_WG_SIZE);
969
+ uint32_t wg_y = 1;
970
+
971
+ bool use_fast = false;
972
+ switch (src1->type) {
973
+ case GGML_TYPE_F16:
974
+ use_fast = (src0->type == GGML_TYPE_F16);
975
+ break;
976
+ case GGML_TYPE_F32:
977
+ switch (src0->type) {
978
+ case GGML_TYPE_F32:
979
+ case GGML_TYPE_F16:
980
+ case GGML_TYPE_Q4_0:
981
+ use_fast = true;
982
+ break;
983
+ default:
984
+ break;
985
+ }
986
+ break;
987
+ default:
988
+ break;
989
+ }
990
+
991
+ if (use_fast) {
992
+ int vectorized = src0->ne[0] % 4 == 0 && dst->ne[0] % 4 == 0 && dst->ne[1] % 4 == 0;
993
+ if (dst->ne[1] == 1) {
994
+ // We don't support vectorized mul_mat_vec for quantized types
995
+ vectorized = vectorized && (src0->type < 2);
996
+ pipeline = ctx->mul_mat_vec_pipelines[src0->type][src1->type][vectorized];
997
+ uint32_t batches = dst->ne[2] * dst->ne[3];
998
+ uint32_t output_groups = CEIL_DIV(dst->ne[0], WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG);
999
+ uint32_t total_wg = output_groups * batches;
1000
+ wg_x = total_wg % ctx->limits.maxComputeWorkgroupsPerDimension;
1001
+ wg_y = CEIL_DIV(total_wg, ctx->limits.maxComputeWorkgroupsPerDimension);
1002
+ } else {
1003
+ pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][vectorized];
1004
+ uint32_t wg_m;
1005
+ uint32_t wg_n;
1006
+ #ifndef __EMSCRIPTEN__
1007
+ if (ctx->supports_subgroup_matrix) {
1008
+ // The total number of subgroups/workgroups needed per matrix.
1009
+ uint32_t wg_m_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_M * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M * ctx->sg_mat_m;
1010
+ wg_m = CEIL_DIV(dst->ne[0], wg_m_sg_tile);
1011
+ uint32_t wg_n_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N * ctx->sg_mat_n;
1012
+ wg_n = CEIL_DIV(dst->ne[1], wg_n_sg_tile);
1013
+ } else {
1014
+ #endif
1015
+ uint32_t tile_m_s = WEBGPU_MUL_MAT_TILE_M * WEBGPU_MUL_MAT_WG_SIZE_M;
1016
+ uint32_t tile_n_s = WEBGPU_MUL_MAT_TILE_N * WEBGPU_MUL_MAT_WG_SIZE_N;
1017
+ wg_m = CEIL_DIV(dst->ne[0], tile_m_s);
1018
+ wg_n = CEIL_DIV(dst->ne[1], tile_n_s);
1019
+ #ifndef __EMSCRIPTEN__
1020
+ }
1021
+ #endif
1022
+
1023
+ wg_x = wg_m * wg_n * dst->ne[2] * dst->ne[3];
1024
+ }
1025
+ }
1026
+ return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
1027
+ }
1028
+
1029
+ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx,
1030
+ ggml_tensor * Q,
1031
+ ggml_tensor * K,
1032
+ ggml_tensor * V,
1033
+ ggml_tensor * mask,
1034
+ ggml_tensor * sinks,
1035
+ ggml_tensor * dst) {
1036
+ float scale = *(float *) dst->op_params;
1037
+ float max_bias;
1038
+ memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
1039
+ float logit_softcap;
1040
+ memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));
1041
+ if (logit_softcap != 0.0f) {
1042
+ scale /= logit_softcap;
1043
+ }
1044
+ float n_head_log2 = float(1u << (uint32_t) floor(log2(Q->ne[2])));
1045
+ float m0 = powf(2.0f, -(max_bias) / n_head_log2);
1046
+ float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
1047
+
1048
+ const int has_mask = (mask != nullptr);
1049
+ const int has_sinks = (sinks != nullptr);
1050
+
1051
+ std::vector<uint32_t> params = {
1052
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, Q) / ggml_type_size(Q->type)),
1053
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type)),
1054
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type)),
1055
+ has_mask ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)) : 0,
1056
+ has_sinks ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, sinks) / ggml_type_size(sinks->type)) : 0,
1057
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1058
+ (uint32_t) Q->ne[2], // number of heads
1059
+ (uint32_t) Q->ne[1], // sequence length (Q)
1060
+ (uint32_t) K->ne[1], // sequence length (K/V)
1061
+ (uint32_t) (Q->nb[1] / ggml_type_size(Q->type)), // stride (elements/blocks) of Q in dimension 1
1062
+ (uint32_t) (Q->nb[2] / ggml_type_size(Q->type)), // stride (elements/blocks) of Q in dimension 2
1063
+ (uint32_t) (Q->nb[3] / ggml_type_size(Q->type)), // stride (elements/blocks) of Q in dimension 3
1064
+ (uint32_t) (K->nb[1] / ggml_type_size(K->type)), // stride (elements/blocks) of K in dimension 1
1065
+ (uint32_t) (K->nb[2] / ggml_type_size(K->type)), // stride (elements/blocks) of K in dimension 2
1066
+ (uint32_t) (K->nb[3] / ggml_type_size(K->type)), // stride (elements/blocks) of K in dimension 3
1067
+ (uint32_t) (V->nb[1] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 1
1068
+ (uint32_t) (V->nb[2] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 2
1069
+ (uint32_t) (V->nb[3] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 3
1070
+ has_mask ? (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)) : 0, // stride of mask dim 3
1071
+ (uint32_t) (Q->ne[2] / K->ne[2]), // repeat factor for K/V in dim 2 (MHA/MQA/GQA)
1072
+ *(uint32_t *) &scale, // scale (possibly adjusted for logit softcap)
1073
+ *(uint32_t *) &max_bias,
1074
+ *(uint32_t *) &logit_softcap,
1075
+ *(uint32_t *) &n_head_log2,
1076
+ *(uint32_t *) &m0,
1077
+ *(uint32_t *) &m1
1078
+
1079
+ };
1080
+ std::vector<wgpu::BindGroupEntry> entries = {
1081
+ { .binding = 0,
1082
+ .buffer = ggml_webgpu_tensor_buf(Q),
1083
+ .offset = ggml_webgpu_tensor_align_offset(ctx, Q),
1084
+ .size = ggml_webgpu_tensor_binding_size(ctx, Q) },
1085
+ { .binding = 1,
1086
+ .buffer = ggml_webgpu_tensor_buf(K),
1087
+ .offset = ggml_webgpu_tensor_align_offset(ctx, K),
1088
+ .size = ggml_webgpu_tensor_binding_size(ctx, K) },
1089
+ { .binding = 2,
1090
+ .buffer = ggml_webgpu_tensor_buf(V),
1091
+ .offset = ggml_webgpu_tensor_align_offset(ctx, V),
1092
+ .size = ggml_webgpu_tensor_binding_size(ctx, V) }
1093
+ };
1094
+ uint32_t binding_index = 3;
1095
+ if (has_mask) {
1096
+ entries.push_back({ .binding = binding_index++,
1097
+ .buffer = ggml_webgpu_tensor_buf(mask),
1098
+ .offset = ggml_webgpu_tensor_align_offset(ctx, mask),
1099
+ .size = ggml_webgpu_tensor_binding_size(ctx, mask) });
1100
+ }
1101
+ if (has_sinks) {
1102
+ entries.push_back({ .binding = binding_index++,
1103
+ .buffer = ggml_webgpu_tensor_buf(sinks),
1104
+ .offset = ggml_webgpu_tensor_align_offset(ctx, sinks),
1105
+ .size = ggml_webgpu_tensor_binding_size(ctx, sinks) });
1106
+ }
1107
+ entries.push_back({ .binding = binding_index++,
1108
+ .buffer = ggml_webgpu_tensor_buf(dst),
1109
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
1110
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
1111
+
1112
+ bool kv_direct =
1113
+ (K->type == GGML_TYPE_F16) && (Q->ne[0] % ctx->sg_mat_k == 0) && (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);
1114
+
1115
+ flash_attn_pipeline_key key = {
1116
+ .q_type = Q->type,
1117
+ .kv_type = K->type,
1118
+ .dst_type = dst->type,
1119
+ .head_dim_qk = (uint32_t) Q->ne[0],
1120
+ .head_dim_v = (uint32_t) V->ne[0],
1121
+ .kv_direct = kv_direct,
1122
+ .has_mask = static_cast<bool>(has_mask),
1123
+ .has_sinks = static_cast<bool>(has_sinks),
1124
+ .uses_logit_softcap = logit_softcap != 0.0f,
1125
+ };
1126
+
1127
+ webgpu_pipeline pipeline;
1128
+ ggml_webgpu_flash_attn_shader_decisions decisions = {};
1129
+
1130
+ auto it = ctx->flash_attn_pipelines.find(key);
1131
+ if (it != ctx->flash_attn_pipelines.end()) {
1132
+ pipeline = it->second;
1133
+ decisions = *static_cast<ggml_webgpu_flash_attn_shader_decisions *>(pipeline.context);
1134
+ } else {
1135
+ std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
1136
+ it = ctx->flash_attn_pipelines.find(key);
1137
+ if (it != ctx->flash_attn_pipelines.end()) {
1138
+ pipeline = it->second;
1139
+ decisions = *static_cast<ggml_webgpu_flash_attn_shader_decisions *>(pipeline.context);
1140
+ } else {
1141
+ ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = { .kv_type = K->type,
1142
+ .head_dim_qk = (uint32_t) Q->ne[0],
1143
+ .head_dim_v = (uint32_t) V->ne[0],
1144
+ .kv_direct = kv_direct,
1145
+ .has_mask = static_cast<bool>(has_mask),
1146
+ .has_sinks = static_cast<bool>(has_sinks),
1147
+ .uses_logit_softcap = logit_softcap != 0.0f,
1148
+ .sg_mat_m = ctx->sg_mat_m,
1149
+ .sg_mat_n = ctx->sg_mat_n,
1150
+ .sg_mat_k = ctx->sg_mat_k,
1151
+ .wg_mem_limit_bytes =
1152
+ ctx->limits.maxComputeWorkgroupStorageSize,
1153
+ .max_subgroup_size = ctx->max_subgroup_size };
1154
+
1155
+ ggml_webgpu_processed_shader processed =
1156
+ ggml_webgpu_preprocess_flash_attn_shader(ctx->p, wgsl_flash_attn, shader_lib_ctx);
1157
+ pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
1158
+ pipeline.context = new ggml_webgpu_flash_attn_shader_decisions(processed.decisions);
1159
+ ctx->flash_attn_pipelines.emplace(key, pipeline);
1160
+ decisions = processed.decisions;
1161
+ }
1162
+ }
1163
+
1164
+ uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions.q_tile);
1165
+ uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches
1166
+ return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
645
1167
  }
646
1168
 
647
- static void ggml_webgpu_binary_op(webgpu_context & ctx,
648
- ggml_tensor * src0,
649
- ggml_tensor * src1,
650
- ggml_tensor * dst,
651
- wgpu::ComputePipeline & pipeline,
652
- bool in_place) {
1169
+ static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
1170
+ uint32_t ne = (uint32_t) ggml_nelements(dst);
1171
+ ggml_unary_op unary_op = ggml_get_unary_op(dst);
1172
+ uint32_t inplace = ggml_webgpu_tensor_equal(src, dst);
1173
+
1174
+ std::vector<uint32_t> params = {
1175
+ ne, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
1176
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1177
+ // Convert byte-strides to element-strides
1178
+ (uint32_t) (src->nb[0] / ggml_type_size(src->type)), (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
1179
+ (uint32_t) (src->nb[2] / ggml_type_size(src->type)), (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
1180
+ (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
1181
+ (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
1182
+ // Logical shapes
1183
+ (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) dst->ne[0],
1184
+ (uint32_t) dst->ne[1], (uint32_t) dst->ne[2]
1185
+ };
1186
+
1187
+ switch (unary_op) {
1188
+ case GGML_UNARY_OP_XIELU:
1189
+ {
1190
+ // Get float parameters and reinterpret their bit patterns as uint32_t
1191
+ // for passing through the params buffer
1192
+ float alpha_n = ggml_get_op_params_f32(dst, 1);
1193
+ float alpha_p = ggml_get_op_params_f32(dst, 2);
1194
+ float beta = ggml_get_op_params_f32(dst, 3);
1195
+ float eps = ggml_get_op_params_f32(dst, 4);
1196
+ params.push_back(*reinterpret_cast<const uint32_t *>(&alpha_n));
1197
+ params.push_back(*reinterpret_cast<const uint32_t *>(&alpha_p));
1198
+ params.push_back(*reinterpret_cast<const uint32_t *>(&beta));
1199
+ params.push_back(*reinterpret_cast<const uint32_t *>(&eps));
1200
+ break;
1201
+ }
1202
+ default:
1203
+ break;
1204
+ }
1205
+
1206
+ std::vector<wgpu::BindGroupEntry> entries = {
1207
+ { .binding = 0,
1208
+ .buffer = ggml_webgpu_tensor_buf(src),
1209
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src),
1210
+ .size = ggml_webgpu_tensor_binding_size(ctx, src) },
1211
+ };
1212
+ if (!inplace) {
1213
+ entries.push_back({ .binding = 1,
1214
+ .buffer = ggml_webgpu_tensor_buf(dst),
1215
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
1216
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
1217
+ }
1218
+
1219
+ uint32_t wg_x = CEIL_DIV(ne, WEBGPU_MAX_WG_SIZE);
1220
+ return ggml_backend_webgpu_build(ctx, ctx->unary_pipelines[unary_op][dst->type][inplace], params, entries, wg_x);
1221
+ }
1222
+
1223
+ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
1224
+ ggml_tensor * src0,
1225
+ ggml_tensor * src1,
1226
+ ggml_tensor * dst,
1227
+ webgpu_pipeline & pipeline,
1228
+ bool inplace) {
653
1229
  std::vector<uint32_t> params = {
654
1230
  (uint32_t) ggml_nelements(dst),
655
1231
  (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
@@ -678,43 +1254,210 @@ static void ggml_webgpu_binary_op(webgpu_context & ctx,
678
1254
  .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
679
1255
  .size = ggml_webgpu_tensor_binding_size(ctx, src1) }
680
1256
  };
681
- if (!in_place) {
1257
+ if (!inplace) {
682
1258
  entries.push_back({ .binding = 2,
683
1259
  .buffer = ggml_webgpu_tensor_buf(dst),
684
1260
  .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
685
1261
  .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
686
1262
  }
687
1263
 
688
- size_t max_wg_size = ctx->max_wg_size_x;
689
- uint32_t wg_x = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size;
690
- ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op));
1264
+ uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
1265
+ return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
691
1266
  }
692
1267
 
693
- static void ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
694
- bool in_place = ggml_webgpu_tensor_equal(src, dst);
1268
+ static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
1269
+ int inplace = ggml_webgpu_tensor_equal(src, dst);
695
1270
 
696
- uint32_t eps;
697
- memcpy(&eps, dst->op_params, sizeof(float));
1271
+ std::vector<uint32_t> params = {
1272
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
1273
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1274
+ (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
1275
+ (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
1276
+ (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
1277
+ (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
1278
+ (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
1279
+ (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
1280
+ (uint32_t) src->ne[0],
1281
+ (uint32_t) src->ne[1],
1282
+ (uint32_t) src->ne[2],
1283
+ (uint32_t) src->ne[3],
1284
+ *(uint32_t *) dst->op_params // epsilon, treated as f32 in the shader
1285
+ };
1286
+
1287
+ std::vector<wgpu::BindGroupEntry> entries = {
1288
+ { .binding = 0,
1289
+ .buffer = ggml_webgpu_tensor_buf(src),
1290
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src),
1291
+ .size = ggml_webgpu_tensor_binding_size(ctx, src) }
1292
+ };
1293
+ if (!inplace) {
1294
+ entries.push_back({ .binding = 1,
1295
+ .buffer = ggml_webgpu_tensor_buf(dst),
1296
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
1297
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
1298
+ }
1299
+
1300
+ return ggml_backend_webgpu_build(ctx, ctx->rms_norm_pipelines[inplace], params, entries, ggml_nrows(src));
1301
+ }
1302
+
1303
+ static webgpu_command ggml_webgpu_rope(webgpu_context & ctx,
1304
+ ggml_tensor * src0,
1305
+ ggml_tensor * src1,
1306
+ ggml_tensor * src2,
1307
+ ggml_tensor * dst) {
1308
+ const int inplace = ggml_webgpu_tensor_equal(src0, dst);
1309
+ const int has_freq_factor = (src2 != nullptr);
1310
+
1311
+ const int n_dims = ((int32_t *) dst->op_params)[1];
1312
+ const int mode = ((int32_t *) dst->op_params)[2];
1313
+ const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
1314
+
1315
+ float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
1316
+ memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
1317
+ memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
1318
+ memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
1319
+ memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
1320
+ memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
1321
+ memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
1322
+
1323
+ int sections[4];
1324
+ memcpy(sections, (int32_t *) dst->op_params + 11, 4 * sizeof(int));
1325
+
1326
+ float theta_scale = powf(freq_base, -2.0f / n_dims);
1327
+
1328
+ float corr_dims[2];
1329
+ ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
1330
+
1331
+ std::vector<uint32_t> params = {
1332
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
1333
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
1334
+ src2 != nullptr ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)) : 0,
1335
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1336
+ (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
1337
+ (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
1338
+ (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
1339
+ (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
1340
+ (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
1341
+ (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
1342
+ (uint32_t) ggml_nelements(src0) / 2,
1343
+ (uint32_t) src0->ne[0],
1344
+ (uint32_t) src0->ne[1],
1345
+ (uint32_t) src0->ne[2],
1346
+ (uint32_t) n_dims,
1347
+ (uint32_t) mode,
1348
+ *(uint32_t *) &theta_scale,
1349
+ *(uint32_t *) &attn_factor,
1350
+ *(uint32_t *) &freq_scale,
1351
+ *(uint32_t *) &ext_factor,
1352
+ *(uint32_t *) &corr_dims[0],
1353
+ *(uint32_t *) &corr_dims[1],
1354
+ (uint32_t) sections[0],
1355
+ (uint32_t) sections[1],
1356
+ (uint32_t) sections[2],
1357
+ (uint32_t) sections[3]
1358
+ };
1359
+
1360
+ std::vector<wgpu::BindGroupEntry> entries = {
1361
+ { .binding = 0,
1362
+ .buffer = ggml_webgpu_tensor_buf(src0),
1363
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src0),
1364
+ .size = ggml_webgpu_tensor_binding_size(ctx, src0) },
1365
+ { .binding = 1,
1366
+ .buffer = ggml_webgpu_tensor_buf(src1),
1367
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
1368
+ .size = ggml_webgpu_tensor_binding_size(ctx, src1) }
1369
+ };
1370
+ uint32_t dst_binding = 2;
1371
+ if (has_freq_factor) {
1372
+ dst_binding = 3;
1373
+ entries.push_back({ .binding = 2,
1374
+ .buffer = ggml_webgpu_tensor_buf(src2),
1375
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src2),
1376
+ .size = ggml_webgpu_tensor_binding_size(ctx, src2) });
1377
+ }
1378
+ if (!inplace) {
1379
+ entries.push_back({ .binding = dst_binding,
1380
+ .buffer = ggml_webgpu_tensor_buf(dst),
1381
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
1382
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
1383
+ }
1384
+
1385
+ webgpu_pipeline pipeline = ctx->rope_pipelines[dst->type][has_freq_factor][inplace];
1386
+ uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
1387
+ return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
1388
+ }
1389
+
1390
+ static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
1391
+ const int split = (src1 != nullptr);
1392
+
1393
+ std::vector<uint32_t> params = {
1394
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
1395
+ src1 != nullptr ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)) : 0,
1396
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1397
+ (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
1398
+ (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
1399
+ (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
1400
+ src1 != nullptr ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) :
1401
+ (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
1402
+ src1 != nullptr ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) :
1403
+ (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
1404
+ src1 != nullptr ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) :
1405
+ (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
1406
+ (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
1407
+ (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
1408
+ (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
1409
+ (uint32_t) ggml_nelements(dst),
1410
+ (uint32_t) dst->ne[0],
1411
+ (uint32_t) dst->ne[1],
1412
+ (uint32_t) dst->ne[2],
1413
+ (uint32_t) ((int32_t *) dst->op_params)[1], // swapped
1414
+ *(uint32_t *) &dst->op_params[2], // alpha, for swiglu_oai
1415
+ *(uint32_t *) &dst->op_params[3], // limit, for swiglu_oai
1416
+ };
1417
+
1418
+ std::vector<wgpu::BindGroupEntry> entries = {
1419
+ { .binding = 0,
1420
+ .buffer = ggml_webgpu_tensor_buf(src0),
1421
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src0),
1422
+ .size = ggml_webgpu_tensor_binding_size(ctx, src0) },
1423
+ };
1424
+ uint32_t dst_binding = 1;
1425
+ if (split) {
1426
+ dst_binding = 2;
1427
+ entries.push_back({ .binding = 1,
1428
+ .buffer = ggml_webgpu_tensor_buf(src1),
1429
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
1430
+ .size = ggml_webgpu_tensor_binding_size(ctx, src1) });
1431
+ }
1432
+ entries.push_back({ .binding = dst_binding,
1433
+ .buffer = ggml_webgpu_tensor_buf(dst),
1434
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
1435
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
1436
+
1437
+ webgpu_pipeline pipeline = ctx->glu_pipelines[ggml_get_glu_op(dst)][dst->type][split];
1438
+ uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
1439
+ return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
1440
+ }
1441
+
1442
+ static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
1443
+ int inplace = ggml_webgpu_tensor_equal(src, dst);
698
1444
 
699
1445
  std::vector<uint32_t> params = {
700
1446
  (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
1447
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1448
+ (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
1449
+ (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
1450
+ (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
1451
+ (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
1452
+ (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
1453
+ (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
1454
+ (uint32_t) ggml_nelements(dst),
1455
+ (uint32_t) src->ne[0],
1456
+ (uint32_t) src->ne[1],
1457
+ (uint32_t) src->ne[2],
1458
+ *(uint32_t *) dst->op_params, // scale
1459
+ *(uint32_t *) &dst->op_params[1] // bias
701
1460
  };
702
- if (!in_place) {
703
- params.push_back((uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)));
704
- }
705
- params.push_back((uint32_t) (src->nb[1] / ggml_type_size(src->type)));
706
- params.push_back((uint32_t) (src->nb[2] / ggml_type_size(src->type)));
707
- params.push_back((uint32_t) (src->nb[3] / ggml_type_size(src->type)));
708
- if (!in_place) {
709
- params.push_back((uint32_t) (dst->nb[1] / ggml_type_size(dst->type)));
710
- params.push_back((uint32_t) (dst->nb[2] / ggml_type_size(dst->type)));
711
- params.push_back((uint32_t) (dst->nb[3] / ggml_type_size(dst->type)));
712
- }
713
- params.push_back((uint32_t) src->ne[0]);
714
- params.push_back((uint32_t) src->ne[1]);
715
- params.push_back((uint32_t) src->ne[2]);
716
- params.push_back((uint32_t) src->ne[3]);
717
- params.push_back(eps); // epsilon, will be bitcast to float in shader
718
1461
 
719
1462
  std::vector<wgpu::BindGroupEntry> entries = {
720
1463
  { .binding = 0,
@@ -722,33 +1465,100 @@ static void ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_t
722
1465
  .offset = ggml_webgpu_tensor_align_offset(ctx, src),
723
1466
  .size = ggml_webgpu_tensor_binding_size(ctx, src) }
724
1467
  };
725
- if (!in_place) {
1468
+ if (!inplace) {
726
1469
  entries.push_back({ .binding = 1,
727
1470
  .buffer = ggml_webgpu_tensor_buf(dst),
728
1471
  .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
729
1472
  .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
730
1473
  }
731
1474
 
732
- wgpu::ComputePipeline pipeline;
733
- if (in_place) {
734
- pipeline = ctx->rms_norm_ip_pipeline;
735
- } else {
736
- pipeline = ctx->rms_norm_pipeline;
1475
+ uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
1476
+ return ggml_backend_webgpu_build(ctx, ctx->scale_pipelines[inplace], params, entries, wg_x);
1477
+ }
1478
+
1479
+ static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx,
1480
+ ggml_tensor * src0,
1481
+ ggml_tensor * src1,
1482
+ ggml_tensor * src2,
1483
+ ggml_tensor * dst) {
1484
+ const int inplace = ggml_webgpu_tensor_equal(src0, dst);
1485
+ const int mask_type = (src1 != nullptr) ? src1->type : 2; // use 2 for no mask here
1486
+ const int has_sink = (src2 != nullptr);
1487
+ float max_bias;
1488
+ memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
1489
+ float n_head_log2 = float(1u << (uint32_t) floor(log2(src0->ne[2])));
1490
+ float m0 = powf(2.0f, -(max_bias) / n_head_log2);
1491
+ float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
1492
+
1493
+ std::vector<uint32_t> params = {
1494
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
1495
+ mask_type < 2 ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)) : 0,
1496
+ has_sink ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)) : 0,
1497
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1498
+ (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
1499
+ (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
1500
+ (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
1501
+ mask_type < 2 ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) : 0,
1502
+ mask_type < 2 ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) : 0,
1503
+ mask_type < 2 ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) : 0,
1504
+ (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
1505
+ (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
1506
+ (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
1507
+ (uint32_t) ggml_nelements(dst),
1508
+ (uint32_t) src0->ne[0],
1509
+ (uint32_t) src0->ne[1],
1510
+ (uint32_t) src0->ne[2],
1511
+ mask_type < 2 ? (uint32_t) src1->ne[2] : 0,
1512
+ mask_type < 2 ? (uint32_t) src1->ne[3] : 0,
1513
+ *(uint32_t *) dst->op_params, // scale
1514
+ *(uint32_t *) &max_bias,
1515
+ *(uint32_t *) &n_head_log2,
1516
+ *(uint32_t *) &m0,
1517
+ *(uint32_t *) &m1
1518
+ };
1519
+
1520
+ std::vector<wgpu::BindGroupEntry> entries = {
1521
+ { .binding = 0,
1522
+ .buffer = ggml_webgpu_tensor_buf(src0),
1523
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src0),
1524
+ .size = ggml_webgpu_tensor_binding_size(ctx, src0) }
1525
+ };
1526
+ uint32_t binding_num = 1;
1527
+ if (mask_type < 2) {
1528
+ entries.push_back({ .binding = binding_num,
1529
+ .buffer = ggml_webgpu_tensor_buf(src1),
1530
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
1531
+ .size = ggml_webgpu_tensor_binding_size(ctx, src1) });
1532
+ binding_num++;
1533
+ }
1534
+ if (has_sink) {
1535
+ entries.push_back({ .binding = binding_num,
1536
+ .buffer = ggml_webgpu_tensor_buf(src2),
1537
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src2),
1538
+ .size = ggml_webgpu_tensor_binding_size(ctx, src2) });
1539
+ binding_num++;
1540
+ }
1541
+ if (!inplace) {
1542
+ entries.push_back({ .binding = binding_num,
1543
+ .buffer = ggml_webgpu_tensor_buf(dst),
1544
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
1545
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
737
1546
  }
738
- size_t max_wg_size = ctx->max_wg_size_x;
739
- uint32_t wg_x = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size;
740
- ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op));
1547
+
1548
+ return ggml_backend_webgpu_build(ctx, ctx->soft_max_pipelines[mask_type][has_sink][inplace], params, entries,
1549
+ ggml_nrows(dst));
741
1550
  }
742
1551
 
743
- // Returns true if node has enqueued work into the queue, false otherwise
744
- static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
1552
+ // Returns the encoded command, or std::nullopt if the operation is a no-op
1553
+ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
745
1554
  if (ggml_is_empty(node)) {
746
- return false;
1555
+ return std::nullopt;
747
1556
  }
748
1557
  WEBGPU_LOG_DEBUG("ggml_webgpu_encode_node(" << node << ", " << ggml_op_name(node->op) << ")");
749
1558
 
750
1559
  ggml_tensor * src0 = node->src[0];
751
1560
  ggml_tensor * src1 = node->src[1];
1561
+ ggml_tensor * src2 = node->src[2];
752
1562
 
753
1563
  switch (node->op) {
754
1564
  // no-ops
@@ -757,40 +1567,53 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
757
1567
  case GGML_OP_PERMUTE:
758
1568
  case GGML_OP_TRANSPOSE:
759
1569
  case GGML_OP_RESHAPE:
760
- return false;
1570
+ return std::nullopt;
761
1571
  case GGML_OP_CPY:
762
- ggml_webgpu_cpy(ctx, src0, node);
763
- break;
1572
+ case GGML_OP_CONT:
1573
+ return ggml_webgpu_cpy(ctx, src0, node);
764
1574
  case GGML_OP_SET_ROWS:
765
- ggml_webgpu_set_rows(ctx, src0, src1, node);
766
- break;
1575
+ return ggml_webgpu_set_rows(ctx, src0, src1, node);
767
1576
  case GGML_OP_GET_ROWS:
768
- ggml_webgpu_get_rows(ctx, src0, src1, node);
769
- break;
1577
+ return ggml_webgpu_get_rows(ctx, src0, src1, node);
770
1578
  case GGML_OP_MUL_MAT:
771
- ggml_webgpu_mul_mat(ctx, src0, src1, node);
772
- break;
1579
+ return ggml_webgpu_mul_mat(ctx, src0, src1, node);
1580
+ case GGML_OP_FLASH_ATTN_EXT:
1581
+ return ggml_webgpu_flash_attn(ctx, src0, src1, src2, node->src[3], node->src[4], node);
773
1582
  case GGML_OP_ADD:
774
- if (ggml_webgpu_tensor_equal(src0, node)) {
775
- ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_ip_pipeline[node->type], true);
776
- } else {
777
- ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipeline[node->type], false);
1583
+ {
1584
+ int inplace = ggml_webgpu_tensor_equal(src0, node);
1585
+ return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipelines[node->type][inplace], inplace);
1586
+ }
1587
+ case GGML_OP_SUB:
1588
+ {
1589
+ int inplace = ggml_webgpu_tensor_equal(src0, node);
1590
+ return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->sub_pipelines[node->type][inplace], inplace);
778
1591
  }
779
- break;
780
1592
  case GGML_OP_MUL:
781
- if (ggml_webgpu_tensor_equal(src0, node)) {
782
- ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_ip_pipeline[node->type], true);
783
- } else {
784
- ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_pipeline[node->type], false);
1593
+ {
1594
+ int inplace = ggml_webgpu_tensor_equal(src0, node);
1595
+ return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_pipelines[node->type][inplace], inplace);
1596
+ }
1597
+ case GGML_OP_DIV:
1598
+ {
1599
+ int inplace = ggml_webgpu_tensor_equal(src0, node);
1600
+ return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->div_pipelines[node->type][inplace], inplace);
785
1601
  }
786
- break;
787
1602
  case GGML_OP_RMS_NORM:
788
- ggml_webgpu_rms_norm(ctx, src0, node);
789
- break;
1603
+ return ggml_webgpu_rms_norm(ctx, src0, node);
1604
+ case GGML_OP_ROPE:
1605
+ return ggml_webgpu_rope(ctx, src0, src1, src2, node);
1606
+ case GGML_OP_GLU:
1607
+ return ggml_webgpu_glu(ctx, src0, src1, node);
1608
+ case GGML_OP_SCALE:
1609
+ return ggml_webgpu_scale(ctx, src0, node);
1610
+ case GGML_OP_SOFT_MAX:
1611
+ return ggml_webgpu_soft_max(ctx, src0, src1, src2, node);
1612
+ case GGML_OP_UNARY:
1613
+ return ggml_webgpu_unary_op(ctx, src0, node);
790
1614
  default:
791
- return false;
1615
+ return std::nullopt;
792
1616
  }
793
- return true;
794
1617
  }
795
1618
 
796
1619
  static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
@@ -799,13 +1622,36 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
799
1622
  ggml_backend_webgpu_context * backend_ctx = static_cast<ggml_backend_webgpu_context *>(backend->context);
800
1623
  webgpu_context ctx = backend_ctx->webgpu_ctx;
801
1624
 
1625
+ WEBGPU_CPU_PROFILE_TOTAL_START(graph_compute);
1626
+
1627
+ ctx->inflight_threads++;
1628
+
1629
+ std::vector<webgpu_command> commands;
1630
+ std::vector<webgpu_submission_futures> futures;
802
1631
  for (int i = 0; i < cgraph->n_nodes; i++) {
803
- ggml_webgpu_encode_node(ctx, cgraph->nodes[i]);
1632
+ if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) {
1633
+ commands.push_back(*cmd);
1634
+ }
1635
+ // compute the batch size based on the number of inflight threads
1636
+ uint32_t inflight_threads = ctx->inflight_threads;
1637
+ uint32_t batch_size = std::min(std::max(1u, WEBGPU_NUM_PARAM_BUFS / std::max(inflight_threads, 1u)),
1638
+ WEBGPU_COMMAND_SUBMIT_BATCH_SIZE);
1639
+ if (commands.size() >= batch_size) {
1640
+ futures.push_back(ggml_backend_webgpu_submit(ctx, commands));
1641
+ // Process events and check for completed submissions
1642
+ ctx->instance.ProcessEvents();
1643
+ ggml_backend_webgpu_wait(ctx, futures, false);
1644
+ commands.clear();
1645
+ }
1646
+ }
1647
+ if (!commands.empty()) {
1648
+ webgpu_submission_futures new_futures = ggml_backend_webgpu_submit(ctx, commands);
1649
+ futures.push_back(new_futures);
804
1650
  }
805
1651
 
806
- ggml_backend_webgpu_submit_queue(ctx);
807
- ggml_backend_webgpu_wait_on_submission(ctx);
808
-
1652
+ ggml_backend_webgpu_wait(ctx, futures);
1653
+ ctx->inflight_threads--;
1654
+ WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx);
809
1655
  return GGML_STATUS_SUCCESS;
810
1656
  }
811
1657
 
@@ -831,7 +1677,6 @@ static ggml_backend_i ggml_backend_webgpu_i = {
831
1677
  /* GGML Backend Buffer Interface */
832
1678
 
833
1679
  static void ggml_backend_webgpu_buffer_free_buffer(ggml_backend_buffer_t buffer) {
834
- WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_free_buffer()");
835
1680
  ggml_backend_webgpu_buffer_context * ctx = static_cast<ggml_backend_webgpu_buffer_context *>(buffer->context);
836
1681
  ctx->buffer.Destroy();
837
1682
  }
@@ -852,16 +1697,19 @@ static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffe
852
1697
  return;
853
1698
  }
854
1699
 
855
- WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor(" << buffer << ", " << tensor << ", " << value << ", "
856
- << offset << ", " << size << ")");
1700
+ WEBGPU_CPU_PROFILE_TOTAL_START(memset_tensor);
857
1701
 
858
1702
  ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
859
1703
 
1704
+ WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor(" << buf_ctx->label << ", " << tensor << ", " << value
1705
+ << ", " << offset << ", " << size << ")");
1706
+
860
1707
  size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
861
1708
 
862
1709
  // This is a trick to set all bytes of a u32 to the same 1 byte value.
863
1710
  uint32_t val32 = (uint32_t) value * 0x01010101;
864
1711
  ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, val32, total_offset, size);
1712
+ WEBGPU_CPU_PROFILE_TOTAL_END(memset_tensor, buf_ctx->webgpu_ctx);
865
1713
  }
866
1714
 
867
1715
  static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
@@ -869,11 +1717,13 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
869
1717
  const void * data,
870
1718
  size_t offset,
871
1719
  size_t size) {
872
- WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_set_tensor(" << buffer << ", " << tensor << ", " << data << ", "
873
- << offset << ", " << size << ")");
1720
+ WEBGPU_CPU_PROFILE_TOTAL_START(set_tensor);
874
1721
  ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
875
1722
  webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx;
876
1723
 
1724
+ WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_set_tensor(" << buf_ctx->label << ", " << tensor << ", " << data
1725
+ << ", " << offset << ", " << size << ")");
1726
+
877
1727
  size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
878
1728
 
879
1729
  webgpu_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size / 4) * 4);
@@ -893,8 +1743,17 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
893
1743
  remaining_size);
894
1744
  } else {
895
1745
  // wait for WriteBuffer to complete
896
- ggml_backend_webgpu_wait_on_submission(webgpu_ctx);
1746
+ webgpu_ctx->instance.WaitAny(
1747
+ webgpu_ctx->queue.OnSubmittedWorkDone(wgpu::CallbackMode::AllowSpontaneous,
1748
+ [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
1749
+ if (status != wgpu::QueueWorkDoneStatus::Success) {
1750
+ GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n",
1751
+ std::string(message).c_str());
1752
+ }
1753
+ }),
1754
+ UINT64_MAX);
897
1755
  }
1756
+ WEBGPU_CPU_PROFILE_TOTAL_END(set_tensor, webgpu_ctx);
898
1757
  }
899
1758
 
900
1759
  static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
@@ -902,12 +1761,12 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
902
1761
  void * data,
903
1762
  size_t offset,
904
1763
  size_t size) {
905
- WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", "
906
- << offset << ", " << size << ")");
907
-
908
- ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
909
- webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx;
910
- wgpu::Device device = webgpu_ctx->device;
1764
+ WEBGPU_CPU_PROFILE_TOTAL_START(get_tensor);
1765
+ ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
1766
+ WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_get_tensor(" << buf_ctx->label << ", " << tensor << ", " << data
1767
+ << ", " << offset << ", " << size << ")");
1768
+ webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx;
1769
+ wgpu::Device device = webgpu_ctx->device;
911
1770
 
912
1771
  size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
913
1772
 
@@ -944,12 +1803,15 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
944
1803
  // Copy the data from the mapped range to the output buffer
945
1804
  std::memcpy(data, mapped_range, size);
946
1805
  webgpu_ctx->get_tensor_staging_buf.Unmap();
1806
+ WEBGPU_CPU_PROFILE_TOTAL_END(get_tensor, webgpu_ctx);
947
1807
  }
948
1808
 
949
1809
  static void ggml_backend_webgpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
950
1810
  WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_clear(" << buffer << ", " << (uint32_t) value << ")");
1811
+ WEBGPU_CPU_PROFILE_TOTAL_START(clear);
951
1812
  ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
952
1813
  ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, value, 0, buffer->size);
1814
+ WEBGPU_CPU_PROFILE_TOTAL_END(clear, buf_ctx->webgpu_ctx);
953
1815
  }
954
1816
 
955
1817
  static ggml_backend_buffer_i ggml_backend_webgpu_buffer_interface = {
@@ -975,16 +1837,19 @@ static const char * ggml_backend_webgpu_buffer_type_get_name(ggml_backend_buffer
975
1837
 
976
1838
  static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
977
1839
  size_t size) {
978
- WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_type_alloc_buffer(" << size << ")");
1840
+ static std::atomic<int> buffer_count;
1841
+ int buffer_id = buffer_count++;
1842
+ std::string buf_name = "tensor_buf" + std::to_string(buffer_id);
1843
+ WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_type_alloc_buffer_" << buffer_id << ": " << size << " bytes");
979
1844
  ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
980
1845
 
981
1846
  wgpu::Buffer buf;
982
- ggml_webgpu_create_buffer(ctx->webgpu_ctx->device, buf,
983
- (size + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1),
1847
+ ggml_webgpu_create_buffer(ctx->webgpu_ctx->device, buf, ROUNDUP_POW2(size, WEBGPU_STORAGE_BUF_BINDING_MULT),
984
1848
  wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst,
985
- "allocated_buffer");
1849
+ buf_name.c_str());
986
1850
 
987
- ggml_backend_webgpu_buffer_context * buf_ctx = new ggml_backend_webgpu_buffer_context(ctx->webgpu_ctx, buf);
1851
+ ggml_backend_webgpu_buffer_context * buf_ctx =
1852
+ new ggml_backend_webgpu_buffer_context(ctx->webgpu_ctx, buf, buf_name);
988
1853
 
989
1854
  return ggml_backend_buffer_init(buft, ggml_backend_webgpu_buffer_interface, buf_ctx, size);
990
1855
  }
@@ -1016,9 +1881,18 @@ static const char * ggml_backend_webgpu_device_get_description(ggml_backend_dev_
1016
1881
 
1017
1882
  static void ggml_backend_webgpu_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
1018
1883
  ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
1019
- // TODO: what do we actually want to return here? maxBufferSize might not be the full available memory.
1020
- *free = ctx->webgpu_ctx->limits.maxBufferSize;
1021
- *total = ctx->webgpu_ctx->limits.maxBufferSize;
1884
+ // TODO: for now, return maxBufferSize as both free and total memory
1885
+ // Track https://github.com/gpuweb/gpuweb/issues/5505 for updates.
1886
+ uint64_t max_buffer_size = ctx->webgpu_ctx->limits.maxBufferSize;
1887
+ // If we're on a 32-bit system, clamp to UINTPTR_MAX
1888
+ #if UINTPTR_MAX < UINT64_MAX
1889
+ uint64_t max_ptr_size = static_cast<uint64_t>(UINTPTR_MAX);
1890
+ if (max_buffer_size > max_ptr_size) {
1891
+ max_buffer_size = max_ptr_size;
1892
+ }
1893
+ #endif
1894
+ *free = static_cast<size_t>(max_buffer_size);
1895
+ *total = static_cast<size_t>(max_buffer_size);
1022
1896
  }
1023
1897
 
1024
1898
  static enum ggml_backend_dev_type ggml_backend_webgpu_device_get_type(ggml_backend_dev_t dev) {
@@ -1044,168 +1918,609 @@ static ggml_guid_t ggml_backend_webgpu_guid(void) {
1044
1918
  return reinterpret_cast<ggml_guid_t>((void *) guid_str);
1045
1919
  }
1046
1920
 
1047
- // The max workgroup size is a common constant
1048
- static std::vector<wgpu::ConstantEntry> ggml_webgpu_max_wg_size_entry(webgpu_context & webgpu_ctx) {
1921
+ // Workgroup size is a common constant
1922
+ static std::vector<wgpu::ConstantEntry> ggml_webgpu_wg_size_entry(uint32_t wg_size) {
1049
1923
  std::vector<wgpu::ConstantEntry> constants(1);
1050
1924
  constants[0].key = "wg_size";
1051
- constants[0].value = webgpu_ctx->max_wg_size_x;
1925
+ constants[0].value = wg_size;
1052
1926
  return constants;
1053
1927
  }
1054
1928
 
1055
1929
  static void ggml_webgpu_init_memset_pipeline(webgpu_context & webgpu_ctx) {
1056
1930
  // we use the maximum workgroup size for the memset pipeline
1057
- size_t max_wg_size = webgpu_ctx->max_wg_size_x;
1058
- size_t max_threads = max_wg_size * webgpu_ctx->limits.maxComputeWorkgroupsPerDimension;
1931
+ size_t max_threads = WEBGPU_MAX_WG_SIZE * webgpu_ctx->limits.maxComputeWorkgroupsPerDimension;
1059
1932
  // Size the bytes_per_thread so that the largest buffer size can be handled
1060
- webgpu_ctx->memset_bytes_per_thread =
1061
- (webgpu_ctx->limits.maxStorageBufferBindingSize + max_threads - 1) / max_threads;
1933
+ webgpu_ctx->memset_bytes_per_thread = CEIL_DIV(webgpu_ctx->limits.maxStorageBufferBindingSize, max_threads);
1062
1934
  std::vector<wgpu::ConstantEntry> constants(2);
1063
- constants[0].key = "wg_size";
1064
- constants[0].value = max_wg_size;
1065
- constants[1].key = "bytes_per_thread";
1066
- constants[1].value = webgpu_ctx->memset_bytes_per_thread;
1067
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->memset_pipeline, wgsl_memset, "memset", constants);
1935
+ constants[0].key = "wg_size";
1936
+ constants[0].value = WEBGPU_MAX_WG_SIZE;
1937
+ constants[1].key = "bytes_per_thread";
1938
+ constants[1].value = webgpu_ctx->memset_bytes_per_thread;
1939
+ webgpu_ctx->memset_pipelines[0] = ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_memset, "memset", constants);
1068
1940
  }
1069
1941
 
1070
1942
  static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
1071
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F32][GGML_TYPE_F32],
1072
- wgsl_mul_mat_f32_f32, "mul_mat_f32_f32");
1073
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F16][GGML_TYPE_F16],
1074
- wgsl_mul_mat_f16_f16, "mul_mat_f16_f16");
1075
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F16][GGML_TYPE_F32],
1076
- wgsl_mul_mat_f16_f32, "mul_mat_f16_f32");
1077
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_0][GGML_TYPE_F32],
1078
- wgsl_mul_mat_q4_0_f32, "mul_mat_q4_0_f32");
1079
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_1][GGML_TYPE_F32],
1080
- wgsl_mul_mat_q4_1_f32, "mul_mat_q4_1_f32");
1081
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_0][GGML_TYPE_F32],
1082
- wgsl_mul_mat_q5_0_f32, "mul_mat_q5_0_f32");
1083
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_1][GGML_TYPE_F32],
1084
- wgsl_mul_mat_q5_1_f32, "mul_mat_q5_1_f32");
1085
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q8_0][GGML_TYPE_F32],
1086
- wgsl_mul_mat_q8_0_f32, "mul_mat_q8_0_f32");
1087
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q2_K][GGML_TYPE_F32],
1088
- wgsl_mul_mat_q2_k_f32, "mul_mat_q2_k_f32");
1089
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q3_K][GGML_TYPE_F32],
1090
- wgsl_mul_mat_q3_k_f32, "mul_mat_q3_k_f32");
1091
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_K][GGML_TYPE_F32],
1092
- wgsl_mul_mat_q4_k_f32, "mul_mat_q4_k_f32");
1093
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_K][GGML_TYPE_F32],
1094
- wgsl_mul_mat_q5_k_f32, "mul_mat_q5_k_f32");
1095
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q6_K][GGML_TYPE_F32],
1096
- wgsl_mul_mat_q6_k_f32, "mul_mat_q6_k_f32");
1097
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_XXS][GGML_TYPE_F32],
1098
- wgsl_mul_mat_iq2_xxs_f32, "mul_mat_iq2_xxs_f32");
1099
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_XS][GGML_TYPE_F32],
1100
- wgsl_mul_mat_iq2_xs_f32, "mul_mat_iq2_xs_f32");
1101
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_S][GGML_TYPE_F32],
1102
- wgsl_mul_mat_iq2_s_f32, "mul_mat_iq2_s_f32");
1103
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ3_XXS][GGML_TYPE_F32],
1104
- wgsl_mul_mat_iq3_xxs_f32, "mul_mat_iq3_xxs_f32");
1105
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ3_S][GGML_TYPE_F32],
1106
- wgsl_mul_mat_iq3_s_f32, "mul_mat_iq3_s_f32");
1107
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ1_S][GGML_TYPE_F32],
1108
- wgsl_mul_mat_iq1_s_f32, "mul_mat_iq1_s_f32");
1109
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ1_M][GGML_TYPE_F32],
1110
- wgsl_mul_mat_iq1_m_f32, "mul_mat_iq1_m_f32");
1111
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ4_NL][GGML_TYPE_F32],
1112
- wgsl_mul_mat_iq4_nl_f32, "mul_mat_iq4_nl_f32");
1113
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ4_XS][GGML_TYPE_F32],
1114
- wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32");
1943
+ // Q4/Q5/Q8 classic quantizations
1944
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] =
1945
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q4_0_f32, "mul_mat_q4_0_f32");
1946
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_1][GGML_TYPE_F32][0] =
1947
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q4_1_f32, "mul_mat_q4_1_f32");
1948
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_0][GGML_TYPE_F32][0] =
1949
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q5_0_f32, "mul_mat_q5_0_f32");
1950
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_1][GGML_TYPE_F32][0] =
1951
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q5_1_f32, "mul_mat_q5_1_f32");
1952
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q8_0][GGML_TYPE_F32][0] =
1953
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q8_0_f32, "mul_mat_q8_0_f32");
1954
+
1955
+ // K-quantizations
1956
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q2_K][GGML_TYPE_F32][0] =
1957
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q2_k_f32, "mul_mat_q2_k_f32");
1958
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q3_K][GGML_TYPE_F32][0] =
1959
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q3_k_f32, "mul_mat_q3_k_f32");
1960
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_K][GGML_TYPE_F32][0] =
1961
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q4_k_f32, "mul_mat_q4_k_f32");
1962
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_K][GGML_TYPE_F32][0] =
1963
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q5_k_f32, "mul_mat_q5_k_f32");
1964
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q6_K][GGML_TYPE_F32][0] =
1965
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q6_k_f32, "mul_mat_q6_k_f32");
1966
+
1967
+ // IQ quantizations (2-, 3-, 4-bit variants)
1968
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_XXS][GGML_TYPE_F32][0] =
1969
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq2_xxs_f32, "mul_mat_iq2_xxs_f32");
1970
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_XS][GGML_TYPE_F32][0] =
1971
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq2_xs_f32, "mul_mat_iq2_xs_f32");
1972
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_S][GGML_TYPE_F32][0] =
1973
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq2_s_f32, "mul_mat_iq2_s_f32");
1974
+
1975
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ3_XXS][GGML_TYPE_F32][0] =
1976
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq3_xxs_f32, "mul_mat_iq3_xxs_f32");
1977
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ3_S][GGML_TYPE_F32][0] =
1978
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq3_s_f32, "mul_mat_iq3_s_f32");
1979
+
1980
+ // 1-bit and 4-bit IQ variants
1981
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ1_S][GGML_TYPE_F32][0] =
1982
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq1_s_f32, "mul_mat_iq1_s_f32");
1983
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ1_M][GGML_TYPE_F32][0] =
1984
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq1_m_f32, "mul_mat_iq1_m_f32");
1985
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ4_NL][GGML_TYPE_F32][0] =
1986
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq4_nl_f32, "mul_mat_iq4_nl_f32");
1987
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ4_XS][GGML_TYPE_F32][0] =
1988
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32");
1989
+
1990
+ std::string proc_mul_mat_f32_f32;
1991
+ std::string proc_mul_mat_f32_f32_vec;
1992
+ std::string proc_mul_mat_f16_f32;
1993
+ std::string proc_mul_mat_f16_f32_vec;
1994
+ std::string proc_mul_mat_f16_f16;
1995
+ std::string proc_mul_mat_f16_f16_vec;
1996
+ std::string proc_mul_mat_q4_0_f32;
1997
+ std::string proc_mul_mat_q4_0_f32_vec;
1998
+
1999
+ std::vector<wgpu::ConstantEntry> mul_mat_constants;
2000
+ #ifndef __EMSCRIPTEN__
2001
+ if (webgpu_ctx->supports_subgroup_matrix) {
2002
+ std::map<std::string, std::string> sg_matrix_repls;
2003
+ sg_matrix_repls["WEBGPU_MAX_SUBGROUP_SIZE"] = std::to_string(webgpu_ctx->max_subgroup_size);
2004
+ sg_matrix_repls["WEBGPU_TILE_K"] = std::to_string(WEBGPU_MUL_MAT_TILE_K);
2005
+ sg_matrix_repls["WEBGPU_SUBGROUP_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M);
2006
+ sg_matrix_repls["WEBGPU_SUBGROUP_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N);
2007
+ sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M);
2008
+ sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N);
2009
+ sg_matrix_repls["WEBGPU_SG_MAT_M_SIZE"] = std::to_string(webgpu_ctx->sg_mat_m);
2010
+ sg_matrix_repls["WEBGPU_SG_MAT_N_SIZE"] = std::to_string(webgpu_ctx->sg_mat_n);
2011
+ sg_matrix_repls["WEBGPU_SG_MAT_K_SIZE"] = std::to_string(webgpu_ctx->sg_mat_k);
2012
+
2013
+ proc_mul_mat_f32_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32, sg_matrix_repls);
2014
+ proc_mul_mat_f32_f32_vec =
2015
+ ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32_vec, sg_matrix_repls);
2016
+ proc_mul_mat_f16_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32, sg_matrix_repls);
2017
+ proc_mul_mat_f16_f32_vec =
2018
+ ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32_vec, sg_matrix_repls);
2019
+ proc_mul_mat_f16_f16 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16, sg_matrix_repls);
2020
+ proc_mul_mat_f16_f16_vec =
2021
+ ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16_vec, sg_matrix_repls);
2022
+ proc_mul_mat_q4_0_f32 =
2023
+ ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32, sg_matrix_repls);
2024
+ proc_mul_mat_q4_0_f32_vec =
2025
+ ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32_vec, sg_matrix_repls);
2026
+ } else {
2027
+ #endif
2028
+ mul_mat_constants.push_back({ .key = "TILE_K", .value = WEBGPU_MUL_MAT_TILE_K });
2029
+ mul_mat_constants.push_back({ .key = "WORKGROUP_SIZE_M", .value = WEBGPU_MUL_MAT_WG_SIZE_M });
2030
+ mul_mat_constants.push_back({ .key = "WORKGROUP_SIZE_N", .value = WEBGPU_MUL_MAT_WG_SIZE_N });
2031
+
2032
+ std::map<std::string, std::string> reg_repls;
2033
+ reg_repls["WEBGPU_TILE_M"] = std::to_string(WEBGPU_MUL_MAT_TILE_M);
2034
+ reg_repls["WEBGPU_TILE_N"] = std::to_string(WEBGPU_MUL_MAT_TILE_N);
2035
+
2036
+ proc_mul_mat_f32_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32, reg_repls);
2037
+ proc_mul_mat_f32_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32_vec, reg_repls);
2038
+ proc_mul_mat_f16_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32, reg_repls);
2039
+ proc_mul_mat_f16_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32_vec, reg_repls);
2040
+ proc_mul_mat_f16_f16 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16, reg_repls);
2041
+ proc_mul_mat_f16_f16_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16_vec, reg_repls);
2042
+ proc_mul_mat_q4_0_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32, reg_repls);
2043
+ proc_mul_mat_q4_0_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32_vec, reg_repls);
2044
+ #ifndef __EMSCRIPTEN__
2045
+ }
2046
+ #endif
2047
+
2048
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
2049
+ webgpu_ctx->device, proc_mul_mat_f32_f32.c_str(), "mul_mat_f32_f32", mul_mat_constants);
2050
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
2051
+ webgpu_ctx->device, proc_mul_mat_f32_f32_vec.c_str(), "mul_mat_f32_f32_vec", mul_mat_constants);
2052
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
2053
+ webgpu_ctx->device, proc_mul_mat_f16_f32.c_str(), "mul_mat_f16_f32", mul_mat_constants);
2054
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
2055
+ webgpu_ctx->device, proc_mul_mat_f16_f32_vec.c_str(), "mul_mat_f16_f32_vec", mul_mat_constants);
2056
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline(
2057
+ webgpu_ctx->device, proc_mul_mat_f16_f16.c_str(), "mul_mat_f16_f16", mul_mat_constants);
2058
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
2059
+ webgpu_ctx->device, proc_mul_mat_f16_f16_vec.c_str(), "mul_mat_f16_f16_vec", mul_mat_constants);
2060
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
2061
+ webgpu_ctx->device, proc_mul_mat_q4_0_f32.c_str(), "mul_mat_q4_0_f32", mul_mat_constants);
2062
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
2063
+ webgpu_ctx->device, proc_mul_mat_q4_0_f32_vec.c_str(), "mul_mat_q4_0_f32_vec", mul_mat_constants);
2064
+
2065
+ std::vector<wgpu::ConstantEntry> mul_mat_vec_constants(3);
2066
+ mul_mat_vec_constants[0].key = "WORKGROUP_SIZE";
2067
+ mul_mat_vec_constants[0].value = WEBGPU_MUL_MAT_VEC_WG_SIZE;
2068
+ mul_mat_vec_constants[1].key = "TILE_K";
2069
+ mul_mat_vec_constants[1].value = WEBGPU_MUL_MAT_VEC_TILE_K;
2070
+ mul_mat_vec_constants[2].key = "OUTPUTS_PER_WG";
2071
+ mul_mat_vec_constants[2].value = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG;
2072
+
2073
+ webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
2074
+ webgpu_ctx->device, wgsl_mul_mat_vec_f32_f32, "mul_mat_vec_f32_f32", mul_mat_vec_constants);
2075
+ webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
2076
+ webgpu_ctx->device, wgsl_mul_mat_vec_f32_f32_vec, "mul_mat_vec_f32_f32_vec", mul_mat_vec_constants);
2077
+ webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
2078
+ webgpu_ctx->device, wgsl_mul_mat_vec_f16_f32, "mul_mat_vec_f16_f32", mul_mat_vec_constants);
2079
+ webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
2080
+ webgpu_ctx->device, wgsl_mul_mat_vec_f16_f32_vec, "mul_mat_vec_f16_f32_vec", mul_mat_vec_constants);
2081
+ webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline(
2082
+ webgpu_ctx->device, wgsl_mul_mat_vec_f16_f16, "mul_mat_vec_f16_f16", mul_mat_vec_constants);
2083
+ webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
2084
+ webgpu_ctx->device, wgsl_mul_mat_vec_f16_f16_vec, "mul_mat_vec_f16_f16_vec", mul_mat_vec_constants);
2085
+ webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
2086
+ webgpu_ctx->device, wgsl_mul_mat_vec_q4_0_f32, "mul_mat_vec_q4_0_f32", mul_mat_vec_constants);
1115
2087
  }
1116
2088
 
1117
2089
  static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) {
1118
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline, wgsl_set_rows, "set_rows",
1119
- ggml_webgpu_max_wg_size_entry(webgpu_ctx));
2090
+ webgpu_ctx->set_rows_pipelines[0][0] = ggml_webgpu_create_pipeline(
2091
+ webgpu_ctx->device, wgsl_set_rows_f16, "set_rows_f16", ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE));
2092
+ webgpu_ctx->set_rows_pipelines[0][1] = ggml_webgpu_create_pipeline(
2093
+ webgpu_ctx->device, wgsl_set_rows_f16_vec, "set_rows_f16_vec", ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE));
1120
2094
  }
1121
2095
 
1122
2096
  static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) {
1123
- std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
1124
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_F32], wgsl_get_rows_f32_vec,
1125
- "get_rows_f32_vec", constants);
1126
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_f32_no_vec_pipeline, wgsl_get_rows_f32,
1127
- "get_rows_f32", constants);
1128
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_F16], wgsl_get_rows_f16,
1129
- "get_rows_f16", constants);
1130
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_I32], wgsl_get_rows_i32,
1131
- "get_rows_i32", constants);
1132
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q4_0], wgsl_get_rows_q4_0,
1133
- "get_rows_q4_0", constants);
1134
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q4_1], wgsl_get_rows_q4_1,
1135
- "get_rows_q4_1", constants);
1136
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q5_0], wgsl_get_rows_q5_0,
1137
- "get_rows_q5_0", constants);
1138
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q5_1], wgsl_get_rows_q5_1,
1139
- "get_rows_q5_1", constants);
1140
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q8_0], wgsl_get_rows_q8_0,
1141
- "get_rows_q8_0", constants);
1142
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q2_K], wgsl_get_rows_q2_k,
1143
- "get_rows_q2_k", constants);
1144
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q3_K], wgsl_get_rows_q3_k,
1145
- "get_rows_q3_k", constants);
1146
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q4_K], wgsl_get_rows_q4_k,
1147
- "get_rows_q4_k", constants);
1148
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q5_K], wgsl_get_rows_q5_k,
1149
- "get_rows_q5_k", constants);
1150
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q6_K], wgsl_get_rows_q6_k,
1151
- "get_rows_q6_k", constants);
1152
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ2_XXS],
1153
- wgsl_get_rows_iq2_xxs, "get_rows_iq2_xxs", constants);
1154
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ2_XS],
1155
- wgsl_get_rows_iq2_xs, "get_rows_iq2_xs", constants);
1156
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ2_S], wgsl_get_rows_iq2_s,
1157
- "get_rows_iq2_s", constants);
1158
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ3_XXS],
1159
- wgsl_get_rows_iq3_xxs, "get_rows_iq3_xxs", constants);
1160
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ3_S], wgsl_get_rows_iq3_s,
1161
- "get_rows_iq3_s", constants);
1162
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ1_S], wgsl_get_rows_iq1_s,
1163
- "get_rows_iq1_s", constants);
1164
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ1_M], wgsl_get_rows_iq1_m,
1165
- "get_rows_iq1_m", constants);
1166
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ4_NL],
1167
- wgsl_get_rows_iq4_nl, "get_rows_iq4_nl", constants);
1168
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ4_XS],
1169
- wgsl_get_rows_iq4_xs, "get_rows_iq4_xs", constants);
2097
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
2098
+
2099
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_F32][0] =
2100
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_f32, "get_rows_f32", constants);
2101
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_F32][1] =
2102
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_f32_vec, "get_rows_f32_vec", constants);
2103
+
2104
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_F16][0] =
2105
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_f16, "get_rows_f16", constants);
2106
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_I32][0] =
2107
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_i32, "get_rows_i32", constants);
2108
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_0][0] =
2109
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q4_0, "get_rows_q4_0", constants);
2110
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_1][0] =
2111
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q4_1, "get_rows_q4_1", constants);
2112
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_0][0] =
2113
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q5_0, "get_rows_q5_0", constants);
2114
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_1][0] =
2115
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q5_1, "get_rows_q5_1", constants);
2116
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q8_0][0] =
2117
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q8_0, "get_rows_q8_0", constants);
2118
+
2119
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q2_K][0] =
2120
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q2_k, "get_rows_q2_k", constants);
2121
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q3_K][0] =
2122
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q3_k, "get_rows_q3_k", constants);
2123
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_K][0] =
2124
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q4_k, "get_rows_q4_k", constants);
2125
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_K][0] =
2126
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q5_k, "get_rows_q5_k", constants);
2127
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q6_K][0] =
2128
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q6_k, "get_rows_q6_k", constants);
2129
+
2130
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_XXS][0] =
2131
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq2_xxs, "get_rows_iq2_xxs", constants);
2132
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_XS][0] =
2133
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq2_xs, "get_rows_iq2_xs", constants);
2134
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_S][0] =
2135
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq2_s, "get_rows_iq2_s", constants);
2136
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ3_XXS][0] =
2137
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq3_xxs, "get_rows_iq3_xxs", constants);
2138
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ3_S][0] =
2139
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq3_s, "get_rows_iq3_s", constants);
2140
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ1_S][0] =
2141
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq1_s, "get_rows_iq1_s", constants);
2142
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ1_M][0] =
2143
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq1_m, "get_rows_iq1_m", constants);
2144
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ4_NL][0] =
2145
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq4_nl, "get_rows_iq4_nl", constants);
2146
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ4_XS][0] =
2147
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq4_xs, "get_rows_iq4_xs", constants);
1170
2148
  }
1171
2149
 
1172
2150
  static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
1173
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline, wgsl_cpy, "cpy",
1174
- ggml_webgpu_max_wg_size_entry(webgpu_ctx));
2151
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
2152
+
2153
+ webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F32] =
2154
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_cpy_f32_f32, "cpy_f32_f32", constants);
2155
+ webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F16] =
2156
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_cpy_f32_f16, "cpy_f32_f16", constants);
2157
+ webgpu_ctx->cpy_pipelines[GGML_TYPE_F16][GGML_TYPE_F32] =
2158
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_cpy_f16_f32, "cpy_f16_f32", constants);
2159
+ webgpu_ctx->cpy_pipelines[GGML_TYPE_F16][GGML_TYPE_F16] =
2160
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_cpy_f16_f16, "cpy_f16_f16", constants);
1175
2161
  }
1176
2162
 
1177
2163
  static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) {
1178
- std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
1179
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32], wgsl_add_f32, "add_f32",
1180
- constants);
1181
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16], wgsl_add_f16, "add_f16",
1182
- constants);
1183
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_ip_pipeline[GGML_TYPE_F32], wgsl_add_in_place_f32,
1184
- "add_in_place_f32", constants);
1185
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_ip_pipeline[GGML_TYPE_F16], wgsl_add_in_place_f16,
1186
- "add_in_place_f16", constants);
2164
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
2165
+
2166
+ webgpu_ctx->add_pipelines[GGML_TYPE_F32][0] =
2167
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_add_f32, "add_f32", constants);
2168
+ webgpu_ctx->add_pipelines[GGML_TYPE_F16][0] =
2169
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_add_f16, "add_f16", constants);
2170
+ webgpu_ctx->add_pipelines[GGML_TYPE_F32][1] =
2171
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_add_f32_inplace, "add_f32_inplace", constants);
2172
+ webgpu_ctx->add_pipelines[GGML_TYPE_F16][1] =
2173
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_add_f16_inplace, "add_f16_inplace", constants);
2174
+ }
2175
+
2176
+ static void ggml_webgpu_init_sub_pipeline(webgpu_context & webgpu_ctx) {
2177
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
2178
+
2179
+ webgpu_ctx->sub_pipelines[GGML_TYPE_F32][0] =
2180
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sub_f32, "sub_f32", constants);
2181
+ webgpu_ctx->sub_pipelines[GGML_TYPE_F16][0] =
2182
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sub_f16, "sub_f16", constants);
2183
+ webgpu_ctx->sub_pipelines[GGML_TYPE_F32][1] =
2184
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sub_f32_inplace, "sub_f32_inplace", constants);
2185
+ webgpu_ctx->sub_pipelines[GGML_TYPE_F16][1] =
2186
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sub_f16_inplace, "sub_f16_inplace", constants);
1187
2187
  }
1188
2188
 
1189
2189
  static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) {
1190
- std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
1191
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32], wgsl_mul_f32, "mul_f32",
1192
- constants);
1193
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16], wgsl_mul_f16, "mul_f16",
1194
- constants);
1195
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_ip_pipeline[GGML_TYPE_F32], wgsl_mul_in_place_f32,
1196
- "mul_in_place_f32", constants);
1197
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_ip_pipeline[GGML_TYPE_F16], wgsl_mul_in_place_f16,
1198
- "mul_in_place_f16", constants);
2190
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
2191
+
2192
+ webgpu_ctx->mul_pipelines[GGML_TYPE_F32][0] =
2193
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_f32, "mul_f32", constants);
2194
+ webgpu_ctx->mul_pipelines[GGML_TYPE_F16][0] =
2195
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_f16, "mul_f16", constants);
2196
+ webgpu_ctx->mul_pipelines[GGML_TYPE_F32][1] =
2197
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_f32_inplace, "mul_f32_inplace", constants);
2198
+ webgpu_ctx->mul_pipelines[GGML_TYPE_F16][1] =
2199
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_f16_inplace, "mul_f16_inplace", constants);
2200
+ }
2201
+
2202
+ static void ggml_webgpu_init_div_pipeline(webgpu_context & webgpu_ctx) {
2203
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
2204
+
2205
+ webgpu_ctx->div_pipelines[GGML_TYPE_F32][0] =
2206
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_div_f32, "div_f32", constants);
2207
+ webgpu_ctx->div_pipelines[GGML_TYPE_F16][0] =
2208
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_div_f16, "div_f16", constants);
2209
+ webgpu_ctx->div_pipelines[GGML_TYPE_F32][1] =
2210
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_div_f32_inplace, "div_f32_inplace", constants);
2211
+ webgpu_ctx->div_pipelines[GGML_TYPE_F16][1] =
2212
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_div_f16_inplace, "div_f16_inplace", constants);
1199
2213
  }
1200
2214
 
1201
2215
  static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) {
1202
- std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
1203
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_pipeline, wgsl_rms_norm, "rms_norm",
1204
- constants);
1205
- ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_ip_pipeline, wgsl_rms_norm_in_place,
1206
- "rms_norm_in_place", constants);
2216
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);
2217
+
2218
+ webgpu_ctx->rms_norm_pipelines[0] =
2219
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rms_norm, "rms_norm", constants);
2220
+ webgpu_ctx->rms_norm_pipelines[1] =
2221
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rms_norm_inplace, "rms_norm_inplace", constants);
2222
+ }
2223
+
2224
+ static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) {
2225
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
2226
+
2227
+ webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][0] =
2228
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f32, "rope_f32", constants);
2229
+ webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][1] =
2230
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f32_inplace, "rope_f32_inplace", constants);
2231
+ webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][0] =
2232
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f32_ff, "rope_f32_ff", constants);
2233
+ webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][1] =
2234
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f32_ff_inplace, "rope_f32_ff_inplace", constants);
2235
+
2236
+ webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][0] =
2237
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f16, "rope_f16", constants);
2238
+ webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][1] =
2239
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f16_inplace, "rope_f16_inplace", constants);
2240
+ webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][0] =
2241
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f16_ff, "rope_f16_ff", constants);
2242
+ webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][1] =
2243
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f16_ff_inplace, "rope_f16_ff_inplace", constants);
2244
+ }
2245
+
2246
+ static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) {
2247
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
2248
+
2249
+ // REGLU
2250
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F32][0] =
2251
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_reglu_f32, "reglu_f32", constants);
2252
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F16][0] =
2253
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_reglu_f16, "reglu_f16", constants);
2254
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F32][1] =
2255
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_reglu_f32_split, "reglu_f32_split", constants);
2256
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F16][1] =
2257
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_reglu_f16_split, "reglu_f16_split", constants);
2258
+
2259
+ // GEGLU
2260
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][0] =
2261
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_f32, "geglu_f32", constants);
2262
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][0] =
2263
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_f16, "geglu_f16", constants);
2264
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][1] =
2265
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_f32_split, "geglu_f32_split", constants);
2266
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][1] =
2267
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_f16_split, "geglu_f16_split", constants);
2268
+
2269
+ // SWIGLU
2270
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][0] =
2271
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_f32, "swiglu_f32", constants);
2272
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][0] =
2273
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_f16, "swiglu_f16", constants);
2274
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][1] =
2275
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_f32_split, "swiglu_f32_split", constants);
2276
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][1] =
2277
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_f16_split, "swiglu_f16_split", constants);
2278
+
2279
+ // SWIGLU_OAI
2280
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][0] =
2281
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_oai_f32, "swiglu_oai_f32", constants);
2282
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][1] =
2283
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_oai_f32_split, "swiglu_oai_f32_split", constants);
2284
+
2285
+ // GEGLU_ERF
2286
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][0] =
2287
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_erf_f32, "geglu_erf_f32", constants);
2288
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][0] =
2289
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_erf_f16, "geglu_erf_f16", constants);
2290
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][1] =
2291
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_erf_f32_split, "geglu_erf_f32_split", constants);
2292
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][1] =
2293
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_erf_f16_split, "geglu_erf_f16_split", constants);
2294
+
2295
+ // GEGLU_QUICK
2296
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][0] =
2297
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_quick_f32, "geglu_quick_f32", constants);
2298
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][0] =
2299
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_quick_f16, "geglu_quick_f16", constants);
2300
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][1] =
2301
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_quick_f32_split, "geglu_quick_f32_split", constants);
2302
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][1] =
2303
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_quick_f16_split, "geglu_quick_f16_split", constants);
2304
+ }
2305
+
2306
+ static void ggml_webgpu_init_unary_pipeline(webgpu_context & webgpu_ctx) {
2307
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
2308
+
2309
+ // ABS
2310
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ABS][GGML_TYPE_F32][0] =
2311
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_abs_f32, "abs_f32", constants);
2312
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ABS][GGML_TYPE_F16][0] =
2313
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_abs_f16, "abs_f16", constants);
2314
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ABS][GGML_TYPE_F32][1] =
2315
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_abs_inplace_f32, "abs_inplace_f32", constants);
2316
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ABS][GGML_TYPE_F16][1] =
2317
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_abs_inplace_f16, "abs_inplace_f16", constants);
2318
+
2319
+ // SGN
2320
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SGN][GGML_TYPE_F32][0] =
2321
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sgn_f32, "sgn_f32", constants);
2322
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SGN][GGML_TYPE_F16][0] =
2323
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sgn_f16, "sgn_f16", constants);
2324
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SGN][GGML_TYPE_F32][1] =
2325
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sgn_inplace_f32, "sgn_inplace_f32", constants);
2326
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SGN][GGML_TYPE_F16][1] =
2327
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sgn_inplace_f16, "sgn_inplace_f16", constants);
2328
+
2329
+ // NEG
2330
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_NEG][GGML_TYPE_F32][0] =
2331
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_neg_f32, "neg_f32", constants);
2332
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_NEG][GGML_TYPE_F16][0] =
2333
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_neg_f16, "neg_f16", constants);
2334
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_NEG][GGML_TYPE_F32][1] =
2335
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_neg_inplace_f32, "neg_inplace_f32", constants);
2336
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_NEG][GGML_TYPE_F16][1] =
2337
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_neg_inplace_f16, "neg_inplace_f16", constants);
2338
+
2339
+ // STEP
2340
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_STEP][GGML_TYPE_F32][0] =
2341
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_step_f32, "step_f32", constants);
2342
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_STEP][GGML_TYPE_F16][0] =
2343
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_step_f16, "step_f16", constants);
2344
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_STEP][GGML_TYPE_F32][1] =
2345
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_step_inplace_f32, "step_inplace_f32", constants);
2346
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_STEP][GGML_TYPE_F16][1] =
2347
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_step_inplace_f16, "step_inplace_f16", constants);
2348
+
2349
+ // TANH
2350
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_TANH][GGML_TYPE_F32][0] =
2351
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_tanh_f32, "tanh_f32", constants);
2352
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_TANH][GGML_TYPE_F16][0] =
2353
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_tanh_f16, "tanh_f16", constants);
2354
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_TANH][GGML_TYPE_F32][1] =
2355
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_tanh_inplace_f32, "tanh_inplace_f32", constants);
2356
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_TANH][GGML_TYPE_F16][1] =
2357
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_tanh_inplace_f16, "tanh_inplace_f16", constants);
2358
+
2359
+ // ELU
2360
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ELU][GGML_TYPE_F32][0] =
2361
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_elu_f32, "elu_f32", constants);
2362
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ELU][GGML_TYPE_F16][0] =
2363
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_elu_f16, "elu_f16", constants);
2364
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ELU][GGML_TYPE_F32][1] =
2365
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_elu_inplace_f32, "elu_inplace_f32", constants);
2366
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ELU][GGML_TYPE_F16][1] =
2367
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_elu_inplace_f16, "elu_inplace_f16", constants);
2368
+
2369
+ // RELU
2370
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_RELU][GGML_TYPE_F32][0] =
2371
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_relu_f32, "relu_f32", constants);
2372
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_RELU][GGML_TYPE_F16][0] =
2373
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_relu_f16, "relu_f16", constants);
2374
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_RELU][GGML_TYPE_F32][1] =
2375
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_relu_inplace_f32, "relu_inplace_f32", constants);
2376
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_RELU][GGML_TYPE_F16][1] =
2377
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_relu_inplace_f16, "relu_inplace_f16", constants);
2378
+
2379
+ // SIGMOID
2380
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F32][0] =
2381
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sigmoid_f32, "sigmoid_f32", constants);
2382
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F16][0] =
2383
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sigmoid_f16, "sigmoid_f16", constants);
2384
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F32][1] =
2385
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sigmoid_inplace_f32, "sigmoid_inplace_f32", constants);
2386
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F16][1] =
2387
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sigmoid_inplace_f16, "sigmoid_inplace_f16", constants);
2388
+
2389
+ // GELU
2390
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU][GGML_TYPE_F32][0] =
2391
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_f32, "gelu_f32", constants);
2392
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU][GGML_TYPE_F16][0] =
2393
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_f16, "gelu_f16", constants);
2394
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU][GGML_TYPE_F32][1] =
2395
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_inplace_f32, "gelu_inplace_f32", constants);
2396
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU][GGML_TYPE_F16][1] =
2397
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_inplace_f16, "gelu_inplace_f16", constants);
2398
+
2399
+ // GELU_QUICK
2400
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F32][0] =
2401
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_quick_f32, "gelu_quick_f32", constants);
2402
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F16][0] =
2403
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_quick_f16, "gelu_quick_f16", constants);
2404
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
2405
+ webgpu_ctx->device, wgsl_gelu_quick_inplace_f32, "gelu_quick_inplace_f32", constants);
2406
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
2407
+ webgpu_ctx->device, wgsl_gelu_quick_inplace_f16, "gelu_quick_inplace_f16", constants);
2408
+
2409
+ // SILU
2410
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SILU][GGML_TYPE_F32][0] =
2411
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_silu_f32, "silu_f32", constants);
2412
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SILU][GGML_TYPE_F16][0] =
2413
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_silu_f16, "silu_f16", constants);
2414
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SILU][GGML_TYPE_F32][1] =
2415
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_silu_inplace_f32, "silu_inplace_f32", constants);
2416
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SILU][GGML_TYPE_F16][1] =
2417
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_silu_inplace_f16, "silu_inplace_f16", constants);
2418
+
2419
+ // HARDSWISH
2420
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F32][0] =
2421
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_hardswish_f32, "hardswish_f32", constants);
2422
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F16][0] =
2423
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_hardswish_f16, "hardswish_f16", constants);
2424
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F32][1] =
2425
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_hardswish_inplace_f32, "hardswish_inplace_f32", constants);
2426
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F16][1] =
2427
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_hardswish_inplace_f16, "hardswish_inplace_f16", constants);
2428
+
2429
+ // HARDSIGMOID
2430
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F32][0] =
2431
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_hardsigmoid_f32, "hardsigmoid_f32", constants);
2432
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F16][0] =
2433
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_hardsigmoid_f16, "hardsigmoid_f16", constants);
2434
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
2435
+ webgpu_ctx->device, wgsl_hardsigmoid_inplace_f32, "hardsigmoid_inplace_f32", constants);
2436
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
2437
+ webgpu_ctx->device, wgsl_hardsigmoid_inplace_f16, "hardsigmoid_inplace_f16", constants);
2438
+
2439
+ // EXP
2440
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_EXP][GGML_TYPE_F32][0] =
2441
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_exp_f32, "exp_f32", constants);
2442
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_EXP][GGML_TYPE_F16][0] =
2443
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_exp_f16, "exp_f16", constants);
2444
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_EXP][GGML_TYPE_F32][1] =
2445
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_exp_inplace_f32, "exp_inplace_f32", constants);
2446
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_EXP][GGML_TYPE_F16][1] =
2447
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_exp_inplace_f16, "exp_inplace_f16", constants);
2448
+
2449
+ // GELU_ERF
2450
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F32][0] =
2451
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_erf_f32, "gelu_erf_f32", constants);
2452
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F16][0] =
2453
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_erf_f16, "gelu_erf_f16", constants);
2454
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F32][1] =
2455
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_erf_inplace_f32, "gelu_erf_inplace_f32", constants);
2456
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F16][1] =
2457
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_erf_inplace_f16, "gelu_erf_inplace_f16", constants);
2458
+
2459
+ // XIELU
2460
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_XIELU][GGML_TYPE_F32][0] =
2461
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_xielu_f32, "xielu_f32", constants);
2462
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_XIELU][GGML_TYPE_F16][0] =
2463
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_xielu_f16, "xielu_f16", constants);
2464
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_XIELU][GGML_TYPE_F32][1] =
2465
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_xielu_inplace_f32, "xielu_inplace_f32", constants);
2466
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_XIELU][GGML_TYPE_F16][1] =
2467
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_xielu_inplace_f16, "xielu_inplace_f16", constants);
2468
+
2469
+ // CEIL
2470
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_CEIL][GGML_TYPE_F32][0] =
2471
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_ceil_f32, "ceil_f32", constants);
2472
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_CEIL][GGML_TYPE_F16][0] =
2473
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_ceil_f16, "ceil_f16", constants);
2474
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_CEIL][GGML_TYPE_F32][1] =
2475
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_ceil_inplace_f32, "ceil_inplace_f32", constants);
2476
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_CEIL][GGML_TYPE_F16][1] =
2477
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_ceil_inplace_f16, "ceil_inplace_f16", constants);
1207
2478
  }
1208
2479
 
2480
+ static void ggml_webgpu_init_scale_pipeline(webgpu_context & webgpu_ctx) {
2481
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
2482
+
2483
+ webgpu_ctx->scale_pipelines[0] =
2484
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_scale_f32, "scale_f32", constants);
2485
+ webgpu_ctx->scale_pipelines[1] =
2486
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_scale_f32_inplace, "scale_f32_inplace", constants);
2487
+ }
2488
+
2489
+ static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) {
2490
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);
2491
+
2492
+ // f32 (no mask)
2493
+ webgpu_ctx->soft_max_pipelines[2][0][0] =
2494
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_soft_max_f32, "soft_max_f32", constants);
2495
+ webgpu_ctx->soft_max_pipelines[2][0][1] =
2496
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_soft_max_f32_inplace, "soft_max_f32_inplace", constants);
2497
+ webgpu_ctx->soft_max_pipelines[2][1][0] =
2498
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_soft_max_f32_sink, "soft_max_f32_sink", constants);
2499
+ webgpu_ctx->soft_max_pipelines[2][1][1] = ggml_webgpu_create_pipeline(
2500
+ webgpu_ctx->device, wgsl_soft_max_f32_sink_inplace, "soft_max_f32_sink_inplace", constants);
2501
+
2502
+ // f32 mask (mask_type = 0)
2503
+ webgpu_ctx->soft_max_pipelines[0][0][0] =
2504
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_soft_max_f32_mask_f32, "soft_max_f32_mask_f32", constants);
2505
+ webgpu_ctx->soft_max_pipelines[0][0][1] = ggml_webgpu_create_pipeline(
2506
+ webgpu_ctx->device, wgsl_soft_max_f32_mask_f32_inplace, "soft_max_f32_mask_f32_inplace", constants);
2507
+ webgpu_ctx->soft_max_pipelines[0][1][0] = ggml_webgpu_create_pipeline(
2508
+ webgpu_ctx->device, wgsl_soft_max_f32_mask_f32_sink, "soft_max_f32_mask_f32_sink", constants);
2509
+ webgpu_ctx->soft_max_pipelines[0][1][1] = ggml_webgpu_create_pipeline(
2510
+ webgpu_ctx->device, wgsl_soft_max_f32_mask_f32_sink_inplace, "soft_max_f32_mask_f32_sink_inplace", constants);
2511
+
2512
+ // f16 mask (mask_type = 1)
2513
+ webgpu_ctx->soft_max_pipelines[1][0][0] =
2514
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_soft_max_f32_mask_f16, "soft_max_f32_mask_f16", constants);
2515
+ webgpu_ctx->soft_max_pipelines[1][0][1] = ggml_webgpu_create_pipeline(
2516
+ webgpu_ctx->device, wgsl_soft_max_f32_mask_f16_inplace, "soft_max_f32_mask_f16_inplace", constants);
2517
+ webgpu_ctx->soft_max_pipelines[1][1][0] = ggml_webgpu_create_pipeline(
2518
+ webgpu_ctx->device, wgsl_soft_max_f32_mask_f16_sink, "soft_max_f32_mask_f16_sink", constants);
2519
+ webgpu_ctx->soft_max_pipelines[1][1][1] = ggml_webgpu_create_pipeline(
2520
+ webgpu_ctx->device, wgsl_soft_max_f32_mask_f16_sink_inplace, "soft_max_f32_mask_f16_sink_inplace", constants);
2521
+ }
2522
+
2523
+ // TODO: move most initialization logic here
1209
2524
  static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) {
1210
2525
  GGML_UNUSED(params);
1211
2526
 
@@ -1225,12 +2540,12 @@ static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, co
1225
2540
  /* .device = */ dev,
1226
2541
  /* .context = */ &backend_ctx,
1227
2542
  };
1228
-
1229
2543
  return &backend;
1230
2544
  }
1231
2545
 
1232
2546
  static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggml_backend_dev_t dev) {
1233
2547
  // See GGML Backend Buffer Type Interface section
2548
+
1234
2549
  static struct ggml_backend_buffer_type ggml_backend_webgpu_buffer_type = {
1235
2550
  /* .iface = */ {
1236
2551
  /* .get_name = */ ggml_backend_webgpu_buffer_type_get_name,
@@ -1287,6 +2602,8 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
1287
2602
 
1288
2603
  ggml_tensor * src0 = op->src[0];
1289
2604
  ggml_tensor * src1 = op->src[1];
2605
+ ggml_tensor * src2 = op->src[2];
2606
+
1290
2607
  // on smaller devices (or CI), tensors may be larger than the max storage buffer size
1291
2608
  if (ggml_nbytes(op) > webgpu_ctx->limits.maxStorageBufferBindingSize ||
1292
2609
  (src0 != nullptr && ggml_nbytes(src0) > webgpu_ctx->limits.maxStorageBufferBindingSize) ||
@@ -1304,28 +2621,36 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
1304
2621
  supports_op = true;
1305
2622
  break;
1306
2623
  case GGML_OP_ADD:
2624
+ case GGML_OP_SUB:
1307
2625
  case GGML_OP_MUL:
1308
- supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (op->src[0]->type == op->type) &&
1309
- (op->src[1]->type == op->type);
2626
+ case GGML_OP_DIV:
2627
+ // TODO: support non-contiguous tensors, e.g. for MOE_EXPERT_REDUCE
2628
+ // see https://github.com/ggml-org/llama.cpp/pull/16857
2629
+ supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type) &&
2630
+ (src1->type == op->type) && ggml_is_contiguous(src0) && ggml_is_contiguous(src1);
1310
2631
  break;
1311
2632
  case GGML_OP_CPY:
2633
+ case GGML_OP_CONT:
2634
+ supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
2635
+ (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
2636
+ break;
1312
2637
  case GGML_OP_SET_ROWS:
1313
- supports_op = (op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_I64);
2638
+ supports_op = (op->type == GGML_TYPE_F16 && src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I64);
1314
2639
  break;
1315
2640
  case GGML_OP_GET_ROWS:
1316
- if (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16 ||
1317
- op->src[0]->type == GGML_TYPE_I32 || ggml_webgpu_supported_qtype(op->src[0]->type)) {
2641
+ if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_I32 ||
2642
+ ggml_webgpu_supported_qtype(src0->type)) {
1318
2643
  supports_op = (op->type == GGML_TYPE_F32);
1319
2644
  }
1320
2645
  break;
1321
2646
  case GGML_OP_MUL_MAT:
1322
2647
  {
1323
- switch (op->src[1]->type) {
2648
+ switch (src1->type) {
1324
2649
  case GGML_TYPE_F16:
1325
- supports_op = (op->src[0]->type == GGML_TYPE_F16);
2650
+ supports_op |= (src0->type == GGML_TYPE_F16);
1326
2651
  break;
1327
2652
  case GGML_TYPE_F32:
1328
- switch (op->src[0]->type) {
2653
+ switch (src0->type) {
1329
2654
  case GGML_TYPE_F32:
1330
2655
  case GGML_TYPE_F16:
1331
2656
  case GGML_TYPE_Q4_0:
@@ -1357,19 +2682,110 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
1357
2682
  }
1358
2683
  break;
1359
2684
  }
2685
+ case GGML_OP_FLASH_ATTN_EXT:
2686
+ {
2687
+ if (!webgpu_ctx->supports_subgroup_matrix) {
2688
+ break;
2689
+ }
2690
+ // Head dimensions must fit in workgroup memory with minimum tile sizes
2691
+ size_t limit_bytes = webgpu_ctx->limits.maxComputeWorkgroupStorageSize;
2692
+ const bool has_mask = op->src[3] != nullptr;
2693
+ const bool kv_direct = src1->type == GGML_TYPE_F16 && (src0->ne[0] % webgpu_ctx->sg_mat_k) == 0 &&
2694
+ (src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0;
2695
+ const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
2696
+ webgpu_ctx->sg_mat_m, webgpu_ctx->sg_mat_n, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0],
2697
+ has_mask, kv_direct);
2698
+ if (min_bytes > limit_bytes) {
2699
+ break;
2700
+ }
2701
+
2702
+ supports_op = src0->type == GGML_TYPE_F32 &&
2703
+ (src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16 ||
2704
+ src1->type == GGML_TYPE_Q4_0 || src1->type == GGML_TYPE_Q8_0) &&
2705
+ src2->type == src1->type && op->type == GGML_TYPE_F32;
2706
+ break;
2707
+ }
1360
2708
  case GGML_OP_RMS_NORM:
1361
- supports_op = op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
2709
+ supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
2710
+ break;
2711
+ case GGML_OP_ROPE:
2712
+ supports_op = op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16;
2713
+ break;
2714
+ case GGML_OP_GLU:
2715
+ switch (ggml_get_glu_op(op)) {
2716
+ case GGML_GLU_OP_REGLU:
2717
+ case GGML_GLU_OP_GEGLU:
2718
+ case GGML_GLU_OP_SWIGLU:
2719
+ case GGML_GLU_OP_GEGLU_ERF:
2720
+ case GGML_GLU_OP_GEGLU_QUICK:
2721
+ supports_op = op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16;
2722
+ break;
2723
+ case GGML_GLU_OP_SWIGLU_OAI:
2724
+ supports_op = op->type == GGML_TYPE_F32;
2725
+ break;
2726
+ default:
2727
+ break;
2728
+ }
2729
+ break;
2730
+ case GGML_OP_SCALE:
2731
+ supports_op = op->type == GGML_TYPE_F32;
2732
+ break;
2733
+ case GGML_OP_SOFT_MAX:
2734
+ supports_op = op->type == GGML_TYPE_F32;
2735
+ break;
2736
+ case GGML_OP_UNARY:
2737
+ {
2738
+ const ggml_unary_op UNARY_OP = ggml_get_unary_op(op);
2739
+
2740
+ switch (UNARY_OP) {
2741
+ case GGML_UNARY_OP_ABS:
2742
+ case GGML_UNARY_OP_SGN:
2743
+ case GGML_UNARY_OP_NEG:
2744
+ case GGML_UNARY_OP_STEP:
2745
+ case GGML_UNARY_OP_TANH:
2746
+ case GGML_UNARY_OP_ELU:
2747
+ case GGML_UNARY_OP_RELU:
2748
+ case GGML_UNARY_OP_SIGMOID:
2749
+ case GGML_UNARY_OP_GELU:
2750
+ case GGML_UNARY_OP_GELU_QUICK:
2751
+ case GGML_UNARY_OP_SILU:
2752
+ case GGML_UNARY_OP_HARDSWISH:
2753
+ case GGML_UNARY_OP_HARDSIGMOID:
2754
+ case GGML_UNARY_OP_EXP:
2755
+ case GGML_UNARY_OP_GELU_ERF:
2756
+ case GGML_UNARY_OP_XIELU:
2757
+ case GGML_UNARY_OP_CEIL:
2758
+ supports_op = supports_op =
2759
+ (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
2760
+ break;
2761
+ default:
2762
+ break;
2763
+ }
2764
+ }
1362
2765
  break;
2766
+
1363
2767
  default:
1364
2768
  break;
1365
2769
  }
1366
- #ifdef GGML_WEBGPU_DEBUG
2770
+ if (ggml_nbytes(op) > webgpu_ctx->limits.maxStorageBufferBindingSize ||
2771
+ (src0 != nullptr && ggml_nbytes(src0) > webgpu_ctx->limits.maxStorageBufferBindingSize) ||
2772
+ (src1 != nullptr && ggml_nbytes(src1) > webgpu_ctx->limits.maxStorageBufferBindingSize) ||
2773
+ (src2 != nullptr && ggml_nbytes(src2) > webgpu_ctx->limits.maxStorageBufferBindingSize)) {
2774
+ supports_op = false;
2775
+ WEBGPU_LOG_DEBUG("ggml_webgpu op not supported due to size: ");
2776
+ }
2777
+
1367
2778
  if (!supports_op) {
1368
- WEBGPU_LOG_DEBUG("not supported: " << ggml_op_name(op->op) << " with types dst: " << ggml_type_name(op->type)
1369
- << ", src0: " << (op->src[0] ? ggml_type_name(op->src[0]->type) : "null")
1370
- << ", src1: " << (op->src[1] ? ggml_type_name(op->src[1]->type) : "null"));
2779
+ WEBGPU_LOG_DEBUG("ggml_webgpu op not supported: "
2780
+ << ggml_op_name(op->op) << " with types dst: " << ggml_type_name(op->type)
2781
+ << ", src0: " << (op->src[0] ? ggml_type_name(op->src[0]->type) : "null")
2782
+ << ", src1: " << (op->src[1] ? ggml_type_name(op->src[1]->type) : "null"));
2783
+ } else {
2784
+ WEBGPU_LOG_DEBUG("ggml_webgpu op supported: "
2785
+ << ggml_op_name(op->op) << " with types dst: " << ggml_type_name(op->type)
2786
+ << ", src0: " << (op->src[0] ? ggml_type_name(op->src[0]->type) : "null")
2787
+ << ", src1: " << (op->src[1] ? ggml_type_name(op->src[1]->type) : "null"));
1371
2788
  }
1372
- #endif
1373
2789
  return supports_op;
1374
2790
  }
1375
2791
 
@@ -1406,16 +2822,29 @@ static size_t ggml_backend_webgpu_reg_get_device_count(ggml_backend_reg_t reg) {
1406
2822
  }
1407
2823
 
1408
2824
  // TODO: Does this need to be thread safe? Is it only called once?
2825
+ // TODO: move most logic to device_init function so backend can be freed/initialized properly
1409
2826
  // Only one device is supported for now
1410
2827
  static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t reg, size_t index) {
1411
2828
  GGML_ASSERT(index == 0);
1412
2829
  WEBGPU_LOG_DEBUG("ggml_backend_reg_get_device()");
1413
2830
 
2831
+ WEBGPU_CPU_PROFILE_TOTAL_START(reg_get_device);
2832
+
1414
2833
  ggml_backend_webgpu_reg_context * reg_ctx = static_cast<ggml_backend_webgpu_reg_context *>(reg->context);
1415
2834
 
1416
2835
  webgpu_context ctx = reg_ctx->webgpu_ctx;
1417
2836
 
1418
2837
  wgpu::RequestAdapterOptions options = {};
2838
+
2839
+ #ifndef __EMSCRIPTEN__
2840
+ // TODO: track need for these toggles: https://issues.chromium.org/issues/42251215
2841
+ const char * const adapterEnabledToggles[] = { "vulkan_enable_f16_on_nvidia", "use_vulkan_memory_model" };
2842
+ wgpu::DawnTogglesDescriptor adapterTogglesDesc;
2843
+ adapterTogglesDesc.enabledToggles = adapterEnabledToggles;
2844
+ adapterTogglesDesc.enabledToggleCount = 2;
2845
+ options.nextInChain = &adapterTogglesDesc;
2846
+ #endif
2847
+
1419
2848
  ctx->instance.WaitAny(ctx->instance.RequestAdapter(
1420
2849
  &options, wgpu::CallbackMode::AllowSpontaneous,
1421
2850
  [&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) {
@@ -1429,15 +2858,61 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
1429
2858
  GGML_ASSERT(ctx->adapter != nullptr);
1430
2859
 
1431
2860
  ctx->adapter.GetLimits(&ctx->limits);
1432
- ctx->max_wg_size_x = 288; // default value
1433
2861
 
1434
2862
  wgpu::AdapterInfo info{};
2863
+ #ifndef __EMSCRIPTEN__
2864
+ wgpu::AdapterPropertiesSubgroupMatrixConfigs subgroup_matrix_configs{};
2865
+ if (ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
2866
+ info.nextInChain = &subgroup_matrix_configs;
2867
+ }
2868
+ #endif
1435
2869
  ctx->adapter.GetInfo(&info);
1436
2870
 
2871
+ wgpu::SupportedFeatures features;
2872
+ ctx->adapter.GetFeatures(&features);
2873
+ // we require f16 support
2874
+ GGML_ASSERT(ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16));
2875
+
2876
+ #ifndef __EMSCRIPTEN__
2877
+ // Only support square f16 matrices of size 8 or 16 for now
2878
+ bool valid_subgroup_matrix_config = false;
2879
+ if (ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
2880
+ for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) {
2881
+ const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i];
2882
+ if (config.M == config.N && config.N == config.K && (config.K == 8 || config.K == 16) &&
2883
+ config.componentType == wgpu::SubgroupMatrixComponentType::F16 &&
2884
+ config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) {
2885
+ ctx->sg_mat_m = config.M;
2886
+ ctx->sg_mat_n = config.N;
2887
+ ctx->sg_mat_k = config.K;
2888
+ valid_subgroup_matrix_config = true;
2889
+ break;
2890
+ }
2891
+ }
2892
+ }
2893
+
2894
+ ctx->supports_subgroup_matrix = valid_subgroup_matrix_config;
2895
+ #endif
2896
+ // For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate.
2897
+ // Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter.
2898
+ ctx->max_subgroup_size = info.subgroupMaxSize;
2899
+
1437
2900
  // Initialize device
1438
- std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16,
1439
- wgpu::FeatureName::ImplicitDeviceSynchronization };
1440
- wgpu::DeviceDescriptor dev_desc;
2901
+ std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16 };
2902
+
2903
+ #ifndef __EMSCRIPTEN__
2904
+ required_features.push_back(wgpu::FeatureName::ImplicitDeviceSynchronization);
2905
+ if (ctx->supports_subgroup_matrix) {
2906
+ required_features.push_back(wgpu::FeatureName::Subgroups);
2907
+ required_features.push_back(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix);
2908
+ }
2909
+ #endif
2910
+
2911
+ #ifdef GGML_WEBGPU_GPU_PROFILE
2912
+ required_features.push_back(wgpu::FeatureName::TimestampQuery);
2913
+ #endif
2914
+
2915
+ wgpu::DeviceDescriptor dev_desc;
1441
2916
  dev_desc.requiredLimits = &ctx->limits;
1442
2917
  dev_desc.requiredFeatures = required_features.data();
1443
2918
  dev_desc.requiredFeatureCount = required_features.size();
@@ -1445,15 +2920,35 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
1445
2920
  wgpu::CallbackMode::AllowSpontaneous,
1446
2921
  [](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
1447
2922
  GGML_UNUSED(device);
1448
- GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason),
1449
- std::string(message).c_str());
2923
+ GGML_UNUSED(reason);
2924
+ GGML_UNUSED(message);
2925
+ //TODO: uncomment once proper free logic is in place
2926
+ //GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason),
2927
+ //std::string(message).c_str());
1450
2928
  });
1451
2929
  dev_desc.SetUncapturedErrorCallback(
1452
2930
  [](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) {
1453
2931
  GGML_UNUSED(device);
1454
- GGML_LOG_ERROR("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast<int>(reason),
1455
- std::string(message).c_str());
2932
+ GGML_ABORT("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast<int>(reason),
2933
+ std::string(message).c_str());
1456
2934
  });
2935
+
2936
+ #ifndef __EMSCRIPTEN__
2937
+ // Enable Dawn-specific toggles to increase native performance
2938
+ // TODO: Maybe WebGPU needs a "fast" mode where you can request compilers skip adding checks like these,
2939
+ // only for native performance?
2940
+ const char * const deviceEnabledToggles[] = { "skip_validation", "disable_robustness", "disable_workgroup_init",
2941
+ "disable_polyfills_on_integer_div_and_mod" };
2942
+ const char * const deviceDisabledToggles[] = { "timestamp_quantization" };
2943
+ wgpu::DawnTogglesDescriptor deviceTogglesDesc;
2944
+ deviceTogglesDesc.enabledToggles = deviceEnabledToggles;
2945
+ deviceTogglesDesc.enabledToggleCount = 4;
2946
+ deviceTogglesDesc.disabledToggles = deviceDisabledToggles;
2947
+ deviceTogglesDesc.disabledToggleCount = 1;
2948
+
2949
+ dev_desc.nextInChain = &deviceTogglesDesc;
2950
+ #endif
2951
+
1457
2952
  ctx->instance.WaitAny(ctx->adapter.RequestDevice(
1458
2953
  &dev_desc, wgpu::CallbackMode::AllowSpontaneous,
1459
2954
  [ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
@@ -1474,6 +2969,15 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
1474
2969
  ctx->param_buf_pool.init(ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES,
1475
2970
  wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
1476
2971
  wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite);
2972
+
2973
+ #ifdef GGML_WEBGPU_GPU_PROFILE
2974
+ // Initialize buffer pool for timestamp queries (profiling)
2975
+ ctx->timestamp_query_buf_pool.init(ctx->device, WEBGPU_NUM_TIMESTAMP_QUERY_BUFS,
2976
+ WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES,
2977
+ wgpu::BufferUsage::QueryResolve | wgpu::BufferUsage::CopySrc,
2978
+ wgpu::BufferUsage::MapRead | wgpu::BufferUsage::CopyDst);
2979
+ #endif
2980
+
1477
2981
  ctx->set_rows_error_buf_pool.init(ctx->device, WEBGPU_NUM_SET_ROWS_ERROR_BUFS, WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
1478
2982
  wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage,
1479
2983
  wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead);
@@ -1484,8 +2988,15 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
1484
2988
  ggml_webgpu_init_get_rows_pipeline(ctx);
1485
2989
  ggml_webgpu_init_cpy_pipeline(ctx);
1486
2990
  ggml_webgpu_init_add_pipeline(ctx);
2991
+ ggml_webgpu_init_sub_pipeline(ctx);
1487
2992
  ggml_webgpu_init_mul_pipeline(ctx);
2993
+ ggml_webgpu_init_div_pipeline(ctx);
1488
2994
  ggml_webgpu_init_rms_norm_pipeline(ctx);
2995
+ ggml_webgpu_init_rope_pipeline(ctx);
2996
+ ggml_webgpu_init_glu_pipeline(ctx);
2997
+ ggml_webgpu_init_scale_pipeline(ctx);
2998
+ ggml_webgpu_init_soft_max_pipeline(ctx);
2999
+ ggml_webgpu_init_unary_pipeline(ctx);
1489
3000
 
1490
3001
  #ifdef GGML_WEBGPU_DEBUG
1491
3002
  // Initialize debug buffers
@@ -1512,6 +3023,8 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
1512
3023
  /* .reg = */ reg,
1513
3024
  /* .context = */ &device_ctx,
1514
3025
  };
3026
+
3027
+ WEBGPU_CPU_PROFILE_TOTAL_END(reg_get_device, ctx);
1515
3028
  return &device;
1516
3029
  }
1517
3030
 
@@ -1538,7 +3051,23 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() {
1538
3051
  std::vector<wgpu::InstanceFeatureName> instance_features = { wgpu::InstanceFeatureName::TimedWaitAny };
1539
3052
  instance_descriptor.requiredFeatures = instance_features.data();
1540
3053
  instance_descriptor.requiredFeatureCount = instance_features.size();
1541
- webgpu_ctx->instance = wgpu::CreateInstance(&instance_descriptor);
3054
+
3055
+ #ifndef __EMSCRIPTEN__
3056
+ const char * const instanceEnabledToggles[] = { "allow_unsafe_apis" };
3057
+ wgpu::DawnTogglesDescriptor instanceTogglesDesc;
3058
+ instanceTogglesDesc.enabledToggles = instanceEnabledToggles;
3059
+ instanceTogglesDesc.enabledToggleCount = 1;
3060
+ instance_descriptor.nextInChain = &instanceTogglesDesc;
3061
+ #endif
3062
+
3063
+ webgpu_ctx->instance = wgpu::CreateInstance(&instance_descriptor);
3064
+
3065
+ #ifdef __EMSCRIPTEN__
3066
+ if (webgpu_ctx->instance == nullptr) {
3067
+ GGML_LOG_ERROR("ggml_webgpu: Failed to create WebGPU instance. Make sure either -sASYNCIFY or -sJSPI is set\n");
3068
+ return nullptr;
3069
+ }
3070
+ #endif
1542
3071
  GGML_ASSERT(webgpu_ctx->instance != nullptr);
1543
3072
 
1544
3073
  static ggml_backend_reg reg = {