whispercpp 1.3.3 → 1.3.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (963)
  1. checksums.yaml +4 -4
  2. data/README.md +60 -43
  3. data/ext/extconf.rb +2 -2
  4. data/ext/ruby_whisper.c +14 -2
  5. data/ext/ruby_whisper.h +39 -0
  6. data/ext/ruby_whisper_context.c +22 -22
  7. data/ext/ruby_whisper_model.c +12 -12
  8. data/ext/ruby_whisper_params.c +79 -25
  9. data/ext/ruby_whisper_segment.c +84 -19
  10. data/ext/ruby_whisper_token.c +351 -0
  11. data/ext/ruby_whisper_transcribe.cpp +1 -1
  12. data/ext/ruby_whisper_vad_context.c +75 -0
  13. data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
  14. data/ext/ruby_whisper_vad_segment.c +139 -0
  15. data/ext/ruby_whisper_vad_segments.c +106 -0
  16. data/ext/sources/CMakeLists.txt +4 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
  19. data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
  20. data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  21. data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
  22. data/ext/sources/examples/CMakeLists.txt +1 -0
  23. data/ext/sources/examples/addon.node/addon.cpp +19 -19
  24. data/ext/sources/examples/addon.node/index.js +7 -5
  25. data/ext/sources/examples/addon.node/vad-example.js +2 -2
  26. data/ext/sources/examples/bench/bench.cpp +26 -16
  27. data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
  28. data/ext/sources/examples/cli/cli.cpp +122 -111
  29. data/ext/sources/examples/command/command.cpp +26 -24
  30. data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
  31. data/ext/sources/examples/common-ggml.cpp +2 -0
  32. data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
  33. data/ext/sources/examples/lsp/lsp.cpp +19 -17
  34. data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
  35. data/ext/sources/examples/server/server.cpp +34 -24
  36. data/ext/sources/examples/server.py +6 -1
  37. data/ext/sources/examples/stream/stream.cpp +4 -2
  38. data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
  39. data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
  40. data/ext/sources/examples/talk-llama/CMakeLists.txt +7 -3
  41. data/ext/sources/examples/talk-llama/llama-adapter.cpp +113 -7
  42. data/ext/sources/examples/talk-llama/llama-adapter.h +13 -1
  43. data/ext/sources/examples/talk-llama/llama-arch.cpp +2136 -1491
  44. data/ext/sources/examples/talk-llama/llama-arch.h +125 -3
  45. data/ext/sources/examples/talk-llama/llama-batch.cpp +174 -100
  46. data/ext/sources/examples/talk-llama/llama-batch.h +46 -20
  47. data/ext/sources/examples/talk-llama/llama-chat.cpp +199 -8
  48. data/ext/sources/examples/talk-llama/llama-chat.h +11 -0
  49. data/ext/sources/examples/talk-llama/llama-context.cpp +1213 -413
  50. data/ext/sources/examples/talk-llama/llama-context.h +99 -36
  51. data/ext/sources/examples/talk-llama/llama-cparams.h +5 -4
  52. data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
  53. data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
  54. data/ext/sources/examples/talk-llama/llama-graph.cpp +883 -294
  55. data/ext/sources/examples/talk-llama/llama-graph.h +361 -161
  56. data/ext/sources/examples/talk-llama/llama-hparams.cpp +144 -6
  57. data/ext/sources/examples/talk-llama/llama-hparams.h +100 -23
  58. data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
  59. data/ext/sources/examples/talk-llama/llama-impl.h +3 -1
  60. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +328 -0
  61. data/ext/sources/examples/talk-llama/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +38 -29
  62. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2100 -0
  63. data/ext/sources/examples/talk-llama/llama-kv-cache.h +373 -27
  64. data/ext/sources/examples/talk-llama/llama-kv-cells.h +124 -30
  65. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +63 -41
  66. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +30 -29
  67. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +77 -35
  68. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +15 -16
  69. data/ext/sources/examples/talk-llama/llama-memory.h +16 -10
  70. data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
  71. data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
  72. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +93 -9
  73. data/ext/sources/examples/talk-llama/llama-model-loader.h +9 -2
  74. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
  75. data/ext/sources/examples/talk-llama/llama-model.cpp +3369 -10145
  76. data/ext/sources/examples/talk-llama/llama-model.h +104 -12
  77. data/ext/sources/examples/talk-llama/llama-quant.cpp +53 -30
  78. data/ext/sources/examples/talk-llama/llama-sampling.cpp +1520 -324
  79. data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
  80. data/ext/sources/examples/talk-llama/llama-vocab.cpp +562 -39
  81. data/ext/sources/examples/talk-llama/llama-vocab.h +50 -0
  82. data/ext/sources/examples/talk-llama/llama.cpp +794 -12
  83. data/ext/sources/examples/talk-llama/llama.h +246 -190
  84. data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
  85. data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
  86. data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
  87. data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
  88. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
  89. data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
  90. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
  91. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
  92. data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
  93. data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
  94. data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
  95. data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
  96. data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
  97. data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
  98. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
  99. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
  100. data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
  101. data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
  102. data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
  103. data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
  104. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
  105. data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
  106. data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
  107. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
  108. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
  109. data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
  110. data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
  111. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
  112. data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
  113. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
  114. data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
  115. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
  116. data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
  117. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
  118. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
  119. data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
  120. data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
  121. data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
  122. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
  123. data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
  124. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
  125. data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
  126. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
  127. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
  128. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
  129. data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
  130. data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
  131. data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
  132. data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
  133. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
  134. data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
  135. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
  136. data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
  137. data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
  138. data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
  139. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
  140. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
  141. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
  142. data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
  143. data/ext/sources/examples/talk-llama/models/models.h +569 -0
  144. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
  145. data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
  146. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
  147. data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
  148. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
  149. data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
  150. data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
  151. data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
  152. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
  153. data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
  154. data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
  155. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
  156. data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
  157. data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
  158. data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
  159. data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
  160. data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
  161. data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
  162. data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
  163. data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
  164. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
  165. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
  166. data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
  167. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
  168. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
  169. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
  170. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
  171. data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
  172. data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
  173. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
  174. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
  175. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
  176. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
  177. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
  178. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
  179. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
  180. data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
  181. data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
  182. data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
  183. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
  184. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
  185. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
  186. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
  187. data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
  188. data/ext/sources/examples/talk-llama/talk-llama.cpp +9 -6
  189. data/ext/sources/examples/talk-llama/unicode.cpp +309 -16
  190. data/ext/sources/examples/talk-llama/unicode.h +45 -0
  191. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
  192. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +4 -2
  193. data/ext/sources/examples/whisper.wasm/index-tmpl.html +18 -17
  194. data/ext/sources/ggml/CMakeLists.txt +135 -79
  195. data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
  196. data/ext/sources/ggml/include/ggml-alloc.h +9 -0
  197. data/ext/sources/ggml/include/ggml-backend.h +21 -2
  198. data/ext/sources/ggml/include/ggml-cpu.h +2 -1
  199. data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
  200. data/ext/sources/ggml/include/ggml-metal.h +1 -6
  201. data/ext/sources/ggml/include/ggml-opt.h +25 -6
  202. data/ext/sources/ggml/include/ggml-rpc.h +8 -11
  203. data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
  204. data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
  205. data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
  206. data/ext/sources/ggml/include/ggml.h +406 -23
  207. data/ext/sources/ggml/src/CMakeLists.txt +99 -13
  208. data/ext/sources/ggml/src/ggml-alloc.c +368 -161
  209. data/ext/sources/ggml/src/ggml-backend-impl.h +5 -5
  210. data/ext/sources/ggml/src/ggml-backend-reg.cpp +55 -14
  211. data/ext/sources/ggml/src/ggml-backend.cpp +290 -57
  212. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
  213. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -13
  214. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +14 -0
  215. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +59 -45
  216. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
  217. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2586 -1917
  218. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +348 -309
  219. data/ext/sources/ggml/src/ggml-cann/common.h +350 -133
  220. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +894 -625
  221. data/ext/sources/ggml/src/ggml-common.h +17 -0
  222. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +167 -75
  223. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +5 -2
  224. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  225. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +560 -622
  226. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1002 -270
  227. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +107 -587
  228. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
  229. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  230. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +373 -486
  231. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
  232. data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  233. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +521 -353
  234. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
  235. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
  236. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +4682 -1660
  237. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +82 -4
  238. data/ext/sources/ggml/src/ggml-cpu/common.h +14 -0
  239. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +18 -9
  240. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +263 -111
  241. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +39 -28
  242. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +683 -82
  243. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +38 -43
  244. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +435 -119
  245. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
  246. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1234 -1182
  247. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
  248. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +2167 -1480
  249. data/ext/sources/ggml/src/ggml-cpu/ops.h +10 -12
  250. data/ext/sources/ggml/src/ggml-cpu/quants.c +35 -0
  251. data/ext/sources/ggml/src/ggml-cpu/quants.h +8 -0
  252. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1132 -81
  253. data/ext/sources/ggml/src/ggml-cpu/repack.h +36 -0
  254. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +120 -93
  255. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1025 -0
  256. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
  257. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
  258. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
  259. data/ext/sources/ggml/src/ggml-cpu/traits.cpp +2 -2
  260. data/ext/sources/ggml/src/ggml-cpu/traits.h +1 -1
  261. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
  262. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
  263. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +294 -27
  264. data/ext/sources/ggml/src/ggml-cpu/vec.h +606 -48
  265. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +92 -17
  266. data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
  267. data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
  268. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
  269. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
  270. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
  271. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
  272. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
  273. data/ext/sources/ggml/src/ggml-cuda/common.cuh +588 -128
  274. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
  275. data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
  276. data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
  277. data/ext/sources/ggml/src/ggml-cuda/convert.cu +95 -22
  278. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +25 -0
  279. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
  280. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +335 -485
  281. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
  282. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  283. data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
  284. data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
  285. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
  286. data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
  287. data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
  288. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +519 -378
  289. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +750 -637
  290. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +49 -0
  291. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1244 -0
  292. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +586 -0
  293. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +98 -61
  294. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  295. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +230 -197
  296. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  297. data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
  298. data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
  299. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
  300. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1557 -294
  301. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
  302. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
  303. data/ext/sources/ggml/src/ggml-cuda/mean.cu +57 -2
  304. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +915 -69
  305. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +171 -0
  306. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +835 -0
  307. data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
  308. data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
  309. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +109 -67
  310. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1601 -733
  311. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +802 -0
  312. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +12 -0
  313. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +286 -149
  314. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
  315. data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
  316. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
  317. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
  318. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
  319. data/ext/sources/ggml/src/ggml-cuda/pad.cu +86 -32
  320. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
  321. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
  322. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +163 -10
  323. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
  324. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
  325. data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
  326. data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
  327. data/ext/sources/ggml/src/ggml-cuda/rope.cu +207 -98
  328. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
  329. data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
  330. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +330 -0
  331. data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  332. data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
  333. data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
  334. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
  335. data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
  336. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +325 -61
  337. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
  338. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
  339. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -12
  340. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +291 -104
  341. data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
  342. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +21 -4
  343. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  344. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  345. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  346. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  347. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  348. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  349. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  350. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  351. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  352. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
  353. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
  354. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
  355. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
  356. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
  357. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
  358. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
  359. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
  360. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
  361. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
  362. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
  363. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
  364. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
  365. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
  366. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
  367. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
  368. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
  369. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
  370. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
  371. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
  372. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
  373. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
  374. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
  375. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
  376. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
  377. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
  378. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
  379. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
  380. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
  381. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
  382. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
  383. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
  384. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
  385. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
  386. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
  387. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
  388. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +40 -19
  389. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
  390. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
  391. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
  392. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
  393. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
  394. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
  395. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
  396. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
  397. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
  398. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
  399. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
  400. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
  401. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
  402. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
  403. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
  404. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
  405. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
  406. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
  407. data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
  408. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +351 -0
  409. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +21 -0
  410. data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
  411. data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
  412. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
  413. data/ext/sources/ggml/src/ggml-cuda/unary.cu +189 -5
  414. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +44 -0
  415. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +248 -6
  416. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
  417. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +8 -0
  418. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +70 -37
  419. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +10 -3
  420. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
  421. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
  422. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
  423. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
  424. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
  425. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  426. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
  427. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
  428. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  429. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
  430. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
  431. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
  432. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
  433. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  434. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
  435. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
  436. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
  437. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
  438. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
  439. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
  440. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
  441. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
  442. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
  443. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
  444. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
  445. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
  446. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
  447. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  448. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
  449. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
  450. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
  451. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +16 -13
  452. data/ext/sources/ggml/src/ggml-impl.h +186 -15
  453. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -7
  454. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
  455. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
  456. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
  457. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +609 -0
  458. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1743 -0
  459. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +273 -0
  460. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1686 -0
  461. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +356 -61
  462. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +4161 -0
  463. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +94 -0
  464. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +724 -0
  465. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +4495 -1876
  466. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +21 -9
  467. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +29 -0
  468. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4005 -427
  469. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
  470. data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  471. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  472. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  473. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +147 -0
  474. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +66 -0
  475. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
  476. data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  477. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
  478. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +371 -0
  479. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
  480. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  481. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  482. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  483. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  484. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +177 -0
  485. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
  486. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  487. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  488. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
  489. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
  490. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  491. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  492. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +146 -0
  493. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +147 -0
  494. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  495. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  496. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
  497. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
  498. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
  499. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  500. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
  501. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
  502. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
  503. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
  504. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  505. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +94 -0
  506. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  507. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  508. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +208 -0
  509. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
  510. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
  511. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
  512. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
  513. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
  514. data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  515. data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  516. data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  517. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
  518. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +33 -0
  519. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +2 -2
  520. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  521. data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
  522. data/ext/sources/ggml/src/ggml-quants.c +111 -16
  523. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  524. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +497 -195
  525. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
  526. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
  527. data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
  528. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +8 -0
  529. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +6 -5
  530. data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
  531. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +50 -30
  532. data/ext/sources/ggml/src/ggml-sycl/conv.cpp +10 -4
  533. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +200 -99
  534. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  535. data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  536. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +72 -309
  537. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
  538. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
  539. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +67 -49
  540. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +77 -34
  541. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +397 -314
  542. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +12 -2
  543. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +14 -26
  544. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +9 -6
  545. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +643 -413
  546. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
  547. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +2 -2
  548. data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +80 -60
  549. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +223 -132
  550. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +230 -55
  551. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
  552. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
  553. data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
  554. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
  555. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
  556. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
  557. data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
  558. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +8 -9
  559. data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  560. data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  561. data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
  562. data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
  563. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +65 -59
  564. data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
  565. data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
  566. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
  567. data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  568. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +330 -165
  569. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
  570. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  571. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  572. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +12 -6
  573. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +60 -6
  574. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +16 -12
  575. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
  576. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +7398 -2635
  577. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  578. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
  579. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +43 -3
  580. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  581. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
  582. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  583. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +15 -6
  584. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +56 -39
  585. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  586. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  587. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  588. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  589. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  590. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  591. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +347 -0
  592. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  593. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  594. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +5 -5
  595. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +67 -13
  596. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  597. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  598. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  599. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
  600. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
  601. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
  602. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
  603. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  604. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +158 -16
  605. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +38 -3
  606. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  607. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  608. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  609. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +2 -2
  610. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  611. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +3 -2
  612. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +7 -6
  613. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +5 -3
  614. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  615. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +4 -4
  618. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +2 -2
  619. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  620. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  621. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +4 -4
  622. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  623. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  624. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +4 -4
  625. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +2 -2
  626. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  627. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
  628. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  629. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  630. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
  631. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  632. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +103 -36
  633. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +220 -0
  634. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +139 -45
  635. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +113 -38
  636. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +75 -14
  637. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  638. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  639. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  640. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  641. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  642. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  643. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  644. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +19 -17
  645. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
  646. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +21 -12
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +28 -18
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +4 -0
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +33 -17
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +125 -0
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +227 -0
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
  663. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
  664. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
  665. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
  666. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
  667. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
  668. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
  669. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +20 -14
  670. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  671. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
  672. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
  673. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
  674. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
  675. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
  676. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +143 -0
  677. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
  678. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -556
  679. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +230 -51
  680. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +566 -0
  681. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
  682. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +90 -223
  683. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
  684. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  685. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +195 -0
  686. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  687. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  688. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  689. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
  690. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +41 -5
  691. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  692. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +59 -9
  693. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  694. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  695. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  696. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  697. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +104 -14
  698. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  699. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
  700. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  701. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
  702. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  703. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -52
  704. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -35
  705. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -35
  706. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
  707. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
  708. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  709. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +5 -0
  710. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +3 -3
  711. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  712. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  713. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  714. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  715. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +30 -8
  716. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +6 -2
  717. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
  718. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
  719. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
  720. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
  721. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  722. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
  723. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
  724. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  725. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
  726. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
  727. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  728. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  729. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +16 -6
  730. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
  731. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  732. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
  733. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  734. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +5 -4
  735. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
  736. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
  737. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
  738. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
  739. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  740. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +435 -24
  741. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +148 -6
  742. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.glsl +25 -0
  743. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +619 -177
  744. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
  745. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +80 -0
  746. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
  747. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3087 -0
  748. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  749. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
  750. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
  751. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
  752. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
  753. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +147 -0
  754. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
  755. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
  756. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  757. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  758. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
  759. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
  760. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
  761. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
  762. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
  763. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +123 -0
  764. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  765. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
  766. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
  767. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
  768. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  769. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
  770. data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  771. data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
  772. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
  773. data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
  774. data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
  775. data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
  776. data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
  777. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
  778. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
  779. data/ext/sources/ggml/src/ggml.c +901 -129
  780. data/ext/sources/ggml/src/gguf.cpp +8 -1
  781. data/ext/sources/include/whisper.h +1 -0
  782. data/ext/sources/src/CMakeLists.txt +3 -1
  783. data/ext/sources/src/whisper.cpp +124 -81
  784. data/ext/sources/tests/CMakeLists.txt +8 -1
  785. data/ext/sources/tests/test-vad-full.cpp +7 -5
  786. data/ext/sources/tests/test-vad.cpp +3 -3
  787. data/extsources.rb +1 -0
  788. data/lib/whisper/model/uri.rb +17 -18
  789. data/sig/whisper.rbs +126 -2
  790. data/test/test_params.rb +24 -8
  791. data/test/test_segment.rb +0 -1
  792. data/test/test_token.rb +70 -0
  793. data/test/test_vad.rb +1 -1
  794. data/test/test_vad_context.rb +50 -0
  795. data/test/test_vad_segment.rb +19 -0
  796. data/test/test_vad_segments.rb +16 -0
  797. data/test/test_whisper.rb +8 -1
  798. data/whispercpp.gemspec +1 -1
  799. metadata +439 -179
  800. data/ext/sources/build-xcframework.sh +0 -547
  801. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +0 -279
  802. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +0 -1841
  803. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +0 -303
  804. data/ext/sources/ggml/include/ggml-kompute.h +0 -50
  805. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
  806. data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
  807. data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
  808. data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
  809. data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
  810. data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
  811. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  812. data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  813. data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  814. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  815. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  816. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  817. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  818. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  819. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  820. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
  821. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
  822. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
  823. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
  824. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
  825. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
  826. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
  827. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +0 -506
  828. data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +0 -11
  829. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
  830. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
  831. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
  832. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
  833. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
  834. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
  835. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
  836. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
  837. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
  838. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
  839. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
  840. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
  841. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
  842. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
  843. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
  844. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
  845. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
  846. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
  847. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
  848. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
  849. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
  850. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
  851. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
  852. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
  853. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
  854. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
  855. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
  856. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
  857. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
  858. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
  859. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
  860. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
  861. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
  862. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
  863. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
  864. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
  865. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
  866. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
  867. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
  868. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
  869. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
  870. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
  871. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
  872. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
  873. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
  874. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
  875. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
  876. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
  877. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
  878. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
  879. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
  880. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
  881. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
  882. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
  883. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
  884. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
  885. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
  886. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
  887. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
  888. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
  889. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
  890. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
  891. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
  892. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
  893. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
  894. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
  895. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
  896. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
  897. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
  898. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
  899. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
  900. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
  901. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
  902. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
  903. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
  904. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
  905. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
  906. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
  907. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
  908. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
  909. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
  910. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
  911. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
  912. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
  913. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
  914. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
  915. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  916. data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  917. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  918. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  919. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  920. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  921. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  922. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  923. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  924. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  925. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  926. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  927. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  928. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  929. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  930. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  931. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  932. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  933. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  934. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  935. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  936. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  937. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  938. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  939. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  940. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  941. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  942. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  943. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  944. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  945. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  946. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  947. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  948. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  949. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  950. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  951. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  952. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  953. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
  954. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -6280
  955. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +0 -162
  956. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +0 -118
  957. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -99
  958. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -58
  959. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  960. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  961. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  962. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  963. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
@@ -0,0 +1,3087 @@
1
+ /*
2
+ WebGPU backend implementation.
3
+ Note: Use ClangFormat to format this file.
4
+ */
5
+
6
+ #include "ggml-webgpu.h"
7
+
8
+ #include "ggml-backend-impl.h"
9
+ #include "ggml-impl.h"
10
+ #include "ggml-webgpu-shader-lib.hpp"
11
+ #include "ggml-wgsl-shaders.hpp"
12
+ #include "pre_wgsl.hpp"
13
+
14
+ #ifdef __EMSCRIPTEN__
15
+ # include <emscripten/emscripten.h>
16
+ #endif
17
+
18
+ #include <webgpu/webgpu_cpp.h>
19
+
20
+ #include <atomic>
21
+ #include <condition_variable>
22
+ #include <cstdint>
23
+ #include <cstring>
24
+ #include <iostream>
25
+ #include <map>
26
+ #include <mutex>
27
+ #include <optional>
28
+ #include <string>
29
+ #include <vector>
30
+
31
+ #define ROUNDUP_POW2(x, pow2) (((x) + ((pow2) - 1)) & ~((pow2) - 1))
32
+ #define CEIL_DIV(M, N) (((M) + (N) - 1) / (N))
33
+
34
+ #ifdef GGML_WEBGPU_DEBUG
35
+ # define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl
36
+ # define WEBGPU_DEBUG_BUF_ELEMS 512
37
+ #else
38
+ # define WEBGPU_LOG_DEBUG(msg) ((void) 0)
39
+ #endif // GGML_WEBGPU_DEBUG
40
+
41
+ #ifdef GGML_WEBGPU_CPU_PROFILE
42
+ // total timing (aggregated)
43
+ # define WEBGPU_CPU_PROFILE_TOTAL_START(id) auto cpu_total_start_##id = std::chrono::high_resolution_clock::now();
44
+
45
+ # define WEBGPU_CPU_PROFILE_TOTAL_END(id, ctx) \
46
+ auto cpu_total_end_##id = std::chrono::high_resolution_clock::now(); \
47
+ double cpu_total_time_##id = \
48
+ std::chrono::duration<double, std::milli>(cpu_total_end_##id - cpu_total_start_##id).count(); \
49
+ (ctx)->cpu_time_ms[#id] += cpu_total_time_##id;
50
+
51
+ // fine-grained timing (not included in totals)
52
+ # define WEBGPU_CPU_PROFILE_DETAIL_START(id) auto cpu_detail_start_##id = std::chrono::high_resolution_clock::now();
53
+
54
+ # define WEBGPU_CPU_PROFILE_DETAIL_END(id, ctx) \
55
+ auto cpu_detail_end_##id = std::chrono::high_resolution_clock::now(); \
56
+ double cpu_detail_time_##id = \
57
+ std::chrono::duration<double, std::milli>(cpu_detail_end_##id - cpu_detail_start_##id).count(); \
58
+ (ctx)->cpu_detail_ms[#id] += cpu_detail_time_##id;
59
+ #else
60
+ # define WEBGPU_CPU_PROFILE_TOTAL_START(id)
61
+ # define WEBGPU_CPU_PROFILE_TOTAL_END(id, ctx)
62
+ # define WEBGPU_CPU_PROFILE_DETAIL_START(id)
63
+ # define WEBGPU_CPU_PROFILE_DETAIL_END(id, ctx)
64
+ #endif // GGML_WEBGPU_CPU_PROFILE
65
+
66
+ #ifdef GGML_WEBGPU_GPU_PROFILE
67
+ # define WEBGPU_NUM_TIMESTAMP_QUERY_BUFS 24
68
+ # define WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES 16 // e.g. enough for two timestamps
69
+ #endif
70
+
71
+ /* Constants */
72
+
73
+ // Track https://github.com/gpuweb/gpuweb/issues/5315 for fixes to implementations so this can be removed.
74
+ #define WEBGPU_MAX_WG_SIZE 288
75
+
76
+ #define WEBGPU_MUL_MAT_WG_SIZE 256
77
+ #define WEBGPU_NUM_PARAM_BUFS 32u
78
+ #define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 8u
79
+ #define WEBGPU_WAIT_ANY_TIMEOUT_MS 0
80
+ // Maximum number of in-flight submissions per-thread, to avoid exhausting the parameter buffer pool
81
+ #define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE
82
+ #define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters
83
+ #define WEBGPU_NUM_SET_ROWS_ERROR_BUFS 32
84
+ #define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4
85
+ #define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4
86
+
87
+ // For operations which process a row in parallel, this seems like a reasonable default
88
+ #define WEBGPU_ROW_SPLIT_WG_SIZE 64
89
+
90
+ // Matrix multiplication parameters
91
+
92
+ // Register tiling parameters
93
+ #define WEBGPU_MUL_MAT_TILE_M 8
94
+ #define WEBGPU_MUL_MAT_TILE_N 8
95
+ #define WEBGPU_MUL_MAT_WG_SIZE_M 8
96
+ #define WEBGPU_MUL_MAT_WG_SIZE_N 8
97
+ #define WEBGPU_MUL_MAT_TILE_K 32
98
+
99
+ // Subgroup matrix parameters
100
+ // The number of subgroups in the M dimension
101
+ #define WEBGPU_MUL_MAT_SUBGROUP_M 2
102
+ // The number of subgroups in the N dimension
103
+ #define WEBGPU_MUL_MAT_SUBGROUP_N 2
104
+ // The number of subgroup matrices each subgroup accumulates over
105
+ #define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4
106
+ #define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2
107
+
108
+ // Matrix-vector multiplication parameters
109
+ #define WEBGPU_MUL_MAT_VEC_WG_SIZE 256
110
+ // Must be multiple of 4 to work with vectorized paths, and must divide mul_mat_vec wg size
111
+ #define WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG 64
112
+ #define WEBGPU_MUL_MAT_VEC_TILE_K 256
113
+
114
+ /* End Constants */
115
+
116
+ // This is a "fake" base pointer, since WebGPU buffers do not have pointers to their locations.
117
+ static void * const webgpu_ptr_base = (void *) (uintptr_t) 0x1000; // NOLINT
118
+
119
+ // Always returns the base offset of a tensor, regardless of views.
120
+ static uint64_t webgpu_tensor_offset(const ggml_tensor * tensor) {
121
+ if (tensor->view_src) {
122
+ return (uint8_t *) tensor->view_src->data - (uint8_t *) webgpu_ptr_base;
123
+ }
124
+ return (uint8_t *) tensor->data - (uint8_t *) webgpu_ptr_base;
125
+ }
126
+
127
+ /* Struct definitions */
128
+
129
+ // Forward reference
130
+ static void ggml_webgpu_create_buffer(wgpu::Device & device,
131
+ wgpu::Buffer & buffer,
132
+ size_t size,
133
+ wgpu::BufferUsage usage,
134
+ const char * label);
135
+
136
+ struct webgpu_pool_bufs {
137
+ wgpu::Buffer host_buf;
138
+ wgpu::Buffer dev_buf;
139
+ };
140
+
141
+ // The futures to wait on for a single queue submission
142
+ struct webgpu_submission_futures {
143
+ std::vector<wgpu::FutureWaitInfo> futures;
144
+ };
145
+
146
+ // Holds a pool of parameter buffers for WebGPU operations
147
+ struct webgpu_buf_pool {
148
+ std::vector<webgpu_pool_bufs> free;
149
+
150
+ std::mutex mutex;
151
+
152
+ std::condition_variable cv;
153
+
154
+ void init(wgpu::Device device,
155
+ int num_bufs,
156
+ size_t buf_size,
157
+ wgpu::BufferUsage dev_buf_usage,
158
+ wgpu::BufferUsage host_buf_usage) {
159
+ for (int i = 0; i < num_bufs; i++) {
160
+ wgpu::Buffer host_buf;
161
+ wgpu::Buffer dev_buf;
162
+ ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_pool_buf");
163
+ ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf");
164
+ free.push_back({ host_buf, dev_buf });
165
+ }
166
+ }
167
+
168
+ webgpu_pool_bufs alloc_bufs() {
169
+ std::unique_lock<std::mutex> lock(mutex);
170
+ cv.wait(lock, [this] { return !free.empty(); });
171
+ webgpu_pool_bufs bufs = free.back();
172
+ free.pop_back();
173
+ return bufs;
174
+ }
175
+
176
+ void free_bufs(std::vector<webgpu_pool_bufs> bufs) {
177
+ std::lock_guard<std::mutex> lock(mutex);
178
+ free.insert(free.end(), bufs.begin(), bufs.end());
179
+ cv.notify_all();
180
+ }
181
+
182
+ void cleanup() {
183
+ std::lock_guard<std::mutex> lock(mutex);
184
+ for (auto & bufs : free) {
185
+ bufs.host_buf.Destroy();
186
+ bufs.dev_buf.Destroy();
187
+ }
188
+ free.clear();
189
+ }
190
+ };
191
+
192
+ #ifdef GGML_WEBGPU_GPU_PROFILE
193
+ struct webgpu_gpu_profile_bufs {
194
+ wgpu::Buffer host_buf;
195
+ wgpu::Buffer dev_buf;
196
+ wgpu::QuerySet query_set;
197
+ };
198
+
199
+ // Holds a pool of parameter buffers for WebGPU operations
200
+ struct webgpu_gpu_profile_buf_pool {
201
+ std::vector<webgpu_gpu_profile_bufs> free;
202
+
203
+ std::mutex mutex;
204
+
205
+ std::condition_variable cv;
206
+
207
+ void init(wgpu::Device device,
208
+ int num_bufs,
209
+ size_t buf_size,
210
+ wgpu::BufferUsage dev_buf_usage,
211
+ wgpu::BufferUsage host_buf_usage) {
212
+ for (int i = 0; i < num_bufs; i++) {
213
+ wgpu::Buffer host_buf;
214
+ wgpu::Buffer dev_buf;
215
+ ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_profile_buf");
216
+ ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_profile_buf");
217
+ // Create a query set for 2 timestamps
218
+ wgpu::QuerySetDescriptor ts_query_set_desc = {};
219
+
220
+ ts_query_set_desc.type = wgpu::QueryType::Timestamp;
221
+ ts_query_set_desc.count = 2;
222
+ wgpu::QuerySet ts_query_set = device.CreateQuerySet(&ts_query_set_desc);
223
+
224
+ free.push_back({ host_buf, dev_buf, ts_query_set });
225
+ }
226
+ }
227
+
228
+ webgpu_gpu_profile_bufs alloc_bufs() {
229
+ std::unique_lock<std::mutex> lock(mutex);
230
+ cv.wait(lock, [this] { return !free.empty(); });
231
+ webgpu_gpu_profile_bufs bufs = free.back();
232
+ free.pop_back();
233
+ return bufs;
234
+ }
235
+
236
+ void free_bufs(std::vector<webgpu_gpu_profile_bufs> bufs) {
237
+ std::lock_guard<std::mutex> lock(mutex);
238
+ free.insert(free.end(), bufs.begin(), bufs.end());
239
+ cv.notify_all();
240
+ }
241
+
242
+ void cleanup() {
243
+ std::lock_guard<std::mutex> lock(mutex);
244
+ for (auto & bufs : free) {
245
+ bufs.host_buf.Destroy();
246
+ bufs.dev_buf.Destroy();
247
+ bufs.query_set.Destroy();
248
+ }
249
+ free.clear();
250
+ }
251
+ };
252
+ #endif
253
+
254
+ struct webgpu_pipeline {
255
+ wgpu::ComputePipeline pipeline;
256
+ std::string name;
257
+ void * context = nullptr;
258
+ };
259
+
260
+ struct webgpu_command {
261
+ wgpu::CommandBuffer commands;
262
+ webgpu_pool_bufs params_bufs;
263
+ std::optional<webgpu_pool_bufs> set_rows_error_bufs;
264
+ #ifdef GGML_WEBGPU_GPU_PROFILE
265
+ webgpu_gpu_profile_bufs timestamp_query_bufs;
266
+ std::string pipeline_name;
267
+ #endif
268
+ };
269
+
270
+ struct flash_attn_pipeline_key {
271
+ int q_type;
272
+ int kv_type;
273
+ int dst_type;
274
+ uint32_t head_dim_qk;
275
+ uint32_t head_dim_v;
276
+ bool kv_direct;
277
+ bool has_mask;
278
+ bool has_sinks;
279
+ bool uses_logit_softcap;
280
+
281
+ bool operator==(const flash_attn_pipeline_key & other) const {
282
+ return q_type == other.q_type && kv_type == other.kv_type && dst_type == other.dst_type &&
283
+ head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v && kv_direct == other.kv_direct &&
284
+ has_mask == other.has_mask && has_sinks == other.has_sinks &&
285
+ uses_logit_softcap == other.uses_logit_softcap;
286
+ }
287
+ };
288
+
289
+ // Same hash combine function as in boost
290
+ template <typename T> inline void ggml_webgpu_hash_combine(size_t & seed, const T & value) {
291
+ seed ^= std::hash<T>{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
292
+ }
293
+
294
+ struct flash_attn_pipeline_key_hash {
295
+ size_t operator()(const flash_attn_pipeline_key & key) const {
296
+ size_t seed = 0;
297
+ ggml_webgpu_hash_combine(seed, key.q_type);
298
+ ggml_webgpu_hash_combine(seed, key.kv_type);
299
+ ggml_webgpu_hash_combine(seed, key.dst_type);
300
+ ggml_webgpu_hash_combine(seed, key.head_dim_qk);
301
+ ggml_webgpu_hash_combine(seed, key.head_dim_v);
302
+ ggml_webgpu_hash_combine(seed, key.kv_direct);
303
+ ggml_webgpu_hash_combine(seed, key.has_mask);
304
+ ggml_webgpu_hash_combine(seed, key.has_sinks);
305
+ ggml_webgpu_hash_combine(seed, key.uses_logit_softcap);
306
+ return seed;
307
+ }
308
+ };
309
+
310
+ // All the base objects needed to run operations on a WebGPU device
311
+ struct webgpu_context_struct {
312
+ wgpu::Instance instance;
313
+ wgpu::Adapter adapter;
314
+ wgpu::Device device;
315
+ wgpu::Queue queue;
316
+ wgpu::Limits limits;
317
+
318
+ uint32_t max_subgroup_size;
319
+
320
+ bool supports_subgroup_matrix = false;
321
+ uint32_t sg_mat_m;
322
+ uint32_t sg_mat_n;
323
+ uint32_t sg_mat_k;
324
+
325
+ std::recursive_mutex mutex;
326
+ std::atomic_uint inflight_threads = 0;
327
+
328
+ webgpu_buf_pool param_buf_pool;
329
+ webgpu_buf_pool set_rows_error_buf_pool;
330
+
331
+ pre_wgsl::Preprocessor p;
332
+
333
+ std::map<int, webgpu_pipeline> memset_pipelines; // variant or type index
334
+
335
+ std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> mul_mat_pipelines; // src0_type, src1_type, vectorized
336
+ std::map<int, std::map<int, std::map<int, webgpu_pipeline>>>
337
+ mul_mat_vec_pipelines; // src0_type, src1_type, vectorized
338
+
339
+ std::unordered_map<flash_attn_pipeline_key, webgpu_pipeline, flash_attn_pipeline_key_hash> flash_attn_pipelines;
340
+
341
+ std::map<int, std::map<int, webgpu_pipeline>> set_rows_pipelines; // dst_type, vectorized
342
+ std::map<int, std::map<int, webgpu_pipeline>> get_rows_pipelines; // src_type, vectorized
343
+
344
+ std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines; // src_type, dst_type
345
+ std::map<int, std::map<int, webgpu_pipeline>> add_pipelines; // type, inplace
346
+ std::map<int, std::map<int, webgpu_pipeline>> sub_pipelines; // type, inplace
347
+ std::map<int, std::map<int, webgpu_pipeline>> mul_pipelines; // type, inplace
348
+ std::map<int, std::map<int, webgpu_pipeline>> div_pipelines; // type, inplace
349
+
350
+ std::map<int, webgpu_pipeline> rms_norm_pipelines; // inplace
351
+ std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> rope_pipelines; // type, ff, inplace
352
+ std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> glu_pipelines; // glu_op, type, split
353
+ std::map<int, webgpu_pipeline> scale_pipelines; // inplace
354
+ std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> soft_max_pipelines; // mask_type, has_sink, inplace
355
+ std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> unary_pipelines; // unary_op, type, inplace
356
+
357
+ size_t memset_bytes_per_thread;
358
+
359
+ // Staging buffer for reading data from the GPU
360
+ wgpu::Buffer get_tensor_staging_buf;
361
+
362
+ #ifdef GGML_WEBGPU_DEBUG
363
+ wgpu::Buffer debug_host_buf;
364
+ wgpu::Buffer debug_dev_buf;
365
+ #endif
366
+
367
+ #ifdef GGML_WEBGPU_CPU_PROFILE
368
+ // Profiling: labeled CPU time in ms (total)
369
+ std::unordered_map<std::string, double> cpu_time_ms;
370
+ // Profiling: detailed CPU time in ms
371
+ std::unordered_map<std::string, double> cpu_detail_ms;
372
+ #endif
373
+
374
+ #ifdef GGML_WEBGPU_GPU_PROFILE
375
+ // Profiling: per-shader GPU time in ms
376
+ std::unordered_map<std::string, double> shader_gpu_time_ms;
377
+ // Profiling: pool of timestamp query buffers (one per operation)
378
+ webgpu_gpu_profile_buf_pool timestamp_query_buf_pool;
379
+ #endif
380
+ };
381
+
382
+ typedef std::shared_ptr<webgpu_context_struct> webgpu_context;
383
+
384
+ struct ggml_backend_webgpu_reg_context {
385
+ webgpu_context webgpu_ctx;
386
+ size_t device_count;
387
+ const char * name;
388
+ };
389
+
390
+ struct ggml_backend_webgpu_device_context {
391
+ webgpu_context webgpu_ctx;
392
+ std::string device_name;
393
+ std::string device_desc;
394
+ };
395
+
396
+ struct ggml_backend_webgpu_context {
397
+ webgpu_context webgpu_ctx;
398
+ std::string name;
399
+ };
400
+
401
+ struct ggml_backend_webgpu_buffer_context {
402
+ webgpu_context webgpu_ctx;
403
+ wgpu::Buffer buffer;
404
+ std::string label;
405
+
406
+ ggml_backend_webgpu_buffer_context(webgpu_context ctx, wgpu::Buffer buf, std::string lbl) :
407
+ webgpu_ctx(std::move(ctx)),
408
+ buffer(std::move(buf)),
409
+ label(std::move(lbl)) {}
410
+ };
411
+
412
+ /* WebGPU object initializations */
413
+
414
+ // Process a WGSL shader string, replacing tokens of the form {{KEY}} with
415
+ // the corresponding values provided in `repls`.
416
+ static std::string ggml_webgpu_process_shader_repls(const char * src,
417
+ const std::map<std::string, std::string> & repls) {
418
+ if (!src) {
419
+ return std::string();
420
+ }
421
+ std::string s = src;
422
+ for (const auto & kv : repls) {
423
+ std::string token = "{{" + kv.first + "}}";
424
+ size_t pos = 0;
425
+ while ((pos = s.find(token, pos)) != std::string::npos) {
426
+ s.replace(pos, token.length(), kv.second);
427
+ pos += kv.second.length();
428
+ }
429
+ }
430
+ return s;
431
+ }
432
+
433
+ static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device,
434
+ const char * shader_code,
435
+ const char * label,
436
+ const std::vector<wgpu::ConstantEntry> & constants = {}) {
437
+ wgpu::ShaderSourceWGSL shader_source;
438
+ shader_source.code = shader_code;
439
+
440
+ wgpu::ShaderModuleDescriptor shader_desc;
441
+ shader_desc.nextInChain = &shader_source;
442
+
443
+ wgpu::ShaderModule shader_module = device.CreateShaderModule(&shader_desc);
444
+
445
+ wgpu::ComputePipelineDescriptor pipeline_desc;
446
+ pipeline_desc.label = label;
447
+ pipeline_desc.compute.module = shader_module;
448
+ pipeline_desc.compute.entryPoint = "main"; // Entry point in the WGSL code
449
+ pipeline_desc.layout = nullptr; // nullptr means auto layout
450
+ if (constants.size() > 0) {
451
+ pipeline_desc.compute.constants = constants.data();
452
+ pipeline_desc.compute.constantCount = constants.size();
453
+ }
454
+ return { device.CreateComputePipeline(&pipeline_desc), label };
455
+ }
456
+
457
+ static void ggml_webgpu_create_buffer(wgpu::Device & device,
458
+ wgpu::Buffer & buffer,
459
+ size_t size,
460
+ wgpu::BufferUsage usage,
461
+ const char * label) {
462
+ wgpu::BufferDescriptor buffer_desc;
463
+ buffer_desc.size = size;
464
+ buffer_desc.usage = usage;
465
+ buffer_desc.label = label;
466
+ buffer_desc.mappedAtCreation = false;
467
+
468
+ // TODO: error handling
469
+ buffer = device.CreateBuffer(&buffer_desc);
470
+ }
471
+
472
+ /** End WebGPU object initializations */
473
+
474
+ /** WebGPU Actions */
475
+
476
+ // Wait for the queue to finish processing all submitted work
477
+ static void ggml_backend_webgpu_wait(webgpu_context & ctx,
478
+ std::vector<webgpu_submission_futures> & futures,
479
+ bool block = true) {
480
+ // If we have too many in-flight submissions, wait on the oldest one first. If there are many threads,
481
+ // inflight_max may be 0, meaning that we must wait on all futures.
482
+ uint64_t timeout_ms = block ? UINT64_MAX : 0;
483
+ uint32_t inflight_threads = ctx->inflight_threads;
484
+ uint32_t inflight_max = WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD / std::max(inflight_threads, 1u);
485
+ while (futures.size() >= inflight_max && futures.size() > 0) {
486
+ ctx->instance.WaitAny(futures[0].futures.size(), futures[0].futures.data(), UINT64_MAX);
487
+ futures.erase(futures.begin());
488
+ }
489
+ size_t i = 0;
490
+ while (i < futures.size()) {
491
+ auto waitStatus = ctx->instance.WaitAny(futures[i].futures.size(), futures[i].futures.data(), timeout_ms);
492
+ switch (waitStatus) {
493
+ case wgpu::WaitStatus::Success:
494
+ futures.erase(futures.begin() + i);
495
+ break;
496
+ case wgpu::WaitStatus::TimedOut:
497
+ i++;
498
+ break;
499
+ case wgpu::WaitStatus::Error:
500
+ GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
501
+ break;
502
+ default:
503
+ GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n");
504
+ break;
505
+ }
506
+ }
507
+ }
508
+
509
+ static void ggml_backend_webgpu_map_buffer(webgpu_context & ctx,
510
+ wgpu::Buffer & buffer,
511
+ wgpu::MapMode mode,
512
+ size_t offset,
513
+ size_t size) {
514
+ ctx->instance.WaitAny(buffer.MapAsync(mode, offset, size, wgpu::CallbackMode::AllowSpontaneous,
515
+ [](wgpu::MapAsyncStatus status, wgpu::StringView message) {
516
+ if (status != wgpu::MapAsyncStatus::Success) {
517
+ GGML_LOG_ERROR("ggml_webgpu: Failed to map buffer: %s\n",
518
+ message.data);
519
+ }
520
+ }),
521
+ UINT64_MAX);
522
+ }
523
+
524
+ #ifdef GGML_WEBGPU_DEBUG
525
+ // This function adds debugging information to shaders, as WebGPU does not support printing directly.
526
+ // To use, add a bind group entry to the setup for the shader you are debugging, add the buffer and
527
+ // debug statements in the shader, and then call this function after encoding the commands and submitting them.
528
+ static void ggml_backend_webgpu_debug(webgpu_context & ctx) {
529
+ wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
530
+ encoder.CopyBufferToBuffer(ctx->debug_dev_buf, 0, ctx->debug_host_buf, 0, ctx->debug_host_buf.GetSize());
531
+ wgpu::CommandBuffer commands = encoder.Finish();
532
+ ctx->queue.Submit(1, &commands);
533
+ ggml_backend_webgpu_map_buffer(ctx, ctx->debug_host_buf, wgpu::MapMode::Read, 0, ctx->debug_host_buf.GetSize());
534
+ const float * debug_data = (const float *) ctx->debug_host_buf.GetConstMappedRange();
535
+ std::cout << "debug[0]: " << debug_data[0] << "\n";
536
+ ctx->debug_host_buf.Unmap();
537
+ }
538
+ #endif
539
+
540
// Submit a batch of recorded command buffers to the device queue and register
// completion callbacks that recycle the pooled staging buffers.
//
// ctx      - shared WebGPU backend context (queue, buffer pools, profiling maps)
// commands - commands previously produced by ggml_backend_webgpu_build()
//
// Returns the futures for all registered callbacks; the caller must wait on
// them (ggml_backend_webgpu_wait) before the pooled buffers can be reused.
static webgpu_submission_futures ggml_backend_webgpu_submit(webgpu_context ctx, std::vector<webgpu_command> commands) {
    std::vector<wgpu::CommandBuffer> command_buffers;
    std::vector<webgpu_pool_bufs>    params_bufs;
    std::vector<webgpu_pool_bufs>    set_rows_error_bufs;
#ifdef GGML_WEBGPU_GPU_PROFILE
    std::vector<std::pair<std::string, webgpu_gpu_profile_bufs>> pipeline_name_and_ts_bufs;
#endif

    // Gather the raw command buffers plus the pooled staging buffers each
    // command holds, so they can be released once the GPU is done with them.
    for (const auto & command : commands) {
        command_buffers.push_back(command.commands);
        params_bufs.push_back(command.params_bufs);
        if (command.set_rows_error_bufs) {
            set_rows_error_bufs.push_back(command.set_rows_error_bufs.value());
        }
    }
    ctx->queue.Submit(command_buffers.size(), command_buffers.data());

    std::vector<wgpu::FutureWaitInfo> futures;

    // Once all submitted work completes, return the parameter staging buffers
    // to the pool. The lambda captures params_bufs by value so the buffers
    // stay alive until the callback fires.
    wgpu::Future p_f = ctx->queue.OnSubmittedWorkDone(
        wgpu::CallbackMode::AllowSpontaneous,
        [ctx, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
            if (status != wgpu::QueueWorkDoneStatus::Success) {
                GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", std::string(message).c_str());
            }
            // Free the staged buffers
            ctx->param_buf_pool.free_bufs({ params_bufs });
        });
    futures.push_back({ p_f });

    // Map each SET_ROWS error buffer back to the host; a non-zero word means
    // the shader saw an index it cannot represent, which is fatal.
    for (const auto & bufs : set_rows_error_bufs) {
        wgpu::Future f = bufs.host_buf.MapAsync(
            wgpu::MapMode::Read, 0, bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous,
            [ctx, bufs](wgpu::MapAsyncStatus status, wgpu::StringView message) {
                if (status != wgpu::MapAsyncStatus::Success) {
                    GGML_LOG_ERROR("ggml_webgpu: Failed to map error buffer: %s\n", std::string(message).c_str());
                } else {
                    const uint32_t * error_data = (const uint32_t *) bufs.host_buf.GetConstMappedRange();
                    if (*error_data) {
                        GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported.");
                    }
                    // We can't unmap in here due to WebGPU reentrancy limitations.
                    ctx->set_rows_error_buf_pool.free_bufs({ bufs });
                }
            });
        futures.push_back({ f });
    }

#ifdef GGML_WEBGPU_GPU_PROFILE
    // Read back the per-pass GPU timestamp pairs (begin/end) and accumulate
    // elapsed time per shader name.
    for (const auto & command : commands) {
        auto label   = command.pipeline_name;
        auto ts_bufs = command.timestamp_query_bufs;

        wgpu::Future f = ts_bufs.host_buf.MapAsync(
            wgpu::MapMode::Read, 0, ts_bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous,
            [ctx, ts_bufs, label](wgpu::MapAsyncStatus status, wgpu::StringView message) {
                if (status != wgpu::MapAsyncStatus::Success) {
                    GGML_LOG_ERROR("ggml_webgpu: Failed to map timestamp buffer: %s\n", std::string(message).c_str());
                } else {
                    const uint64_t * ts_data = (const uint64_t *) ts_bufs.host_buf.GetConstMappedRange();
                    // WebGPU timestamps are in ns; convert to ms
                    double elapsed_ms = double(ts_data[1] - ts_data[0]) * 1e-6;
                    ctx->shader_gpu_time_ms[label] += elapsed_ms;
                    // We can't unmap in here due to WebGPU reentrancy limitations.
                    ctx->timestamp_query_buf_pool.free_bufs({ ts_bufs });
                }
            });
        futures.push_back({ f });
    }
#endif
    return { futures };
}
612
+
613
// Record a single compute dispatch into a command buffer (not yet submitted).
//
// Stages `params` into a pooled host buffer, appends the device-side params
// buffer as the last bind-group entry, encodes the host->device params copy,
// the compute pass, and (if present) the SET_ROWS error-buffer readback copy.
//
// pipeline            - compute pipeline to dispatch
// params              - uniform words copied to the shader's params binding
// bind_group_entries  - tensor bindings; the params buffer is appended after them
// wg_x, wg_y          - workgroup grid dimensions (z is always 1)
// set_rows_error_bufs - optional error-flag buffers for SET_ROWS dispatches
static webgpu_command ggml_backend_webgpu_build(webgpu_context &                  ctx,
                                                webgpu_pipeline &                 pipeline,
                                                std::vector<uint32_t>             params,
                                                std::vector<wgpu::BindGroupEntry> bind_group_entries,
                                                uint32_t                          wg_x,
                                                uint32_t                          wg_y = 1,
                                                std::optional<webgpu_pool_bufs>   set_rows_error_bufs = std::nullopt) {
    webgpu_pool_bufs params_bufs = ctx->param_buf_pool.alloc_bufs();

    // Write the parameter words into the mapped host staging buffer.
    ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0, params_bufs.host_buf.GetSize());
    uint32_t * _params = (uint32_t *) params_bufs.host_buf.GetMappedRange();
    for (size_t i = 0; i < params.size(); i++) {
        _params[i] = params[i];
    };

    params_bufs.host_buf.Unmap();

    // The params buffer always occupies the binding slot after the tensors.
    uint32_t params_bufs_binding_num = bind_group_entries.size();
    bind_group_entries.push_back({ .binding = params_bufs_binding_num,
                                   .buffer  = params_bufs.dev_buf,
                                   .offset  = 0,
                                   .size    = params_bufs.dev_buf.GetSize() });

    wgpu::BindGroupDescriptor bind_group_desc;
    bind_group_desc.layout     = pipeline.pipeline.GetBindGroupLayout(0);
    bind_group_desc.entryCount = bind_group_entries.size();
    bind_group_desc.entries    = bind_group_entries.data();
    bind_group_desc.label      = pipeline.name.c_str();
    wgpu::BindGroup bind_group = ctx->device.CreateBindGroup(&bind_group_desc);

    wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
    // Copy the staged params to the device buffer before the pass runs.
    encoder.CopyBufferToBuffer(params_bufs.host_buf, 0, params_bufs.dev_buf, 0, params_bufs.dev_buf.GetSize());

#ifdef GGML_WEBGPU_GPU_PROFILE
    // --- Profiling: GPU timestamp queries ---
    // Allocate a timestamp query buffer (2 timestamps: start/end)
    webgpu_gpu_profile_bufs ts_bufs = ctx->timestamp_query_buf_pool.alloc_bufs();
    // Pooled buffer may still be mapped from a previous readback; unmap first.
    if (ts_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) {
        ts_bufs.host_buf.Unmap();
    }

    wgpu::PassTimestampWrites ts_writes = { .querySet                  = ts_bufs.query_set,
                                            .beginningOfPassWriteIndex = 0,
                                            .endOfPassWriteIndex       = 1 };
    wgpu::ComputePassDescriptor pass_desc = { .timestampWrites = &ts_writes };
    wgpu::ComputePassEncoder    pass      = encoder.BeginComputePass(&pass_desc);
#else
    wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
#endif
    pass.SetPipeline(pipeline.pipeline);
    pass.SetBindGroup(0, bind_group);
    pass.DispatchWorkgroups(wg_x, wg_y, 1);
    pass.End();

#ifdef GGML_WEBGPU_GPU_PROFILE
    // Resolve the query set into the device buffer
    encoder.ResolveQuerySet(ts_bufs.query_set, 0, 2, ts_bufs.dev_buf, 0);
    encoder.CopyBufferToBuffer(ts_bufs.dev_buf, 0, ts_bufs.host_buf, 0, ts_bufs.host_buf.GetSize());
#endif

    // If there are SET_ROWS operations in this submission, copy their error buffers to the host.
    if (set_rows_error_bufs) {
        encoder.CopyBufferToBuffer(set_rows_error_bufs->dev_buf, 0, set_rows_error_bufs->host_buf, 0,
                                   set_rows_error_bufs->host_buf.GetSize());
    }

    wgpu::CommandBuffer commands = encoder.Finish();
    webgpu_command      result   = {};
    result.commands              = commands;
    result.params_bufs           = params_bufs;
    result.set_rows_error_bufs   = set_rows_error_bufs;
#ifdef GGML_WEBGPU_GPU_PROFILE
    result.timestamp_query_bufs = ts_bufs;
    result.pipeline_name        = pipeline.name;
#endif
    return result;
}
690
+
691
// Fill `size` bytes of `buf` starting at `offset` with the 32-bit pattern
// `value`, using the memset compute shader, and block until it completes.
//
// NOTE(review): offset/size are narrowed to uint32_t for the shader params —
// presumably buffers are < 4 GiB here; confirm against buffer allocation limits.
static void ggml_backend_webgpu_buffer_memset(webgpu_context & ctx,
                                              wgpu::Buffer &   buf,
                                              uint32_t         value,
                                              size_t           offset,
                                              size_t           size) {
    std::vector<uint32_t>             params  = { (uint32_t) offset, (uint32_t) size, value };
    std::vector<wgpu::BindGroupEntry> entries = {
        { .binding = 0, .buffer = buf, .offset = 0, .size = buf.GetSize() }
    };
    size_t   bytes_per_wg = WEBGPU_MAX_WG_SIZE * ctx->memset_bytes_per_thread;
    // `size + 3` rounds a byte count up so a trailing partial 32-bit word
    // still gets a workgroup assigned to it.
    uint32_t wg_x         = CEIL_DIV(size + 3, bytes_per_wg);

    webgpu_command                         command = ggml_backend_webgpu_build(ctx, ctx->memset_pipelines[0], params, entries, wg_x);
    std::vector<webgpu_submission_futures> futures = { ggml_backend_webgpu_submit(ctx, { command }) };
    // Synchronous: wait for the fill to finish before returning.
    ggml_backend_webgpu_wait(ctx, futures);
}
707
+
708
+ /** End WebGPU Actions */
709
+
710
+ /** GGML Backend Interface */
711
+
712
+ static const char * ggml_backend_webgpu_name(ggml_backend_t backend) {
713
+ ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *) backend->context;
714
+ return ctx->name.c_str();
715
+ }
716
+
717
// TODO: implement proper cleanup
// Backend-interface hook: free the backend. Currently only dumps the optional
// CPU/GPU profiling summaries to stdout; device/buffer teardown is not yet
// implemented (see TODO above).
static void ggml_backend_webgpu_free(ggml_backend_t backend) {
    ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *) backend->context;
    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_free(" << ctx->name << ")");

#ifdef GGML_WEBGPU_CPU_PROFILE
    // Aggregate and print host-side timing, both per category and (if
    // recorded) per detailed label; percentages are relative to the total.
    std::cout << "\n[ggml_webgpu cpu profiling summary]\n";
    double total_cpu = 0.0;
    for (const auto & kv : ctx->webgpu_ctx->cpu_time_ms) {
        total_cpu += kv.second;
    }
    std::cout << "ggml_webgpu: total cpu time: " << total_cpu << " ms\n";
    std::cout << "ggml_webgpu: cpu breakdown:\n";
    for (const auto & kv : ctx->webgpu_ctx->cpu_time_ms) {
        double pct = (total_cpu > 0.0) ? (kv.second / total_cpu * 100.0) : 0.0;
        std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n";
    }
    if (ctx->webgpu_ctx->cpu_detail_ms.size() > 0) {
        std::cout << "ggml_webgpu: cpu detailed breakdown:\n";
    }
    for (const auto & kv : ctx->webgpu_ctx->cpu_detail_ms) {
        double pct = (total_cpu > 0.0) ? (kv.second / total_cpu * 100.0) : 0.0;
        std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n";
    }
#endif

#ifdef GGML_WEBGPU_GPU_PROFILE
    // Aggregate and print per-shader GPU time collected from timestamp queries.
    std::cout << "\n[ggml_webgpu gpu profiling summary]\n";
    double total_gpu = 0.0;
    for (const auto & kv : ctx->webgpu_ctx->shader_gpu_time_ms) {
        total_gpu += kv.second;
    }
    std::cout << "ggml_webgpu: total gpu time (all shaders): " << total_gpu << " ms\n";
    std::cout << "\nggml_webgpu: gpu breakdown:\n";
    for (const auto & kv : ctx->webgpu_ctx->shader_gpu_time_ms) {
        double pct = (total_gpu > 0.0) ? (kv.second / total_gpu * 100.0) : 0.0;
        std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n";
    }
#endif

#if defined(GGML_WEBGPU_CPU_PROFILE) && defined(GGML_WEBGPU_GPU_PROFILE)
    std::cout << "ggml_webgpu: gpu/cpu ratio: " << (total_cpu > 0.0 ? total_gpu / total_cpu : 0.0) << "\n";
#endif

#if !defined(GGML_WEBGPU_CPU_PROFILE) && !defined(GGML_WEBGPU_GPU_PROFILE)
    // ctx is only read by the profiling blocks above.
    GGML_UNUSED(ctx);
#endif
}
765
+
766
// Absolute byte offset of a tensor within its backing WebGPU buffer,
// including the extra offset of views into a parent tensor.
static size_t ggml_webgpu_tensor_offset(const ggml_tensor * tensor) {
    return webgpu_tensor_offset(tensor) + tensor->view_offs;
}
769
+
770
+ static wgpu::Buffer ggml_webgpu_tensor_buf(const ggml_tensor * tensor) {
771
+ ggml_backend_webgpu_buffer_context * ctx = (ggml_backend_webgpu_buffer_context *) tensor->buffer->context;
772
+ return ctx->buffer;
773
+ }
774
+
775
+ static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, const ggml_tensor * t) {
776
+ size_t offset = ggml_webgpu_tensor_offset(t);
777
+ return offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
778
+ }
779
+
780
+ static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, const ggml_tensor * t) {
781
+ size_t offset = ggml_webgpu_tensor_offset(t);
782
+ return offset & ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
783
+ }
784
+
785
// Size of the buffer binding needed to cover the tensor's data plus its
// leading misalignment, rounded up to the binding-size granularity.
static size_t ggml_webgpu_tensor_binding_size(webgpu_context & ctx, ggml_tensor * t) {
    return ROUNDUP_POW2(ggml_nbytes(t) + ggml_webgpu_tensor_misalignment(ctx, t), WEBGPU_STORAGE_BUF_BINDING_MULT);
}
788
+
789
+ // Used to determine if two tensors are the same for in-place operations
790
+ static bool ggml_webgpu_tensor_equal(ggml_tensor * a, ggml_tensor * b) {
791
+ return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) &&
792
+ (ggml_webgpu_tensor_offset(a) == ggml_webgpu_tensor_offset(b));
793
+ }
794
+
795
// Build a CPY/CONT/DUP dispatch: copy (and possibly type-convert) src into
// dst, handling arbitrary strides on both sides. One thread per destination
// element. The pipeline is selected by the (src type, dst type) pair.
static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
    uint32_t ne = (uint32_t) ggml_nelements(dst);

    // Positional parameter layout must match the cpy shader's params struct.
    std::vector<uint32_t> params = {
        ne, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
        // Convert byte-strides to element-strides
        (uint32_t) (src->nb[0] / ggml_type_size(src->type)), (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
        (uint32_t) (src->nb[2] / ggml_type_size(src->type)), (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
        (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
        (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
        // Logical shapes
        (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) dst->ne[0],
        (uint32_t) dst->ne[1], (uint32_t) dst->ne[2]
    };

    std::vector<wgpu::BindGroupEntry> entries = {
        { .binding = 0,
         .buffer  = ggml_webgpu_tensor_buf(src),
         .offset  = ggml_webgpu_tensor_align_offset(ctx, src),
         .size    = ggml_webgpu_tensor_binding_size(ctx, src) },
        { .binding = 1,
         .buffer  = ggml_webgpu_tensor_buf(dst),
         .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
         .size    = ggml_webgpu_tensor_binding_size(ctx, dst) }
    };

    uint32_t wg_x = CEIL_DIV(ne, WEBGPU_MAX_WG_SIZE);
    return ggml_backend_webgpu_build(ctx, ctx->cpy_pipelines[src->type][dst->type], params, entries, wg_x);
}
825
+
826
// Build a SET_ROWS dispatch: scatter rows of src into dst at row positions
// given by idx. Returns std::nullopt when there is nothing to do (empty src
// or idx). A per-dispatch error buffer lets the shader report indices that
// don't fit in 32 bits; it is checked on submit (see ggml_backend_webgpu_submit).
static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
                                                          ggml_tensor *    src,
                                                          ggml_tensor *    idx,
                                                          ggml_tensor *    dst) {
    // For set rows specifically, we need to check if src and idx are empty tensors.
    if (ggml_is_empty(src) || ggml_is_empty(idx)) {
        return std::nullopt;
    }

    webgpu_pool_bufs error_bufs = ctx->set_rows_error_buf_pool.alloc_bufs();
    // Pooled buffer may still be mapped from an earlier readback; unmap first.
    if (error_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) {
        error_bufs.host_buf.Unmap();
    }

    // Positional parameter layout must match the set_rows shader's params struct.
    std::vector<uint32_t> params = {
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)),
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
        // Convert byte-strides to element-strides
        (uint32_t) (src->nb[1] / ggml_type_size(src->type)), (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
        (uint32_t) (src->nb[3] / ggml_type_size(src->type)), (uint32_t) (idx->nb[0] / ggml_type_size(idx->type)),
        (uint32_t) (idx->nb[1] / ggml_type_size(idx->type)), (uint32_t) (idx->nb[2] / ggml_type_size(idx->type)),
        (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
        (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
        // Shape of src
        (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) src->ne[3],
        // Shape of idx
        (uint32_t) (idx->ne[1]), (uint32_t) (idx->ne[2])
    };

    std::vector<wgpu::BindGroupEntry> entries = {
        { .binding = 0,
         .buffer  = ggml_webgpu_tensor_buf(src),
         .offset  = ggml_webgpu_tensor_align_offset(ctx, src),
         .size    = ggml_webgpu_tensor_binding_size(ctx, src) },
        { .binding = 1,
         .buffer  = ggml_webgpu_tensor_buf(idx),
         .offset  = ggml_webgpu_tensor_align_offset(ctx, idx),
         .size    = ggml_webgpu_tensor_binding_size(ctx, idx) },
        { .binding = 2,
         .buffer  = ggml_webgpu_tensor_buf(dst),
         .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
         .size    = ggml_webgpu_tensor_binding_size(ctx, dst) },
        { .binding = 3, .buffer = error_bufs.dev_buf, .offset = 0, .size = error_bufs.dev_buf.GetSize() }
    };

    // Use the 4-wide vectorized shader when rows are a multiple of 4 elements.
    int             vectorized = src->ne[0] % 4 == 0;
    webgpu_pipeline pipeline   = ctx->set_rows_pipelines[0][vectorized];
    uint32_t        threads;
    if (vectorized) {
        threads = (src->ne[1] * src->ne[2] * src->ne[3]) * (src->ne[0] / 4);
    } else {
        threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3];
    }

    uint32_t wg_x = CEIL_DIV(threads, WEBGPU_MAX_WG_SIZE);

    return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, 1, error_bufs);
}
885
+
886
// Build a GET_ROWS dispatch: gather rows of src selected by idx into dst.
// One thread per destination row (ne1*ne2*ne3 of dst). A vectorized variant
// is used for F32 sources whose rows are a multiple of 4 elements.
static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx,
                                           ggml_tensor *    src,
                                           ggml_tensor *    idx,
                                           ggml_tensor *    dst) {
    // Positional parameter layout must match the get_rows shader's params struct.
    std::vector<uint32_t> params = {
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)),
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
        // Convert byte-strides to element-strides
        (uint32_t) (src->nb[1] / ggml_type_size(src->type)), (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
        (uint32_t) (src->nb[3] / ggml_type_size(src->type)), (uint32_t) (idx->nb[0] / ggml_type_size(idx->type)),
        (uint32_t) (idx->nb[1] / ggml_type_size(idx->type)), (uint32_t) (idx->nb[2] / ggml_type_size(idx->type)),
        (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
        (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
        // Shape of dst
        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3],
        // Shape of idx
        (uint32_t) (idx->ne[1]), (uint32_t) (idx->ne[2])
    };

    std::vector<wgpu::BindGroupEntry> entries = {
        { .binding = 0,
         .buffer  = ggml_webgpu_tensor_buf(src),
         .offset  = ggml_webgpu_tensor_align_offset(ctx, src),
         .size    = ggml_webgpu_tensor_binding_size(ctx, src) },
        { .binding = 1,
         .buffer  = ggml_webgpu_tensor_buf(idx),
         .offset  = ggml_webgpu_tensor_align_offset(ctx, idx),
         .size    = ggml_webgpu_tensor_binding_size(ctx, idx) },
        { .binding = 2,
         .buffer  = ggml_webgpu_tensor_buf(dst),
         .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
         .size    = ggml_webgpu_tensor_binding_size(ctx, dst) }
    };

    uint32_t wg_x = CEIL_DIV(dst->ne[1] * dst->ne[2] * dst->ne[3], WEBGPU_MAX_WG_SIZE);

    uint32_t        vectorized = src->type == GGML_TYPE_F32 && dst->ne[0] % 4 == 0;
    webgpu_pipeline pipeline   = ctx->get_rows_pipelines[src->type][vectorized];
    return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
}
927
+
928
// Build a MUL_MAT dispatch: dst = src0^T * src1 with batching/broadcast in
// dims 2 and 3. Chooses between the generic pipeline, a fast matrix-vector
// path (when dst has a single column), and a fast tiled / subgroup-matrix
// path, depending on types and divisibility.
static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
                                          ggml_tensor *    src0,
                                          ggml_tensor *    src1,
                                          ggml_tensor *    dst) {
    // Positional parameter layout must match the mul_mat shader's params struct.
    std::vector<uint32_t> params = {
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
        (uint32_t) dst->ne[0],                                 // number of rows in result (M, transposed)
        (uint32_t) dst->ne[1],                                 // number of columns in result (N)
        (uint32_t) src0->ne[0],                                // number of columns in src0/src1 (K)
        (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 1
        (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 1
        (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 2
        (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 2
        (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 3
        (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 3
        (uint32_t) src0->ne[2],                                // batch size in dimension 2
        (uint32_t) src0->ne[3],                                // batch size in dimension 3
        (uint32_t) (src1->ne[2] / src0->ne[2]),                // broadcast in dimension 2
        (uint32_t) (src1->ne[3] / src0->ne[3])                 // broadcast in dimension 3
    };

    std::vector<wgpu::BindGroupEntry> entries = {
        { .binding = 0,
         .buffer  = ggml_webgpu_tensor_buf(src0),
         .offset  = ggml_webgpu_tensor_align_offset(ctx, src0),
         .size    = ggml_webgpu_tensor_binding_size(ctx, src0) },
        { .binding = 1,
         .buffer  = ggml_webgpu_tensor_buf(src1),
         .offset  = ggml_webgpu_tensor_align_offset(ctx, src1),
         .size    = ggml_webgpu_tensor_binding_size(ctx, src1) },
        { .binding = 2,
         .buffer  = ggml_webgpu_tensor_buf(dst),
         .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
         .size    = ggml_webgpu_tensor_binding_size(ctx, dst) },
    };

    // Default: generic (non-fast) pipeline, one thread per output element.
    webgpu_pipeline pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][0];

    uint32_t wg_x = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], WEBGPU_MUL_MAT_WG_SIZE);
    uint32_t wg_y = 1;

    // The fast paths exist only for these type combinations.
    bool use_fast = false;
    switch (src1->type) {
        case GGML_TYPE_F16:
            use_fast = (src0->type == GGML_TYPE_F16);
            break;
        case GGML_TYPE_F32:
            switch (src0->type) {
                case GGML_TYPE_F32:
                case GGML_TYPE_F16:
                case GGML_TYPE_Q4_0:
                    use_fast = true;
                    break;
                default:
                    break;
            }
            break;
        default:
            break;
    }

    if (use_fast) {
        int vectorized = src0->ne[0] % 4 == 0 && dst->ne[0] % 4 == 0 && dst->ne[1] % 4 == 0;
        if (dst->ne[1] == 1) {
            // Matrix-vector product: split output rows across workgroups, and
            // fold the workgroup count into a 2D grid to stay under the
            // per-dimension dispatch limit.
            // We don't support vectorized mul_mat_vec for quantized types
            // NOTE(review): `src0->type < 2` presumably means F32/F16 (ggml
            // type ids 0 and 1) — confirm against the ggml_type enum.
            vectorized = vectorized && (src0->type < 2);
            pipeline   = ctx->mul_mat_vec_pipelines[src0->type][src1->type][vectorized];
            uint32_t batches       = dst->ne[2] * dst->ne[3];
            uint32_t output_groups = CEIL_DIV(dst->ne[0], WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG);
            uint32_t total_wg      = output_groups * batches;
            // NOTE(review): if total_wg is an exact multiple of
            // maxComputeWorkgroupsPerDimension, wg_x becomes 0 here — verify
            // the shader/dispatch handles that case.
            wg_x = total_wg % ctx->limits.maxComputeWorkgroupsPerDimension;
            wg_y = CEIL_DIV(total_wg, ctx->limits.maxComputeWorkgroupsPerDimension);
        } else {
            // Tiled matrix-matrix product: one workgroup per output tile.
            pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][vectorized];
            uint32_t wg_m;
            uint32_t wg_n;
#ifndef __EMSCRIPTEN__
            if (ctx->supports_subgroup_matrix) {
                // The total number of subgroups/workgroups needed per matrix.
                uint32_t wg_m_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_M * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M * ctx->sg_mat_m;
                wg_m                  = CEIL_DIV(dst->ne[0], wg_m_sg_tile);
                uint32_t wg_n_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N * ctx->sg_mat_n;
                wg_n                  = CEIL_DIV(dst->ne[1], wg_n_sg_tile);
            } else {
#endif
                uint32_t tile_m_s = WEBGPU_MUL_MAT_TILE_M * WEBGPU_MUL_MAT_WG_SIZE_M;
                uint32_t tile_n_s = WEBGPU_MUL_MAT_TILE_N * WEBGPU_MUL_MAT_WG_SIZE_N;
                wg_m              = CEIL_DIV(dst->ne[0], tile_m_s);
                wg_n              = CEIL_DIV(dst->ne[1], tile_n_s);
#ifndef __EMSCRIPTEN__
            }
#endif

            wg_x = wg_m * wg_n * dst->ne[2] * dst->ne[3];
        }
    }
    return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
}
1028
+
1029
// Build a FLASH_ATTN_EXT dispatch. Pipelines are specialized per
// (types, head dims, mask/sinks/softcap flags) and cached in
// ctx->flash_attn_pipelines; a cache miss preprocesses the WGSL template and
// compiles a new pipeline under ctx->mutex.
//
// mask and sinks may be nullptr; their bindings are then omitted and the
// corresponding params are zeroed.
static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx,
                                             ggml_tensor *    Q,
                                             ggml_tensor *    K,
                                             ggml_tensor *    V,
                                             ggml_tensor *    mask,
                                             ggml_tensor *    sinks,
                                             ggml_tensor *    dst) {
    // op_params layout: [0] scale, [1] max_bias, [2] logit_softcap (as floats).
    float scale = *(float *) dst->op_params;
    float max_bias;
    memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
    float logit_softcap;
    memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));
    // With softcapping enabled the shader applies tanh(x * softcap) scaling,
    // so the base scale is pre-divided here.
    if (logit_softcap != 0.0f) {
        scale /= logit_softcap;
    }
    // ALiBi slope parameters derived from the head count.
    float n_head_log2 = float(1u << (uint32_t) floor(log2(Q->ne[2])));
    float m0          = powf(2.0f, -(max_bias) / n_head_log2);
    float m1          = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

    const int has_mask  = (mask != nullptr);
    const int has_sinks = (sinks != nullptr);

    // Positional parameter layout must match the flash_attn shader's params
    // struct; float values are bit-cast into the uint32 words.
    std::vector<uint32_t> params = {
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, Q) / ggml_type_size(Q->type)),
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type)),
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type)),
        has_mask ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)) : 0,
        has_sinks ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, sinks) / ggml_type_size(sinks->type)) : 0,
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
        (uint32_t) Q->ne[2],                                               // number of heads
        (uint32_t) Q->ne[1],                                               // sequence length (Q)
        (uint32_t) K->ne[1],                                               // sequence length (K/V)
        (uint32_t) (Q->nb[1] / ggml_type_size(Q->type)),                   // stride (elements/blocks) of Q in dimension 1
        (uint32_t) (Q->nb[2] / ggml_type_size(Q->type)),                   // stride (elements/blocks) of Q in dimension 2
        (uint32_t) (Q->nb[3] / ggml_type_size(Q->type)),                   // stride (elements/blocks) of Q in dimension 3
        (uint32_t) (K->nb[1] / ggml_type_size(K->type)),                   // stride (elements/blocks) of K in dimension 1
        (uint32_t) (K->nb[2] / ggml_type_size(K->type)),                   // stride (elements/blocks) of K in dimension 2
        (uint32_t) (K->nb[3] / ggml_type_size(K->type)),                   // stride (elements/blocks) of K in dimension 3
        (uint32_t) (V->nb[1] / ggml_type_size(V->type)),                   // stride (elements/blocks) of V in dimension 1
        (uint32_t) (V->nb[2] / ggml_type_size(V->type)),                   // stride (elements/blocks) of V in dimension 2
        (uint32_t) (V->nb[3] / ggml_type_size(V->type)),                   // stride (elements/blocks) of V in dimension 3
        has_mask ? (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)) : 0, // stride of mask dim 3
        (uint32_t) (Q->ne[2] / K->ne[2]),                                  // repeat factor for K/V in dim 2 (MHA/MQA/GQA)
        *(uint32_t *) &scale,                                              // scale (possibly adjusted for logit softcap)
        *(uint32_t *) &max_bias,
        *(uint32_t *) &logit_softcap,
        *(uint32_t *) &n_head_log2,
        *(uint32_t *) &m0,
        *(uint32_t *) &m1

    };
    std::vector<wgpu::BindGroupEntry> entries = {
        { .binding = 0,
         .buffer  = ggml_webgpu_tensor_buf(Q),
         .offset  = ggml_webgpu_tensor_align_offset(ctx, Q),
         .size    = ggml_webgpu_tensor_binding_size(ctx, Q) },
        { .binding = 1,
         .buffer  = ggml_webgpu_tensor_buf(K),
         .offset  = ggml_webgpu_tensor_align_offset(ctx, K),
         .size    = ggml_webgpu_tensor_binding_size(ctx, K) },
        { .binding = 2,
         .buffer  = ggml_webgpu_tensor_buf(V),
         .offset  = ggml_webgpu_tensor_align_offset(ctx, V),
         .size    = ggml_webgpu_tensor_binding_size(ctx, V) }
    };
    // Optional bindings are packed densely; the shader variant is compiled
    // with matching bindings (see has_mask/has_sinks in the pipeline key).
    uint32_t binding_index = 3;
    if (has_mask) {
        entries.push_back({ .binding = binding_index++,
                            .buffer  = ggml_webgpu_tensor_buf(mask),
                            .offset  = ggml_webgpu_tensor_align_offset(ctx, mask),
                            .size    = ggml_webgpu_tensor_binding_size(ctx, mask) });
    }
    if (has_sinks) {
        entries.push_back({ .binding = binding_index++,
                            .buffer  = ggml_webgpu_tensor_buf(sinks),
                            .offset  = ggml_webgpu_tensor_align_offset(ctx, sinks),
                            .size    = ggml_webgpu_tensor_binding_size(ctx, sinks) });
    }
    entries.push_back({ .binding = binding_index++,
                        .buffer  = ggml_webgpu_tensor_buf(dst),
                        .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
                        .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });

    // Whether K/V can be loaded directly (without conversion/padding) by the
    // subgroup-matrix path.
    bool kv_direct =
        (K->type == GGML_TYPE_F16) && (Q->ne[0] % ctx->sg_mat_k == 0) && (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);

    flash_attn_pipeline_key key = {
        .q_type             = Q->type,
        .kv_type            = K->type,
        .dst_type           = dst->type,
        .head_dim_qk        = (uint32_t) Q->ne[0],
        .head_dim_v         = (uint32_t) V->ne[0],
        .kv_direct          = kv_direct,
        .has_mask           = static_cast<bool>(has_mask),
        .has_sinks          = static_cast<bool>(has_sinks),
        .uses_logit_softcap = logit_softcap != 0.0f,
    };

    webgpu_pipeline                       pipeline;
    ggml_webgpu_flash_attn_shader_decisions decisions = {};

    // Pipeline cache lookup with a locked re-check on miss.
    // NOTE(review): this first find() runs without holding ctx->mutex while
    // another thread may emplace() into the same map — confirm all callers
    // are serialized, otherwise this read races with the insert.
    auto it = ctx->flash_attn_pipelines.find(key);
    if (it != ctx->flash_attn_pipelines.end()) {
        pipeline  = it->second;
        decisions = *static_cast<ggml_webgpu_flash_attn_shader_decisions *>(pipeline.context);
    } else {
        std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
        it = ctx->flash_attn_pipelines.find(key);
        if (it != ctx->flash_attn_pipelines.end()) {
            pipeline  = it->second;
            decisions = *static_cast<ggml_webgpu_flash_attn_shader_decisions *>(pipeline.context);
        } else {
            // Compile a new specialized pipeline from the WGSL template.
            ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = { .kv_type     = K->type,
                                                                         .head_dim_qk = (uint32_t) Q->ne[0],
                                                                         .head_dim_v  = (uint32_t) V->ne[0],
                                                                         .kv_direct   = kv_direct,
                                                                         .has_mask    = static_cast<bool>(has_mask),
                                                                         .has_sinks   = static_cast<bool>(has_sinks),
                                                                         .uses_logit_softcap = logit_softcap != 0.0f,
                                                                         .sg_mat_m           = ctx->sg_mat_m,
                                                                         .sg_mat_n           = ctx->sg_mat_n,
                                                                         .sg_mat_k           = ctx->sg_mat_k,
                                                                         .wg_mem_limit_bytes =
                                                                             ctx->limits.maxComputeWorkgroupStorageSize,
                                                                         .max_subgroup_size = ctx->max_subgroup_size };

            ggml_webgpu_processed_shader processed =
                ggml_webgpu_preprocess_flash_attn_shader(ctx->p, wgsl_flash_attn, shader_lib_ctx);
            pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
            // NOTE(review): this allocation is never freed; covered by the
            // file's "TODO: implement proper cleanup".
            pipeline.context = new ggml_webgpu_flash_attn_shader_decisions(processed.decisions);
            ctx->flash_attn_pipelines.emplace(key, pipeline);
            decisions = processed.decisions;
        }
    }

    uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions.q_tile);
    uint32_t wg_x        = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches
    return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
}
1168
+
1169
// Build a dispatch for an elementwise unary op (RELU, GELU, XIELU, ...).
// When src and dst alias (in-place), the dst binding is omitted and an
// in-place pipeline variant is selected. XIELU appends its four float
// hyperparameters to the params buffer.
static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
    uint32_t      ne       = (uint32_t) ggml_nelements(dst);
    ggml_unary_op unary_op = ggml_get_unary_op(dst);
    uint32_t      inplace  = ggml_webgpu_tensor_equal(src, dst);

    // Positional parameter layout must match the unary shader's params struct.
    std::vector<uint32_t> params = {
        ne, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
        // Convert byte-strides to element-strides
        (uint32_t) (src->nb[0] / ggml_type_size(src->type)), (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
        (uint32_t) (src->nb[2] / ggml_type_size(src->type)), (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
        (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
        (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
        // Logical shapes
        (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) dst->ne[0],
        (uint32_t) dst->ne[1], (uint32_t) dst->ne[2]
    };

    switch (unary_op) {
        case GGML_UNARY_OP_XIELU:
            {
                // Get float parameters and reinterpret their bit patterns as uint32_t
                // for passing through the params buffer
                float alpha_n = ggml_get_op_params_f32(dst, 1);
                float alpha_p = ggml_get_op_params_f32(dst, 2);
                float beta    = ggml_get_op_params_f32(dst, 3);
                float eps     = ggml_get_op_params_f32(dst, 4);
                params.push_back(*reinterpret_cast<const uint32_t *>(&alpha_n));
                params.push_back(*reinterpret_cast<const uint32_t *>(&alpha_p));
                params.push_back(*reinterpret_cast<const uint32_t *>(&beta));
                params.push_back(*reinterpret_cast<const uint32_t *>(&eps));
                break;
            }
        default:
            break;
    }

    std::vector<wgpu::BindGroupEntry> entries = {
        { .binding = 0,
         .buffer  = ggml_webgpu_tensor_buf(src),
         .offset  = ggml_webgpu_tensor_align_offset(ctx, src),
         .size    = ggml_webgpu_tensor_binding_size(ctx, src) },
    };
    // In-place variants read and write binding 0 only.
    if (!inplace) {
        entries.push_back({ .binding = 1,
                            .buffer  = ggml_webgpu_tensor_buf(dst),
                            .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
                            .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
    }

    uint32_t wg_x = CEIL_DIV(ne, WEBGPU_MAX_WG_SIZE);
    return ggml_backend_webgpu_build(ctx, ctx->unary_pipelines[unary_op][dst->type][inplace], params, entries, wg_x);
}
1222
+
1223
// Build a dispatch for an elementwise binary op (ADD, MUL, ...) with
// broadcasting of src1 over src0's shape. The caller supplies the pipeline
// (it encodes the op and types) and whether the op is in-place on src0,
// in which case the dst binding is omitted.
static webgpu_command ggml_webgpu_binary_op(webgpu_context &  ctx,
                                            ggml_tensor *     src0,
                                            ggml_tensor *     src1,
                                            ggml_tensor *     dst,
                                            webgpu_pipeline & pipeline,
                                            bool              inplace) {
    // Positional parameter layout must match the binary shader's params struct.
    std::vector<uint32_t> params = {
        (uint32_t) ggml_nelements(dst),
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
        (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)),
        (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
        (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
        (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
        (uint32_t) src0->ne[0],
        (uint32_t) src0->ne[1],
        (uint32_t) src0->ne[2],
        (uint32_t) src1->ne[0],
        (uint32_t) src1->ne[1],
        (uint32_t) src1->ne[2],
        (uint32_t) src1->ne[3],
    };

    std::vector<wgpu::BindGroupEntry> entries = {
        { .binding = 0,
         .buffer  = ggml_webgpu_tensor_buf(src0),
         .offset  = ggml_webgpu_tensor_align_offset(ctx, src0),
         .size    = ggml_webgpu_tensor_binding_size(ctx, src0) },
        { .binding = 1,
         .buffer  = ggml_webgpu_tensor_buf(src1),
         .offset  = ggml_webgpu_tensor_align_offset(ctx, src1),
         .size    = ggml_webgpu_tensor_binding_size(ctx, src1) }
    };
    // In-place variants write the result back into src0's binding.
    if (!inplace) {
        entries.push_back({ .binding = 2,
                            .buffer  = ggml_webgpu_tensor_buf(dst),
                            .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
                            .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
    }

    uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
    return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
}
1267
+
1268
// Encode an RMS_NORM of src into dst. Dispatches one workgroup per row
// (ggml_nrows(src)); the shader normalizes each row of ne[0] elements.
static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
    // used as an index into rms_norm_pipelines (0 = out-of-place, 1 = in-place)
    int inplace = ggml_webgpu_tensor_equal(src, dst);

    std::vector<uint32_t> params = {
        // element offsets within the aligned bindings
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
        // row/plane/batch strides for src and dst, in elements
        (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
        (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
        (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
        (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
        (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
        (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
        (uint32_t) src->ne[0],
        (uint32_t) src->ne[1],
        (uint32_t) src->ne[2],
        (uint32_t) src->ne[3],
        *(uint32_t *) dst->op_params // epsilon, treated as f32 in the shader
    };

    std::vector<wgpu::BindGroupEntry> entries = {
        { .binding = 0,
          .buffer  = ggml_webgpu_tensor_buf(src),
          .offset  = ggml_webgpu_tensor_align_offset(ctx, src),
          .size    = ggml_webgpu_tensor_binding_size(ctx, src) }
    };
    if (!inplace) {
        entries.push_back({ .binding = 1,
                            .buffer  = ggml_webgpu_tensor_buf(dst),
                            .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
                            .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
    }

    return ggml_backend_webgpu_build(ctx, ctx->rms_norm_pipelines[inplace], params, entries, ggml_nrows(src));
}
1302
+
1303
// Encode a ROPE (rotary position embedding) node.
//   src0: activations to rotate, src1: positions, src2: optional frequency
//   factors, dst: output (may alias src0 for the in-place variant).
// The rope parameters are unpacked from dst->op_params with the standard ggml
// ROPE layout (indices 1..10 scalars, 11..14 the mrope sections).
static webgpu_command ggml_webgpu_rope(webgpu_context & ctx,
                                       ggml_tensor *    src0,
                                       ggml_tensor *    src1,
                                       ggml_tensor *    src2,
                                       ggml_tensor *    dst) {
    // both flags double as pipeline-array indices
    const int inplace         = ggml_webgpu_tensor_equal(src0, dst);
    const int has_freq_factor = (src2 != nullptr);

    const int n_dims     = ((int32_t *) dst->op_params)[1];
    const int mode       = ((int32_t *) dst->op_params)[2];
    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];

    // float parameters are stored bitwise in the int32 op_params array;
    // memcpy avoids strict-aliasing issues
    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
    memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
    memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
    memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
    memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
    memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
    memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));

    // mrope section sizes (used by the multi-section rope modes)
    int sections[4];
    memcpy(sections, (int32_t *) dst->op_params + 11, 4 * sizeof(int));

    float theta_scale = powf(freq_base, -2.0f / n_dims);

    // YaRN correction range, computed on the host once per node
    float corr_dims[2];
    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);

    std::vector<uint32_t> params = {
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
        src2 != nullptr ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)) : 0,
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
        (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
        (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
        (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
        (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
        (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
        (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
        (uint32_t) ggml_nelements(src0) / 2, // each thread rotates a pair of elements
        (uint32_t) src0->ne[0],
        (uint32_t) src0->ne[1],
        (uint32_t) src0->ne[2],
        (uint32_t) n_dims,
        (uint32_t) mode,
        // float values are reinterpreted bit-for-bit into the u32 param stream
        *(uint32_t *) &theta_scale,
        *(uint32_t *) &attn_factor,
        *(uint32_t *) &freq_scale,
        *(uint32_t *) &ext_factor,
        *(uint32_t *) &corr_dims[0],
        *(uint32_t *) &corr_dims[1],
        (uint32_t) sections[0],
        (uint32_t) sections[1],
        (uint32_t) sections[2],
        (uint32_t) sections[3]
    };

    std::vector<wgpu::BindGroupEntry> entries = {
        { .binding = 0,
          .buffer  = ggml_webgpu_tensor_buf(src0),
          .offset  = ggml_webgpu_tensor_align_offset(ctx, src0),
          .size    = ggml_webgpu_tensor_binding_size(ctx, src0) },
        { .binding = 1,
          .buffer  = ggml_webgpu_tensor_buf(src1),
          .offset  = ggml_webgpu_tensor_align_offset(ctx, src1),
          .size    = ggml_webgpu_tensor_binding_size(ctx, src1) }
    };
    // the optional freq-factor buffer takes binding 2, shifting dst to 3
    uint32_t dst_binding = 2;
    if (has_freq_factor) {
        dst_binding = 3;
        entries.push_back({ .binding = 2,
                            .buffer  = ggml_webgpu_tensor_buf(src2),
                            .offset  = ggml_webgpu_tensor_align_offset(ctx, src2),
                            .size    = ggml_webgpu_tensor_binding_size(ctx, src2) });
    }
    if (!inplace) {
        entries.push_back({ .binding = dst_binding,
                            .buffer  = ggml_webgpu_tensor_buf(dst),
                            .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
                            .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
    }

    // pipeline variant is selected by type, presence of freq factors, and in-place-ness
    webgpu_pipeline pipeline = ctx->rope_pipelines[dst->type][has_freq_factor][inplace];
    uint32_t        wg_x     = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
    return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
}
1389
+
1390
// Encode a GLU (gated linear unit) node. Two layouts are supported:
//   split  (src1 != nullptr): gate comes from a separate tensor src1;
//   fused  (src1 == nullptr): value and gate are interleaved halves of src0,
//          in which case src0's strides stand in for src1's.
static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
    // doubles as the pipeline-array index (0 = fused, 1 = split)
    const int split = (src1 != nullptr);

    std::vector<uint32_t> params = {
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
        src1 != nullptr ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)) : 0,
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
        (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
        (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
        (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
        // fall back to src0 strides for the fused layout
        src1 != nullptr ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) :
                          (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
        src1 != nullptr ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) :
                          (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
        src1 != nullptr ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) :
                          (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
        (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
        (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
        (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
        (uint32_t) ggml_nelements(dst),
        (uint32_t) dst->ne[0],
        (uint32_t) dst->ne[1],
        (uint32_t) dst->ne[2],
        (uint32_t) ((int32_t *) dst->op_params)[1], // swapped
        *(uint32_t *) &dst->op_params[2],           // alpha, for swiglu_oai
        *(uint32_t *) &dst->op_params[3],           // limit, for swiglu_oai
    };

    std::vector<wgpu::BindGroupEntry> entries = {
        { .binding = 0,
          .buffer  = ggml_webgpu_tensor_buf(src0),
          .offset  = ggml_webgpu_tensor_align_offset(ctx, src0),
          .size    = ggml_webgpu_tensor_binding_size(ctx, src0) },
    };
    // the optional gate buffer takes binding 1, shifting dst to 2
    uint32_t dst_binding = 1;
    if (split) {
        dst_binding = 2;
        entries.push_back({ .binding = 1,
                            .buffer  = ggml_webgpu_tensor_buf(src1),
                            .offset  = ggml_webgpu_tensor_align_offset(ctx, src1),
                            .size    = ggml_webgpu_tensor_binding_size(ctx, src1) });
    }
    entries.push_back({ .binding = dst_binding,
                        .buffer  = ggml_webgpu_tensor_buf(dst),
                        .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
                        .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });

    webgpu_pipeline pipeline = ctx->glu_pipelines[ggml_get_glu_op(dst)][dst->type][split];
    uint32_t        wg_x     = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
    return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
}
1441
+
1442
// Encode a SCALE node: dst = src * scale + bias, with the scale/bias floats
// read bitwise from dst->op_params[0] and [1].
static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
    // pipeline-array index (0 = out-of-place, 1 = in-place)
    int inplace = ggml_webgpu_tensor_equal(src, dst);

    std::vector<uint32_t> params = {
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
        // src/dst strides in elements
        (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
        (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
        (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
        (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
        (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
        (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
        (uint32_t) ggml_nelements(dst),
        (uint32_t) src->ne[0],
        (uint32_t) src->ne[1],
        (uint32_t) src->ne[2],
        *(uint32_t *) dst->op_params,    // scale
        *(uint32_t *) &dst->op_params[1] // bias
    };

    std::vector<wgpu::BindGroupEntry> entries = {
        { .binding = 0,
          .buffer  = ggml_webgpu_tensor_buf(src),
          .offset  = ggml_webgpu_tensor_align_offset(ctx, src),
          .size    = ggml_webgpu_tensor_binding_size(ctx, src) }
    };
    if (!inplace) {
        entries.push_back({ .binding = 1,
                            .buffer  = ggml_webgpu_tensor_buf(dst),
                            .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
                            .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
    }

    // one thread per output element
    uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
    return ggml_backend_webgpu_build(ctx, ctx->scale_pipelines[inplace], params, entries, wg_x);
}
1478
+
1479
// Encode a SOFT_MAX node.
//   src0: logits, src1: optional mask (f32 or f16), src2: optional attention
//   sinks, dst: output (may alias src0).
// ALiBi slope constants (m0/m1, n_head_log2) are precomputed on the host from
// max_bias (op_params[1]).
static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx,
                                           ggml_tensor *    src0,
                                           ggml_tensor *    src1,
                                           ggml_tensor *    src2,
                                           ggml_tensor *    dst) {
    const int inplace   = ggml_webgpu_tensor_equal(src0, dst);
    // pipeline index: the mask tensor type (F32=0/F16=1) when present, else 2
    const int mask_type = (src1 != nullptr) ? src1->type : 2; // use 2 for no mask here
    const int has_sink  = (src2 != nullptr);
    float     max_bias;
    memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
    // largest power of two <= number of heads, used to split the ALiBi slopes
    float n_head_log2 = float(1u << (uint32_t) floor(log2(src0->ne[2])));
    float m0          = powf(2.0f, -(max_bias) / n_head_log2);
    float m1          = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

    std::vector<uint32_t> params = {
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
        // mask/sink fields are zeroed when the corresponding tensor is absent
        mask_type < 2 ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)) : 0,
        has_sink ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)) : 0,
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
        (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
        (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
        (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
        mask_type < 2 ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) : 0,
        mask_type < 2 ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) : 0,
        mask_type < 2 ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) : 0,
        (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
        (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
        (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
        (uint32_t) ggml_nelements(dst),
        (uint32_t) src0->ne[0],
        (uint32_t) src0->ne[1],
        (uint32_t) src0->ne[2],
        mask_type < 2 ? (uint32_t) src1->ne[2] : 0,
        mask_type < 2 ? (uint32_t) src1->ne[3] : 0,
        *(uint32_t *) dst->op_params, // scale
        *(uint32_t *) &max_bias,
        *(uint32_t *) &n_head_log2,
        *(uint32_t *) &m0,
        *(uint32_t *) &m1
    };

    std::vector<wgpu::BindGroupEntry> entries = {
        { .binding = 0,
          .buffer  = ggml_webgpu_tensor_buf(src0),
          .offset  = ggml_webgpu_tensor_align_offset(ctx, src0),
          .size    = ggml_webgpu_tensor_binding_size(ctx, src0) }
    };
    // bindings are assigned consecutively: mask, then sinks, then dst
    uint32_t binding_num = 1;
    if (mask_type < 2) {
        entries.push_back({ .binding = binding_num,
                            .buffer  = ggml_webgpu_tensor_buf(src1),
                            .offset  = ggml_webgpu_tensor_align_offset(ctx, src1),
                            .size    = ggml_webgpu_tensor_binding_size(ctx, src1) });
        binding_num++;
    }
    if (has_sink) {
        entries.push_back({ .binding = binding_num,
                            .buffer  = ggml_webgpu_tensor_buf(src2),
                            .offset  = ggml_webgpu_tensor_align_offset(ctx, src2),
                            .size    = ggml_webgpu_tensor_binding_size(ctx, src2) });
        binding_num++;
    }
    if (!inplace) {
        entries.push_back({ .binding = binding_num,
                            .buffer  = ggml_webgpu_tensor_buf(dst),
                            .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
                            .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
    }

    // one workgroup per row of the output
    return ggml_backend_webgpu_build(ctx, ctx->soft_max_pipelines[mask_type][has_sink][inplace], params, entries,
                                     ggml_nrows(dst));
}
1551
+
1552
// Returns the encoded command, or std::nullopt if the operation is a no-op
// (empty tensors, pure view/layout ops, or ops this backend does not handle).
// Acts as the single dispatch point from graph nodes to per-op encoders.
static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
    if (ggml_is_empty(node)) {
        return std::nullopt;
    }
    WEBGPU_LOG_DEBUG("ggml_webgpu_encode_node(" << node << ", " << ggml_op_name(node->op) << ")");

    ggml_tensor * src0 = node->src[0];
    ggml_tensor * src1 = node->src[1];
    ggml_tensor * src2 = node->src[2];

    switch (node->op) {
        // no-ops: these only change metadata (shape/strides), no GPU work needed
        case GGML_OP_NONE:
        case GGML_OP_VIEW:
        case GGML_OP_PERMUTE:
        case GGML_OP_TRANSPOSE:
        case GGML_OP_RESHAPE:
            return std::nullopt;
        case GGML_OP_CPY:
        case GGML_OP_CONT:
            return ggml_webgpu_cpy(ctx, src0, node);
        case GGML_OP_SET_ROWS:
            return ggml_webgpu_set_rows(ctx, src0, src1, node);
        case GGML_OP_GET_ROWS:
            return ggml_webgpu_get_rows(ctx, src0, src1, node);
        case GGML_OP_MUL_MAT:
            return ggml_webgpu_mul_mat(ctx, src0, src1, node);
        case GGML_OP_FLASH_ATTN_EXT:
            // q, k, v, mask, sinks
            return ggml_webgpu_flash_attn(ctx, src0, src1, src2, node->src[3], node->src[4], node);
        // binary elementwise ops share one encoder; only the pipeline differs
        case GGML_OP_ADD:
            {
                int inplace = ggml_webgpu_tensor_equal(src0, node);
                return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipelines[node->type][inplace], inplace);
            }
        case GGML_OP_SUB:
            {
                int inplace = ggml_webgpu_tensor_equal(src0, node);
                return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->sub_pipelines[node->type][inplace], inplace);
            }
        case GGML_OP_MUL:
            {
                int inplace = ggml_webgpu_tensor_equal(src0, node);
                return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_pipelines[node->type][inplace], inplace);
            }
        case GGML_OP_DIV:
            {
                int inplace = ggml_webgpu_tensor_equal(src0, node);
                return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->div_pipelines[node->type][inplace], inplace);
            }
        case GGML_OP_RMS_NORM:
            return ggml_webgpu_rms_norm(ctx, src0, node);
        case GGML_OP_ROPE:
            return ggml_webgpu_rope(ctx, src0, src1, src2, node);
        case GGML_OP_GLU:
            return ggml_webgpu_glu(ctx, src0, src1, node);
        case GGML_OP_SCALE:
            return ggml_webgpu_scale(ctx, src0, node);
        case GGML_OP_SOFT_MAX:
            return ggml_webgpu_soft_max(ctx, src0, src1, src2, node);
        case GGML_OP_UNARY:
            return ggml_webgpu_unary_op(ctx, src0, node);
        default:
            // unsupported op — caller treats this the same as a no-op
            return std::nullopt;
    }
}
1618
+
1619
// Backend graph-compute entry point: encodes each node into GPU commands and
// submits them in batches, throttling the batch size by the number of threads
// currently computing graphs so that the shared parameter-buffer pool is not
// exhausted.
static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_graph_compute(" << cgraph->n_nodes << " nodes)");

    ggml_backend_webgpu_context * backend_ctx = static_cast<ggml_backend_webgpu_context *>(backend->context);
    webgpu_context                ctx         = backend_ctx->webgpu_ctx;

    WEBGPU_CPU_PROFILE_TOTAL_START(graph_compute);

    ctx->inflight_threads++;

    std::vector<webgpu_command>            commands;
    std::vector<webgpu_submission_futures> futures;
    for (int i = 0; i < cgraph->n_nodes; i++) {
        if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) {
            commands.push_back(*cmd);
        }
        // compute the batch size based on the number of inflight threads:
        // fewer param buffers per thread when more threads are active
        uint32_t inflight_threads = ctx->inflight_threads;
        uint32_t batch_size       = std::min(std::max(1u, WEBGPU_NUM_PARAM_BUFS / std::max(inflight_threads, 1u)),
                                             WEBGPU_COMMAND_SUBMIT_BATCH_SIZE);
        if (commands.size() >= batch_size) {
            futures.push_back(ggml_backend_webgpu_submit(ctx, commands));
            // Process events and check for completed submissions (non-blocking
            // wait, so encoding can continue while the GPU works)
            ctx->instance.ProcessEvents();
            ggml_backend_webgpu_wait(ctx, futures, false);
            commands.clear();
        }
    }
    // flush any commands left over from the last partial batch
    if (!commands.empty()) {
        webgpu_submission_futures new_futures = ggml_backend_webgpu_submit(ctx, commands);
        futures.push_back(new_futures);
    }

    // blocking wait for everything submitted above
    ggml_backend_webgpu_wait(ctx, futures);
    ctx->inflight_threads--;
    WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx);
    return GGML_STATUS_SUCCESS;
}
1657
+
1658
// ggml backend vtable for WebGPU. Only synchronous graph_compute is
// implemented; async transfers, events and graph plans are unsupported (NULL).
// Initializer order is positional and must match ggml_backend_i.
static ggml_backend_i ggml_backend_webgpu_i = {
    /* .get_name                = */ ggml_backend_webgpu_name,
    /* .free                    = */ ggml_backend_webgpu_free,
    /* .set_tensor_async        = */ NULL,
    /* .get_tensor_async        = */ NULL,
    /* .cpy_tensor_async        = */ NULL,
    /* .synchronize             = */ NULL,
    /* .graph_plan_create       = */ NULL,
    /* .graph_plan_free         = */ NULL,
    /* .graph_plan_update       = */ NULL,
    /* .graph_plan_compute      = */ NULL,
    /* .graph_compute           = */ ggml_backend_webgpu_graph_compute,
    /* .event_record            = */ NULL,
    /* .event_wait              = */ NULL,
    /* .graph_optimize          = */ NULL,
};
1674
+
1675
+ /* End GGML Backend Interface */
1676
+
1677
+ /* GGML Backend Buffer Interface */
1678
+
1679
+ static void ggml_backend_webgpu_buffer_free_buffer(ggml_backend_buffer_t buffer) {
1680
+ ggml_backend_webgpu_buffer_context * ctx = static_cast<ggml_backend_webgpu_buffer_context *>(buffer->context);
1681
+ ctx->buffer.Destroy();
1682
+ }
1683
+
1684
// Returns the "fake" base pointer. WebGPU buffers are not host-addressable,
// so tensor "addresses" are offsets from this sentinel; webgpu_tensor_offset
// recovers the real offset by subtracting it.
static void * ggml_backend_webgpu_buffer_get_base(ggml_backend_buffer_t buffer) {
    GGML_UNUSED(buffer);
    return webgpu_ptr_base;
}
1689
+
1690
+ static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffer,
1691
+ ggml_tensor * tensor,
1692
+ uint8_t value,
1693
+ size_t offset,
1694
+ size_t size) {
1695
+ if (size == 0) {
1696
+ WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor: size is zero, nothing to do.");
1697
+ return;
1698
+ }
1699
+
1700
+ WEBGPU_CPU_PROFILE_TOTAL_START(memset_tensor);
1701
+
1702
+ ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
1703
+
1704
+ WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor(" << buf_ctx->label << ", " << tensor << ", " << value
1705
+ << ", " << offset << ", " << size << ")");
1706
+
1707
+ size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
1708
+
1709
+ // This is a trick to set all bytes of a u32 to the same 1 byte value.
1710
+ uint32_t val32 = (uint32_t) value * 0x01010101;
1711
+ ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, val32, total_offset, size);
1712
+ WEBGPU_CPU_PROFILE_TOTAL_END(memset_tensor, buf_ctx->webgpu_ctx);
1713
+ }
1714
+
1715
// Upload host data into a tensor. WriteBuffer requires 4-byte-aligned sizes,
// so the bulk is written with WriteBuffer and any trailing 1-3 bytes are
// packed into a u32 and written via the GPU memset path.
static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
                                                  ggml_tensor *         tensor,
                                                  const void *          data,
                                                  size_t                offset,
                                                  size_t                size) {
    WEBGPU_CPU_PROFILE_TOTAL_START(set_tensor);
    ggml_backend_webgpu_buffer_context * buf_ctx    = (ggml_backend_webgpu_buffer_context *) buffer->context;
    webgpu_context                       webgpu_ctx = buf_ctx->webgpu_ctx;

    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_set_tensor(" << buf_ctx->label << ", " << tensor << ", " << data
                                                              << ", " << offset << ", " << size << ")");

    // absolute byte position inside the backing wgpu::Buffer
    size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;

    // write the largest 4-byte-aligned prefix
    webgpu_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size / 4) * 4);

    if (size % 4 != 0) {
        // If size is not a multiple of 4, we need to memset the remaining bytes
        size_t remaining_size = size % 4;

        // pack the remaining bytes into a uint32_t
        uint32_t val32 = 0;

        for (size_t i = 0; i < remaining_size; i++) {
            ((uint8_t *) &val32)[i] = ((const uint8_t *) data)[size - remaining_size + i];
        }
        // memset the remaining bytes (this path waits for completion internally)
        ggml_backend_webgpu_buffer_memset(webgpu_ctx, buf_ctx->buffer, val32, total_offset + (size - remaining_size),
                                          remaining_size);
    } else {
        // wait for WriteBuffer to complete so the call is synchronous either way
        webgpu_ctx->instance.WaitAny(
            webgpu_ctx->queue.OnSubmittedWorkDone(wgpu::CallbackMode::AllowSpontaneous,
                                                  [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
                                                      if (status != wgpu::QueueWorkDoneStatus::Success) {
                                                          GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n",
                                                                         std::string(message).c_str());
                                                      }
                                                  }),
            UINT64_MAX);
    }
    WEBGPU_CPU_PROFILE_TOTAL_END(set_tensor, webgpu_ctx);
}
1758
+
1759
// Download tensor data to host memory. GPU-to-CPU reads go through a cached,
// lazily-grown MapRead staging buffer: copy on the GPU, map, memcpy, unmap.
// Copy sizes must be 4-byte aligned, so the request is rounded up and only
// `size` bytes are copied out of the mapped range.
static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
                                                  const ggml_tensor *   tensor,
                                                  void *                data,
                                                  size_t                offset,
                                                  size_t                size) {
    WEBGPU_CPU_PROFILE_TOTAL_START(get_tensor);
    ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_get_tensor(" << buf_ctx->label << ", " << tensor << ", " << data
                                                              << ", " << offset << ", " << size << ")");
    webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx;
    wgpu::Device   device     = webgpu_ctx->device;

    // absolute byte position inside the backing wgpu::Buffer
    size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;

    size_t final_size = size;
    if (size % 4 != 0) {
        // If size is not a multiple of 4, we need to round it up to the next multiple of 4
        final_size = size + (4 - (size % 4));
    }

    // the staging buffer is shared across calls; serialize access to it
    std::lock_guard<std::recursive_mutex> lock(webgpu_ctx->mutex);

    if (webgpu_ctx->get_tensor_staging_buf == nullptr || webgpu_ctx->get_tensor_staging_buf.GetSize() < final_size) {
        // Create a new staging buffer if it doesn't exist or is too small
        if (webgpu_ctx->get_tensor_staging_buf) {
            webgpu_ctx->get_tensor_staging_buf.Destroy();
        }
        ggml_webgpu_create_buffer(device, webgpu_ctx->get_tensor_staging_buf, final_size,
                                  wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "get_tensor_staging_buf");
    }

    // Copy the data from the buffer to the staging buffer
    wgpu::CommandEncoder encoder = device.CreateateCommandEncoder != nullptr ? device.CreateCommandEncoder() : device.CreateCommandEncoder();
    encoder.CopyBufferToBuffer(buf_ctx->buffer, total_offset, webgpu_ctx->get_tensor_staging_buf, 0, final_size);
    wgpu::CommandBuffer commands = encoder.Finish();

    // Submit the command buffer to the queue
    webgpu_ctx->queue.Submit(1, &commands);

    // Map the staging buffer to read the data
    ggml_backend_webgpu_map_buffer(webgpu_ctx, webgpu_ctx->get_tensor_staging_buf, wgpu::MapMode::Read, 0, final_size);
    // Must specify size here since the staging buffer might be larger than the tensor size
    const void * mapped_range = webgpu_ctx->get_tensor_staging_buf.GetConstMappedRange(0, final_size);

    // Copy the data from the mapped range to the output buffer
    std::memcpy(data, mapped_range, size);
    webgpu_ctx->get_tensor_staging_buf.Unmap();
    WEBGPU_CPU_PROFILE_TOTAL_END(get_tensor, webgpu_ctx);
}
1808
+
1809
+ static void ggml_backend_webgpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
1810
+ WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_clear(" << buffer << ", " << (uint32_t) value << ")");
1811
+ WEBGPU_CPU_PROFILE_TOTAL_START(clear);
1812
+ ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
1813
+ ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, value, 0, buffer->size);
1814
+ WEBGPU_CPU_PROFILE_TOTAL_END(clear, buf_ctx->webgpu_ctx);
1815
+ }
1816
+
1817
// Buffer vtable for WebGPU-backed tensor buffers. Initializer order is
// positional and must match ggml_backend_buffer_i.
static ggml_backend_buffer_i ggml_backend_webgpu_buffer_interface = {
    /* .free_buffer     = */ ggml_backend_webgpu_buffer_free_buffer,
    /* .get_base        = */ ggml_backend_webgpu_buffer_get_base,
    /* .init_tensor     = */ NULL, // TODO: optional, needed?
    /* .memset_tensor   = */ ggml_backend_webgpu_buffer_memset_tensor,
    /* .set_tensor      = */ ggml_backend_webgpu_buffer_set_tensor,
    /* .get_tensor      = */ ggml_backend_webgpu_buffer_get_tensor,
    /* .cpy_tensor      = */ NULL, // TODO: optional, implement this
    /* .clear           = */ ggml_backend_webgpu_buffer_clear,
    /* .reset           = */ NULL, // TODO: optional, think it coordinates with .init_tensor
};
1828
+
1829
+ /* End GGML Backend Buffer Interface */
1830
+
1831
+ /* GGML Backend Buffer Type Interface */
1832
+
1833
+ static const char * ggml_backend_webgpu_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
1834
+ ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
1835
+ return ctx->device_name.c_str();
1836
+ }
1837
+
1838
+ static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
1839
+ size_t size) {
1840
+ static std::atomic<int> buffer_count;
1841
+ int buffer_id = buffer_count++;
1842
+ std::string buf_name = "tensor_buf" + std::to_string(buffer_id);
1843
+ WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_type_alloc_buffer_" << buffer_id << ": " << size << " bytes");
1844
+ ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
1845
+
1846
+ wgpu::Buffer buf;
1847
+ ggml_webgpu_create_buffer(ctx->webgpu_ctx->device, buf, ROUNDUP_POW2(size, WEBGPU_STORAGE_BUF_BINDING_MULT),
1848
+ wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst,
1849
+ buf_name.c_str());
1850
+
1851
+ ggml_backend_webgpu_buffer_context * buf_ctx =
1852
+ new ggml_backend_webgpu_buffer_context(ctx->webgpu_ctx, buf, buf_name);
1853
+
1854
+ return ggml_backend_buffer_init(buft, ggml_backend_webgpu_buffer_interface, buf_ctx, size);
1855
+ }
1856
+
1857
+ static size_t ggml_backend_webgpu_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
1858
+ ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
1859
+ return ctx->webgpu_ctx->limits.minStorageBufferOffsetAlignment;
1860
+ }
1861
+
1862
+ // maxBufferSize might be larger, but you can't bind more than maxStorageBufferBindingSize to a single binding.
1863
+ static size_t ggml_backend_webgpu_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
1864
+ ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
1865
+ return ctx->webgpu_ctx->limits.maxStorageBufferBindingSize;
1866
+ }
1867
+
1868
+ /* End GGML Backend Buffer Type Interface */
1869
+
1870
+ /* GGML Backend Device Interface */
1871
+
1872
+ static const char * ggml_backend_webgpu_device_get_name(ggml_backend_dev_t dev) {
1873
+ ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
1874
+ return ctx->device_name.c_str();
1875
+ }
1876
+
1877
+ static const char * ggml_backend_webgpu_device_get_description(ggml_backend_dev_t dev) {
1878
+ ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
1879
+ return ctx->device_desc.c_str();
1880
+ }
1881
+
1882
+ static void ggml_backend_webgpu_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
1883
+ ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
1884
+ // TODO: for now, return maxBufferSize as both free and total memory
1885
+ // Track https://github.com/gpuweb/gpuweb/issues/5505 for updates.
1886
+ uint64_t max_buffer_size = ctx->webgpu_ctx->limits.maxBufferSize;
1887
+ // If we're on a 32-bit system, clamp to UINTPTR_MAX
1888
+ #if UINTPTR_MAX < UINT64_MAX
1889
+ uint64_t max_ptr_size = static_cast<uint64_t>(UINTPTR_MAX);
1890
+ if (max_buffer_size > max_ptr_size) {
1891
+ max_buffer_size = max_ptr_size;
1892
+ }
1893
+ #endif
1894
+ *free = static_cast<size_t>(max_buffer_size);
1895
+ *total = static_cast<size_t>(max_buffer_size);
1896
+ }
1897
+
1898
+ static enum ggml_backend_dev_type ggml_backend_webgpu_device_get_type(ggml_backend_dev_t dev) {
1899
+ GGML_UNUSED(dev);
1900
+ return GGML_BACKEND_DEVICE_TYPE_GPU;
1901
+ }
1902
+
1903
+ static void ggml_backend_webgpu_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
1904
+ props->name = ggml_backend_webgpu_device_get_name(dev);
1905
+ props->description = ggml_backend_webgpu_device_get_description(dev);
1906
+ props->type = ggml_backend_webgpu_device_get_type(dev);
1907
+ ggml_backend_webgpu_device_get_memory(dev, &props->memory_free, &props->memory_total);
1908
+ props->caps = {
1909
+ /* .async = */ false,
1910
+ /* .host_buffer = */ false,
1911
+ /* .buffer_from_host_ptr = */ false,
1912
+ /* .events = */ false,
1913
+ };
1914
+ }
1915
+
1916
+ static ggml_guid_t ggml_backend_webgpu_guid(void) {
1917
+ static const char * guid_str = "__ggml_webgpu :)";
1918
+ return reinterpret_cast<ggml_guid_t>((void *) guid_str);
1919
+ }
1920
+
1921
+ // Workgroup size is a common constant
1922
+ static std::vector<wgpu::ConstantEntry> ggml_webgpu_wg_size_entry(uint32_t wg_size) {
1923
+ std::vector<wgpu::ConstantEntry> constants(1);
1924
+ constants[0].key = "wg_size";
1925
+ constants[0].value = wg_size;
1926
+ return constants;
1927
+ }
1928
+
1929
// Build the memset compute pipeline. Each thread clears a fixed number of
// bytes; bytes_per_thread is sized so a single maximal dispatch can cover the
// largest possible storage-buffer binding.
static void ggml_webgpu_init_memset_pipeline(webgpu_context & webgpu_ctx) {
    // we use the maximum workgroup size for the memset pipeline
    size_t max_threads = WEBGPU_MAX_WG_SIZE * webgpu_ctx->limits.maxComputeWorkgroupsPerDimension;
    // Size the bytes_per_thread so that the largest buffer size can be handled
    webgpu_ctx->memset_bytes_per_thread = CEIL_DIV(webgpu_ctx->limits.maxStorageBufferBindingSize, max_threads);
    // both values are baked into the shader as override constants
    std::vector<wgpu::ConstantEntry> constants(2);
    constants[0].key   = "wg_size";
    constants[0].value = WEBGPU_MAX_WG_SIZE;
    constants[1].key   = "bytes_per_thread";
    constants[1].value = webgpu_ctx->memset_bytes_per_thread;
    webgpu_ctx->memset_pipelines[0] = ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_memset, "memset", constants);
}
1941
+
1942
+ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
1943
+ // Q4/Q5/Q8 classic quantizations
1944
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] =
1945
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q4_0_f32, "mul_mat_q4_0_f32");
1946
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_1][GGML_TYPE_F32][0] =
1947
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q4_1_f32, "mul_mat_q4_1_f32");
1948
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_0][GGML_TYPE_F32][0] =
1949
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q5_0_f32, "mul_mat_q5_0_f32");
1950
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_1][GGML_TYPE_F32][0] =
1951
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q5_1_f32, "mul_mat_q5_1_f32");
1952
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q8_0][GGML_TYPE_F32][0] =
1953
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q8_0_f32, "mul_mat_q8_0_f32");
1954
+
1955
+ // K-quantizations
1956
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q2_K][GGML_TYPE_F32][0] =
1957
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q2_k_f32, "mul_mat_q2_k_f32");
1958
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q3_K][GGML_TYPE_F32][0] =
1959
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q3_k_f32, "mul_mat_q3_k_f32");
1960
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_K][GGML_TYPE_F32][0] =
1961
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q4_k_f32, "mul_mat_q4_k_f32");
1962
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_K][GGML_TYPE_F32][0] =
1963
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q5_k_f32, "mul_mat_q5_k_f32");
1964
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q6_K][GGML_TYPE_F32][0] =
1965
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q6_k_f32, "mul_mat_q6_k_f32");
1966
+
1967
+ // IQ quantizations (2-, 3-, 4-bit variants)
1968
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_XXS][GGML_TYPE_F32][0] =
1969
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq2_xxs_f32, "mul_mat_iq2_xxs_f32");
1970
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_XS][GGML_TYPE_F32][0] =
1971
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq2_xs_f32, "mul_mat_iq2_xs_f32");
1972
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_S][GGML_TYPE_F32][0] =
1973
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq2_s_f32, "mul_mat_iq2_s_f32");
1974
+
1975
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ3_XXS][GGML_TYPE_F32][0] =
1976
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq3_xxs_f32, "mul_mat_iq3_xxs_f32");
1977
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ3_S][GGML_TYPE_F32][0] =
1978
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq3_s_f32, "mul_mat_iq3_s_f32");
1979
+
1980
+ // 1-bit and 4-bit IQ variants
1981
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ1_S][GGML_TYPE_F32][0] =
1982
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq1_s_f32, "mul_mat_iq1_s_f32");
1983
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ1_M][GGML_TYPE_F32][0] =
1984
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq1_m_f32, "mul_mat_iq1_m_f32");
1985
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ4_NL][GGML_TYPE_F32][0] =
1986
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq4_nl_f32, "mul_mat_iq4_nl_f32");
1987
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ4_XS][GGML_TYPE_F32][0] =
1988
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32");
1989
+
1990
+ std::string proc_mul_mat_f32_f32;
1991
+ std::string proc_mul_mat_f32_f32_vec;
1992
+ std::string proc_mul_mat_f16_f32;
1993
+ std::string proc_mul_mat_f16_f32_vec;
1994
+ std::string proc_mul_mat_f16_f16;
1995
+ std::string proc_mul_mat_f16_f16_vec;
1996
+ std::string proc_mul_mat_q4_0_f32;
1997
+ std::string proc_mul_mat_q4_0_f32_vec;
1998
+
1999
+ std::vector<wgpu::ConstantEntry> mul_mat_constants;
2000
+ #ifndef __EMSCRIPTEN__
2001
+ if (webgpu_ctx->supports_subgroup_matrix) {
2002
+ std::map<std::string, std::string> sg_matrix_repls;
2003
+ sg_matrix_repls["WEBGPU_MAX_SUBGROUP_SIZE"] = std::to_string(webgpu_ctx->max_subgroup_size);
2004
+ sg_matrix_repls["WEBGPU_TILE_K"] = std::to_string(WEBGPU_MUL_MAT_TILE_K);
2005
+ sg_matrix_repls["WEBGPU_SUBGROUP_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M);
2006
+ sg_matrix_repls["WEBGPU_SUBGROUP_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N);
2007
+ sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M);
2008
+ sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N);
2009
+ sg_matrix_repls["WEBGPU_SG_MAT_M_SIZE"] = std::to_string(webgpu_ctx->sg_mat_m);
2010
+ sg_matrix_repls["WEBGPU_SG_MAT_N_SIZE"] = std::to_string(webgpu_ctx->sg_mat_n);
2011
+ sg_matrix_repls["WEBGPU_SG_MAT_K_SIZE"] = std::to_string(webgpu_ctx->sg_mat_k);
2012
+
2013
+ proc_mul_mat_f32_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32, sg_matrix_repls);
2014
+ proc_mul_mat_f32_f32_vec =
2015
+ ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32_vec, sg_matrix_repls);
2016
+ proc_mul_mat_f16_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32, sg_matrix_repls);
2017
+ proc_mul_mat_f16_f32_vec =
2018
+ ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32_vec, sg_matrix_repls);
2019
+ proc_mul_mat_f16_f16 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16, sg_matrix_repls);
2020
+ proc_mul_mat_f16_f16_vec =
2021
+ ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16_vec, sg_matrix_repls);
2022
+ proc_mul_mat_q4_0_f32 =
2023
+ ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32, sg_matrix_repls);
2024
+ proc_mul_mat_q4_0_f32_vec =
2025
+ ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32_vec, sg_matrix_repls);
2026
+ } else {
2027
+ #endif
2028
+ mul_mat_constants.push_back({ .key = "TILE_K", .value = WEBGPU_MUL_MAT_TILE_K });
2029
+ mul_mat_constants.push_back({ .key = "WORKGROUP_SIZE_M", .value = WEBGPU_MUL_MAT_WG_SIZE_M });
2030
+ mul_mat_constants.push_back({ .key = "WORKGROUP_SIZE_N", .value = WEBGPU_MUL_MAT_WG_SIZE_N });
2031
+
2032
+ std::map<std::string, std::string> reg_repls;
2033
+ reg_repls["WEBGPU_TILE_M"] = std::to_string(WEBGPU_MUL_MAT_TILE_M);
2034
+ reg_repls["WEBGPU_TILE_N"] = std::to_string(WEBGPU_MUL_MAT_TILE_N);
2035
+
2036
+ proc_mul_mat_f32_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32, reg_repls);
2037
+ proc_mul_mat_f32_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32_vec, reg_repls);
2038
+ proc_mul_mat_f16_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32, reg_repls);
2039
+ proc_mul_mat_f16_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32_vec, reg_repls);
2040
+ proc_mul_mat_f16_f16 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16, reg_repls);
2041
+ proc_mul_mat_f16_f16_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16_vec, reg_repls);
2042
+ proc_mul_mat_q4_0_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32, reg_repls);
2043
+ proc_mul_mat_q4_0_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32_vec, reg_repls);
2044
+ #ifndef __EMSCRIPTEN__
2045
+ }
2046
+ #endif
2047
+
2048
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
2049
+ webgpu_ctx->device, proc_mul_mat_f32_f32.c_str(), "mul_mat_f32_f32", mul_mat_constants);
2050
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
2051
+ webgpu_ctx->device, proc_mul_mat_f32_f32_vec.c_str(), "mul_mat_f32_f32_vec", mul_mat_constants);
2052
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
2053
+ webgpu_ctx->device, proc_mul_mat_f16_f32.c_str(), "mul_mat_f16_f32", mul_mat_constants);
2054
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
2055
+ webgpu_ctx->device, proc_mul_mat_f16_f32_vec.c_str(), "mul_mat_f16_f32_vec", mul_mat_constants);
2056
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline(
2057
+ webgpu_ctx->device, proc_mul_mat_f16_f16.c_str(), "mul_mat_f16_f16", mul_mat_constants);
2058
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
2059
+ webgpu_ctx->device, proc_mul_mat_f16_f16_vec.c_str(), "mul_mat_f16_f16_vec", mul_mat_constants);
2060
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
2061
+ webgpu_ctx->device, proc_mul_mat_q4_0_f32.c_str(), "mul_mat_q4_0_f32", mul_mat_constants);
2062
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
2063
+ webgpu_ctx->device, proc_mul_mat_q4_0_f32_vec.c_str(), "mul_mat_q4_0_f32_vec", mul_mat_constants);
2064
+
2065
+ std::vector<wgpu::ConstantEntry> mul_mat_vec_constants(3);
2066
+ mul_mat_vec_constants[0].key = "WORKGROUP_SIZE";
2067
+ mul_mat_vec_constants[0].value = WEBGPU_MUL_MAT_VEC_WG_SIZE;
2068
+ mul_mat_vec_constants[1].key = "TILE_K";
2069
+ mul_mat_vec_constants[1].value = WEBGPU_MUL_MAT_VEC_TILE_K;
2070
+ mul_mat_vec_constants[2].key = "OUTPUTS_PER_WG";
2071
+ mul_mat_vec_constants[2].value = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG;
2072
+
2073
+ webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
2074
+ webgpu_ctx->device, wgsl_mul_mat_vec_f32_f32, "mul_mat_vec_f32_f32", mul_mat_vec_constants);
2075
+ webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
2076
+ webgpu_ctx->device, wgsl_mul_mat_vec_f32_f32_vec, "mul_mat_vec_f32_f32_vec", mul_mat_vec_constants);
2077
+ webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
2078
+ webgpu_ctx->device, wgsl_mul_mat_vec_f16_f32, "mul_mat_vec_f16_f32", mul_mat_vec_constants);
2079
+ webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
2080
+ webgpu_ctx->device, wgsl_mul_mat_vec_f16_f32_vec, "mul_mat_vec_f16_f32_vec", mul_mat_vec_constants);
2081
+ webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline(
2082
+ webgpu_ctx->device, wgsl_mul_mat_vec_f16_f16, "mul_mat_vec_f16_f16", mul_mat_vec_constants);
2083
+ webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
2084
+ webgpu_ctx->device, wgsl_mul_mat_vec_f16_f16_vec, "mul_mat_vec_f16_f16_vec", mul_mat_vec_constants);
2085
+ webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
2086
+ webgpu_ctx->device, wgsl_mul_mat_vec_q4_0_f32, "mul_mat_vec_q4_0_f32", mul_mat_vec_constants);
2087
+ }
2088
+
2089
+ static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) {
2090
+ webgpu_ctx->set_rows_pipelines[0][0] = ggml_webgpu_create_pipeline(
2091
+ webgpu_ctx->device, wgsl_set_rows_f16, "set_rows_f16", ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE));
2092
+ webgpu_ctx->set_rows_pipelines[0][1] = ggml_webgpu_create_pipeline(
2093
+ webgpu_ctx->device, wgsl_set_rows_f16_vec, "set_rows_f16_vec", ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE));
2094
+ }
2095
+
2096
+ static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) {
2097
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
2098
+
2099
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_F32][0] =
2100
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_f32, "get_rows_f32", constants);
2101
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_F32][1] =
2102
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_f32_vec, "get_rows_f32_vec", constants);
2103
+
2104
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_F16][0] =
2105
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_f16, "get_rows_f16", constants);
2106
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_I32][0] =
2107
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_i32, "get_rows_i32", constants);
2108
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_0][0] =
2109
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q4_0, "get_rows_q4_0", constants);
2110
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_1][0] =
2111
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q4_1, "get_rows_q4_1", constants);
2112
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_0][0] =
2113
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q5_0, "get_rows_q5_0", constants);
2114
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_1][0] =
2115
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q5_1, "get_rows_q5_1", constants);
2116
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q8_0][0] =
2117
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q8_0, "get_rows_q8_0", constants);
2118
+
2119
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q2_K][0] =
2120
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q2_k, "get_rows_q2_k", constants);
2121
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q3_K][0] =
2122
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q3_k, "get_rows_q3_k", constants);
2123
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_K][0] =
2124
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q4_k, "get_rows_q4_k", constants);
2125
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_K][0] =
2126
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q5_k, "get_rows_q5_k", constants);
2127
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q6_K][0] =
2128
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q6_k, "get_rows_q6_k", constants);
2129
+
2130
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_XXS][0] =
2131
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq2_xxs, "get_rows_iq2_xxs", constants);
2132
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_XS][0] =
2133
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq2_xs, "get_rows_iq2_xs", constants);
2134
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_S][0] =
2135
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq2_s, "get_rows_iq2_s", constants);
2136
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ3_XXS][0] =
2137
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq3_xxs, "get_rows_iq3_xxs", constants);
2138
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ3_S][0] =
2139
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq3_s, "get_rows_iq3_s", constants);
2140
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ1_S][0] =
2141
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq1_s, "get_rows_iq1_s", constants);
2142
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ1_M][0] =
2143
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq1_m, "get_rows_iq1_m", constants);
2144
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ4_NL][0] =
2145
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq4_nl, "get_rows_iq4_nl", constants);
2146
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ4_XS][0] =
2147
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq4_xs, "get_rows_iq4_xs", constants);
2148
+ }
2149
+
2150
+ static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
2151
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
2152
+
2153
+ webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F32] =
2154
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_cpy_f32_f32, "cpy_f32_f32", constants);
2155
+ webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F16] =
2156
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_cpy_f32_f16, "cpy_f32_f16", constants);
2157
+ webgpu_ctx->cpy_pipelines[GGML_TYPE_F16][GGML_TYPE_F32] =
2158
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_cpy_f16_f32, "cpy_f16_f32", constants);
2159
+ webgpu_ctx->cpy_pipelines[GGML_TYPE_F16][GGML_TYPE_F16] =
2160
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_cpy_f16_f16, "cpy_f16_f16", constants);
2161
+ }
2162
+
2163
+ static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) {
2164
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
2165
+
2166
+ webgpu_ctx->add_pipelines[GGML_TYPE_F32][0] =
2167
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_add_f32, "add_f32", constants);
2168
+ webgpu_ctx->add_pipelines[GGML_TYPE_F16][0] =
2169
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_add_f16, "add_f16", constants);
2170
+ webgpu_ctx->add_pipelines[GGML_TYPE_F32][1] =
2171
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_add_f32_inplace, "add_f32_inplace", constants);
2172
+ webgpu_ctx->add_pipelines[GGML_TYPE_F16][1] =
2173
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_add_f16_inplace, "add_f16_inplace", constants);
2174
+ }
2175
+
2176
+ static void ggml_webgpu_init_sub_pipeline(webgpu_context & webgpu_ctx) {
2177
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
2178
+
2179
+ webgpu_ctx->sub_pipelines[GGML_TYPE_F32][0] =
2180
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sub_f32, "sub_f32", constants);
2181
+ webgpu_ctx->sub_pipelines[GGML_TYPE_F16][0] =
2182
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sub_f16, "sub_f16", constants);
2183
+ webgpu_ctx->sub_pipelines[GGML_TYPE_F32][1] =
2184
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sub_f32_inplace, "sub_f32_inplace", constants);
2185
+ webgpu_ctx->sub_pipelines[GGML_TYPE_F16][1] =
2186
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sub_f16_inplace, "sub_f16_inplace", constants);
2187
+ }
2188
+
2189
+ static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) {
2190
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
2191
+
2192
+ webgpu_ctx->mul_pipelines[GGML_TYPE_F32][0] =
2193
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_f32, "mul_f32", constants);
2194
+ webgpu_ctx->mul_pipelines[GGML_TYPE_F16][0] =
2195
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_f16, "mul_f16", constants);
2196
+ webgpu_ctx->mul_pipelines[GGML_TYPE_F32][1] =
2197
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_f32_inplace, "mul_f32_inplace", constants);
2198
+ webgpu_ctx->mul_pipelines[GGML_TYPE_F16][1] =
2199
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_f16_inplace, "mul_f16_inplace", constants);
2200
+ }
2201
+
2202
+ static void ggml_webgpu_init_div_pipeline(webgpu_context & webgpu_ctx) {
2203
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
2204
+
2205
+ webgpu_ctx->div_pipelines[GGML_TYPE_F32][0] =
2206
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_div_f32, "div_f32", constants);
2207
+ webgpu_ctx->div_pipelines[GGML_TYPE_F16][0] =
2208
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_div_f16, "div_f16", constants);
2209
+ webgpu_ctx->div_pipelines[GGML_TYPE_F32][1] =
2210
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_div_f32_inplace, "div_f32_inplace", constants);
2211
+ webgpu_ctx->div_pipelines[GGML_TYPE_F16][1] =
2212
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_div_f16_inplace, "div_f16_inplace", constants);
2213
+ }
2214
+
2215
+ static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) {
2216
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);
2217
+
2218
+ webgpu_ctx->rms_norm_pipelines[0] =
2219
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rms_norm, "rms_norm", constants);
2220
+ webgpu_ctx->rms_norm_pipelines[1] =
2221
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rms_norm_inplace, "rms_norm_inplace", constants);
2222
+ }
2223
+
2224
+ static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) {
2225
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
2226
+
2227
+ webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][0] =
2228
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f32, "rope_f32", constants);
2229
+ webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][1] =
2230
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f32_inplace, "rope_f32_inplace", constants);
2231
+ webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][0] =
2232
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f32_ff, "rope_f32_ff", constants);
2233
+ webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][1] =
2234
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f32_ff_inplace, "rope_f32_ff_inplace", constants);
2235
+
2236
+ webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][0] =
2237
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f16, "rope_f16", constants);
2238
+ webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][1] =
2239
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f16_inplace, "rope_f16_inplace", constants);
2240
+ webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][0] =
2241
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f16_ff, "rope_f16_ff", constants);
2242
+ webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][1] =
2243
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f16_ff_inplace, "rope_f16_ff_inplace", constants);
2244
+ }
2245
+
2246
+ static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) {
2247
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
2248
+
2249
+ // REGLU
2250
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F32][0] =
2251
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_reglu_f32, "reglu_f32", constants);
2252
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F16][0] =
2253
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_reglu_f16, "reglu_f16", constants);
2254
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F32][1] =
2255
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_reglu_f32_split, "reglu_f32_split", constants);
2256
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F16][1] =
2257
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_reglu_f16_split, "reglu_f16_split", constants);
2258
+
2259
+ // GEGLU
2260
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][0] =
2261
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_f32, "geglu_f32", constants);
2262
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][0] =
2263
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_f16, "geglu_f16", constants);
2264
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][1] =
2265
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_f32_split, "geglu_f32_split", constants);
2266
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][1] =
2267
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_f16_split, "geglu_f16_split", constants);
2268
+
2269
+ // SWIGLU
2270
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][0] =
2271
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_f32, "swiglu_f32", constants);
2272
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][0] =
2273
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_f16, "swiglu_f16", constants);
2274
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][1] =
2275
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_f32_split, "swiglu_f32_split", constants);
2276
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][1] =
2277
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_f16_split, "swiglu_f16_split", constants);
2278
+
2279
+ // SWIGLU_OAI
2280
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][0] =
2281
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_oai_f32, "swiglu_oai_f32", constants);
2282
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][1] =
2283
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_oai_f32_split, "swiglu_oai_f32_split", constants);
2284
+
2285
+ // GEGLU_ERF
2286
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][0] =
2287
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_erf_f32, "geglu_erf_f32", constants);
2288
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][0] =
2289
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_erf_f16, "geglu_erf_f16", constants);
2290
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][1] =
2291
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_erf_f32_split, "geglu_erf_f32_split", constants);
2292
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][1] =
2293
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_erf_f16_split, "geglu_erf_f16_split", constants);
2294
+
2295
+ // GEGLU_QUICK
2296
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][0] =
2297
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_quick_f32, "geglu_quick_f32", constants);
2298
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][0] =
2299
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_quick_f16, "geglu_quick_f16", constants);
2300
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][1] =
2301
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_quick_f32_split, "geglu_quick_f32_split", constants);
2302
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][1] =
2303
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_quick_f16_split, "geglu_quick_f16_split", constants);
2304
+ }
2305
+
2306
+ static void ggml_webgpu_init_unary_pipeline(webgpu_context & webgpu_ctx) {
2307
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
2308
+
2309
+ // ABS
2310
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ABS][GGML_TYPE_F32][0] =
2311
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_abs_f32, "abs_f32", constants);
2312
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ABS][GGML_TYPE_F16][0] =
2313
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_abs_f16, "abs_f16", constants);
2314
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ABS][GGML_TYPE_F32][1] =
2315
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_abs_inplace_f32, "abs_inplace_f32", constants);
2316
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ABS][GGML_TYPE_F16][1] =
2317
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_abs_inplace_f16, "abs_inplace_f16", constants);
2318
+
2319
+ // SGN
2320
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SGN][GGML_TYPE_F32][0] =
2321
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sgn_f32, "sgn_f32", constants);
2322
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SGN][GGML_TYPE_F16][0] =
2323
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sgn_f16, "sgn_f16", constants);
2324
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SGN][GGML_TYPE_F32][1] =
2325
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sgn_inplace_f32, "sgn_inplace_f32", constants);
2326
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SGN][GGML_TYPE_F16][1] =
2327
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sgn_inplace_f16, "sgn_inplace_f16", constants);
2328
+
2329
+ // NEG
2330
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_NEG][GGML_TYPE_F32][0] =
2331
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_neg_f32, "neg_f32", constants);
2332
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_NEG][GGML_TYPE_F16][0] =
2333
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_neg_f16, "neg_f16", constants);
2334
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_NEG][GGML_TYPE_F32][1] =
2335
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_neg_inplace_f32, "neg_inplace_f32", constants);
2336
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_NEG][GGML_TYPE_F16][1] =
2337
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_neg_inplace_f16, "neg_inplace_f16", constants);
2338
+
2339
+ // STEP
2340
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_STEP][GGML_TYPE_F32][0] =
2341
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_step_f32, "step_f32", constants);
2342
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_STEP][GGML_TYPE_F16][0] =
2343
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_step_f16, "step_f16", constants);
2344
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_STEP][GGML_TYPE_F32][1] =
2345
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_step_inplace_f32, "step_inplace_f32", constants);
2346
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_STEP][GGML_TYPE_F16][1] =
2347
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_step_inplace_f16, "step_inplace_f16", constants);
2348
+
2349
+ // TANH
2350
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_TANH][GGML_TYPE_F32][0] =
2351
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_tanh_f32, "tanh_f32", constants);
2352
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_TANH][GGML_TYPE_F16][0] =
2353
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_tanh_f16, "tanh_f16", constants);
2354
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_TANH][GGML_TYPE_F32][1] =
2355
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_tanh_inplace_f32, "tanh_inplace_f32", constants);
2356
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_TANH][GGML_TYPE_F16][1] =
2357
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_tanh_inplace_f16, "tanh_inplace_f16", constants);
2358
+
2359
+ // ELU
2360
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ELU][GGML_TYPE_F32][0] =
2361
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_elu_f32, "elu_f32", constants);
2362
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ELU][GGML_TYPE_F16][0] =
2363
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_elu_f16, "elu_f16", constants);
2364
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ELU][GGML_TYPE_F32][1] =
2365
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_elu_inplace_f32, "elu_inplace_f32", constants);
2366
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ELU][GGML_TYPE_F16][1] =
2367
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_elu_inplace_f16, "elu_inplace_f16", constants);
2368
+
2369
+ // RELU
2370
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_RELU][GGML_TYPE_F32][0] =
2371
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_relu_f32, "relu_f32", constants);
2372
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_RELU][GGML_TYPE_F16][0] =
2373
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_relu_f16, "relu_f16", constants);
2374
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_RELU][GGML_TYPE_F32][1] =
2375
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_relu_inplace_f32, "relu_inplace_f32", constants);
2376
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_RELU][GGML_TYPE_F16][1] =
2377
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_relu_inplace_f16, "relu_inplace_f16", constants);
2378
+
2379
+ // SIGMOID
2380
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F32][0] =
2381
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sigmoid_f32, "sigmoid_f32", constants);
2382
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F16][0] =
2383
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sigmoid_f16, "sigmoid_f16", constants);
2384
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F32][1] =
2385
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sigmoid_inplace_f32, "sigmoid_inplace_f32", constants);
2386
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F16][1] =
2387
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sigmoid_inplace_f16, "sigmoid_inplace_f16", constants);
2388
+
2389
+ // GELU
2390
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU][GGML_TYPE_F32][0] =
2391
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_f32, "gelu_f32", constants);
2392
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU][GGML_TYPE_F16][0] =
2393
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_f16, "gelu_f16", constants);
2394
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU][GGML_TYPE_F32][1] =
2395
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_inplace_f32, "gelu_inplace_f32", constants);
2396
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU][GGML_TYPE_F16][1] =
2397
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_inplace_f16, "gelu_inplace_f16", constants);
2398
+
2399
+ // GELU_QUICK
2400
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F32][0] =
2401
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_quick_f32, "gelu_quick_f32", constants);
2402
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F16][0] =
2403
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_quick_f16, "gelu_quick_f16", constants);
2404
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
2405
+ webgpu_ctx->device, wgsl_gelu_quick_inplace_f32, "gelu_quick_inplace_f32", constants);
2406
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
2407
+ webgpu_ctx->device, wgsl_gelu_quick_inplace_f16, "gelu_quick_inplace_f16", constants);
2408
+
2409
+ // SILU
2410
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SILU][GGML_TYPE_F32][0] =
2411
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_silu_f32, "silu_f32", constants);
2412
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SILU][GGML_TYPE_F16][0] =
2413
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_silu_f16, "silu_f16", constants);
2414
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SILU][GGML_TYPE_F32][1] =
2415
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_silu_inplace_f32, "silu_inplace_f32", constants);
2416
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SILU][GGML_TYPE_F16][1] =
2417
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_silu_inplace_f16, "silu_inplace_f16", constants);
2418
+
2419
+ // HARDSWISH
2420
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F32][0] =
2421
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_hardswish_f32, "hardswish_f32", constants);
2422
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F16][0] =
2423
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_hardswish_f16, "hardswish_f16", constants);
2424
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F32][1] =
2425
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_hardswish_inplace_f32, "hardswish_inplace_f32", constants);
2426
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F16][1] =
2427
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_hardswish_inplace_f16, "hardswish_inplace_f16", constants);
2428
+
2429
+ // HARDSIGMOID
2430
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F32][0] =
2431
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_hardsigmoid_f32, "hardsigmoid_f32", constants);
2432
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F16][0] =
2433
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_hardsigmoid_f16, "hardsigmoid_f16", constants);
2434
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
2435
+ webgpu_ctx->device, wgsl_hardsigmoid_inplace_f32, "hardsigmoid_inplace_f32", constants);
2436
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
2437
+ webgpu_ctx->device, wgsl_hardsigmoid_inplace_f16, "hardsigmoid_inplace_f16", constants);
2438
+
2439
+ // EXP
2440
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_EXP][GGML_TYPE_F32][0] =
2441
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_exp_f32, "exp_f32", constants);
2442
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_EXP][GGML_TYPE_F16][0] =
2443
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_exp_f16, "exp_f16", constants);
2444
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_EXP][GGML_TYPE_F32][1] =
2445
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_exp_inplace_f32, "exp_inplace_f32", constants);
2446
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_EXP][GGML_TYPE_F16][1] =
2447
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_exp_inplace_f16, "exp_inplace_f16", constants);
2448
+
2449
+ // GELU_ERF
2450
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F32][0] =
2451
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_erf_f32, "gelu_erf_f32", constants);
2452
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F16][0] =
2453
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_erf_f16, "gelu_erf_f16", constants);
2454
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F32][1] =
2455
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_erf_inplace_f32, "gelu_erf_inplace_f32", constants);
2456
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F16][1] =
2457
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_erf_inplace_f16, "gelu_erf_inplace_f16", constants);
2458
+
2459
+ // XIELU
2460
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_XIELU][GGML_TYPE_F32][0] =
2461
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_xielu_f32, "xielu_f32", constants);
2462
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_XIELU][GGML_TYPE_F16][0] =
2463
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_xielu_f16, "xielu_f16", constants);
2464
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_XIELU][GGML_TYPE_F32][1] =
2465
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_xielu_inplace_f32, "xielu_inplace_f32", constants);
2466
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_XIELU][GGML_TYPE_F16][1] =
2467
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_xielu_inplace_f16, "xielu_inplace_f16", constants);
2468
+
2469
+ // CEIL
2470
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_CEIL][GGML_TYPE_F32][0] =
2471
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_ceil_f32, "ceil_f32", constants);
2472
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_CEIL][GGML_TYPE_F16][0] =
2473
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_ceil_f16, "ceil_f16", constants);
2474
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_CEIL][GGML_TYPE_F32][1] =
2475
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_ceil_inplace_f32, "ceil_inplace_f32", constants);
2476
+ webgpu_ctx->unary_pipelines[GGML_UNARY_OP_CEIL][GGML_TYPE_F16][1] =
2477
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_ceil_inplace_f16, "ceil_inplace_f16", constants);
2478
+ }
2479
+
2480
+ static void ggml_webgpu_init_scale_pipeline(webgpu_context & webgpu_ctx) {
2481
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
2482
+
2483
+ webgpu_ctx->scale_pipelines[0] =
2484
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_scale_f32, "scale_f32", constants);
2485
+ webgpu_ctx->scale_pipelines[1] =
2486
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_scale_f32_inplace, "scale_f32_inplace", constants);
2487
+ }
2488
+
2489
+ static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) {
2490
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);
2491
+
2492
+ // f32 (no mask)
2493
+ webgpu_ctx->soft_max_pipelines[2][0][0] =
2494
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_soft_max_f32, "soft_max_f32", constants);
2495
+ webgpu_ctx->soft_max_pipelines[2][0][1] =
2496
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_soft_max_f32_inplace, "soft_max_f32_inplace", constants);
2497
+ webgpu_ctx->soft_max_pipelines[2][1][0] =
2498
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_soft_max_f32_sink, "soft_max_f32_sink", constants);
2499
+ webgpu_ctx->soft_max_pipelines[2][1][1] = ggml_webgpu_create_pipeline(
2500
+ webgpu_ctx->device, wgsl_soft_max_f32_sink_inplace, "soft_max_f32_sink_inplace", constants);
2501
+
2502
+ // f32 mask (mask_type = 0)
2503
+ webgpu_ctx->soft_max_pipelines[0][0][0] =
2504
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_soft_max_f32_mask_f32, "soft_max_f32_mask_f32", constants);
2505
+ webgpu_ctx->soft_max_pipelines[0][0][1] = ggml_webgpu_create_pipeline(
2506
+ webgpu_ctx->device, wgsl_soft_max_f32_mask_f32_inplace, "soft_max_f32_mask_f32_inplace", constants);
2507
+ webgpu_ctx->soft_max_pipelines[0][1][0] = ggml_webgpu_create_pipeline(
2508
+ webgpu_ctx->device, wgsl_soft_max_f32_mask_f32_sink, "soft_max_f32_mask_f32_sink", constants);
2509
+ webgpu_ctx->soft_max_pipelines[0][1][1] = ggml_webgpu_create_pipeline(
2510
+ webgpu_ctx->device, wgsl_soft_max_f32_mask_f32_sink_inplace, "soft_max_f32_mask_f32_sink_inplace", constants);
2511
+
2512
+ // f16 mask (mask_type = 1)
2513
+ webgpu_ctx->soft_max_pipelines[1][0][0] =
2514
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_soft_max_f32_mask_f16, "soft_max_f32_mask_f16", constants);
2515
+ webgpu_ctx->soft_max_pipelines[1][0][1] = ggml_webgpu_create_pipeline(
2516
+ webgpu_ctx->device, wgsl_soft_max_f32_mask_f16_inplace, "soft_max_f32_mask_f16_inplace", constants);
2517
+ webgpu_ctx->soft_max_pipelines[1][1][0] = ggml_webgpu_create_pipeline(
2518
+ webgpu_ctx->device, wgsl_soft_max_f32_mask_f16_sink, "soft_max_f32_mask_f16_sink", constants);
2519
+ webgpu_ctx->soft_max_pipelines[1][1][1] = ggml_webgpu_create_pipeline(
2520
+ webgpu_ctx->device, wgsl_soft_max_f32_mask_f16_sink_inplace, "soft_max_f32_mask_f16_sink_inplace", constants);
2521
+ }
2522
+
2523
+ // TODO: move most initialization logic here
2524
+ static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) {
2525
+ GGML_UNUSED(params);
2526
+
2527
+ WEBGPU_LOG_DEBUG("ggml_backend_webgpu_device_init()");
2528
+
2529
+ ggml_backend_webgpu_device_context * dev_ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
2530
+ webgpu_context webgpu_ctx = dev_ctx->webgpu_ctx;
2531
+
2532
+ static ggml_backend_webgpu_context backend_ctx;
2533
+ backend_ctx.name = GGML_WEBGPU_NAME + std::string(": ") + dev_ctx->device_name;
2534
+ backend_ctx.webgpu_ctx = webgpu_ctx;
2535
+
2536
+ // See GGML Backend Interface section
2537
+ static ggml_backend backend = {
2538
+ /* .guid = */ ggml_backend_webgpu_guid(),
2539
+ /* .interface = */ ggml_backend_webgpu_i,
2540
+ /* .device = */ dev,
2541
+ /* .context = */ &backend_ctx,
2542
+ };
2543
+ return &backend;
2544
+ }
2545
+
2546
// Returns the (single, static) buffer type for the WebGPU backend.
static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggml_backend_dev_t dev) {
    // See GGML Backend Buffer Type Interface section

    // NOTE(review): this static is initialized with the `dev` of the FIRST
    // call; subsequent calls reuse that binding. Fine while only one device
    // is supported.
    static struct ggml_backend_buffer_type ggml_backend_webgpu_buffer_type = {
        /* .iface = */ {
            /* .get_name = */ ggml_backend_webgpu_buffer_type_get_name,
            /* .alloc_buffer = */ ggml_backend_webgpu_buffer_type_alloc_buffer,
            /* .get_alignment = */ ggml_backend_webgpu_buffer_type_get_alignment,
            /* .get_max_size = */ ggml_backend_webgpu_buffer_type_get_max_size,
            /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
            /* .is_host = */ NULL, // defaults to false
        },
        /* .device = */
        dev,
        /* .context = */ NULL,
    };

    return &ggml_backend_webgpu_buffer_type;
}
2565
+
2566
+ static bool ggml_backend_webgpu_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
2567
+ GGML_UNUSED(dev);
2568
+ return buft->iface.get_name == ggml_backend_webgpu_buffer_type_get_name;
2569
+ }
2570
+
2571
// Returns true for the quantized ggml types that have WebGPU shader support
// (used by the GET_ROWS and MUL_MAT checks in device_supports_op).
static bool ggml_webgpu_supported_qtype(ggml_type type) {
    switch (type) {
        // classic block quants
        case GGML_TYPE_Q4_0:
        case GGML_TYPE_Q4_1:
        case GGML_TYPE_Q5_0:
        case GGML_TYPE_Q5_1:
        case GGML_TYPE_Q8_0:
        // k-quants
        case GGML_TYPE_Q2_K:
        case GGML_TYPE_Q3_K:
        case GGML_TYPE_Q4_K:
        case GGML_TYPE_Q5_K:
        case GGML_TYPE_Q6_K:
        // i-quants
        case GGML_TYPE_IQ2_XXS:
        case GGML_TYPE_IQ2_XS:
        case GGML_TYPE_IQ2_S:
        case GGML_TYPE_IQ3_XXS:
        case GGML_TYPE_IQ3_S:
        case GGML_TYPE_IQ1_S:
        case GGML_TYPE_IQ1_M:
        case GGML_TYPE_IQ4_NL:
        case GGML_TYPE_IQ4_XS:
            return true;
        default:
            return false;
    }
}
2597
+
2598
+ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
2599
+ ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
2600
+
2601
+ webgpu_context webgpu_ctx = ctx->webgpu_ctx;
2602
+
2603
+ ggml_tensor * src0 = op->src[0];
2604
+ ggml_tensor * src1 = op->src[1];
2605
+ ggml_tensor * src2 = op->src[2];
2606
+
2607
+ // on smaller devices (or CI), tensors may be larger than the max storage buffer size
2608
+ if (ggml_nbytes(op) > webgpu_ctx->limits.maxStorageBufferBindingSize ||
2609
+ (src0 != nullptr && ggml_nbytes(src0) > webgpu_ctx->limits.maxStorageBufferBindingSize) ||
2610
+ (src1 != nullptr && ggml_nbytes(src1) > webgpu_ctx->limits.maxStorageBufferBindingSize)) {
2611
+ return false;
2612
+ }
2613
+
2614
+ bool supports_op = false;
2615
+ switch (op->op) {
2616
+ case GGML_OP_NONE:
2617
+ case GGML_OP_VIEW:
2618
+ case GGML_OP_PERMUTE:
2619
+ case GGML_OP_TRANSPOSE:
2620
+ case GGML_OP_RESHAPE:
2621
+ supports_op = true;
2622
+ break;
2623
+ case GGML_OP_ADD:
2624
+ case GGML_OP_SUB:
2625
+ case GGML_OP_MUL:
2626
+ case GGML_OP_DIV:
2627
+ // TODO: support non-contiguous tensors, e.g. for MOE_EXPERT_REDUCE
2628
+ // see https://github.com/ggml-org/llama.cpp/pull/16857
2629
+ supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type) &&
2630
+ (src1->type == op->type) && ggml_is_contiguous(src0) && ggml_is_contiguous(src1);
2631
+ break;
2632
+ case GGML_OP_CPY:
2633
+ case GGML_OP_CONT:
2634
+ supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
2635
+ (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
2636
+ break;
2637
+ case GGML_OP_SET_ROWS:
2638
+ supports_op = (op->type == GGML_TYPE_F16 && src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I64);
2639
+ break;
2640
+ case GGML_OP_GET_ROWS:
2641
+ if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_I32 ||
2642
+ ggml_webgpu_supported_qtype(src0->type)) {
2643
+ supports_op = (op->type == GGML_TYPE_F32);
2644
+ }
2645
+ break;
2646
+ case GGML_OP_MUL_MAT:
2647
+ {
2648
+ switch (src1->type) {
2649
+ case GGML_TYPE_F16:
2650
+ supports_op |= (src0->type == GGML_TYPE_F16);
2651
+ break;
2652
+ case GGML_TYPE_F32:
2653
+ switch (src0->type) {
2654
+ case GGML_TYPE_F32:
2655
+ case GGML_TYPE_F16:
2656
+ case GGML_TYPE_Q4_0:
2657
+ case GGML_TYPE_Q4_1:
2658
+ case GGML_TYPE_Q5_0:
2659
+ case GGML_TYPE_Q5_1:
2660
+ case GGML_TYPE_Q8_0:
2661
+ case GGML_TYPE_Q2_K:
2662
+ case GGML_TYPE_Q3_K:
2663
+ case GGML_TYPE_Q4_K:
2664
+ case GGML_TYPE_Q5_K:
2665
+ case GGML_TYPE_Q6_K:
2666
+ case GGML_TYPE_IQ2_XXS:
2667
+ case GGML_TYPE_IQ2_XS:
2668
+ case GGML_TYPE_IQ2_S:
2669
+ case GGML_TYPE_IQ3_XXS:
2670
+ case GGML_TYPE_IQ3_S:
2671
+ case GGML_TYPE_IQ1_S:
2672
+ case GGML_TYPE_IQ1_M:
2673
+ case GGML_TYPE_IQ4_NL:
2674
+ case GGML_TYPE_IQ4_XS:
2675
+ supports_op = true;
2676
+ break;
2677
+ default:
2678
+ break;
2679
+ }
2680
+ default:
2681
+ break;
2682
+ }
2683
+ break;
2684
+ }
2685
+ case GGML_OP_FLASH_ATTN_EXT:
2686
+ {
2687
+ if (!webgpu_ctx->supports_subgroup_matrix) {
2688
+ break;
2689
+ }
2690
+ // Head dimensions must fit in workgroup memory with minimum tile sizes
2691
+ size_t limit_bytes = webgpu_ctx->limits.maxComputeWorkgroupStorageSize;
2692
+ const bool has_mask = op->src[3] != nullptr;
2693
+ const bool kv_direct = src1->type == GGML_TYPE_F16 && (src0->ne[0] % webgpu_ctx->sg_mat_k) == 0 &&
2694
+ (src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0;
2695
+ const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
2696
+ webgpu_ctx->sg_mat_m, webgpu_ctx->sg_mat_n, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0],
2697
+ has_mask, kv_direct);
2698
+ if (min_bytes > limit_bytes) {
2699
+ break;
2700
+ }
2701
+
2702
+ supports_op = src0->type == GGML_TYPE_F32 &&
2703
+ (src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16 ||
2704
+ src1->type == GGML_TYPE_Q4_0 || src1->type == GGML_TYPE_Q8_0) &&
2705
+ src2->type == src1->type && op->type == GGML_TYPE_F32;
2706
+ break;
2707
+ }
2708
+ case GGML_OP_RMS_NORM:
2709
+ supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
2710
+ break;
2711
+ case GGML_OP_ROPE:
2712
+ supports_op = op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16;
2713
+ break;
2714
+ case GGML_OP_GLU:
2715
+ switch (ggml_get_glu_op(op)) {
2716
+ case GGML_GLU_OP_REGLU:
2717
+ case GGML_GLU_OP_GEGLU:
2718
+ case GGML_GLU_OP_SWIGLU:
2719
+ case GGML_GLU_OP_GEGLU_ERF:
2720
+ case GGML_GLU_OP_GEGLU_QUICK:
2721
+ supports_op = op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16;
2722
+ break;
2723
+ case GGML_GLU_OP_SWIGLU_OAI:
2724
+ supports_op = op->type == GGML_TYPE_F32;
2725
+ break;
2726
+ default:
2727
+ break;
2728
+ }
2729
+ break;
2730
+ case GGML_OP_SCALE:
2731
+ supports_op = op->type == GGML_TYPE_F32;
2732
+ break;
2733
+ case GGML_OP_SOFT_MAX:
2734
+ supports_op = op->type == GGML_TYPE_F32;
2735
+ break;
2736
+ case GGML_OP_UNARY:
2737
+ {
2738
+ const ggml_unary_op UNARY_OP = ggml_get_unary_op(op);
2739
+
2740
+ switch (UNARY_OP) {
2741
+ case GGML_UNARY_OP_ABS:
2742
+ case GGML_UNARY_OP_SGN:
2743
+ case GGML_UNARY_OP_NEG:
2744
+ case GGML_UNARY_OP_STEP:
2745
+ case GGML_UNARY_OP_TANH:
2746
+ case GGML_UNARY_OP_ELU:
2747
+ case GGML_UNARY_OP_RELU:
2748
+ case GGML_UNARY_OP_SIGMOID:
2749
+ case GGML_UNARY_OP_GELU:
2750
+ case GGML_UNARY_OP_GELU_QUICK:
2751
+ case GGML_UNARY_OP_SILU:
2752
+ case GGML_UNARY_OP_HARDSWISH:
2753
+ case GGML_UNARY_OP_HARDSIGMOID:
2754
+ case GGML_UNARY_OP_EXP:
2755
+ case GGML_UNARY_OP_GELU_ERF:
2756
+ case GGML_UNARY_OP_XIELU:
2757
+ case GGML_UNARY_OP_CEIL:
2758
+ supports_op = supports_op =
2759
+ (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
2760
+ break;
2761
+ default:
2762
+ break;
2763
+ }
2764
+ }
2765
+ break;
2766
+
2767
+ default:
2768
+ break;
2769
+ }
2770
+ if (ggml_nbytes(op) > webgpu_ctx->limits.maxStorageBufferBindingSize ||
2771
+ (src0 != nullptr && ggml_nbytes(src0) > webgpu_ctx->limits.maxStorageBufferBindingSize) ||
2772
+ (src1 != nullptr && ggml_nbytes(src1) > webgpu_ctx->limits.maxStorageBufferBindingSize) ||
2773
+ (src2 != nullptr && ggml_nbytes(src2) > webgpu_ctx->limits.maxStorageBufferBindingSize)) {
2774
+ supports_op = false;
2775
+ WEBGPU_LOG_DEBUG("ggml_webgpu op not supported due to size: ");
2776
+ }
2777
+
2778
+ if (!supports_op) {
2779
+ WEBGPU_LOG_DEBUG("ggml_webgpu op not supported: "
2780
+ << ggml_op_name(op->op) << " with types dst: " << ggml_type_name(op->type)
2781
+ << ", src0: " << (op->src[0] ? ggml_type_name(op->src[0]->type) : "null")
2782
+ << ", src1: " << (op->src[1] ? ggml_type_name(op->src[1]->type) : "null"));
2783
+ } else {
2784
+ WEBGPU_LOG_DEBUG("ggml_webgpu op supported: "
2785
+ << ggml_op_name(op->op) << " with types dst: " << ggml_type_name(op->type)
2786
+ << ", src0: " << (op->src[0] ? ggml_type_name(op->src[0]->type) : "null")
2787
+ << ", src1: " << (op->src[1] ? ggml_type_name(op->src[1]->type) : "null"));
2788
+ }
2789
+ return supports_op;
2790
+ }
2791
+
2792
// Device v-table wiring the WebGPU backend into the generic ggml backend
// device interface; NULL entries fall back to default/unsupported behavior.
static struct ggml_backend_device_i ggml_backend_webgpu_device_i = {
    /* .get_name = */ ggml_backend_webgpu_device_get_name,
    /* .get_description = */ ggml_backend_webgpu_device_get_description,
    /* .get_memory = */ ggml_backend_webgpu_device_get_memory,
    /* .get_type = */ ggml_backend_webgpu_device_get_type,
    /* .get_props = */ ggml_backend_webgpu_device_get_props,
    /* .init_backend = */ ggml_backend_webgpu_device_init,
    /* .get_buffer_type = */ ggml_backend_webgpu_device_get_buffer_type,
    /* .get_host_buffer_type = */ NULL,
    /* .buffer_from_host_ptr = */ NULL,
    /* .supports_op = */ ggml_backend_webgpu_device_supports_op,
    /* .supports_buft = */ ggml_backend_webgpu_device_supports_buft,
    /* .offload_op = */ NULL,
    /* .event_new = */ NULL,
    /* .event_free = */ NULL,
    /* .event_synchronize = */ NULL,
};
2809
+
2810
+ /* End GGML Backend Device Interface */
2811
+
2812
+ /* GGML Backend Registration Interface */
2813
+
2814
+ static const char * ggml_backend_webgpu_reg_get_name(ggml_backend_reg_t reg) {
2815
+ ggml_backend_webgpu_reg_context * ctx = static_cast<ggml_backend_webgpu_reg_context *>(reg->context);
2816
+ return ctx->name;
2817
+ }
2818
+
2819
+ static size_t ggml_backend_webgpu_reg_get_device_count(ggml_backend_reg_t reg) {
2820
+ ggml_backend_webgpu_reg_context * ctx = static_cast<ggml_backend_webgpu_reg_context *>(reg->context);
2821
+ return ctx->device_count;
2822
+ }
2823
+
2824
+ // TODO: Does this need to be thread safe? Is it only called once?
2825
+ // TODO: move most logic to device_init function so backend can be freed/initialized properly
2826
+ // Only one device is supported for now
2827
// Requests the WebGPU adapter and device, queries limits/features (f16 is
// mandatory; subgroup-matrix support is optional), initializes buffer pools
// and all compute pipelines, and returns the single static device object.
// Both WaitAny calls block until the asynchronous request callbacks complete.
// TODO: Does this need to be thread safe? Is it only called once?
// TODO: move most logic to device_init function so backend can be freed/initialized properly
// Only one device is supported for now
static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t reg, size_t index) {
    GGML_ASSERT(index == 0);
    WEBGPU_LOG_DEBUG("ggml_backend_reg_get_device()");

    WEBGPU_CPU_PROFILE_TOTAL_START(reg_get_device);

    ggml_backend_webgpu_reg_context * reg_ctx = static_cast<ggml_backend_webgpu_reg_context *>(reg->context);

    webgpu_context ctx = reg_ctx->webgpu_ctx;

    wgpu::RequestAdapterOptions options = {};

#ifndef __EMSCRIPTEN__
    // TODO: track need for these toggles: https://issues.chromium.org/issues/42251215
    const char * const adapterEnabledToggles[] = { "vulkan_enable_f16_on_nvidia", "use_vulkan_memory_model" };
    wgpu::DawnTogglesDescriptor adapterTogglesDesc;
    adapterTogglesDesc.enabledToggles = adapterEnabledToggles;
    adapterTogglesDesc.enabledToggleCount = 2;
    options.nextInChain = &adapterTogglesDesc;
#endif

    // Synchronously wait for the adapter request; the callback stores the
    // adapter into ctx on success.
    ctx->instance.WaitAny(ctx->instance.RequestAdapter(
        &options, wgpu::CallbackMode::AllowSpontaneous,
        [&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) {
            if (status != wgpu::RequestAdapterStatus::Success) {
                GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
                return;
            }
            ctx->adapter = std::move(adapter);
        }),
        UINT64_MAX);
    GGML_ASSERT(ctx->adapter != nullptr);

    ctx->adapter.GetLimits(&ctx->limits);

    wgpu::AdapterInfo info{};
#ifndef __EMSCRIPTEN__
    // Chain the subgroup-matrix-config query onto the AdapterInfo request so
    // both are filled by the single GetInfo call below.
    wgpu::AdapterPropertiesSubgroupMatrixConfigs subgroup_matrix_configs{};
    if (ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
        info.nextInChain = &subgroup_matrix_configs;
    }
#endif
    ctx->adapter.GetInfo(&info);

    wgpu::SupportedFeatures features;
    ctx->adapter.GetFeatures(&features);
    // we require f16 support
    GGML_ASSERT(ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16));

#ifndef __EMSCRIPTEN__
    // Only support square f16 matrices of size 8 or 16 for now
    bool valid_subgroup_matrix_config = false;
    if (ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
        for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) {
            const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i];
            if (config.M == config.N && config.N == config.K && (config.K == 8 || config.K == 16) &&
                config.componentType == wgpu::SubgroupMatrixComponentType::F16 &&
                config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) {
                ctx->sg_mat_m = config.M;
                ctx->sg_mat_n = config.N;
                ctx->sg_mat_k = config.K;
                valid_subgroup_matrix_config = true;
                break;
            }
        }
    }

    ctx->supports_subgroup_matrix = valid_subgroup_matrix_config;
#endif
    // For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate.
    // Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter.
    ctx->max_subgroup_size = info.subgroupMaxSize;

    // Initialize device
    std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16 };

#ifndef __EMSCRIPTEN__
    required_features.push_back(wgpu::FeatureName::ImplicitDeviceSynchronization);
    if (ctx->supports_subgroup_matrix) {
        required_features.push_back(wgpu::FeatureName::Subgroups);
        required_features.push_back(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix);
    }
#endif

#ifdef GGML_WEBGPU_GPU_PROFILE
    required_features.push_back(wgpu::FeatureName::TimestampQuery);
#endif

    wgpu::DeviceDescriptor dev_desc;
    dev_desc.requiredLimits = &ctx->limits;
    dev_desc.requiredFeatures = required_features.data();
    dev_desc.requiredFeatureCount = required_features.size();
    dev_desc.SetDeviceLostCallback(
        wgpu::CallbackMode::AllowSpontaneous,
        [](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
            GGML_UNUSED(device);
            GGML_UNUSED(reason);
            GGML_UNUSED(message);
            //TODO: uncomment once proper free logic is in place
            //GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason),
            //std::string(message).c_str());
        });
    dev_desc.SetUncapturedErrorCallback(
        [](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) {
            GGML_UNUSED(device);
            GGML_ABORT("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast<int>(reason),
                       std::string(message).c_str());
        });

#ifndef __EMSCRIPTEN__
    // Enable Dawn-specific toggles to increase native performance
    // TODO: Maybe WebGPU needs a "fast" mode where you can request compilers skip adding checks like these,
    // only for native performance?
    const char * const deviceEnabledToggles[] = { "skip_validation", "disable_robustness", "disable_workgroup_init",
                                                  "disable_polyfills_on_integer_div_and_mod" };
    const char * const deviceDisabledToggles[] = { "timestamp_quantization" };
    wgpu::DawnTogglesDescriptor deviceTogglesDesc;
    deviceTogglesDesc.enabledToggles = deviceEnabledToggles;
    deviceTogglesDesc.enabledToggleCount = 4;
    deviceTogglesDesc.disabledToggles = deviceDisabledToggles;
    deviceTogglesDesc.disabledToggleCount = 1;

    dev_desc.nextInChain = &deviceTogglesDesc;
#endif

    // Synchronously wait for the device request; the callback stores the
    // device into ctx on success.
    ctx->instance.WaitAny(ctx->adapter.RequestDevice(
        &dev_desc, wgpu::CallbackMode::AllowSpontaneous,
        [ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
            if (status != wgpu::RequestDeviceStatus::Success) {
                GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n",
                               std::string(message).c_str());
                return;
            }
            ctx->device = std::move(device);
        }),
        UINT64_MAX);
    GGML_ASSERT(ctx->device != nullptr);

    // Initialize (compute) queue
    ctx->queue = ctx->device.GetQueue();

    // Create buffer pool for shader parameters
    ctx->param_buf_pool.init(ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES,
                             wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
                             wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite);

#ifdef GGML_WEBGPU_GPU_PROFILE
    // Initialize buffer pool for timestamp queries (profiling)
    ctx->timestamp_query_buf_pool.init(ctx->device, WEBGPU_NUM_TIMESTAMP_QUERY_BUFS,
                                       WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES,
                                       wgpu::BufferUsage::QueryResolve | wgpu::BufferUsage::CopySrc,
                                       wgpu::BufferUsage::MapRead | wgpu::BufferUsage::CopyDst);
#endif

    ctx->set_rows_error_buf_pool.init(ctx->device, WEBGPU_NUM_SET_ROWS_ERROR_BUFS, WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
                                      wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage,
                                      wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead);

    // Compile all compute pipelines up front.
    ggml_webgpu_init_memset_pipeline(ctx);
    ggml_webgpu_init_mul_mat_pipeline(ctx);
    ggml_webgpu_init_set_rows_pipeline(ctx);
    ggml_webgpu_init_get_rows_pipeline(ctx);
    ggml_webgpu_init_cpy_pipeline(ctx);
    ggml_webgpu_init_add_pipeline(ctx);
    ggml_webgpu_init_sub_pipeline(ctx);
    ggml_webgpu_init_mul_pipeline(ctx);
    ggml_webgpu_init_div_pipeline(ctx);
    ggml_webgpu_init_rms_norm_pipeline(ctx);
    ggml_webgpu_init_rope_pipeline(ctx);
    ggml_webgpu_init_glu_pipeline(ctx);
    ggml_webgpu_init_scale_pipeline(ctx);
    ggml_webgpu_init_soft_max_pipeline(ctx);
    ggml_webgpu_init_unary_pipeline(ctx);

#ifdef GGML_WEBGPU_DEBUG
    // Initialize debug buffers
    ggml_webgpu_create_buffer(ctx->device, ctx->debug_host_buf, WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
                              wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "debug_host_buf");
    ggml_webgpu_create_buffer(ctx->device, ctx->debug_dev_buf, WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
                              wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, "debug_dev_buf");
#endif

    // Static: only a single device is supported, and callers receive a
    // pointer to this long-lived object.
    static ggml_backend_webgpu_device_context device_ctx;
    device_ctx.webgpu_ctx = ctx;
    device_ctx.device_name = GGML_WEBGPU_NAME;
    device_ctx.device_desc = info.description;

    GGML_LOG_INFO(
        "ggml_webgpu: adapter_info: vendor_id: %u | vendor: %s | architecture: %s | device_id: %u | name: %s | "
        "device_desc: %s\n",
        info.vendorID, std::string(info.vendor).c_str(), std::string(info.architecture).c_str(), info.deviceID,
        std::string(info.device).c_str(), std::string(info.description).c_str());

    // See GGML Backend Device Interface section
    static ggml_backend_device device = {
        /* .iface = */ ggml_backend_webgpu_device_i,
        /* .reg = */ reg,
        /* .context = */ &device_ctx,
    };

    WEBGPU_CPU_PROFILE_TOTAL_END(reg_get_device, ctx);
    return &device;
}
3030
+
3031
// Registry v-table for the WebGPU backend; get_proc_address is not provided.
static const struct ggml_backend_reg_i ggml_backend_webgpu_reg_i = {
    /* .get_name = */ ggml_backend_webgpu_reg_get_name,
    /* .get_device_count = */ ggml_backend_webgpu_reg_get_device_count,
    /* .get_device = */ ggml_backend_webgpu_reg_get_device,
    /* .get_proc_address = */ NULL,
};
3037
+
3038
+ /* End GGML Backend Registration Interface */
3039
+
3040
// Creates the WebGPU instance (requiring TimedWaitAny so the adapter/device
// requests can be waited on synchronously) and returns the static backend
// registry object. The shared webgpu_context is stored in the registry
// context and reused by the device/backend initialization.
ggml_backend_reg_t ggml_backend_webgpu_reg() {
    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_reg()");

    webgpu_context webgpu_ctx = std::make_shared<webgpu_context_struct>();

    // NOTE(review): ctx is static, but webgpu_ctx is re-created on every call
    // and ctx.webgpu_ctx is re-assigned below — repeated calls replace the
    // previously shared context.
    static ggml_backend_webgpu_reg_context ctx;
    ctx.webgpu_ctx = webgpu_ctx;
    ctx.name = GGML_WEBGPU_NAME;
    ctx.device_count = 1;

    wgpu::InstanceDescriptor instance_descriptor{};
    std::vector<wgpu::InstanceFeatureName> instance_features = { wgpu::InstanceFeatureName::TimedWaitAny };
    instance_descriptor.requiredFeatures = instance_features.data();
    instance_descriptor.requiredFeatureCount = instance_features.size();

#ifndef __EMSCRIPTEN__
    // allow_unsafe_apis is needed for the experimental Dawn features used above
    const char * const instanceEnabledToggles[] = { "allow_unsafe_apis" };
    wgpu::DawnTogglesDescriptor instanceTogglesDesc;
    instanceTogglesDesc.enabledToggles = instanceEnabledToggles;
    instanceTogglesDesc.enabledToggleCount = 1;
    instance_descriptor.nextInChain = &instanceTogglesDesc;
#endif

    webgpu_ctx->instance = wgpu::CreateInstance(&instance_descriptor);

#ifdef __EMSCRIPTEN__
    if (webgpu_ctx->instance == nullptr) {
        GGML_LOG_ERROR("ggml_webgpu: Failed to create WebGPU instance. Make sure either -sASYNCIFY or -sJSPI is set\n");
        return nullptr;
    }
#endif
    GGML_ASSERT(webgpu_ctx->instance != nullptr);

    static ggml_backend_reg reg = {
        /* .api_version = */ GGML_BACKEND_API_VERSION,
        /* .iface = */ ggml_backend_webgpu_reg_i,
        /* .context = */ &ctx,
    };
    return &reg;
}
3080
+
3081
+ ggml_backend_t ggml_backend_webgpu_init(void) {
3082
+ ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_webgpu_reg(), 0);
3083
+
3084
+ return ggml_backend_webgpu_device_init(dev, nullptr);
3085
+ }
3086
+
3087
// Export the registration entry point for dynamic-loading (plugin) builds.
GGML_BACKEND_DL_IMPL(ggml_backend_webgpu_reg)