whispercpp 1.3.5 → 1.3.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (610):
  1. checksums.yaml +4 -4
  2. data/LICENSE +1 -1
  3. data/README.md +99 -2
  4. data/ext/extconf.rb +1 -0
  5. data/ext/ruby_whisper.c +20 -4
  6. data/ext/ruby_whisper.h +30 -2
  7. data/ext/ruby_whisper_context.c +216 -124
  8. data/ext/ruby_whisper_context_params.c +163 -0
  9. data/ext/ruby_whisper_model.c +0 -1
  10. data/ext/ruby_whisper_params.c +0 -1
  11. data/ext/ruby_whisper_segment.c +0 -1
  12. data/ext/ruby_whisper_token.c +29 -9
  13. data/ext/ruby_whisper_transcribe.cpp +4 -1
  14. data/ext/ruby_whisper_vad_context.c +48 -1
  15. data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
  16. data/ext/ruby_whisper_vad_params.c +0 -1
  17. data/ext/ruby_whisper_vad_segment.c +0 -1
  18. data/ext/ruby_whisper_vad_segments.c +0 -1
  19. data/ext/sources/CMakeLists.txt +1 -1
  20. data/ext/sources/bindings/javascript/package.json +1 -1
  21. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  22. data/ext/sources/examples/bench/bench.cpp +23 -18
  23. data/ext/sources/examples/cli/cli.cpp +8 -0
  24. data/ext/sources/examples/common-ggml.cpp +2 -0
  25. data/ext/sources/examples/miniaudio.h +4507 -2131
  26. data/ext/sources/examples/server/server.cpp +18 -4
  27. data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -2
  28. data/ext/sources/examples/talk-llama/llama-adapter.cpp +7 -13
  29. data/ext/sources/examples/talk-llama/llama-adapter.h +4 -3
  30. data/ext/sources/examples/talk-llama/llama-arch.cpp +335 -17
  31. data/ext/sources/examples/talk-llama/llama-arch.h +42 -0
  32. data/ext/sources/examples/talk-llama/llama-batch.cpp +3 -1
  33. data/ext/sources/examples/talk-llama/llama-chat.cpp +21 -1
  34. data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
  35. data/ext/sources/examples/talk-llama/llama-context.cpp +508 -520
  36. data/ext/sources/examples/talk-llama/llama-context.h +27 -28
  37. data/ext/sources/examples/talk-llama/llama-cparams.h +5 -0
  38. data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
  39. data/ext/sources/examples/talk-llama/llama-grammar.cpp +8 -8
  40. data/ext/sources/examples/talk-llama/llama-graph.cpp +583 -130
  41. data/ext/sources/examples/talk-llama/llama-graph.h +131 -10
  42. data/ext/sources/examples/talk-llama/llama-hparams.cpp +57 -40
  43. data/ext/sources/examples/talk-llama/llama-hparams.h +79 -10
  44. data/ext/sources/examples/talk-llama/llama-impl.cpp +4 -4
  45. data/ext/sources/examples/talk-llama/llama-impl.h +13 -1
  46. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +3 -1
  47. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +274 -89
  48. data/ext/sources/examples/talk-llama/llama-kv-cache.h +2 -3
  49. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
  50. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
  51. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +11 -13
  52. data/ext/sources/examples/talk-llama/llama-mmap.cpp +28 -11
  53. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +527 -119
  54. data/ext/sources/examples/talk-llama/llama-model-loader.h +35 -5
  55. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +60 -46
  56. data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
  57. data/ext/sources/examples/talk-llama/llama-model.cpp +1365 -647
  58. data/ext/sources/examples/talk-llama/llama-model.h +72 -19
  59. data/ext/sources/examples/talk-llama/llama-quant.cpp +578 -346
  60. data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +190 -76
  61. data/ext/sources/examples/talk-llama/{llama-sampling.h → llama-sampler.h} +0 -2
  62. data/ext/sources/examples/talk-llama/llama-vocab.cpp +118 -48
  63. data/ext/sources/examples/talk-llama/llama-vocab.h +5 -0
  64. data/ext/sources/examples/talk-llama/llama.cpp +76 -22
  65. data/ext/sources/examples/talk-llama/llama.h +63 -30
  66. data/ext/sources/examples/talk-llama/models/afmoe.cpp +2 -3
  67. data/ext/sources/examples/talk-llama/models/apertus.cpp +3 -3
  68. data/ext/sources/examples/talk-llama/models/arcee.cpp +3 -3
  69. data/ext/sources/examples/talk-llama/models/arctic.cpp +4 -5
  70. data/ext/sources/examples/talk-llama/models/baichuan.cpp +4 -3
  71. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +1 -2
  72. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +3 -5
  73. data/ext/sources/examples/talk-llama/models/bert.cpp +13 -7
  74. data/ext/sources/examples/talk-llama/models/bitnet.cpp +9 -24
  75. data/ext/sources/examples/talk-llama/models/bloom.cpp +2 -2
  76. data/ext/sources/examples/talk-llama/models/chameleon.cpp +3 -3
  77. data/ext/sources/examples/talk-llama/models/chatglm.cpp +2 -2
  78. data/ext/sources/examples/talk-llama/models/codeshell.cpp +3 -3
  79. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +3 -3
  80. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +2 -2
  81. data/ext/sources/examples/talk-llama/models/command-r.cpp +2 -2
  82. data/ext/sources/examples/talk-llama/models/dbrx.cpp +4 -5
  83. data/ext/sources/examples/talk-llama/models/deci.cpp +3 -3
  84. data/ext/sources/examples/talk-llama/models/deepseek.cpp +4 -6
  85. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +24 -21
  86. data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
  87. data/ext/sources/examples/talk-llama/models/dots1.cpp +4 -6
  88. data/ext/sources/examples/talk-llama/models/dream.cpp +3 -3
  89. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +4 -6
  90. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +3 -3
  91. data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
  92. data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
  93. data/ext/sources/examples/talk-llama/models/exaone.cpp +3 -3
  94. data/ext/sources/examples/talk-llama/models/exaone4.cpp +3 -3
  95. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +2 -4
  96. data/ext/sources/examples/talk-llama/models/falcon.cpp +3 -3
  97. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +1 -1
  98. data/ext/sources/examples/talk-llama/models/gemma.cpp +1 -1
  99. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +1 -1
  100. data/ext/sources/examples/talk-llama/models/gemma3.cpp +1 -1
  101. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +7 -7
  102. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +3 -3
  103. data/ext/sources/examples/talk-llama/models/glm4.cpp +14 -7
  104. data/ext/sources/examples/talk-llama/models/gpt2.cpp +2 -2
  105. data/ext/sources/examples/talk-llama/models/gptneox.cpp +2 -2
  106. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +4 -5
  107. data/ext/sources/examples/talk-llama/models/granite.cpp +4 -5
  108. data/ext/sources/examples/talk-llama/models/grok.cpp +4 -4
  109. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +5 -7
  110. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +3 -3
  111. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +4 -5
  112. data/ext/sources/examples/talk-llama/models/internlm2.cpp +3 -3
  113. data/ext/sources/examples/talk-llama/models/jais.cpp +2 -2
  114. data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
  115. data/ext/sources/examples/talk-llama/models/jamba.cpp +3 -3
  116. data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
  117. data/ext/sources/examples/talk-llama/models/lfm2.cpp +145 -124
  118. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +4 -4
  119. data/ext/sources/examples/talk-llama/models/llada.cpp +3 -3
  120. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +4 -4
  121. data/ext/sources/examples/talk-llama/models/llama.cpp +18 -11
  122. data/ext/sources/examples/talk-llama/models/maincoder.cpp +3 -3
  123. data/ext/sources/examples/talk-llama/models/{graph-context-mamba.cpp → mamba-base.cpp} +9 -3
  124. data/ext/sources/examples/talk-llama/models/mamba.cpp +1 -2
  125. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +11 -5
  126. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +14 -13
  127. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +4 -5
  128. data/ext/sources/examples/talk-llama/models/mistral3.cpp +4 -4
  129. data/ext/sources/examples/talk-llama/models/models.h +181 -46
  130. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +2 -9
  131. data/ext/sources/examples/talk-llama/models/mpt.cpp +2 -2
  132. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +26 -14
  133. data/ext/sources/examples/talk-llama/models/nemotron.cpp +3 -3
  134. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +2 -2
  135. data/ext/sources/examples/talk-llama/models/olmo.cpp +3 -3
  136. data/ext/sources/examples/talk-llama/models/olmo2.cpp +3 -3
  137. data/ext/sources/examples/talk-llama/models/olmoe.cpp +4 -4
  138. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +1 -1
  139. data/ext/sources/examples/talk-llama/models/openelm.cpp +3 -3
  140. data/ext/sources/examples/talk-llama/models/orion.cpp +3 -3
  141. data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
  142. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +3 -3
  143. data/ext/sources/examples/talk-llama/models/phi2.cpp +2 -2
  144. data/ext/sources/examples/talk-llama/models/phi3.cpp +3 -3
  145. data/ext/sources/examples/talk-llama/models/plamo.cpp +3 -3
  146. data/ext/sources/examples/talk-llama/models/plamo2.cpp +9 -5
  147. data/ext/sources/examples/talk-llama/models/plamo3.cpp +2 -2
  148. data/ext/sources/examples/talk-llama/models/plm.cpp +15 -14
  149. data/ext/sources/examples/talk-llama/models/qwen.cpp +2 -2
  150. data/ext/sources/examples/talk-llama/models/qwen2.cpp +3 -3
  151. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +4 -4
  152. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +3 -3
  153. data/ext/sources/examples/talk-llama/models/qwen3.cpp +12 -9
  154. data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
  155. data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
  156. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +15 -8
  157. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +84 -432
  158. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +9 -18
  159. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +8 -17
  160. data/ext/sources/examples/talk-llama/models/refact.cpp +2 -2
  161. data/ext/sources/examples/talk-llama/models/rnd1.cpp +4 -4
  162. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +2 -0
  163. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +2 -0
  164. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +3 -3
  165. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +4 -4
  166. data/ext/sources/examples/talk-llama/models/smollm3.cpp +3 -3
  167. data/ext/sources/examples/talk-llama/models/stablelm.cpp +2 -2
  168. data/ext/sources/examples/talk-llama/models/starcoder.cpp +2 -2
  169. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +3 -3
  170. data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
  171. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +2 -2
  172. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +2 -2
  173. data/ext/sources/examples/talk-llama/models/xverse.cpp +3 -3
  174. data/ext/sources/examples/talk-llama/unicode.cpp +21 -65
  175. data/ext/sources/ggml/CMakeLists.txt +9 -3
  176. data/ext/sources/ggml/include/ggml-backend.h +1 -1
  177. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  178. data/ext/sources/ggml/include/ggml-cpu.h +5 -0
  179. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  180. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  181. data/ext/sources/ggml/include/ggml-rpc.h +6 -1
  182. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  183. data/ext/sources/ggml/include/ggml.h +56 -9
  184. data/ext/sources/ggml/src/CMakeLists.txt +3 -0
  185. data/ext/sources/ggml/src/ggml-alloc.c +4 -9
  186. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  187. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  188. data/ext/sources/ggml/src/ggml-backend-reg.cpp +28 -86
  189. data/ext/sources/ggml/src/ggml-backend.cpp +5 -2
  190. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  191. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +6 -2
  192. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
  193. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
  194. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +348 -189
  195. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +40 -85
  196. data/ext/sources/ggml/src/ggml-cann/common.h +3 -4
  197. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +44 -62
  198. data/ext/sources/ggml/src/ggml-common.h +11 -0
  199. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +16 -11
  200. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -19
  201. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  202. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
  203. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +85 -1
  204. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2744 -548
  205. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1653 -0
  206. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
  207. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
  208. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
  209. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +118 -18
  210. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +107 -26
  211. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  212. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  213. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +3 -0
  214. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +59 -12
  215. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +15 -0
  216. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
  217. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +965 -252
  218. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +584 -197
  219. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +903 -188
  220. data/ext/sources/ggml/src/ggml-cpu/ops.h +1 -0
  221. data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
  222. data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
  223. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2890 -679
  224. data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
  225. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
  226. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +111 -3
  227. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
  228. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +17 -0
  229. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +1 -1
  230. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +19 -10
  231. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +32 -30
  232. data/ext/sources/ggml/src/ggml-cuda/common.cuh +134 -18
  233. data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
  234. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +6 -3
  235. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +78 -64
  236. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +384 -143
  237. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +36 -22
  238. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +3 -3
  239. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +26 -5
  240. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
  241. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +127 -12
  242. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
  243. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  244. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +595 -200
  245. data/ext/sources/ggml/src/ggml-cuda/mean.cu +9 -8
  246. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +173 -6
  247. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
  248. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +158 -85
  249. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +34 -22
  250. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +127 -67
  251. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
  252. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +157 -65
  253. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -0
  254. data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
  255. data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
  256. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +1 -1
  257. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
  258. data/ext/sources/ggml/src/ggml-cuda/rope.cu +233 -133
  259. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
  260. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
  261. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +56 -32
  262. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  263. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
  264. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
  265. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
  266. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
  267. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
  268. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
  269. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +3 -3
  270. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +0 -1
  271. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +199 -135
  272. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
  273. data/ext/sources/ggml/src/ggml-cuda/unary.cu +55 -0
  274. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
  275. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
  276. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +10 -0
  277. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +82 -45
  278. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +334 -160
  279. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +7 -5
  280. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +328 -197
  281. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
  282. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +765 -234
  283. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
  284. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +412 -265
  285. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +23 -23
  286. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +1 -1
  287. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.h → hex-dma.h} +28 -3
  288. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
  289. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  290. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
  291. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +1 -1
  292. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +27 -37
  293. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +6 -35
  294. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  295. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
  296. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
  297. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  302. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  303. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  304. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
  305. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  306. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  307. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +20 -1347
  308. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +211 -13
  309. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +1119 -952
  310. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +254 -244
  311. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +36 -36
  312. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +155 -138
  313. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
  314. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  315. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +209 -114
  316. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
  317. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  318. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  319. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  320. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
  321. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +6 -0
  322. data/ext/sources/ggml/src/ggml-impl.h +62 -0
  323. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  324. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
  325. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  326. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
  327. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +274 -73
  328. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +22 -4
  329. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +102 -36
  330. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +174 -23
  331. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +580 -280
  332. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +5 -4
  333. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +320 -107
  334. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1068 -825
  335. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +19 -1
  336. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +3108 -636
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +204 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  365. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
  366. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  367. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  368. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  369. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
  370. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  371. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
  372. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  373. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
  374. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
  375. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  376. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  377. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  378. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  379. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  380. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  381. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  382. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  383. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  384. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  385. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  386. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  387. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  388. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  389. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  390. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  391. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  392. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
  393. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  394. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  395. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  396. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  397. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  398. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  399. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
  400. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
  401. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
  402. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
  403. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  404. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  405. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  406. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  407. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  408. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
  409. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  410. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
  411. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
  412. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
  413. data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
  414. data/ext/sources/ggml/src/ggml-quants.c +96 -5
  415. data/ext/sources/ggml/src/ggml-quants.h +3 -0
  416. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +15 -88
  417. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
  418. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +1 -0
  419. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
  420. data/ext/sources/ggml/src/ggml-sycl/common.hpp +315 -10
  421. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +69 -1
  422. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
  423. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
  424. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
  425. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +78 -68
  426. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
  427. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
  428. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
  429. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
  430. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
  431. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
  432. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  433. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
  434. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
  435. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +316 -51
  436. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
  437. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  438. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
  439. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
  440. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
  441. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  442. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
  443. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  444. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  445. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  446. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  447. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  448. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  449. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  450. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  451. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  452. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
  453. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
  454. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
  455. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
  456. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
  457. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
  458. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
  459. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
  460. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
  461. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
  462. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
  463. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
  464. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
  465. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
  466. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
  467. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
  468. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
  469. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
  470. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
  471. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
  472. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
  473. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
  474. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
  475. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
  476. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
  477. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
  478. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
  479. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
  480. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
  481. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
  482. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
  483. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
  484. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
  485. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
  486. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
  487. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
  488. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +13 -0
  489. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  490. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  491. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  492. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  493. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  494. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  495. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  496. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  497. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  498. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  499. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  500. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  501. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  502. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  503. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  504. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  505. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  506. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  507. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  508. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  509. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  510. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  511. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
  512. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
  513. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  514. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
  515. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  516. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  517. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  518. data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
  519. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  520. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  521. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  522. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  523. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  524. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  525. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  526. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
  527. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  528. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  529. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  530. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
  531. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
  532. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +1 -1
  533. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1250 -465
  534. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
  535. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  536. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +374 -170
  537. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +66 -22
  538. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +389 -201
  539. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +106 -58
  540. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  541. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
  542. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
  543. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
  544. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
  545. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -3
  546. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +8 -4
  547. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
  548. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
  549. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +3 -3
  550. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
  551. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +36 -63
  552. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
  553. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
  554. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
  555. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +10 -5
  556. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
  557. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  558. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +16 -10
  559. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +55 -35
  560. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1314 -109
  561. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1660 -1371
  562. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  563. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  564. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  565. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
  566. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
  567. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
  568. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +6 -0
  569. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  570. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +40 -5
  571. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +105 -60
  572. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
  573. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +68 -257
  574. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +692 -23
  575. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_reg_tile.tmpl.wgsl → mul_mat_reg_tile.wgsl} +28 -128
  576. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +31 -137
  577. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
  578. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  579. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  580. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +9 -36
  581. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
  582. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  583. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
  584. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
  585. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
  586. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +9 -6
  587. data/ext/sources/ggml/src/ggml.c +167 -33
  588. data/ext/sources/ggml/src/gguf.cpp +229 -44
  589. data/ext/sources/src/whisper.cpp +6 -28
  590. data/sig/whisper.rbs +43 -2
  591. data/test/test_context_params.rb +82 -0
  592. data/test/test_token.rb +11 -0
  593. data/test/test_vad_context.rb +58 -8
  594. data/test/test_whisper.rb +20 -0
  595. data/whispercpp.gemspec +1 -1
  596. metadata +240 -28
  597. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  598. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
  599. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
  600. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
  601. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
  602. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
  603. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
  604. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
  605. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
  606. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
  607. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  608. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
  609. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
  610. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
@@ -1,20 +1,273 @@
1
1
  #ifndef GGML_WEBGPU_SHADER_LIB_HPP
2
2
  #define GGML_WEBGPU_SHADER_LIB_HPP
3
3
 
4
+ #include "ggml-wgsl-shaders.hpp"
4
5
  #include "ggml.h"
5
6
  #include "pre_wgsl.hpp"
6
7
 
8
+ #include <webgpu/webgpu_cpp.h>
9
+
10
+ #include <algorithm>
11
+ #include <memory>
7
12
  #include <string>
13
+ #include <unordered_map>
8
14
  #include <vector>
9
15
 
10
16
  #define GGML_WEBGPU_F16_SIZE_BYTES 2
11
17
  #define GGML_WEBGPU_F32_SIZE_BYTES 4
18
+ #define GGML_WEBGPU_I32_SIZE_BYTES 4
12
19
  #define GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES 8u
13
20
  #define GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE 128u
14
21
  // Matches GGML_PAD(..., 256) in src/llama-context.cpp for KV cache sizing.
15
22
  #define GGML_WEBGPU_KV_SEQ_PAD 256u
16
23
 
17
- struct ggml_webgpu_flash_attn_shader_lib_context {
24
+ #define GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE 512u
25
+
26
+ // Matrix multiplication parameters
27
+
28
+ // Register tiling parameters
29
+ #define WEBGPU_MUL_MAT_TILE_M 8
30
+ #define WEBGPU_MUL_MAT_TILE_N 8
31
+ #define WEBGPU_MUL_MAT_WG_SIZE_M 8
32
+ #define WEBGPU_MUL_MAT_WG_SIZE_N 8
33
+ #define WEBGPU_MUL_MAT_TILE_K 32
34
+
35
+ // Subgroup matrix parameters
36
+ // The number of subgroups in the M dimension
37
+ #define WEBGPU_MUL_MAT_SUBGROUP_M 2
38
+ // The number of subgroups in the N dimension
39
+ #define WEBGPU_MUL_MAT_SUBGROUP_N 2
40
+ // The number of subgroup matrices each subgroup accumulates over
41
+ #define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4
42
+ #define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2
43
+
44
+ // Matrix-vector multiplication parameters
45
+ #define WEBGPU_MUL_MAT_VEC_WG_SIZE 256
46
+
47
+ // Must be multiple of 4 to work with vectorized paths, and must divide
48
+ // mul_mat_vec wg size
49
+ #define WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG 64
50
+ #define WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K 256
51
+
52
+ #define WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG 64
53
+ #define WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K 256
54
+
55
+ // Requires 32 threads per output (wg_size/outputs_per_wg == 32)
56
+ #define WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG 8
57
+ // Requires at least two (and multiple of 2) k-quant blocks per tile
58
+ #define WEBGPU_MUL_MAT_VEC_K_Q_TILE_K 512
59
+
60
+ // default size for legacy matrix multiplication
61
+ #define WEBGPU_MUL_MAT_WG_SIZE 256
62
+
63
+ // Same hash combine function as in boost
64
+ template <typename T> inline void ggml_webgpu_hash_combine(size_t & seed, const T & value) {
65
+ seed ^= std::hash<T>{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
66
+ }
67
+
68
+ struct ggml_webgpu_shader_lib_context {
69
+ ggml_tensor * src0;
70
+ ggml_tensor * src1;
71
+ ggml_tensor * src2;
72
+ ggml_tensor * src3;
73
+ ggml_tensor * src4;
74
+ ggml_tensor * dst;
75
+
76
+ uint32_t max_wg_size;
77
+ size_t wg_mem_limit_bytes = 0;
78
+ bool inplace = false;
79
+ bool overlap = false;
80
+ bool src_overlap = false;
81
+ bool supports_subgroup_matrix = false;
82
+ uint32_t sg_mat_m = 0;
83
+ uint32_t sg_mat_n = 0;
84
+ uint32_t sg_mat_k = 0;
85
+ uint32_t max_subgroup_size = 0;
86
+ };
87
+
88
+ struct webgpu_pipeline {
89
+ wgpu::ComputePipeline pipeline;
90
+ std::string name;
91
+ std::shared_ptr<void> context = nullptr;
92
+ };
93
+
94
+ struct ggml_webgpu_generic_shader_decisions {
95
+ uint32_t wg_size = 0;
96
+ };
97
+
98
+ /** Argsort **/
99
+
100
+ struct ggml_webgpu_argsort_shader_lib_context {
101
+ uint32_t max_wg_size;
102
+ size_t wg_mem_limit_bytes;
103
+ int32_t order;
104
+ };
105
+
106
+ /** Set Rows **/
107
+
108
+ struct ggml_webgpu_set_rows_pipeline_key {
109
+ int dst_type;
110
+ int vec4;
111
+ int i64_idx;
112
+
113
+ bool operator==(const ggml_webgpu_set_rows_pipeline_key & other) const {
114
+ return dst_type == other.dst_type && vec4 == other.vec4 && i64_idx == other.i64_idx;
115
+ }
116
+ };
117
+
118
+ struct ggml_webgpu_set_rows_pipeline_key_hash {
119
+ size_t operator()(const ggml_webgpu_set_rows_pipeline_key & key) const {
120
+ size_t seed = 0;
121
+ ggml_webgpu_hash_combine(seed, key.dst_type);
122
+ ggml_webgpu_hash_combine(seed, key.vec4);
123
+ ggml_webgpu_hash_combine(seed, key.i64_idx);
124
+ return seed;
125
+ }
126
+ };
127
+
128
+ struct ggml_webgpu_set_rows_shader_decisions {
129
+ bool vec4;
130
+ bool i64_idx;
131
+ uint32_t wg_size;
132
+ };
133
+
134
+ /** Get Rows **/
135
+
136
+ struct ggml_webgpu_get_rows_pipeline_key {
137
+ ggml_type src_type;
138
+ int vectorized;
139
+
140
+ bool operator==(const ggml_webgpu_get_rows_pipeline_key & other) const {
141
+ return src_type == other.src_type && vectorized == other.vectorized;
142
+ }
143
+ };
144
+
145
+ struct ggml_webgpu_get_rows_pipeline_key_hash {
146
+ size_t operator()(const ggml_webgpu_get_rows_pipeline_key & key) const {
147
+ size_t seed = 0;
148
+ ggml_webgpu_hash_combine(seed, key.src_type);
149
+ ggml_webgpu_hash_combine(seed, key.vectorized);
150
+ return seed;
151
+ }
152
+ };
153
+
154
+ /** Pad **/
155
+ struct ggml_webgpu_pad_pipeline_key {
156
+ bool circular;
157
+
158
+ bool operator==(const ggml_webgpu_pad_pipeline_key & other) const { return circular == other.circular; }
159
+ };
160
+
161
+ struct ggml_webgpu_pad_pipeline_key_hash {
162
+ size_t operator()(const ggml_webgpu_pad_pipeline_key & key) const {
163
+ size_t seed = 0;
164
+ ggml_webgpu_hash_combine(seed, key.circular);
165
+ return seed;
166
+ }
167
+ };
168
+
169
+ /** Scale **/
170
+
171
+ struct ggml_webgpu_scale_pipeline_key {
172
+ int inplace;
173
+
174
+ bool operator==(const ggml_webgpu_scale_pipeline_key & other) const { return inplace == other.inplace; }
175
+ };
176
+
177
+ struct ggml_webgpu_scale_pipeline_key_hash {
178
+ size_t operator()(const ggml_webgpu_scale_pipeline_key & key) const {
179
+ size_t seed = 0;
180
+ ggml_webgpu_hash_combine(seed, key.inplace);
181
+ return seed;
182
+ }
183
+ };
184
+
185
+ /** Concat **/
186
+
187
+ struct ggml_webgpu_concat_pipeline_key {
188
+ int type;
189
+
190
+ bool operator==(const ggml_webgpu_concat_pipeline_key & other) const { return type == other.type; }
191
+ };
192
+
193
+ struct ggml_webgpu_concat_pipeline_key_hash {
194
+ size_t operator()(const ggml_webgpu_concat_pipeline_key & key) const {
195
+ size_t seed = 0;
196
+ ggml_webgpu_hash_combine(seed, key.type);
197
+ return seed;
198
+ }
199
+ };
200
+
201
+ /** Repeat **/
202
+
203
+ struct ggml_webgpu_repeat_pipeline_key {
204
+ int type;
205
+
206
+ bool operator==(const ggml_webgpu_repeat_pipeline_key & other) const { return type == other.type; }
207
+ };
208
+
209
+ struct ggml_webgpu_repeat_pipeline_key_hash {
210
+ size_t operator()(const ggml_webgpu_repeat_pipeline_key & key) const {
211
+ size_t seed = 0;
212
+ ggml_webgpu_hash_combine(seed, key.type);
213
+ return seed;
214
+ }
215
+ };
216
+
217
+ /** Binary **/
218
+
219
+ struct ggml_webgpu_binary_pipeline_key {
220
+ int type;
221
+ int op;
222
+ bool inplace;
223
+ bool overlap;
224
+ bool src_overlap;
225
+
226
+ bool operator==(const ggml_webgpu_binary_pipeline_key & other) const {
227
+ return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap &&
228
+ src_overlap == other.src_overlap;
229
+ }
230
+ };
231
+
232
+ struct ggml_webgpu_binary_pipeline_key_hash {
233
+ size_t operator()(const ggml_webgpu_binary_pipeline_key & key) const {
234
+ size_t seed = 0;
235
+ ggml_webgpu_hash_combine(seed, key.type);
236
+ ggml_webgpu_hash_combine(seed, key.op);
237
+ ggml_webgpu_hash_combine(seed, key.inplace);
238
+ ggml_webgpu_hash_combine(seed, key.overlap);
239
+ ggml_webgpu_hash_combine(seed, key.src_overlap);
240
+ return seed;
241
+ }
242
+ };
243
+
244
+ /** Unary **/
245
+
246
+ struct ggml_webgpu_unary_pipeline_key {
247
+ int type;
248
+ int op;
249
+ bool is_unary; // many unary operators fall under the GGML_OP_UNARY umbrella
250
+ bool inplace;
251
+
252
+ bool operator==(const ggml_webgpu_unary_pipeline_key & other) const {
253
+ return type == other.type && op == other.op && is_unary == other.is_unary && inplace == other.inplace;
254
+ }
255
+ };
256
+
257
+ struct ggml_webgpu_unary_pipeline_key_hash {
258
+ size_t operator()(const ggml_webgpu_unary_pipeline_key & key) const {
259
+ size_t seed = 0;
260
+ ggml_webgpu_hash_combine(seed, key.type);
261
+ ggml_webgpu_hash_combine(seed, key.op);
262
+ ggml_webgpu_hash_combine(seed, key.is_unary);
263
+ ggml_webgpu_hash_combine(seed, key.inplace);
264
+ return seed;
265
+ }
266
+ };
267
+
268
+ /** FlashAttention */
269
+
270
+ struct ggml_webgpu_flash_attn_pipeline_key {
18
271
  ggml_type kv_type;
19
272
  uint32_t head_dim_qk;
20
273
  uint32_t head_dim_v;
@@ -22,11 +275,35 @@ struct ggml_webgpu_flash_attn_shader_lib_context {
22
275
  bool has_mask;
23
276
  bool has_sinks;
24
277
  bool uses_logit_softcap;
25
- uint32_t sg_mat_m;
26
- uint32_t sg_mat_n;
27
- uint32_t sg_mat_k;
28
- size_t wg_mem_limit_bytes;
29
- uint32_t max_subgroup_size;
278
+
279
+ bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const {
280
+ return kv_type == other.kv_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v &&
281
+ kv_direct == other.kv_direct && has_mask == other.has_mask && has_sinks == other.has_sinks &&
282
+ uses_logit_softcap == other.uses_logit_softcap;
283
+ }
284
+ };
285
+
286
+ struct ggml_webgpu_flash_attn_pipeline_key_hash {
287
+ size_t operator()(const ggml_webgpu_flash_attn_pipeline_key & key) const {
288
+ size_t seed = 0;
289
+ ggml_webgpu_hash_combine(seed, key.kv_type);
290
+ ggml_webgpu_hash_combine(seed, key.head_dim_qk);
291
+ ggml_webgpu_hash_combine(seed, key.head_dim_v);
292
+ ggml_webgpu_hash_combine(seed, key.kv_direct);
293
+ ggml_webgpu_hash_combine(seed, key.has_mask);
294
+ ggml_webgpu_hash_combine(seed, key.has_sinks);
295
+ ggml_webgpu_hash_combine(seed, key.uses_logit_softcap);
296
+ return seed;
297
+ }
298
+ };
299
+
300
+ struct ggml_webgpu_flash_attn_shader_lib_context {
301
+ ggml_webgpu_flash_attn_pipeline_key key;
302
+ uint32_t sg_mat_m;
303
+ uint32_t sg_mat_n;
304
+ uint32_t sg_mat_k;
305
+ size_t wg_mem_limit_bytes;
306
+ uint32_t max_subgroup_size;
30
307
  };
31
308
 
32
309
  struct ggml_webgpu_flash_attn_shader_decisions {
@@ -35,12 +312,6 @@ struct ggml_webgpu_flash_attn_shader_decisions {
35
312
  uint32_t wg_size = 0;
36
313
  };
37
314
 
38
- struct ggml_webgpu_processed_shader {
39
- std::string wgsl;
40
- std::string variant;
41
- ggml_webgpu_flash_attn_shader_decisions decisions;
42
- };
43
-
44
315
  // This is exposed because it's necessary in supports_op
45
316
  inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile,
46
317
  uint32_t kv_tile,
@@ -65,105 +336,1039 @@ inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile,
65
336
  return f16_elems * GGML_WEBGPU_F16_SIZE_BYTES + f32_elems * GGML_WEBGPU_F32_SIZE_BYTES;
66
337
  }
67
338
 
68
- static uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_flash_attn_shader_lib_context & context) {
69
- const size_t limit_bytes = context.wg_mem_limit_bytes;
70
- const size_t q_tile = context.sg_mat_m;
71
- const size_t base_q_bytes = (context.head_dim_qk + context.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
72
- 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
73
- size_t bytes_per_kv = 0;
74
- if (!context.kv_direct) {
75
- bytes_per_kv += std::max(context.head_dim_qk, context.head_dim_v);
339
+ /** Matrix Multiplication **/
340
+
341
+ struct ggml_webgpu_legacy_mul_mat_pipeline_key {
342
+ ggml_type src0_type;
343
+ ggml_type src1_type;
344
+
345
+ bool operator==(const ggml_webgpu_legacy_mul_mat_pipeline_key & other) const {
346
+ return src0_type == other.src0_type && src1_type == other.src1_type;
76
347
  }
77
- if (context.has_mask) {
78
- bytes_per_kv += q_tile;
348
+ };
349
+
350
+ struct ggml_webgpu_legacy_mul_mat_pipeline_key_hash {
351
+ size_t operator()(const ggml_webgpu_legacy_mul_mat_pipeline_key & key) const {
352
+ size_t seed = 0;
353
+ ggml_webgpu_hash_combine(seed, key.src0_type);
354
+ ggml_webgpu_hash_combine(seed, key.src1_type);
355
+ return seed;
79
356
  }
80
- bytes_per_kv += q_tile;
81
- bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES;
82
- const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv;
83
- return (max_kv_tile / context.sg_mat_n) * context.sg_mat_n;
84
- }
357
+ };
85
358
 
86
- inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader(
87
- pre_wgsl::Preprocessor & preprocessor,
88
- const char * shader_src,
89
- const ggml_webgpu_flash_attn_shader_lib_context & context) {
90
- std::vector<std::string> defines;
91
- std::string variant = "flash_attn";
92
-
93
- switch (context.kv_type) {
94
- case GGML_TYPE_F32:
95
- defines.push_back("KV_F32");
96
- break;
97
- case GGML_TYPE_F16:
98
- defines.push_back("KV_F16");
99
- break;
100
- case GGML_TYPE_Q4_0:
101
- defines.push_back("KV_Q4_0");
102
- break;
103
- case GGML_TYPE_Q8_0:
104
- defines.push_back("KV_Q8_0");
105
- break;
106
- default:
107
- GGML_ABORT("Unsupported KV type for flash attention shader");
108
- }
109
- variant += std::string("_") + ggml_type_name(context.kv_type);
110
-
111
- if (context.has_mask) {
112
- defines.push_back("MASK");
113
- variant += "_mask";
114
- }
115
- if (context.has_sinks) {
116
- defines.push_back("SINKS");
117
- variant += "_sinks";
118
- }
119
- if (context.uses_logit_softcap) {
120
- defines.push_back("LOGIT_SOFTCAP");
121
- variant += "_lgsc";
122
- }
123
-
124
- if (context.kv_direct) {
125
- defines.push_back("KV_DIRECT");
126
- variant += "_kvdirect";
127
- }
128
-
129
- defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(context.head_dim_qk));
130
- variant += std::string("_hsqk") + std::to_string(context.head_dim_qk);
131
-
132
- defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.head_dim_v));
133
- variant += std::string("_hsv") + std::to_string(context.head_dim_v);
134
-
135
- // For now these are not part of the variant name
136
- defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m));
137
- defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n));
138
- defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k));
139
-
140
- // Add chosen Q/KV tile sizes
141
- uint32_t q_tile = context.sg_mat_m;
142
- uint32_t kv_tile = std::min(ggml_webgpu_flash_attn_max_kv_tile(context),
143
- context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
144
- if (context.kv_direct) {
145
- GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD);
146
- // Avoids having to use bounds-checks and decreasing performance for direct KV loads
147
- while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {
148
- kv_tile -= context.sg_mat_n;
149
- }
150
- }
151
-
152
- defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile));
153
- defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile));
154
-
155
- // workgroup size
156
- uint32_t wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);
157
-
158
- defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
159
-
160
- ggml_webgpu_processed_shader result;
161
- result.wgsl = preprocessor.preprocess(shader_src, defines);
162
- result.variant = variant;
163
- result.decisions.q_tile = q_tile;
164
- result.decisions.kv_tile = kv_tile;
165
- result.decisions.wg_size = wg_size;
166
- return result;
167
- }
359
+ struct ggml_webgpu_mul_mat_vec_pipeline_key {
360
+ ggml_type src0_type;
361
+ ggml_type src1_type;
362
+ int vectorized;
363
+
364
+ bool operator==(const ggml_webgpu_mul_mat_vec_pipeline_key & other) const {
365
+ return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized;
366
+ }
367
+ };
368
+
369
+ struct ggml_webgpu_mul_mat_vec_pipeline_key_hash {
370
+ size_t operator()(const ggml_webgpu_mul_mat_vec_pipeline_key & key) const {
371
+ size_t seed = 0;
372
+ ggml_webgpu_hash_combine(seed, key.src0_type);
373
+ ggml_webgpu_hash_combine(seed, key.src1_type);
374
+ ggml_webgpu_hash_combine(seed, key.vectorized);
375
+ return seed;
376
+ }
377
+ };
378
+
379
+ struct ggml_webgpu_mul_mat_vec_shader_decisions {
380
+ uint32_t wg_size;
381
+ uint32_t tile_k;
382
+ uint32_t outputs_per_wg;
383
+ uint32_t vec_size;
384
+ };
385
+
386
+ struct ggml_webgpu_mul_mat_pipeline_key {
387
+ ggml_type src0_type;
388
+ ggml_type src1_type;
389
+ int vectorized;
390
+ int use_subgroup_matrix;
391
+
392
+ bool operator==(const ggml_webgpu_mul_mat_pipeline_key & other) const {
393
+ return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized &&
394
+ use_subgroup_matrix == other.use_subgroup_matrix;
395
+ }
396
+ };
397
+
398
+ struct ggml_webgpu_mul_mat_pipeline_key_hash {
399
+ size_t operator()(const ggml_webgpu_mul_mat_pipeline_key & key) const {
400
+ size_t seed = 0;
401
+ ggml_webgpu_hash_combine(seed, key.src0_type);
402
+ ggml_webgpu_hash_combine(seed, key.src1_type);
403
+ ggml_webgpu_hash_combine(seed, key.vectorized);
404
+ ggml_webgpu_hash_combine(seed, key.use_subgroup_matrix);
405
+ return seed;
406
+ }
407
+ };
408
+
409
+ struct ggml_webgpu_mul_mat_shader_decisions {
410
+ uint32_t tile_k;
411
+ uint32_t wg_size_m;
412
+ uint32_t wg_size_n;
413
+ uint32_t wg_size;
414
+ uint32_t outputs_per_wg;
415
+ int use_subgroup_matrix;
416
+
417
+ uint32_t tile_m;
418
+ uint32_t tile_n;
419
+
420
+ // Subgroup matrix parameters
421
+ uint32_t subgroup_m;
422
+ uint32_t subgroup_n;
423
+ uint32_t subgroup_matrix_m;
424
+ uint32_t subgroup_matrix_n;
425
+
426
+ uint32_t mul_mat_wg_size;
427
+ };
428
+
429
+ class ggml_webgpu_shader_lib {
430
+ wgpu::Device device;
431
+ pre_wgsl::Preprocessor preprocessor;
432
+
433
+ std::unordered_map<int, webgpu_pipeline> sum_rows_pipelines; // key is fixed, no variants yet
434
+ std::unordered_map<int, webgpu_pipeline> argmax_pipelines; // key is vec4
435
+ std::unordered_map<int, webgpu_pipeline> argsort_pipelines; // key is order
436
+ std::unordered_map<int, webgpu_pipeline> argsort_merge_pipelines; // key is order
437
+ std::unordered_map<int, webgpu_pipeline> cumsum_pipelines; // key is fixed, no variants yet
438
+ std::unordered_map<ggml_webgpu_get_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_get_rows_pipeline_key_hash>
439
+ get_rows_pipelines; // src_type, vectorized
440
+ std::unordered_map<ggml_webgpu_unary_pipeline_key, webgpu_pipeline, ggml_webgpu_unary_pipeline_key_hash>
441
+ unary_pipelines; // type/op/inplace
442
+ std::unordered_map<ggml_webgpu_scale_pipeline_key, webgpu_pipeline, ggml_webgpu_scale_pipeline_key_hash>
443
+ scale_pipelines; // inplace
444
+ std::unordered_map<ggml_webgpu_pad_pipeline_key, webgpu_pipeline, ggml_webgpu_pad_pipeline_key_hash>
445
+ pad_pipelines; // circular/non-circular
446
+ std::unordered_map<ggml_webgpu_binary_pipeline_key, webgpu_pipeline, ggml_webgpu_binary_pipeline_key_hash>
447
+ binary_pipelines; // type/op/inplace/overlap
448
+ std::unordered_map<ggml_webgpu_concat_pipeline_key, webgpu_pipeline, ggml_webgpu_concat_pipeline_key_hash>
449
+ concat_pipelines; // type
450
+ std::unordered_map<ggml_webgpu_repeat_pipeline_key, webgpu_pipeline, ggml_webgpu_repeat_pipeline_key_hash>
451
+ repeat_pipelines; // type
452
+ std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
453
+ flash_attn_pipelines;
454
+ std::unordered_map<ggml_webgpu_legacy_mul_mat_pipeline_key,
455
+ webgpu_pipeline,
456
+ ggml_webgpu_legacy_mul_mat_pipeline_key_hash>
457
+ mul_mat_legacy_pipelines; // legacy mul_mat (non-subgroup/non-regtile/non-vec)
458
+ std::unordered_map<ggml_webgpu_mul_mat_vec_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_vec_pipeline_key_hash>
459
+ mul_mat_vec_pipelines; // fast mat-vec (n==1)
460
+ std::unordered_map<ggml_webgpu_mul_mat_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_pipeline_key_hash>
461
+ mul_mat_fast_pipelines; // fast mat-mat (reg-tile or subgroup)
462
+
463
+ std::unordered_map<ggml_webgpu_set_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_set_rows_pipeline_key_hash>
464
+ set_rows_pipelines;
465
+
466
+ public:
467
+ ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; }
468
+
469
+ webgpu_pipeline get_sum_rows_pipeline(const ggml_webgpu_shader_lib_context & context) {
470
+ auto it = sum_rows_pipelines.find(1);
471
+ if (it != sum_rows_pipelines.end()) {
472
+ return it->second;
473
+ }
474
+ std::vector<std::string> defines;
475
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
476
+
477
+ auto processed = preprocessor.preprocess(wgsl_sum_rows, defines);
478
+ sum_rows_pipelines[1] = ggml_webgpu_create_pipeline(device, processed, "sum_rows");
479
+ return sum_rows_pipelines[1];
480
+ }
481
+
482
+ webgpu_pipeline get_argmax_pipeline(const ggml_webgpu_shader_lib_context & context) {
483
+ bool vec4 = context.src0->ne[0] % 4 == 0;
484
+
485
+ auto it = argmax_pipelines.find(vec4);
486
+ if (it != argmax_pipelines.end()) {
487
+ return it->second;
488
+ }
489
+ std::string variant = "argmax";
490
+ std::vector<std::string> defines;
491
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
492
+ if (vec4) {
493
+ defines.push_back("VEC4");
494
+ variant += "_vec4";
495
+ }
496
+
497
+ auto processed = preprocessor.preprocess(wgsl_argmax, defines);
498
+ argmax_pipelines[vec4] = ggml_webgpu_create_pipeline(device, processed, variant);
499
+ return argmax_pipelines.at(vec4);
500
+ }
501
+
502
+ webgpu_pipeline get_set_rows_pipeline(const ggml_webgpu_shader_lib_context & context) {
503
+ ggml_webgpu_set_rows_pipeline_key key = { .dst_type = context.dst->type,
504
+ .vec4 = context.src0->ne[0] % 4 == 0,
505
+ .i64_idx = context.src1->type == GGML_TYPE_I64 };
506
+
507
+ auto it = set_rows_pipelines.find(key);
508
+ if (it != set_rows_pipelines.end()) {
509
+ return it->second;
510
+ }
511
+
512
+ std::vector<std::string> defines;
513
+ std::string variant = "set_rows";
514
+
515
+ switch (context.dst->type) {
516
+ case GGML_TYPE_F32:
517
+ defines.push_back("DST_F32");
518
+ variant += "_dstf32";
519
+ break;
520
+ case GGML_TYPE_F16:
521
+ defines.push_back("DST_F16");
522
+ variant += "_dstf16";
523
+ break;
524
+ default:
525
+ GGML_ABORT("Unsupported dst type for set_rows shader");
526
+ }
527
+
528
+ if (key.vec4) {
529
+ defines.push_back("VEC4");
530
+ variant += "_vec4";
531
+ }
532
+ if (key.i64_idx) {
533
+ defines.push_back("I64_IDX");
534
+ variant += "_i64idx";
535
+ }
536
+
537
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
538
+
539
+ auto processed = preprocessor.preprocess(wgsl_set_rows, defines);
540
+ auto decisions = std::make_shared<ggml_webgpu_set_rows_shader_decisions>();
541
+ decisions->vec4 = key.vec4;
542
+ decisions->i64_idx = key.i64_idx;
543
+ decisions->wg_size = context.max_wg_size;
544
+ set_rows_pipelines[key] = ggml_webgpu_create_pipeline(device, processed, variant);
545
+ set_rows_pipelines[key].context = decisions;
546
+ return set_rows_pipelines[key];
547
+ }
548
+
549
+ webgpu_pipeline get_cumsum_pipeline(const ggml_webgpu_shader_lib_context & context) {
550
+ auto it = cumsum_pipelines.find(1);
551
+ if (it != cumsum_pipelines.end()) {
552
+ return it->second;
553
+ }
554
+
555
+ std::vector<std::string> defines;
556
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
557
+
558
+ auto processed = preprocessor.preprocess(wgsl_cumsum, defines);
559
+ cumsum_pipelines[1] = ggml_webgpu_create_pipeline(device, processed, "cumsum");
560
+ return cumsum_pipelines[1];
561
+ }
562
+
563
+ webgpu_pipeline get_argsort_pipeline(const ggml_webgpu_shader_lib_context & context) {
564
+ bool is_top_k = context.dst->op == GGML_OP_TOP_K;
565
+ // ascending order is 0, descending order is 1
566
+ const int32_t order =
567
+ is_top_k ? (int32_t) GGML_SORT_ORDER_DESC : (int32_t) ggml_get_op_params_i32(context.dst, 0);
568
+
569
+ auto it = argsort_pipelines.find(order);
570
+ if (it != argsort_pipelines.end()) {
571
+ return it->second;
572
+ }
573
+
574
+ std::vector<std::string> defines;
575
+ std::string variant = "argsort";
576
+ defines.push_back(std::string("ORDER=") + std::to_string(order));
577
+ variant += std::string("_order") + std::to_string(order);
578
+ uint32_t wg_size = 1;
579
+ while (wg_size * 2 <= context.max_wg_size &&
580
+ wg_size * GGML_WEBGPU_I32_SIZE_BYTES <= context.wg_mem_limit_bytes / 2) {
581
+ wg_size *= 2;
582
+ }
583
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
584
+ auto processed = preprocessor.preprocess(wgsl_argsort, defines);
585
+ auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
586
+ decisions->wg_size = wg_size;
587
+ argsort_pipelines[order] = ggml_webgpu_create_pipeline(device, processed, variant);
588
+ argsort_pipelines[order].context = decisions;
589
+ return argsort_pipelines[order];
590
+ }
591
+
592
+ webgpu_pipeline get_argsort_merge_pipeline(const ggml_webgpu_shader_lib_context & context) {
593
+ bool is_top_k = context.dst->op == GGML_OP_TOP_K;
594
+ // ascending order is 0, descending order is 1
595
+ const int32_t order =
596
+ is_top_k ? (int32_t) GGML_SORT_ORDER_DESC : (int32_t) ggml_get_op_params_i32(context.dst, 0);
597
+
598
+ auto it = argsort_merge_pipelines.find(order);
599
+ if (it != argsort_merge_pipelines.end()) {
600
+ return it->second;
601
+ }
602
+
603
+ std::vector<std::string> defines;
604
+ std::string variant = "argsort_merge";
605
+ defines.push_back(std::string("ORDER=") + std::to_string(order));
606
+ variant += std::string("_order") + std::to_string(order);
607
+ uint32_t wg_size = std::min(GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE, context.max_wg_size);
608
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
609
+
610
+ auto processed = preprocessor.preprocess(wgsl_argsort_merge, defines);
611
+ argsort_merge_pipelines[order] = ggml_webgpu_create_pipeline(device, processed, variant);
612
+ return argsort_merge_pipelines[order];
613
+ }
614
+
615
+ webgpu_pipeline get_get_rows_pipeline(const ggml_webgpu_shader_lib_context & context) {
616
+ const bool vectorized = context.src0->type == GGML_TYPE_F32 && context.dst->ne[0] % 4 == 0;
617
+ ggml_webgpu_get_rows_pipeline_key key = {
618
+ .src_type = context.src0->type,
619
+ .vectorized = (int) vectorized,
620
+ };
621
+
622
+ auto it = get_rows_pipelines.find(key);
623
+ if (it != get_rows_pipelines.end()) {
624
+ return it->second;
625
+ }
626
+
627
+ std::vector<std::string> defines;
628
+ std::string variant = "get_rows";
629
+
630
+ const struct ggml_type_traits * type_traits = ggml_get_type_traits(key.src_type);
631
+ const char * type_str = type_traits->type_name;
632
+
633
+ switch (key.src_type) {
634
+ case GGML_TYPE_F32:
635
+ if (key.vectorized) {
636
+ defines.push_back("F32_VEC");
637
+ defines.push_back("SRC_TYPE=vec4<f32>");
638
+ defines.push_back("DST_TYPE=vec4<f32>");
639
+ defines.push_back("BLOCK_SIZE=4u");
640
+ } else {
641
+ defines.push_back("F32");
642
+ defines.push_back("SRC_TYPE=f32");
643
+ defines.push_back("DST_TYPE=f32");
644
+ defines.push_back("BLOCK_SIZE=1u");
645
+ }
646
+ variant += "_f32";
647
+ break;
648
+ case GGML_TYPE_F16:
649
+ defines.push_back("F16");
650
+ defines.push_back("SRC_TYPE=f16");
651
+ defines.push_back("DST_TYPE=f32");
652
+ defines.push_back("BLOCK_SIZE=1u");
653
+ variant += "_f16";
654
+ break;
655
+ case GGML_TYPE_I32:
656
+ defines.push_back("I32");
657
+ defines.push_back("SRC_TYPE=i32");
658
+ defines.push_back("DST_TYPE=i32");
659
+ defines.push_back("BLOCK_SIZE=1u");
660
+ variant += "_i32";
661
+ break;
662
+ default:
663
+ {
664
+ std::string type_upper = type_str;
665
+ std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
666
+
667
+ defines.push_back("BYTE_HELPERS");
668
+ defines.push_back(type_upper + "_T");
669
+ defines.push_back(type_upper);
670
+ defines.push_back(type_upper + "_SCALE_MIN");
671
+ defines.push_back(type_upper + "_TABLES");
672
+ defines.push_back(type_upper + "_GRID");
673
+
674
+ variant += "_";
675
+ variant += type_str;
676
+
677
+ defines.push_back(std::string("SRC_TYPE=") + type_str);
678
+ defines.push_back("DST_TYPE=f32");
679
+
680
+ if ((key.src_type >= GGML_TYPE_Q4_0 && key.src_type <= GGML_TYPE_Q8_1) ||
681
+ key.src_type == GGML_TYPE_IQ4_NL) {
682
+ defines.push_back("BLOCK_SIZE=32u");
683
+ } else if (key.src_type >= GGML_TYPE_Q2_K) {
684
+ defines.push_back("BLOCK_SIZE=256u");
685
+ } else {
686
+ defines.push_back("BLOCK_SIZE=1u");
687
+ }
688
+ break;
689
+ }
690
+ }
691
+
692
+ if (key.vectorized) {
693
+ variant += "_vec";
694
+ }
695
+
696
+ defines.push_back("WG_SIZE=" + std::to_string(context.max_wg_size));
697
+
698
+ auto processed = preprocessor.preprocess(wgsl_get_rows, defines);
699
+ auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
700
+ decisions->wg_size = context.max_wg_size;
701
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
702
+ pipeline.context = decisions;
703
+ get_rows_pipelines[key] = pipeline;
704
+ return get_rows_pipelines[key];
705
+ }
706
+
707
+ webgpu_pipeline get_scale_pipeline(const ggml_webgpu_shader_lib_context & context) {
708
+ ggml_webgpu_scale_pipeline_key key = { .inplace = context.inplace };
709
+
710
+ auto it = scale_pipelines.find(key);
711
+ if (it != scale_pipelines.end()) {
712
+ return it->second;
713
+ }
714
+
715
+ std::vector<std::string> defines;
716
+ std::string variant = "scale";
717
+
718
+ if (key.inplace) {
719
+ defines.push_back("INPLACE");
720
+ variant += "_inplace";
721
+ }
722
+
723
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
724
+
725
+ auto processed = preprocessor.preprocess(wgsl_scale, defines);
726
+ auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
727
+ decisions->wg_size = context.max_wg_size;
728
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
729
+ pipeline.context = decisions;
730
+ scale_pipelines[key] = pipeline;
731
+ return scale_pipelines[key];
732
+ }
733
+
734
+ webgpu_pipeline get_pad_pipeline(const ggml_webgpu_shader_lib_context & context) {
735
+ ggml_webgpu_pad_pipeline_key key = { .circular = ggml_get_op_params_i32(context.dst, 8) != 0 };
736
+
737
+ auto it = pad_pipelines.find(key);
738
+ if (it != pad_pipelines.end()) {
739
+ return it->second;
740
+ }
741
+
742
+ std::vector<std::string> defines;
743
+ std::string variant = "pad";
744
+
745
+ if (key.circular) {
746
+ defines.push_back("CIRCULAR");
747
+ variant += "_circular";
748
+ }
749
+
750
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
751
+
752
+ auto processed = preprocessor.preprocess(wgsl_pad, defines);
753
+ auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
754
+ decisions->wg_size = context.max_wg_size;
755
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
756
+ pipeline.context = decisions;
757
+ pad_pipelines[key] = pipeline;
758
+ return pad_pipelines[key];
759
+ }
760
+
761
+ webgpu_pipeline get_mul_mat_vec_pipeline(const ggml_webgpu_shader_lib_context & context) {
762
+ ggml_webgpu_mul_mat_vec_pipeline_key key = {
763
+ .src0_type = context.src0->type,
764
+ .src1_type = context.src1->type,
765
+ // Quantized mat-vec path currently runs scalar; only allow vectorization when both inputs are float
766
+ .vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 &&
767
+ (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
768
+ 1 :
769
+ 0,
770
+ };
771
+
772
+ auto it = mul_mat_vec_pipelines.find(key);
773
+ if (it != mul_mat_vec_pipelines.end()) {
774
+ return it->second;
775
+ }
776
+
777
+ std::vector<std::string> defines;
778
+ std::string variant = "mul_mat_vec";
779
+
780
+ // src0 type (matrix row)
781
+ switch (context.src0->type) {
782
+ case GGML_TYPE_F32:
783
+ defines.push_back("SRC0_INNER_TYPE=f32");
784
+ defines.push_back("MUL_ACC_FLOAT");
785
+ variant += "_f32";
786
+ break;
787
+ case GGML_TYPE_F16:
788
+ defines.push_back("SRC0_INNER_TYPE=f16");
789
+ defines.push_back("MUL_ACC_FLOAT");
790
+ variant += "_f16";
791
+ break;
792
+ default:
793
+ {
794
+ // Quantized types: use helpers but accumulate in f16
795
+ const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);
796
+ std::string src0_name = src0_traits->type_name;
797
+ std::string type_upper = src0_name;
798
+ variant += "_" + src0_name;
799
+ std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
800
+
801
+ defines.push_back("BYTE_HELPERS");
802
+ defines.push_back("MUL_ACC_" + type_upper);
803
+
804
+ // For fast path we always dequantize from f16 inside the shader
805
+ defines.push_back("SRC0_INNER_TYPE=f16");
806
+ break;
807
+ }
808
+ }
809
+
810
+ // src1 type (vector)
811
+ switch (context.src1->type) {
812
+ case GGML_TYPE_F32:
813
+ defines.push_back("SRC1_INNER_TYPE=f32");
814
+ variant += "_f32";
815
+ break;
816
+ case GGML_TYPE_F16:
817
+ defines.push_back("SRC1_INNER_TYPE=f16");
818
+ variant += "_f16";
819
+ break;
820
+ default:
821
+ GGML_ABORT("Unsupported src1 type for mul_mat_vec shader");
822
+ }
823
+
824
+ // VEC/SCALAR controls
825
+ defines.push_back(key.vectorized ? "VEC" : "SCALAR");
826
+
827
+ uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE;
828
+ uint32_t tile_k = WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K;
829
+ uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG;
830
+
831
+ if (key.src0_type >= GGML_TYPE_Q2_K) {
832
+ tile_k = WEBGPU_MUL_MAT_VEC_K_Q_TILE_K;
833
+ outputs_per_wg = WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG;
834
+ } else if (key.src0_type >= GGML_TYPE_Q4_0) {
835
+ tile_k = WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K;
836
+ outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG;
837
+ }
838
+
839
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
840
+ defines.push_back(std::string("TILE_K=") + std::to_string(tile_k));
841
+ defines.push_back(std::string("OUTPUTS_PER_WG=") + std::to_string(outputs_per_wg));
842
+
843
+ auto processed = preprocessor.preprocess(wgsl_mul_mat_vec, defines);
844
+ auto decisions = std::make_shared<ggml_webgpu_mul_mat_vec_shader_decisions>();
845
+ decisions->wg_size = wg_size;
846
+ decisions->tile_k = tile_k;
847
+ decisions->outputs_per_wg = outputs_per_wg;
848
+ decisions->vec_size = key.vectorized ? 4 : 1;
849
+
850
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
851
+ pipeline.context = decisions;
852
+ mul_mat_vec_pipelines[key] = pipeline;
853
+ return mul_mat_vec_pipelines[key];
854
+ }
855
+
856
+ webgpu_pipeline get_mul_mat_fast_pipeline(const ggml_webgpu_shader_lib_context & context) {
857
+ ggml_webgpu_mul_mat_pipeline_key key = {
858
+ .src0_type = context.src0->type,
859
+ .src1_type = context.src1->type,
860
+ .vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && context.dst->ne[1] % 4 == 0 &&
861
+ (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
862
+ 1 :
863
+ 0,
864
+ .use_subgroup_matrix = context.supports_subgroup_matrix
865
+ };
866
+
867
+ auto it = mul_mat_fast_pipelines.find(key);
868
+ if (it != mul_mat_fast_pipelines.end()) {
869
+ return it->second;
870
+ }
871
+
872
+ const char * shader_src = key.use_subgroup_matrix ? wgsl_mul_mat_subgroup_matrix : wgsl_mul_mat_reg_tile;
873
+ std::vector<std::string> defines;
874
+ std::string variant = key.use_subgroup_matrix ? "mul_mat_subgroup_matrix" : "mul_mat_reg_tile";
875
+
876
+ // src1 type
877
+ switch (context.src1->type) {
878
+ case GGML_TYPE_F32:
879
+ defines.push_back("SRC1_INNER_TYPE=f32");
880
+ break;
881
+ case GGML_TYPE_F16:
882
+ defines.push_back("SRC1_INNER_TYPE=f16");
883
+ break;
884
+ default:
885
+ GGML_ABORT("Unsupported src1 type for mul_mat fast shader");
886
+ }
887
+
888
+ // src0 type
889
+ const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);
890
+ const char * src0_name = src0_traits->type_name;
891
+
892
+ switch (context.src0->type) {
893
+ case GGML_TYPE_F32:
894
+ defines.push_back("SRC0_INNER_TYPE=f32");
895
+ defines.push_back("FLOAT");
896
+ defines.push_back("MUL_ACC_FLOAT");
897
+ defines.push_back("INIT_SRC0_SHMEM_FLOAT");
898
+ defines.push_back("INIT_SRC1_SHMEM_FLOAT");
899
+ variant += "_f32";
900
+ break;
901
+ case GGML_TYPE_F16:
902
+ defines.push_back("SRC0_INNER_TYPE=f16");
903
+ defines.push_back("FLOAT");
904
+ defines.push_back("MUL_ACC_FLOAT");
905
+ defines.push_back("INIT_SRC0_SHMEM_FLOAT");
906
+ defines.push_back("INIT_SRC1_SHMEM_FLOAT");
907
+ variant += "_f16";
908
+ break;
909
+ default:
910
+ {
911
+ std::string type_upper = src0_name;
912
+ std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
913
+
914
+ defines.push_back("BYTE_HELPERS");
915
+ defines.push_back("MUL_ACC_" + type_upper);
916
+ defines.push_back("INIT_SRC0_SHMEM_" + type_upper);
917
+ defines.push_back("INIT_SRC1_SHMEM_FLOAT");
918
+
919
+ // Use f16 inside the shader for quantized types
920
+ defines.push_back("SRC0_INNER_TYPE=f16");
921
+
922
+ variant += std::string("_") + src0_name;
923
+ break;
924
+ }
925
+ }
926
+
927
+ // VEC/SCALAR controls
928
+ defines.push_back(key.vectorized ? "VEC" : "SCALAR");
929
+
930
+ // Tiles
931
+ defines.push_back("TILE_M=" + std::to_string(WEBGPU_MUL_MAT_TILE_M) + "u");
932
+ defines.push_back("TILE_N=" + std::to_string(WEBGPU_MUL_MAT_TILE_N) + "u");
933
+ defines.push_back("TILE_K=" + std::to_string(WEBGPU_MUL_MAT_TILE_K) + "u");
934
+
935
+ // Subgroup matrix specifics
936
+ if (key.use_subgroup_matrix) {
937
+ defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(context.max_subgroup_size) + "u");
938
+ defines.push_back("SUBGROUP_M=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M) + "u");
939
+ defines.push_back("SUBGROUP_N=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N) + "u");
940
+ defines.push_back("SUBGROUP_MATRIX_M=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M) + "u");
941
+ defines.push_back("SUBGROUP_MATRIX_N=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N) + "u");
942
+ defines.push_back("SUBGROUP_MATRIX_M_SIZE=" + std::to_string(context.sg_mat_m) + "u");
943
+ defines.push_back("SUBGROUP_MATRIX_N_SIZE=" + std::to_string(context.sg_mat_n) + "u");
944
+ defines.push_back("SUBGROUP_MATRIX_K_SIZE=" + std::to_string(context.sg_mat_k) + "u");
945
+ }
946
+
947
+ // variant suffix for src1 type
948
+ variant += std::string("_") + (context.src1->type == GGML_TYPE_F32 ? "f32" : "f16");
949
+ if (key.vectorized) {
950
+ variant += "_vectorized";
951
+ }
952
+
953
+ if (!key.use_subgroup_matrix) {
954
+ defines.push_back("WORKGROUP_SIZE_M=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_M) + "u");
955
+ defines.push_back("WORKGROUP_SIZE_N=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_N) + "u");
956
+ }
957
+
958
+ auto processed = preprocessor.preprocess(shader_src, defines);
959
+
960
+ auto decisions = std::make_shared<ggml_webgpu_mul_mat_shader_decisions>();
961
+ decisions->tile_k = WEBGPU_MUL_MAT_TILE_K;
962
+ decisions->tile_m = WEBGPU_MUL_MAT_TILE_M;
963
+ decisions->tile_n = WEBGPU_MUL_MAT_TILE_N;
964
+ decisions->use_subgroup_matrix = key.use_subgroup_matrix;
965
+ if (key.use_subgroup_matrix) {
966
+ decisions->subgroup_m = WEBGPU_MUL_MAT_SUBGROUP_M;
967
+ decisions->subgroup_n = WEBGPU_MUL_MAT_SUBGROUP_N;
968
+ decisions->subgroup_matrix_m = WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M;
969
+ decisions->subgroup_matrix_n = WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N;
970
+ decisions->wg_size = context.max_subgroup_size;
971
+ } else {
972
+ decisions->wg_size_m = WEBGPU_MUL_MAT_WG_SIZE_M;
973
+ decisions->wg_size_n = WEBGPU_MUL_MAT_WG_SIZE_N;
974
+ decisions->wg_size = WEBGPU_MUL_MAT_WG_SIZE_M * WEBGPU_MUL_MAT_WG_SIZE_N;
975
+ decisions->mul_mat_wg_size = WEBGPU_MUL_MAT_WG_SIZE;
976
+ }
977
+
978
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
979
+ pipeline.context = decisions;
980
+ mul_mat_fast_pipelines[key] = pipeline;
981
+ return mul_mat_fast_pipelines[key];
982
+ }
983
+
984
+ webgpu_pipeline get_mul_mat_legacy_pipeline(const ggml_webgpu_shader_lib_context & context) {
985
+ ggml_webgpu_legacy_mul_mat_pipeline_key key = { .src0_type = context.src0->type,
986
+ .src1_type = context.src1->type };
987
+
988
+ auto it = mul_mat_legacy_pipelines.find(key);
989
+ if (it != mul_mat_legacy_pipelines.end()) {
990
+ return it->second;
991
+ }
992
+
993
+ std::vector<std::string> defines;
994
+ std::string variant = "mul_mat";
995
+
996
+ switch (context.src1->type) {
997
+ case GGML_TYPE_F32:
998
+ defines.push_back("SRC1_TYPE=f32");
999
+ variant += "_f32";
1000
+ break;
1001
+ case GGML_TYPE_F16:
1002
+ defines.push_back("SRC1_TYPE=f16");
1003
+ variant += "_f16";
1004
+ break;
1005
+ default:
1006
+ GGML_ABORT("Unsupported src1 type for mul_mat legacy shader");
1007
+ }
1008
+
1009
+ const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);
1010
+ const char * src0_name = src0_traits->type_name;
1011
+
1012
+ switch (context.src0->type) {
1013
+ case GGML_TYPE_F32:
1014
+ defines.push_back("SRC0_TYPE=f32");
1015
+ defines.push_back("FLOAT");
1016
+ variant += "_f32";
1017
+ break;
1018
+ case GGML_TYPE_F16:
1019
+ defines.push_back("SRC0_TYPE=f16");
1020
+ defines.push_back("FLOAT");
1021
+ variant += "_f16";
1022
+ break;
1023
+ default:
1024
+ {
1025
+ // quantized types
1026
+ std::string type_upper = src0_name;
1027
+ std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
1028
+
1029
+ defines.push_back(std::string("SRC0_TYPE=") + src0_name);
1030
+ defines.push_back("BYTE_HELPERS");
1031
+ defines.push_back(type_upper + "_T");
1032
+ defines.push_back(type_upper);
1033
+ defines.push_back(type_upper + "_SCALE_MIN");
1034
+ defines.push_back(type_upper + "_TABLES");
1035
+ defines.push_back(type_upper + "_GRID");
1036
+
1037
+ variant += std::string("_") + src0_name;
1038
+ break;
1039
+ }
1040
+ }
1041
+
1042
+ auto processed = preprocessor.preprocess(wgsl_mul_mat, defines);
1043
+
1044
+ auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
1045
+ decisions->wg_size = WEBGPU_MUL_MAT_WG_SIZE;
1046
+
1047
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
1048
+ pipeline.context = decisions;
1049
+ mul_mat_legacy_pipelines[key] = pipeline;
1050
+ return mul_mat_legacy_pipelines[key];
1051
+ }
1052
+
1053
+ webgpu_pipeline get_unary_pipeline(const ggml_webgpu_shader_lib_context & context) {
1054
+ const bool is_unary = context.dst->op == GGML_OP_UNARY;
1055
+ const int op = is_unary ? (int) ggml_get_unary_op(context.dst) : context.dst->op;
1056
+ ggml_webgpu_unary_pipeline_key key = {
1057
+ .type = context.dst->type,
1058
+ .op = op,
1059
+ .is_unary = is_unary,
1060
+ .inplace = context.inplace,
1061
+ };
1062
+
1063
+ auto it = unary_pipelines.find(key);
1064
+ if (it != unary_pipelines.end()) {
1065
+ return it->second;
1066
+ }
1067
+
1068
+ std::vector<std::string> defines;
1069
+ std::string variant =
1070
+ key.is_unary ? ggml_unary_op_name((ggml_unary_op) key.op) : ggml_op_name((ggml_op) key.op);
1071
+ defines.push_back(variant);
1072
+
1073
+ switch (key.type) {
1074
+ case GGML_TYPE_F32:
1075
+ defines.push_back("TYPE_F32");
1076
+ variant += "_f32";
1077
+ break;
1078
+ case GGML_TYPE_F16:
1079
+ defines.push_back("TYPE_F16");
1080
+ variant += "_f16";
1081
+ break;
1082
+ default:
1083
+ GGML_ABORT("Unsupported type for unary shader");
1084
+ }
1085
+
1086
+ if (key.inplace) {
1087
+ defines.push_back("INPLACE");
1088
+ variant += "_inplace";
1089
+ }
1090
+
1091
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
1092
+
1093
+ auto processed = preprocessor.preprocess(wgsl_unary, defines);
1094
+ auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
1095
+ decisions->wg_size = context.max_wg_size;
1096
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
1097
+ pipeline.context = decisions;
1098
+ unary_pipelines[key] = pipeline;
1099
+ return unary_pipelines[key];
1100
+ }
1101
+
1102
+ webgpu_pipeline get_binary_pipeline(const ggml_webgpu_shader_lib_context & context) {
1103
+ ggml_webgpu_binary_pipeline_key key = {
1104
+ .type = context.dst->type,
1105
+ .op = context.dst->op,
1106
+ .inplace = context.inplace,
1107
+ .overlap = context.overlap,
1108
+ .src_overlap = context.src_overlap,
1109
+ };
1110
+
1111
+ auto it = binary_pipelines.find(key);
1112
+ if (it != binary_pipelines.end()) {
1113
+ return it->second;
1114
+ }
1115
+
1116
+ std::vector<std::string> defines;
1117
+ std::string op_name = ggml_op_name((ggml_op) key.op);
1118
+ std::string variant = op_name;
1119
+
1120
+ defines.push_back(std::string("OP_") + op_name);
1121
+
1122
+ switch (key.type) {
1123
+ case GGML_TYPE_F32:
1124
+ defines.push_back("TYPE_F32");
1125
+ variant += "_f32";
1126
+ break;
1127
+ case GGML_TYPE_F16:
1128
+ defines.push_back("TYPE_F16");
1129
+ variant += "_f16";
1130
+ break;
1131
+ default:
1132
+ GGML_ABORT("Unsupported type for binary shader");
1133
+ }
1134
+
1135
+ if (key.inplace) {
1136
+ defines.push_back("INPLACE");
1137
+ variant += "_inplace";
1138
+ } else if (key.overlap) {
1139
+ defines.push_back("OVERLAP");
1140
+ variant += "_overlap";
1141
+ } else if (key.src_overlap) {
1142
+ defines.push_back("SRC_OVERLAP");
1143
+ variant += "_src_overlap";
1144
+ }
1145
+
1146
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
1147
+
1148
+ auto processed = preprocessor.preprocess(wgsl_binary, defines);
1149
+ auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
1150
+ decisions->wg_size = context.max_wg_size;
1151
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
1152
+ pipeline.context = decisions;
1153
+ binary_pipelines[key] = pipeline;
1154
+ return binary_pipelines[key];
1155
+ }
1156
+
1157
+ webgpu_pipeline get_concat_pipeline(const ggml_webgpu_shader_lib_context & context) {
1158
+ ggml_webgpu_concat_pipeline_key key = {
1159
+ .type = context.dst->type,
1160
+ };
1161
+
1162
+ auto it = concat_pipelines.find(key);
1163
+ if (it != concat_pipelines.end()) {
1164
+ return it->second;
1165
+ }
1166
+
1167
+ std::vector<std::string> defines;
1168
+ std::string variant = "concat";
1169
+
1170
+ switch (key.type) {
1171
+ case GGML_TYPE_F32:
1172
+ defines.push_back("TYPE_F32");
1173
+ variant += "_f32";
1174
+ break;
1175
+ case GGML_TYPE_I32:
1176
+ defines.push_back("TYPE_I32");
1177
+ variant += "_i32";
1178
+ break;
1179
+ default:
1180
+ GGML_ABORT("Unsupported type for concat shader");
1181
+ }
1182
+
1183
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
1184
+
1185
+ auto processed = preprocessor.preprocess(wgsl_concat, defines);
1186
+ auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
1187
+ decisions->wg_size = context.max_wg_size;
1188
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
1189
+ pipeline.context = decisions;
1190
+ concat_pipelines[key] = pipeline;
1191
+ return concat_pipelines[key];
1192
+ }
1193
+
1194
+ webgpu_pipeline get_repeat_pipeline(const ggml_webgpu_shader_lib_context & context) {
1195
+ ggml_webgpu_repeat_pipeline_key key = {
1196
+ .type = context.dst->type,
1197
+ };
1198
+
1199
+ auto it = repeat_pipelines.find(key);
1200
+ if (it != repeat_pipelines.end()) {
1201
+ return it->second;
1202
+ }
1203
+
1204
+ std::vector<std::string> defines;
1205
+ std::string variant = "repeat";
1206
+
1207
+ switch (key.type) {
1208
+ case GGML_TYPE_F32:
1209
+ defines.push_back("TYPE_F32");
1210
+ variant += "_f32";
1211
+ break;
1212
+ case GGML_TYPE_I32:
1213
+ defines.push_back("TYPE_I32");
1214
+ variant += "_i32";
1215
+ break;
1216
+ case GGML_TYPE_I16:
1217
+ defines.push_back("TYPE_I16");
1218
+ variant += "_i16";
1219
+ break;
1220
+ default:
1221
+ GGML_ABORT("Unsupported type for repeat shader");
1222
+ }
1223
+
1224
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
1225
+
1226
+ auto processed = preprocessor.preprocess(wgsl_repeat, defines);
1227
+ auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
1228
+ decisions->wg_size = context.max_wg_size;
1229
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
1230
+ pipeline.context = decisions;
1231
+ repeat_pipelines[key] = pipeline;
1232
+ return repeat_pipelines[key];
1233
+ }
1234
+
1235
+ webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) {
1236
+ const bool has_mask = context.src3 != nullptr;
1237
+ const bool has_sinks = context.src4 != nullptr;
1238
+
1239
+ bool kv_direct = (context.src1->type == GGML_TYPE_F16) && (context.src0->ne[0] % context.sg_mat_k == 0) &&
1240
+ (context.src1->ne[1] % context.sg_mat_n == 0);
1241
+
1242
+ ggml_webgpu_flash_attn_pipeline_key key = {
1243
+ .kv_type = context.src1->type,
1244
+ .head_dim_qk = (uint32_t) context.src0->ne[0],
1245
+ .head_dim_v = (uint32_t) context.src2->ne[0],
1246
+ .kv_direct = kv_direct,
1247
+ .has_mask = has_mask,
1248
+ .has_sinks = has_sinks,
1249
+ .uses_logit_softcap = (*(float *) &context.dst->op_params[2]) != 0.0f,
1250
+ };
1251
+
1252
+ auto it = flash_attn_pipelines.find(key);
1253
+ if (it != flash_attn_pipelines.end()) {
1254
+ return it->second;
1255
+ }
1256
+
1257
+ std::vector<std::string> defines;
1258
+ std::string variant = "flash_attn";
1259
+
1260
+ switch (key.kv_type) {
1261
+ case GGML_TYPE_F32:
1262
+ defines.push_back("KV_F32");
1263
+ break;
1264
+ case GGML_TYPE_F16:
1265
+ defines.push_back("KV_F16");
1266
+ break;
1267
+ case GGML_TYPE_Q4_0:
1268
+ defines.push_back("KV_Q4_0");
1269
+ break;
1270
+ case GGML_TYPE_Q8_0:
1271
+ defines.push_back("KV_Q8_0");
1272
+ break;
1273
+ default:
1274
+ GGML_ABORT("Unsupported KV type for flash attention shader");
1275
+ }
1276
+ variant += std::string("_") + ggml_type_name(key.kv_type);
1277
+
1278
+ if (key.has_mask) {
1279
+ defines.push_back("MASK");
1280
+ variant += "_mask";
1281
+ }
1282
+ if (key.has_sinks) {
1283
+ defines.push_back("SINKS");
1284
+ variant += "_sinks";
1285
+ }
1286
+ if (key.uses_logit_softcap) {
1287
+ defines.push_back("LOGIT_SOFTCAP");
1288
+ variant += "_lgsc";
1289
+ }
1290
+ if (key.kv_direct) {
1291
+ defines.push_back("KV_DIRECT");
1292
+ variant += "_kvdirect";
1293
+ }
1294
+
1295
+ defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk));
1296
+ variant += std::string("_hsqk") + std::to_string(key.head_dim_qk);
1297
+
1298
+ defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v));
1299
+ variant += std::string("_hsv") + std::to_string(key.head_dim_v);
1300
+
1301
+ defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m));
1302
+ defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n));
1303
+ defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k));
1304
+
1305
+ uint32_t q_tile = context.sg_mat_m;
1306
+ uint32_t kv_tile =
1307
+ std::min(ggml_webgpu_flash_attn_max_kv_tile({ key, context.sg_mat_m, context.sg_mat_n, context.sg_mat_k,
1308
+ context.wg_mem_limit_bytes, context.max_subgroup_size }),
1309
+ context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
1310
+ if (key.kv_direct) {
1311
+ while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {
1312
+ kv_tile -= context.sg_mat_n;
1313
+ }
1314
+ }
1315
+
1316
+ defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile));
1317
+ defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile));
1318
+
1319
+ uint32_t wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);
1320
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
1321
+
1322
+ auto processed = preprocessor.preprocess(wgsl_flash_attn, defines);
1323
+ auto decisions = std::make_shared<ggml_webgpu_flash_attn_shader_decisions>();
1324
+ decisions->q_tile = q_tile;
1325
+ decisions->kv_tile = kv_tile;
1326
+ decisions->wg_size = wg_size;
1327
+
1328
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
1329
+ pipeline.context = decisions;
1330
+ flash_attn_pipelines[key] = pipeline;
1331
+ return flash_attn_pipelines[key];
1332
+ }
1333
+
1334
+ private:
1335
+ static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device,
1336
+ std::string shader_code,
1337
+ std::string label) {
1338
+ wgpu::ShaderSourceWGSL shader_source;
1339
+ shader_source.code = shader_code.c_str();
1340
+
1341
+ wgpu::ShaderModuleDescriptor shader_desc;
1342
+ shader_desc.nextInChain = &shader_source;
1343
+
1344
+ wgpu::ShaderModule shader_module = device.CreateShaderModule(&shader_desc);
1345
+
1346
+ wgpu::ComputePipelineDescriptor pipeline_desc;
1347
+ pipeline_desc.label = label.c_str();
1348
+ pipeline_desc.compute.module = shader_module;
1349
+ pipeline_desc.compute.entryPoint = "main"; // Entry point in the WGSL code
1350
+ pipeline_desc.layout = nullptr; // nullptr means auto layout
1351
+ return { device.CreateComputePipeline(&pipeline_desc), label };
1352
+ }
1353
+
1354
+ static uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_flash_attn_shader_lib_context & context) {
1355
+ const size_t limit_bytes = context.wg_mem_limit_bytes;
1356
+ const size_t q_tile = context.sg_mat_m;
1357
+ const size_t base_q_bytes =
1358
+ (context.key.head_dim_qk + context.key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
1359
+ 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
1360
+ size_t bytes_per_kv = 0;
1361
+ if (!context.key.kv_direct) {
1362
+ bytes_per_kv += std::max(context.key.head_dim_qk, context.key.head_dim_v);
1363
+ }
1364
+ if (context.key.has_mask) {
1365
+ bytes_per_kv += q_tile;
1366
+ }
1367
+ bytes_per_kv += q_tile;
1368
+ bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES;
1369
+ const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv;
1370
+ return (max_kv_tile / context.sg_mat_n) * context.sg_mat_n;
1371
+ }
1372
+ };
168
1373
 
169
1374
  #endif // GGML_WEBGPU_SHADER_LIB_HPP