whispercpp 1.3.5 → 1.3.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (1017)
  1. checksums.yaml +4 -4
  2. data/.document +3 -0
  3. data/.rdoc_options +2 -0
  4. data/LICENSE +1 -1
  5. data/README.md +133 -3
  6. data/Rakefile +18 -3
  7. data/ext/dependencies.rb +10 -4
  8. data/ext/dependencies_for_windows.rb +17 -0
  9. data/ext/extconf.rb +20 -7
  10. data/ext/options.rb +54 -14
  11. data/ext/options_for_windows.rb +51 -0
  12. data/ext/ruby_whisper.c +56 -46
  13. data/ext/ruby_whisper.h +165 -2
  14. data/ext/ruby_whisper_context.c +297 -126
  15. data/ext/ruby_whisper_context_params.c +163 -0
  16. data/ext/ruby_whisper_log_queue.c +180 -0
  17. data/ext/ruby_whisper_log_settable.h +47 -0
  18. data/ext/ruby_whisper_model.c +0 -1
  19. data/ext/ruby_whisper_parakeet.c +49 -0
  20. data/ext/ruby_whisper_parakeet_context.c +304 -0
  21. data/ext/ruby_whisper_parakeet_context_params.c +117 -0
  22. data/ext/ruby_whisper_parakeet_model.c +84 -0
  23. data/ext/ruby_whisper_parakeet_params.c +548 -0
  24. data/ext/ruby_whisper_parakeet_segment.c +157 -0
  25. data/ext/ruby_whisper_parakeet_token.c +188 -0
  26. data/ext/ruby_whisper_parakeet_transcribe.cpp +58 -0
  27. data/ext/ruby_whisper_params.c +256 -66
  28. data/ext/ruby_whisper_segment.c +6 -7
  29. data/ext/ruby_whisper_token.c +29 -9
  30. data/ext/ruby_whisper_transcribe.cpp +46 -16
  31. data/ext/ruby_whisper_vad_context.c +48 -1
  32. data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
  33. data/ext/ruby_whisper_vad_params.c +0 -1
  34. data/ext/ruby_whisper_vad_segment.c +0 -1
  35. data/ext/ruby_whisper_vad_segments.c +0 -1
  36. data/ext/sources/CMakeLists.txt +41 -3
  37. data/ext/sources/CMakePresets.json +95 -0
  38. data/ext/sources/cmake/parakeet-config.cmake.in +30 -0
  39. data/ext/sources/cmake/parakeet.pc.in +10 -0
  40. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  41. data/ext/sources/cmake/whisper.pc.in +1 -1
  42. data/ext/sources/examples/CMakeLists.txt +4 -2
  43. data/ext/sources/examples/bench/bench.cpp +24 -19
  44. data/ext/sources/examples/cli/cli.cpp +51 -9
  45. data/ext/sources/examples/common-ggml.cpp +4 -0
  46. data/ext/sources/examples/common-whisper.cpp +139 -67
  47. data/ext/sources/examples/common-whisper.h +11 -0
  48. data/ext/sources/examples/ffmpeg-transcode.cpp +211 -341
  49. data/ext/sources/examples/miniaudio.h +4507 -2131
  50. data/ext/sources/examples/parakeet-cli/CMakeLists.txt +8 -0
  51. data/ext/sources/examples/parakeet-cli/parakeet-cli.cpp +243 -0
  52. data/ext/sources/examples/parakeet-quantize/CMakeLists.txt +7 -0
  53. data/ext/sources/examples/parakeet-quantize/parakeet-quantize.cpp +230 -0
  54. data/ext/sources/examples/server/server.cpp +213 -163
  55. data/ext/sources/ggml/CMakeLists.txt +29 -15
  56. data/ext/sources/ggml/cmake/FindNCCL.cmake +36 -0
  57. data/ext/sources/ggml/cmake/ggml-config.cmake.in +12 -2
  58. data/ext/sources/ggml/include/ggml-alloc.h +1 -0
  59. data/ext/sources/ggml/include/ggml-backend.h +73 -11
  60. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  61. data/ext/sources/ggml/include/ggml-cpu.h +5 -0
  62. data/ext/sources/ggml/include/ggml-cuda.h +3 -0
  63. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  64. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  65. data/ext/sources/ggml/include/ggml-rpc.h +8 -3
  66. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  67. data/ext/sources/ggml/include/ggml.h +155 -16
  68. data/ext/sources/ggml/include/gguf.h +10 -2
  69. data/ext/sources/ggml/src/CMakeLists.txt +25 -5
  70. data/ext/sources/ggml/src/ggml-alloc.c +9 -10
  71. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  72. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  73. data/ext/sources/ggml/src/ggml-backend-impl.h +22 -2
  74. data/ext/sources/ggml/src/ggml-backend-meta.cpp +2263 -0
  75. data/ext/sources/ggml/src/ggml-backend-reg.cpp +40 -86
  76. data/ext/sources/ggml/src/ggml-backend.cpp +114 -10
  77. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  78. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -2
  79. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
  80. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
  81. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +1016 -442
  82. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +111 -85
  83. data/ext/sources/ggml/src/ggml-cann/common.h +23 -14
  84. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +255 -92
  85. data/ext/sources/ggml/src/ggml-common.h +22 -0
  86. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +68 -34
  87. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +44 -19
  88. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  89. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +101 -101
  90. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +194 -1
  91. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2874 -613
  92. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +151 -1
  93. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +0 -1
  94. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +5480 -840
  95. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1361 -0
  96. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -11
  97. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +72 -1
  98. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +186 -36
  99. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +119 -19
  100. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +112 -26
  101. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  102. data/ext/sources/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake +32 -0
  103. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  104. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +13 -0
  105. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +153 -16
  106. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +17 -0
  107. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
  108. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +976 -251
  109. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +671 -266
  110. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1277 -263
  111. data/ext/sources/ggml/src/ggml-cpu/ops.h +4 -0
  112. data/ext/sources/ggml/src/ggml-cpu/quants.c +95 -0
  113. data/ext/sources/ggml/src/ggml-cpu/quants.h +6 -0
  114. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2893 -679
  115. data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
  116. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +226 -0
  117. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +114 -19
  118. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1402 -687
  119. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +8 -0
  120. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +597 -2766
  121. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp +5768 -0
  122. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.cpp +320 -0
  123. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.h +55 -0
  124. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +182 -19
  125. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.cpp +1795 -0
  126. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.h +14 -0
  127. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp +3178 -0
  128. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.h +95 -0
  129. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_barrier.h +34 -0
  130. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp +760 -0
  131. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h +32 -0
  132. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_tcm.h +409 -0
  133. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
  134. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +54 -53
  135. data/ext/sources/ggml/src/ggml-cpu/vec.h +225 -240
  136. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +18 -8
  137. data/ext/sources/ggml/src/ggml-cuda/allreduce.cu +971 -0
  138. data/ext/sources/ggml/src/ggml-cuda/allreduce.cuh +29 -0
  139. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +73 -28
  140. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +69 -41
  141. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +1 -0
  142. data/ext/sources/ggml/src/ggml-cuda/common.cuh +359 -29
  143. data/ext/sources/ggml/src/ggml-cuda/concat.cu +120 -114
  144. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +45 -21
  145. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +1 -0
  146. data/ext/sources/ggml/src/ggml-cuda/convert.cu +94 -27
  147. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  148. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +20 -9
  149. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +22 -0
  150. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +333 -85
  151. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +632 -190
  152. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +12 -0
  153. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +162 -49
  154. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +43 -18
  155. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +44 -14
  156. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
  157. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +241 -23
  158. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  159. data/ext/sources/ggml/src/ggml-cuda/fwht.cu +101 -0
  160. data/ext/sources/ggml/src/ggml-cuda/fwht.cuh +4 -0
  161. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +312 -0
  162. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  163. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +34 -12
  164. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1454 -599
  165. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +32 -29
  166. data/ext/sources/ggml/src/ggml-cuda/mean.cu +13 -10
  167. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +397 -183
  168. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
  169. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +161 -88
  170. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +18 -12
  171. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +522 -431
  172. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +139 -72
  173. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
  174. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +608 -88
  175. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +6 -0
  176. data/ext/sources/ggml/src/ggml-cuda/norm.cu +47 -79
  177. data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +23 -7
  178. data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
  179. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +134 -27
  180. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +1 -1
  181. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +7 -17
  182. data/ext/sources/ggml/src/ggml-cuda/rope.cu +244 -137
  183. data/ext/sources/ggml/src/ggml-cuda/scale.cu +4 -1
  184. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +14 -6
  185. data/ext/sources/ggml/src/ggml-cuda/snake.cu +72 -0
  186. data/ext/sources/ggml/src/ggml-cuda/snake.cuh +8 -0
  187. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +4 -1
  188. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
  189. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
  190. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +96 -40
  191. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  192. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +40 -18
  193. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +8 -4
  194. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +1 -0
  195. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +6 -0
  196. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +2 -0
  197. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +2 -0
  198. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +1 -0
  199. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +6 -0
  200. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +2 -0
  201. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +2 -0
  202. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +1 -0
  203. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +2 -0
  204. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +2 -0
  205. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +2 -0
  206. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +2 -0
  207. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu +5 -0
  208. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu +5 -0
  209. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu +5 -0
  210. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu +7 -0
  211. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu +7 -0
  212. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu +7 -0
  213. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu +7 -0
  214. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu +7 -0
  215. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu +7 -0
  216. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu +7 -0
  217. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu +7 -0
  218. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu +7 -0
  219. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu +7 -0
  220. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu +7 -0
  221. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu +7 -0
  222. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu +7 -0
  223. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu +5 -0
  224. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu +5 -0
  225. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +5 -5
  226. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +202 -135
  227. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
  228. data/ext/sources/ggml/src/ggml-cuda/unary.cu +86 -2
  229. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +4 -0
  230. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +111 -17
  231. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +7 -2
  232. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +30 -2
  233. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +3 -0
  234. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +84 -46
  235. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +1612 -753
  236. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +51 -11
  237. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +361 -261
  238. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +294 -0
  239. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +753 -241
  240. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +5 -5
  241. data/ext/sources/ggml/src/ggml-hexagon/htp/concat-ops.c +277 -0
  242. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +295 -0
  243. data/ext/sources/ggml/src/ggml-hexagon/htp/cumsum-ops.c +270 -0
  244. data/ext/sources/ggml/src/ggml-hexagon/htp/diag-ops.c +216 -0
  245. data/ext/sources/ggml/src/ggml-hexagon/htp/fill-ops.c +123 -0
  246. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +471 -296
  247. data/ext/sources/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +1148 -0
  248. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +159 -53
  249. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +3 -3
  250. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +372 -0
  251. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +86 -0
  252. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  253. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +137 -0
  254. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +1878 -0
  255. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +2066 -0
  256. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.c +6 -0
  257. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.h +88 -0
  258. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-profile.h +34 -0
  259. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.c +158 -0
  260. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.h +134 -0
  261. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-utils.h +200 -0
  262. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +97 -14
  263. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +163 -67
  264. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +9 -3
  265. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  266. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +308 -0
  267. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +262 -0
  268. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +291 -0
  269. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  270. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +216 -0
  271. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h +47 -0
  272. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  273. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  274. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-log.h +65 -0
  275. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-pow.h +42 -0
  276. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  277. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-repl.h +74 -0
  278. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  279. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +142 -0
  280. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h +90 -0
  281. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  282. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  283. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +18 -1348
  284. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +547 -635
  285. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +3556 -1101
  286. data/ext/sources/ggml/src/ggml-hexagon/htp/pad-ops.c +547 -0
  287. data/ext/sources/ggml/src/ggml-hexagon/htp/repeat-ops.c +148 -0
  288. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +475 -269
  289. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +94 -72
  290. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +222 -217
  291. data/ext/sources/ggml/src/ggml-hexagon/htp/solve-tri-ops.c +267 -0
  292. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +432 -0
  293. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  294. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +886 -117
  295. data/ext/sources/ggml/src/ggml-hexagon/htp/vtcm-utils.h +16 -0
  296. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
  297. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/htp-opnode.h +272 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +40 -0
  302. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +28 -9
  303. data/ext/sources/ggml/src/ggml-impl.h +68 -1
  304. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  305. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
  306. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  307. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
  308. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +409 -83
  309. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +54 -5
  310. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +254 -52
  311. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +254 -23
  312. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +756 -285
  313. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +7 -4
  314. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +359 -133
  315. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1867 -1123
  316. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +5 -6
  317. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +71 -4
  318. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +14127 -5314
  319. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +97 -88
  320. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +104 -0
  321. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  322. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +1978 -67
  323. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  324. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  325. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
  326. data/ext/sources/ggml/src/ggml-opencl/kernels/gated_delta_net.cl +249 -0
  327. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +306 -0
  328. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +256 -0
  329. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +258 -0
  330. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl +283 -0
  331. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +260 -0
  332. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +262 -0
  333. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl +288 -0
  334. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl +267 -0
  335. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl +150 -0
  336. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mat_Ab_Bi_8x4.cl → gemm_noshuffle_q4_0_f32.cl} +1 -1
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl +172 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl +131 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl +134 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl +176 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl +140 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q8_0_f32.cl +129 -0
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl +233 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +165 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +120 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +123 -0
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl +155 -0
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +123 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +125 -0
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl +160 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl +141 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl +302 -0
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general.cl → gemv_noshuffle_q4_0_f32.cl} +5 -5
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle.cl → gemv_noshuffle_q4_0_f32_spec.cl} +5 -5
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl +318 -0
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl +291 -0
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl +294 -0
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl +326 -0
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl +293 -0
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q8_0_f32.cl +195 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +15 -9
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  365. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
  366. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl +30 -0
  367. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl +82 -0
  368. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl +171 -0
  369. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  370. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  371. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl +179 -0
  372. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl +173 -0
  373. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl +175 -0
  374. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl +192 -0
  375. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  376. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl +164 -0
  377. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl +202 -0
  378. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  379. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  380. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  381. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl +196 -0
  382. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl +241 -0
  383. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl +243 -0
  384. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl +243 -0
  385. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl +247 -0
  386. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl +187 -0
  387. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl +203 -0
  388. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  389. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +178 -0
  390. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  391. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  392. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  393. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
  394. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  395. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  396. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  397. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
  398. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  399. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  400. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  401. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +985 -0
  402. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  403. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +380 -0
  404. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  405. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1132 -0
  406. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +956 -0
  407. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  408. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  409. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  410. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  411. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  412. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  413. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  414. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  415. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  416. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  417. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  418. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  419. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  420. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  421. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  422. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  423. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  424. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +149 -0
  425. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  426. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  427. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  428. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  429. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp +25 -0
  430. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  431. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  432. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +47 -0
  433. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +40 -0
  434. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  435. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  436. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  437. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  438. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  439. data/ext/sources/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp +41 -0
  440. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +317 -0
  441. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  442. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +257 -0
  443. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +86 -0
  444. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +880 -0
  445. data/ext/sources/ggml/src/ggml-openvino/utils.h +143 -0
  446. data/ext/sources/ggml/src/ggml-opt.cpp +1 -0
  447. data/ext/sources/ggml/src/ggml-quants.c +385 -119
  448. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  449. data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +24 -0
  450. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +167 -311
  451. data/ext/sources/ggml/src/ggml-rpc/transport.cpp +683 -0
  452. data/ext/sources/ggml/src/ggml-rpc/transport.h +34 -0
  453. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +64 -91
  454. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
  455. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +4 -1
  456. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
  457. data/ext/sources/ggml/src/ggml-sycl/common.cpp +74 -2
  458. data/ext/sources/ggml/src/ggml-sycl/common.hpp +356 -11
  459. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +184 -14
  460. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +31 -1
  461. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
  462. data/ext/sources/ggml/src/ggml-sycl/cumsum.cpp +148 -0
  463. data/ext/sources/ggml/src/ggml-sycl/cumsum.hpp +5 -0
  464. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +663 -0
  465. data/ext/sources/ggml/src/ggml-sycl/diag.cpp +67 -0
  466. data/ext/sources/ggml/src/ggml-sycl/diag.hpp +5 -0
  467. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +586 -6
  468. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
  469. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +77 -156
  470. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -2
  471. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.cpp +56 -0
  472. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.hpp +63 -0
  473. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1181 -0
  474. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +59 -0
  475. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1246 -0
  476. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +674 -0
  477. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +227 -0
  478. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  479. data/ext/sources/ggml/src/ggml-sycl/fill.cpp +55 -0
  480. data/ext/sources/ggml/src/ggml-sycl/fill.hpp +5 -0
  481. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +347 -0
  482. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +9 -0
  483. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
  484. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +79 -3
  485. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +1134 -236
  486. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +353 -89
  487. data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +5 -3
  488. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1344 -26
  489. data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +16 -0
  490. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
  491. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  492. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +27 -27
  493. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
  494. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +72 -1
  495. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
  496. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  497. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +7 -1
  498. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
  499. data/ext/sources/ggml/src/ggml-sycl/solve_tri.cpp +172 -0
  500. data/ext/sources/ggml/src/ggml-sycl/solve_tri.hpp +8 -0
  501. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +6 -1
  502. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.cpp +156 -0
  503. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.hpp +5 -0
  504. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +62 -10
  505. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +18 -6
  506. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  507. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  508. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  509. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  510. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp +6 -0
  511. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  512. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  513. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  514. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  515. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  516. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +8 -0
  517. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +8 -0
  518. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +8 -0
  519. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +8 -0
  520. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +8 -0
  521. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +8 -0
  522. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +8 -0
  523. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +8 -0
  524. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +8 -0
  525. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +8 -0
  526. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +8 -0
  527. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +8 -0
  528. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +8 -0
  529. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +8 -0
  530. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +8 -0
  531. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +8 -0
  532. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +8 -0
  533. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +8 -0
  534. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +8 -0
  535. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +8 -0
  536. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +8 -0
  537. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +8 -0
  538. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +8 -0
  539. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +8 -0
  540. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +8 -0
  541. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +8 -0
  542. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +8 -0
  543. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +8 -0
  544. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +8 -0
  545. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +8 -0
  546. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +8 -0
  547. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +8 -0
  548. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +8 -0
  549. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +8 -0
  550. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +8 -0
  551. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +8 -0
  552. data/ext/sources/ggml/src/ggml-sycl/type.hpp +112 -0
  553. data/ext/sources/ggml/src/ggml-sycl/upscale.cpp +410 -0
  554. data/ext/sources/ggml/src/ggml-sycl/upscale.hpp +9 -0
  555. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +228 -53
  556. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  557. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  558. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  559. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  560. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  561. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  562. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  563. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  564. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  565. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  566. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  567. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  568. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  569. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  570. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  571. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  572. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  573. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  574. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  575. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  576. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  577. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  578. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +123 -0
  579. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +160 -0
  580. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  581. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +71 -0
  582. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  583. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  584. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  585. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  586. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  587. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  588. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  589. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  590. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  591. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  592. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +99 -0
  593. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  594. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  595. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  596. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +545 -0
  597. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +115 -0
  598. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
  599. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +3250 -940
  600. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  601. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
  602. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +6 -2
  603. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +146 -13
  604. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +3 -1
  605. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +1 -1
  606. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +25 -1
  607. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +88 -0
  608. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +643 -1
  609. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp +32 -0
  610. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp +29 -0
  611. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +0 -1
  612. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl +27 -0
  613. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  614. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +0 -1
  615. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp +7 -0
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +533 -180
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +113 -68
  618. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +412 -222
  619. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +222 -83
  620. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +131 -0
  621. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  622. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +203 -0
  623. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
  624. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp +115 -0
  625. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +189 -0
  626. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +0 -1
  627. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +10 -1
  628. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +16 -6
  629. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +76 -54
  630. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +0 -1
  631. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
  632. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +0 -1
  633. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +122 -27
  634. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
  635. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +6 -6
  636. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +1 -1
  637. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +1 -1
  638. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +1 -1
  639. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +1 -0
  640. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +88 -55
  641. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +22 -20
  642. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +51 -14
  643. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +159 -125
  644. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
  645. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
  646. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +8 -8
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +24 -9
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +0 -1
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +39 -63
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +0 -1
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +13 -7
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp +49 -0
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +27 -11
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +0 -1
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +79 -2
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +193 -149
  663. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +5 -2
  664. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +3221 -97
  665. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3493 -1997
  666. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +37 -7
  667. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl +64 -0
  668. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  669. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  670. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  671. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +142 -0
  672. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +115 -141
  673. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +93 -0
  674. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl +165 -0
  675. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{cpy.tmpl.wgsl → cpy.wgsl} +25 -44
  676. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  677. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +198 -230
  678. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl +124 -0
  679. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +397 -0
  680. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +101 -0
  681. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +84 -0
  682. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +619 -0
  683. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +149 -0
  684. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +234 -335
  685. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl +155 -0
  686. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl +101 -0
  687. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +871 -42
  688. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl +195 -0
  689. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl +52 -0
  690. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +154 -0
  691. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +149 -0
  692. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +36 -138
  693. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +151 -0
  694. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +1432 -0
  695. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +303 -0
  696. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  697. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl +21 -0
  698. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +173 -0
  699. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  700. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +152 -0
  701. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{rope.tmpl.wgsl → rope.wgsl} +71 -142
  702. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +153 -0
  703. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +15 -40
  704. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl +109 -0
  705. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +39 -12
  706. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl +224 -0
  707. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{soft_max.tmpl.wgsl → soft_max.wgsl} +106 -206
  708. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl +121 -0
  709. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl +65 -0
  710. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl +193 -0
  711. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  712. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +213 -0
  713. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl +240 -0
  714. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +24 -15
  715. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
  716. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +253 -16
  717. data/ext/sources/ggml/src/ggml.c +268 -52
  718. data/ext/sources/ggml/src/gguf.cpp +377 -47
  719. data/ext/sources/include/parakeet.h +342 -0
  720. data/ext/sources/include/whisper.h +10 -0
  721. data/ext/sources/media/matmul.png +0 -0
  722. data/ext/sources/src/CMakeLists.txt +23 -0
  723. data/ext/sources/src/parakeet-arch.h +188 -0
  724. data/ext/sources/src/parakeet.cpp +3838 -0
  725. data/ext/sources/src/whisper.cpp +62 -40
  726. data/extsources.rb +26 -10
  727. data/lib/whisper/log_settable.rb +36 -0
  728. data/lib/whisper/model/uri.rb +13 -1
  729. data/lib/whisper/output.rb +74 -0
  730. data/sig/whisper.rbs +445 -55
  731. data/test/helper.rb +2 -0
  732. data/test/jfk_reader/jfk_reader.c +50 -7
  733. data/test/test_callback.rb +1 -0
  734. data/test/test_context_params.rb +82 -0
  735. data/test/test_package.rb +6 -5
  736. data/test/test_parakeet.rb +28 -0
  737. data/test/test_parakeet_callback.rb +107 -0
  738. data/test/test_parakeet_context.rb +116 -0
  739. data/test/test_parakeet_context_params.rb +24 -0
  740. data/test/test_parakeet_model.rb +21 -0
  741. data/test/test_parakeet_params.rb +78 -0
  742. data/test/test_parakeet_segment.rb +42 -0
  743. data/test/test_parakeet_token.rb +73 -0
  744. data/test/test_params.rb +2 -0
  745. data/test/test_token.rb +11 -0
  746. data/test/test_vad_context.rb +58 -8
  747. data/test/test_vad_segment.rb +1 -1
  748. data/test/test_whisper.rb +44 -6
  749. data/whispercpp.gemspec +2 -2
  750. metadata +426 -280
  751. data/ext/sources/bindings/javascript/CMakeLists.txt +0 -41
  752. data/ext/sources/bindings/javascript/emscripten.cpp +0 -93
  753. data/ext/sources/bindings/javascript/libwhisper.worker.js +0 -1
  754. data/ext/sources/bindings/javascript/package.json +0 -26
  755. data/ext/sources/bindings/javascript/whisper.js +0 -19
  756. data/ext/sources/examples/addon.node/CMakeLists.txt +0 -31
  757. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +0 -133
  758. data/ext/sources/examples/addon.node/addon.cpp +0 -557
  759. data/ext/sources/examples/addon.node/index.js +0 -59
  760. data/ext/sources/examples/addon.node/package.json +0 -16
  761. data/ext/sources/examples/addon.node/vad-example.js +0 -132
  762. data/ext/sources/examples/bench.wasm/CMakeLists.txt +0 -49
  763. data/ext/sources/examples/bench.wasm/emscripten.cpp +0 -87
  764. data/ext/sources/examples/bench.wasm/index-tmpl.html +0 -285
  765. data/ext/sources/examples/coi-serviceworker.js +0 -146
  766. data/ext/sources/examples/command/CMakeLists.txt +0 -10
  767. data/ext/sources/examples/command/command.cpp +0 -802
  768. data/ext/sources/examples/command/commands.txt +0 -9
  769. data/ext/sources/examples/command.wasm/CMakeLists.txt +0 -50
  770. data/ext/sources/examples/command.wasm/emscripten.cpp +0 -327
  771. data/ext/sources/examples/command.wasm/index-tmpl.html +0 -415
  772. data/ext/sources/examples/generate-karaoke.sh +0 -57
  773. data/ext/sources/examples/helpers.js +0 -191
  774. data/ext/sources/examples/livestream.sh +0 -112
  775. data/ext/sources/examples/lsp/CMakeLists.txt +0 -10
  776. data/ext/sources/examples/lsp/lsp.cpp +0 -471
  777. data/ext/sources/examples/lsp/whisper.vim +0 -362
  778. data/ext/sources/examples/python/test_whisper_processor.py +0 -7
  779. data/ext/sources/examples/python/whisper_processor.py +0 -54
  780. data/ext/sources/examples/server/bench.js +0 -29
  781. data/ext/sources/examples/server.py +0 -120
  782. data/ext/sources/examples/stream/CMakeLists.txt +0 -10
  783. data/ext/sources/examples/stream/stream.cpp +0 -437
  784. data/ext/sources/examples/stream.wasm/CMakeLists.txt +0 -49
  785. data/ext/sources/examples/stream.wasm/emscripten.cpp +0 -216
  786. data/ext/sources/examples/stream.wasm/index-tmpl.html +0 -491
  787. data/ext/sources/examples/sycl/CMakeLists.txt +0 -9
  788. data/ext/sources/examples/sycl/build.sh +0 -22
  789. data/ext/sources/examples/sycl/ls-sycl-device.cpp +0 -11
  790. data/ext/sources/examples/sycl/run-whisper.sh +0 -17
  791. data/ext/sources/examples/talk-llama/CMakeLists.txt +0 -47
  792. data/ext/sources/examples/talk-llama/eleven-labs.py +0 -80
  793. data/ext/sources/examples/talk-llama/llama-adapter.cpp +0 -494
  794. data/ext/sources/examples/talk-llama/llama-adapter.h +0 -88
  795. data/ext/sources/examples/talk-llama/llama-arch.cpp +0 -2559
  796. data/ext/sources/examples/talk-llama/llama-arch.h +0 -586
  797. data/ext/sources/examples/talk-llama/llama-batch.cpp +0 -917
  798. data/ext/sources/examples/talk-llama/llama-batch.h +0 -173
  799. data/ext/sources/examples/talk-llama/llama-chat.cpp +0 -876
  800. data/ext/sources/examples/talk-llama/llama-chat.h +0 -70
  801. data/ext/sources/examples/talk-llama/llama-context.cpp +0 -3645
  802. data/ext/sources/examples/talk-llama/llama-context.h +0 -360
  803. data/ext/sources/examples/talk-llama/llama-cparams.cpp +0 -5
  804. data/ext/sources/examples/talk-llama/llama-cparams.h +0 -42
  805. data/ext/sources/examples/talk-llama/llama-grammar.cpp +0 -1464
  806. data/ext/sources/examples/talk-llama/llama-grammar.h +0 -194
  807. data/ext/sources/examples/talk-llama/llama-graph.cpp +0 -2282
  808. data/ext/sources/examples/talk-llama/llama-graph.h +0 -910
  809. data/ext/sources/examples/talk-llama/llama-hparams.cpp +0 -241
  810. data/ext/sources/examples/talk-llama/llama-hparams.h +0 -284
  811. data/ext/sources/examples/talk-llama/llama-impl.cpp +0 -171
  812. data/ext/sources/examples/talk-llama/llama-impl.h +0 -63
  813. data/ext/sources/examples/talk-llama/llama-io.cpp +0 -15
  814. data/ext/sources/examples/talk-llama/llama-io.h +0 -35
  815. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +0 -328
  816. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +0 -137
  817. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2100
  818. data/ext/sources/examples/talk-llama/llama-kv-cache.h +0 -390
  819. data/ext/sources/examples/talk-llama/llama-kv-cells.h +0 -533
  820. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +0 -268
  821. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +0 -139
  822. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +0 -1167
  823. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +0 -182
  824. data/ext/sources/examples/talk-llama/llama-memory.cpp +0 -59
  825. data/ext/sources/examples/talk-llama/llama-memory.h +0 -122
  826. data/ext/sources/examples/talk-llama/llama-mmap.cpp +0 -735
  827. data/ext/sources/examples/talk-llama/llama-mmap.h +0 -73
  828. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +0 -1247
  829. data/ext/sources/examples/talk-llama/llama-model-loader.h +0 -176
  830. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +0 -285
  831. data/ext/sources/examples/talk-llama/llama-model-saver.h +0 -37
  832. data/ext/sources/examples/talk-llama/llama-model.cpp +0 -8338
  833. data/ext/sources/examples/talk-llama/llama-model.h +0 -544
  834. data/ext/sources/examples/talk-llama/llama-quant.cpp +0 -1072
  835. data/ext/sources/examples/talk-llama/llama-quant.h +0 -1
  836. data/ext/sources/examples/talk-llama/llama-sampling.cpp +0 -3771
  837. data/ext/sources/examples/talk-llama/llama-sampling.h +0 -44
  838. data/ext/sources/examples/talk-llama/llama-vocab.cpp +0 -3900
  839. data/ext/sources/examples/talk-llama/llama-vocab.h +0 -182
  840. data/ext/sources/examples/talk-llama/llama.cpp +0 -1140
  841. data/ext/sources/examples/talk-llama/llama.h +0 -1540
  842. data/ext/sources/examples/talk-llama/models/afmoe.cpp +0 -191
  843. data/ext/sources/examples/talk-llama/models/apertus.cpp +0 -125
  844. data/ext/sources/examples/talk-llama/models/arcee.cpp +0 -135
  845. data/ext/sources/examples/talk-llama/models/arctic.cpp +0 -138
  846. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +0 -86
  847. data/ext/sources/examples/talk-llama/models/baichuan.cpp +0 -122
  848. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +0 -144
  849. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +0 -135
  850. data/ext/sources/examples/talk-llama/models/bert.cpp +0 -178
  851. data/ext/sources/examples/talk-llama/models/bitnet.cpp +0 -160
  852. data/ext/sources/examples/talk-llama/models/bloom.cpp +0 -101
  853. data/ext/sources/examples/talk-llama/models/chameleon.cpp +0 -178
  854. data/ext/sources/examples/talk-llama/models/chatglm.cpp +0 -132
  855. data/ext/sources/examples/talk-llama/models/codeshell.cpp +0 -111
  856. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +0 -102
  857. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +0 -134
  858. data/ext/sources/examples/talk-llama/models/command-r.cpp +0 -122
  859. data/ext/sources/examples/talk-llama/models/dbrx.cpp +0 -123
  860. data/ext/sources/examples/talk-llama/models/deci.cpp +0 -135
  861. data/ext/sources/examples/talk-llama/models/deepseek.cpp +0 -144
  862. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +0 -259
  863. data/ext/sources/examples/talk-llama/models/dots1.cpp +0 -134
  864. data/ext/sources/examples/talk-llama/models/dream.cpp +0 -105
  865. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +0 -150
  866. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +0 -110
  867. data/ext/sources/examples/talk-llama/models/exaone.cpp +0 -114
  868. data/ext/sources/examples/talk-llama/models/exaone4.cpp +0 -123
  869. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +0 -113
  870. data/ext/sources/examples/talk-llama/models/falcon.cpp +0 -120
  871. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +0 -116
  872. data/ext/sources/examples/talk-llama/models/gemma.cpp +0 -112
  873. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +0 -128
  874. data/ext/sources/examples/talk-llama/models/gemma3.cpp +0 -155
  875. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +0 -384
  876. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +0 -170
  877. data/ext/sources/examples/talk-llama/models/glm4.cpp +0 -150
  878. data/ext/sources/examples/talk-llama/models/gpt2.cpp +0 -105
  879. data/ext/sources/examples/talk-llama/models/gptneox.cpp +0 -144
  880. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +0 -196
  881. data/ext/sources/examples/talk-llama/models/granite.cpp +0 -211
  882. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +0 -283
  883. data/ext/sources/examples/talk-llama/models/grok.cpp +0 -159
  884. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +0 -141
  885. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +0 -132
  886. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +0 -154
  887. data/ext/sources/examples/talk-llama/models/internlm2.cpp +0 -120
  888. data/ext/sources/examples/talk-llama/models/jais.cpp +0 -86
  889. data/ext/sources/examples/talk-llama/models/jamba.cpp +0 -106
  890. data/ext/sources/examples/talk-llama/models/lfm2.cpp +0 -175
  891. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +0 -122
  892. data/ext/sources/examples/talk-llama/models/llada.cpp +0 -99
  893. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +0 -178
  894. data/ext/sources/examples/talk-llama/models/llama.cpp +0 -168
  895. data/ext/sources/examples/talk-llama/models/maincoder.cpp +0 -117
  896. data/ext/sources/examples/talk-llama/models/mamba.cpp +0 -55
  897. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +0 -123
  898. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +0 -199
  899. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +0 -124
  900. data/ext/sources/examples/talk-llama/models/mistral3.cpp +0 -160
  901. data/ext/sources/examples/talk-llama/models/models.h +0 -569
  902. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +0 -116
  903. data/ext/sources/examples/talk-llama/models/mpt.cpp +0 -126
  904. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +0 -150
  905. data/ext/sources/examples/talk-llama/models/nemotron.cpp +0 -122
  906. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +0 -104
  907. data/ext/sources/examples/talk-llama/models/olmo.cpp +0 -121
  908. data/ext/sources/examples/talk-llama/models/olmo2.cpp +0 -150
  909. data/ext/sources/examples/talk-llama/models/olmoe.cpp +0 -124
  910. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +0 -127
  911. data/ext/sources/examples/talk-llama/models/openelm.cpp +0 -124
  912. data/ext/sources/examples/talk-llama/models/orion.cpp +0 -123
  913. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +0 -121
  914. data/ext/sources/examples/talk-llama/models/phi2.cpp +0 -121
  915. data/ext/sources/examples/talk-llama/models/phi3.cpp +0 -152
  916. data/ext/sources/examples/talk-llama/models/plamo.cpp +0 -110
  917. data/ext/sources/examples/talk-llama/models/plamo2.cpp +0 -316
  918. data/ext/sources/examples/talk-llama/models/plamo3.cpp +0 -128
  919. data/ext/sources/examples/talk-llama/models/plm.cpp +0 -168
  920. data/ext/sources/examples/talk-llama/models/qwen.cpp +0 -108
  921. data/ext/sources/examples/talk-llama/models/qwen2.cpp +0 -126
  922. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +0 -151
  923. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +0 -117
  924. data/ext/sources/examples/talk-llama/models/qwen3.cpp +0 -117
  925. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +0 -124
  926. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +0 -873
  927. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +0 -149
  928. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +0 -141
  929. data/ext/sources/examples/talk-llama/models/refact.cpp +0 -94
  930. data/ext/sources/examples/talk-llama/models/rnd1.cpp +0 -126
  931. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +0 -162
  932. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +0 -94
  933. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +0 -86
  934. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +0 -135
  935. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +0 -90
  936. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +0 -124
  937. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +0 -126
  938. data/ext/sources/examples/talk-llama/models/smollm3.cpp +0 -128
  939. data/ext/sources/examples/talk-llama/models/stablelm.cpp +0 -146
  940. data/ext/sources/examples/talk-llama/models/starcoder.cpp +0 -100
  941. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +0 -121
  942. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +0 -166
  943. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +0 -96
  944. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +0 -149
  945. data/ext/sources/examples/talk-llama/models/xverse.cpp +0 -108
  946. data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +0 -23
  947. data/ext/sources/examples/talk-llama/speak +0 -40
  948. data/ext/sources/examples/talk-llama/speak.bat +0 -1
  949. data/ext/sources/examples/talk-llama/speak.ps1 +0 -14
  950. data/ext/sources/examples/talk-llama/talk-llama.cpp +0 -813
  951. data/ext/sources/examples/talk-llama/unicode-data.cpp +0 -7034
  952. data/ext/sources/examples/talk-llama/unicode-data.h +0 -20
  953. data/ext/sources/examples/talk-llama/unicode.cpp +0 -1147
  954. data/ext/sources/examples/talk-llama/unicode.h +0 -111
  955. data/ext/sources/examples/wchess/CMakeLists.txt +0 -10
  956. data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +0 -19
  957. data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +0 -803
  958. data/ext/sources/examples/wchess/libwchess/Chessboard.h +0 -33
  959. data/ext/sources/examples/wchess/libwchess/WChess.cpp +0 -193
  960. data/ext/sources/examples/wchess/libwchess/WChess.h +0 -63
  961. data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +0 -117
  962. data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +0 -8
  963. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +0 -253
  964. data/ext/sources/examples/whisper.wasm/CMakeLists.txt +0 -50
  965. data/ext/sources/examples/whisper.wasm/emscripten.cpp +0 -118
  966. data/ext/sources/examples/whisper.wasm/index-tmpl.html +0 -659
  967. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  968. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
  969. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +0 -99
  970. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +0 -157
  971. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +0 -165
  972. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
  973. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
  974. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
  975. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
  976. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
  977. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
  978. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
  979. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +0 -153
  980. data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +0 -26
  981. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +0 -5
  982. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
  983. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  984. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +0 -147
  985. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +0 -323
  986. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +0 -907
  987. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +0 -247
  988. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
  989. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +0 -123
  990. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
  991. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
  992. data/ext/sources/tests/CMakeLists.txt +0 -112
  993. data/ext/sources/tests/earnings21/eval.mk +0 -58
  994. data/ext/sources/tests/earnings21/eval.py +0 -68
  995. data/ext/sources/tests/earnings21/normalizers/__init__.py +0 -2
  996. data/ext/sources/tests/earnings21/normalizers/basic.py +0 -80
  997. data/ext/sources/tests/earnings21/normalizers/english.json +0 -1741
  998. data/ext/sources/tests/earnings21/normalizers/english.py +0 -550
  999. data/ext/sources/tests/earnings21/requirements.txt +0 -6
  1000. data/ext/sources/tests/en-0-ref.txt +0 -1
  1001. data/ext/sources/tests/en-1-ref.txt +0 -1
  1002. data/ext/sources/tests/en-2-ref.txt +0 -1
  1003. data/ext/sources/tests/es-0-ref.txt +0 -1
  1004. data/ext/sources/tests/librispeech/eval.mk +0 -39
  1005. data/ext/sources/tests/librispeech/eval.py +0 -47
  1006. data/ext/sources/tests/librispeech/normalizers/__init__.py +0 -2
  1007. data/ext/sources/tests/librispeech/normalizers/basic.py +0 -80
  1008. data/ext/sources/tests/librispeech/normalizers/english.json +0 -1741
  1009. data/ext/sources/tests/librispeech/normalizers/english.py +0 -550
  1010. data/ext/sources/tests/librispeech/requirements.txt +0 -6
  1011. data/ext/sources/tests/run-tests.sh +0 -130
  1012. data/ext/sources/tests/test-c.c +0 -3
  1013. data/ext/sources/tests/test-vad-full.cpp +0 -56
  1014. data/ext/sources/tests/test-vad.cpp +0 -83
  1015. data/ext/sources/tests/test-whisper.js +0 -58
  1016. data/lib/whisper/context.rb +0 -15
  1017. data/lib/whisper/segment.rb +0 -58
@@ -1,169 +1,3293 @@
1
1
  #ifndef GGML_WEBGPU_SHADER_LIB_HPP
2
2
  #define GGML_WEBGPU_SHADER_LIB_HPP
3
3
 
4
+ #include "ggml-impl.h"
5
+ #include "ggml-wgsl-shaders.hpp"
4
6
  #include "ggml.h"
5
7
  #include "pre_wgsl.hpp"
6
8
 
9
+ #include <webgpu/webgpu_cpp.h>
10
+
11
+ #include <algorithm>
12
+ #include <memory>
7
13
  #include <string>
14
+ #include <unordered_map>
8
15
  #include <vector>
9
16
 
10
17
  #define GGML_WEBGPU_F16_SIZE_BYTES 2
11
18
  #define GGML_WEBGPU_F32_SIZE_BYTES 4
19
+ #define GGML_WEBGPU_I32_SIZE_BYTES 4
12
20
  #define GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES 8u
21
+ #define GGML_WEBGPU_FLASH_ATTN_VEC_MAX_SEQ_LEN 20u
22
+ #define GGML_WEBGPU_FLASH_ATTN_VEC_MAX_KV_TILE 32u
23
+ #define GGML_WEBGPU_FLASH_ATTN_TILE_MAX_KV_TILE 64u
13
24
  #define GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE 128u
14
25
  // Matches GGML_PAD(..., 256) in src/llama-context.cpp for KV cache sizing.
15
26
  #define GGML_WEBGPU_KV_SEQ_PAD 256u
16
27
 
17
- struct ggml_webgpu_flash_attn_shader_lib_context {
18
- ggml_type kv_type;
28
+ #define GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE 512u
29
+
30
+ // Matrix multiplication parameters
31
+
32
+ // Register tiling parameters
33
+ #define WEBGPU_MUL_MAT_TILE_M 4
34
+ #define WEBGPU_MUL_MAT_TILE_N 4
35
+ #define WEBGPU_MUL_MAT_WG_SIZE_M 8
36
+ #define WEBGPU_MUL_MAT_WG_SIZE_N 8
37
+ #define WEBGPU_MUL_MAT_REG_TILE_K_FLOAT 8
38
+ #define WEBGPU_MUL_MAT_REG_TILE_K_QUANT 32
39
+
40
+ // Subgroup matrix parameters
41
+ // The number of subgroups in the M dimension
42
+ #define WEBGPU_MUL_MAT_SUBGROUP_M 2
43
+ // The number of subgroups in the N dimension
44
+ #define WEBGPU_MUL_MAT_SUBGROUP_N 4
45
+ // The number of subgroup matrices each subgroup accumulates over
46
+ #define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4
47
+ #define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2
48
+ #define WEBGPU_MUL_MAT_SUBGROUP_TILE_K_FLOAT 32
49
+ #define WEBGPU_MUL_MAT_SUBGROUP_TILE_K_QUANT 32
50
+
51
+ // Matrix-vector multiplication parameters
52
+ #define WEBGPU_MUL_MAT_VEC_WG_SIZE 256
53
+
54
+ #define WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG 4
55
+ #define WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG 4
56
+ #define WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG 4
57
+
58
+ // default size for reg-tile matrix multiplication
59
+ #define WEBGPU_MUL_MAT_WG_SIZE 256
60
+
61
+ // Same hash combine function as in boost
62
+ template <typename T> inline void ggml_webgpu_hash_combine(size_t & seed, const T & value) {
63
+ seed ^= std::hash<T>{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
64
+ }
65
+
66
+ // Calculates base address of a tensor ignoring the fake base pointer
67
+ inline uintptr_t ggml_webgpu_tensor_addr(const ggml_tensor * tensor) {
68
+ const ggml_tensor * base_tensor = tensor->view_src ? tensor->view_src : tensor;
69
+ return (uintptr_t) base_tensor->data + tensor->view_offs;
70
+ }
71
+
72
+ inline bool ggml_webgpu_tensor_equal(const ggml_tensor * a, const ggml_tensor * b) {
73
+ return a->buffer == b->buffer && ggml_webgpu_tensor_addr(a) == ggml_webgpu_tensor_addr(b);
74
+ }
75
+
76
+ inline bool ggml_webgpu_tensor_overlap(const ggml_tensor * a, const ggml_tensor * b) {
77
+ return a->buffer == b->buffer && ggml_webgpu_tensor_addr(a) < ggml_webgpu_tensor_addr(b) + ggml_nbytes(b) &&
78
+ ggml_webgpu_tensor_addr(b) < ggml_webgpu_tensor_addr(a) + ggml_nbytes(a);
79
+ }
80
+
81
+ struct ggml_webgpu_shader_lib_context {
82
+ ggml_tensor * src0;
83
+ ggml_tensor * src1;
84
+ ggml_tensor * src2;
85
+ ggml_tensor * src3;
86
+ ggml_tensor * src4;
87
+ ggml_tensor * src5;
88
+ ggml_tensor * dst;
89
+
90
+ uint32_t max_wg_size;
91
+ size_t wg_mem_limit_bytes = 0;
92
+ bool supports_subgroups = false;
93
+ bool supports_subgroup_matrix = false;
94
+ uint32_t sg_mat_m = 0;
95
+ uint32_t sg_mat_n = 0;
96
+ uint32_t sg_mat_k = 0;
97
+ uint32_t min_subgroup_size = 0;
98
+ uint32_t max_subgroup_size = 0;
99
+ bool supports_dot_product = false;
100
+ std::string vendor;
101
+ };
102
+
103
+ struct webgpu_pipeline {
104
+ wgpu::ComputePipeline pipeline;
105
+ std::string name;
106
+ std::shared_ptr<void> context = nullptr;
107
+ };
108
+
109
+ struct ggml_webgpu_generic_shader_decisions {
110
+ uint32_t wg_size = 0;
111
+ bool inplace = false;
112
+ };
113
+
114
+ struct ggml_webgpu_binary_shader_decisions {
115
+ uint32_t wg_size = 0;
116
+ bool inplace = false;
117
+ bool overlap = false;
118
+ bool src_overlap = false;
119
+ };
120
+
121
+ struct ggml_webgpu_processed_shader {
122
+ std::string wgsl;
123
+ std::string variant;
124
+ std::shared_ptr<void> decisions;
125
+ };
126
+
127
+ struct ggml_webgpu_ssm_conv_shader_decisions {
128
+ uint32_t block_size;
129
+ uint32_t tokens_per_wg;
130
+ };
131
+
132
+ struct ggml_webgpu_ssm_scan_pipeline_key {
133
+ int type;
134
+ int d_state;
135
+ bool xbc_overlap;
136
+
137
+ bool operator==(const ggml_webgpu_ssm_scan_pipeline_key & other) const {
138
+ return type == other.type && d_state == other.d_state && xbc_overlap == other.xbc_overlap;
139
+ }
140
+ };
141
+
142
+ struct ggml_webgpu_ssm_scan_pipeline_key_hash {
143
+ size_t operator()(const ggml_webgpu_ssm_scan_pipeline_key & key) const {
144
+ size_t seed = 0;
145
+ ggml_webgpu_hash_combine(seed, key.type);
146
+ ggml_webgpu_hash_combine(seed, key.d_state);
147
+ ggml_webgpu_hash_combine(seed, key.xbc_overlap);
148
+ return seed;
149
+ }
150
+ };
151
+
152
+ struct ggml_webgpu_ssm_scan_shader_decisions {
153
+ uint32_t wg_size;
154
+ uint32_t tokens_per_tile;
155
+ bool xbc_overlap = false;
156
+ };
157
+
158
+ /** Argsort **/
159
+
160
+ struct ggml_webgpu_argsort_shader_lib_context {
161
+ uint32_t max_wg_size;
162
+ size_t wg_mem_limit_bytes;
163
+ int32_t order;
164
+ };
165
+
166
+ /** Set Rows **/
167
+
168
+ struct ggml_webgpu_set_rows_pipeline_key {
169
+ int dst_type;
170
+ int vec4;
171
+ int i64_idx;
172
+ int pair_blocks;
173
+
174
+ bool operator==(const ggml_webgpu_set_rows_pipeline_key & other) const {
175
+ return dst_type == other.dst_type && vec4 == other.vec4 && i64_idx == other.i64_idx &&
176
+ pair_blocks == other.pair_blocks;
177
+ }
178
+ };
179
+
180
+ struct ggml_webgpu_set_rows_pipeline_key_hash {
181
+ size_t operator()(const ggml_webgpu_set_rows_pipeline_key & key) const {
182
+ size_t seed = 0;
183
+ ggml_webgpu_hash_combine(seed, key.dst_type);
184
+ ggml_webgpu_hash_combine(seed, key.vec4);
185
+ ggml_webgpu_hash_combine(seed, key.i64_idx);
186
+ ggml_webgpu_hash_combine(seed, key.pair_blocks);
187
+ return seed;
188
+ }
189
+ };
190
+
191
+ struct ggml_webgpu_set_rows_shader_decisions {
192
+ bool vec4;
193
+ bool i64_idx;
194
+ bool pair_blocks;
195
+ uint32_t wg_size;
196
+ };
197
+
198
+ /** Set **/
199
+
200
+ struct ggml_webgpu_set_pipeline_key {
201
+ ggml_type type;
202
+ bool inplace;
203
+
204
+ bool operator==(const ggml_webgpu_set_pipeline_key & other) const {
205
+ return type == other.type && inplace == other.inplace;
206
+ }
207
+ };
208
+
209
+ struct ggml_webgpu_set_pipeline_key_hash {
210
+ size_t operator()(const ggml_webgpu_set_pipeline_key & key) const {
211
+ size_t seed = 0;
212
+ ggml_webgpu_hash_combine(seed, key.type);
213
+ ggml_webgpu_hash_combine(seed, key.inplace);
214
+ return seed;
215
+ }
216
+ };
217
+
218
+ /** Get Rows **/
219
+
220
+ struct ggml_webgpu_get_rows_pipeline_key {
221
+ ggml_type src_type;
222
+ int vectorized;
223
+
224
+ bool operator==(const ggml_webgpu_get_rows_pipeline_key & other) const {
225
+ return src_type == other.src_type && vectorized == other.vectorized;
226
+ }
227
+ };
228
+
229
+ struct ggml_webgpu_get_rows_pipeline_key_hash {
230
+ size_t operator()(const ggml_webgpu_get_rows_pipeline_key & key) const {
231
+ size_t seed = 0;
232
+ ggml_webgpu_hash_combine(seed, key.src_type);
233
+ ggml_webgpu_hash_combine(seed, key.vectorized);
234
+ return seed;
235
+ }
236
+ };
237
+
238
+ /** Row Norm **/
239
+
240
+ struct ggml_webgpu_row_norm_pipeline_key {
241
+ ggml_op op;
242
+ ggml_type src_type;
243
+ ggml_type dst_type;
244
+ bool inplace;
245
+
246
+ bool operator==(const ggml_webgpu_row_norm_pipeline_key & other) const {
247
+ return op == other.op && src_type == other.src_type && dst_type == other.dst_type && inplace == other.inplace;
248
+ }
249
+ };
250
+
251
+ struct ggml_webgpu_row_norm_pipeline_key_hash {
252
+ size_t operator()(const ggml_webgpu_row_norm_pipeline_key & key) const {
253
+ size_t seed = 0;
254
+ ggml_webgpu_hash_combine(seed, key.op);
255
+ ggml_webgpu_hash_combine(seed, key.src_type);
256
+ ggml_webgpu_hash_combine(seed, key.dst_type);
257
+ ggml_webgpu_hash_combine(seed, key.inplace);
258
+ return seed;
259
+ }
260
+ };
261
+
262
+ /** RMS_NORM + MUL **/
263
+
264
+ struct ggml_webgpu_rms_norm_mul_pipeline_key {
265
+ bool inplace; // rn_src == dst
266
+ bool overlap; // mul_src == dst
267
+ bool src_overlap; // rn_src == mul_src
268
+
269
+ bool operator==(const ggml_webgpu_rms_norm_mul_pipeline_key & other) const {
270
+ return inplace == other.inplace && overlap == other.overlap && src_overlap == other.src_overlap;
271
+ }
272
+ };
273
+
274
+ struct ggml_webgpu_rms_norm_mul_pipeline_key_hash {
275
+ size_t operator()(const ggml_webgpu_rms_norm_mul_pipeline_key & key) const {
276
+ size_t seed = 0;
277
+ ggml_webgpu_hash_combine(seed, key.inplace);
278
+ ggml_webgpu_hash_combine(seed, key.overlap);
279
+ ggml_webgpu_hash_combine(seed, key.src_overlap);
280
+ return seed;
281
+ }
282
+ };
283
+
284
+ struct ggml_webgpu_rms_norm_mul_shader_decisions {
285
+ uint32_t wg_size = 0;
286
+ bool inplace = false;
287
+ bool overlap = false;
288
+ bool src_overlap = false;
289
+ };
290
+
291
+ /** Pad **/
292
+ struct ggml_webgpu_pad_pipeline_key {
293
+ bool circular;
294
+
295
+ bool operator==(const ggml_webgpu_pad_pipeline_key & other) const { return circular == other.circular; }
296
+ };
297
+
298
+ struct ggml_webgpu_pad_pipeline_key_hash {
299
+ size_t operator()(const ggml_webgpu_pad_pipeline_key & key) const {
300
+ size_t seed = 0;
301
+ ggml_webgpu_hash_combine(seed, key.circular);
302
+ return seed;
303
+ }
304
+ };
305
+
306
+ /** Solve Tri **/
307
+ struct ggml_webgpu_solve_tri_pipeline_key {
308
+ int type;
309
+ int n;
310
+ int k;
311
+
312
+ bool operator==(const ggml_webgpu_solve_tri_pipeline_key & other) const {
313
+ return type == other.type && n == other.n && k == other.k;
314
+ }
315
+ };
316
+
317
+ struct ggml_webgpu_solve_tri_pipeline_key_hash {
318
+ size_t operator()(const ggml_webgpu_solve_tri_pipeline_key & key) const {
319
+ size_t seed = 0;
320
+ ggml_webgpu_hash_combine(seed, key.type);
321
+ ggml_webgpu_hash_combine(seed, key.n);
322
+ ggml_webgpu_hash_combine(seed, key.k);
323
+ return seed;
324
+ }
325
+ };
326
+
327
+ /** SSM Conv **/
328
+ struct ggml_webgpu_ssm_conv_pipeline_key {
329
+ int type;
330
+ int vectorized;
331
+
332
+ bool operator==(const ggml_webgpu_ssm_conv_pipeline_key & other) const {
333
+ return type == other.type && vectorized == other.vectorized;
334
+ }
335
+ };
336
+
337
+ /** CONV 2D */
338
+ struct ggml_webgpu_conv2d_pipeline_key {
339
+ ggml_type weight_type;
340
+ ggml_type input_type;
341
+ ggml_type output_type;
342
+
343
+ bool operator==(const ggml_webgpu_conv2d_pipeline_key & other) const {
344
+ return weight_type == other.weight_type && input_type == other.input_type && output_type == other.output_type;
345
+ }
346
+ };
347
+
348
+ struct ggml_webgpu_conv2d_pipeline_key_hash {
349
+ size_t operator()(const ggml_webgpu_conv2d_pipeline_key & key) const {
350
+ size_t seed = 0;
351
+ ggml_webgpu_hash_combine(seed, key.weight_type);
352
+ ggml_webgpu_hash_combine(seed, key.input_type);
353
+ ggml_webgpu_hash_combine(seed, key.output_type);
354
+ return seed;
355
+ }
356
+ };
357
+
358
+ /** Im2Col **/
359
+ struct ggml_webgpu_im2col_pipeline_key {
360
+ ggml_type input_type;
361
+ ggml_type output_type;
362
+
363
+ bool operator==(const ggml_webgpu_im2col_pipeline_key & other) const {
364
+ return input_type == other.input_type && output_type == other.output_type;
365
+ }
366
+ };
367
+
368
+ struct ggml_webgpu_im2col_pipeline_key_hash {
369
+ size_t operator()(const ggml_webgpu_im2col_pipeline_key & key) const {
370
+ size_t seed = 0;
371
+ ggml_webgpu_hash_combine(seed, key.input_type);
372
+ ggml_webgpu_hash_combine(seed, key.output_type);
373
+ return seed;
374
+ }
375
+ };
376
+
377
+ /** Gated Delta Net **/
378
+ struct ggml_webgpu_gated_delta_net_pipeline_key {
379
+ int type;
380
+ int s_v;
381
+ int kda;
382
+
383
+ bool operator==(const ggml_webgpu_gated_delta_net_pipeline_key & other) const {
384
+ return type == other.type && s_v == other.s_v && kda == other.kda;
385
+ }
386
+ };
387
+
388
+ struct ggml_webgpu_gated_delta_net_pipeline_key_hash {
389
+ size_t operator()(const ggml_webgpu_gated_delta_net_pipeline_key & key) const {
390
+ size_t seed = 0;
391
+ ggml_webgpu_hash_combine(seed, key.type);
392
+ ggml_webgpu_hash_combine(seed, key.s_v);
393
+ ggml_webgpu_hash_combine(seed, key.kda);
394
+ return seed;
395
+ }
396
+ };
397
+
398
+ struct ggml_webgpu_ssm_conv_pipeline_key_hash {
399
+ size_t operator()(const ggml_webgpu_ssm_conv_pipeline_key & key) const {
400
+ size_t seed = 0;
401
+ ggml_webgpu_hash_combine(seed, key.type);
402
+ ggml_webgpu_hash_combine(seed, key.vectorized);
403
+ return seed;
404
+ }
405
+ };
406
+
407
+ /** Scale **/
408
+
409
+ struct ggml_webgpu_scale_pipeline_key {
410
+ int inplace;
411
+
412
+ bool operator==(const ggml_webgpu_scale_pipeline_key & other) const { return inplace == other.inplace; }
413
+ };
414
+
415
+ struct ggml_webgpu_scale_pipeline_key_hash {
416
+ size_t operator()(const ggml_webgpu_scale_pipeline_key & key) const {
417
+ size_t seed = 0;
418
+ ggml_webgpu_hash_combine(seed, key.inplace);
419
+ return seed;
420
+ }
421
+ };
422
+
423
+ /** Upscale **/
424
+
425
+ struct ggml_webgpu_upscale_pipeline_key {
426
+ ggml_type input_type;
427
+ ggml_type output_type;
428
+ uint32_t base_mode;
429
+ bool antialias;
430
+
431
+ bool operator==(const ggml_webgpu_upscale_pipeline_key & other) const {
432
+ return input_type == other.input_type && output_type == other.output_type && base_mode == other.base_mode &&
433
+ antialias == other.antialias;
434
+ }
435
+ };
436
+
437
+ struct ggml_webgpu_upscale_pipeline_key_hash {
438
+ size_t operator()(const ggml_webgpu_upscale_pipeline_key & key) const {
439
+ size_t seed = 0;
440
+ ggml_webgpu_hash_combine(seed, key.input_type);
441
+ ggml_webgpu_hash_combine(seed, key.output_type);
442
+ ggml_webgpu_hash_combine(seed, key.base_mode);
443
+ ggml_webgpu_hash_combine(seed, key.antialias);
444
+ return seed;
445
+ }
446
+ };
447
+
448
+ /** Concat **/
449
+
450
+ struct ggml_webgpu_concat_pipeline_key {
451
+ int type;
452
+ bool src_overlap;
453
+
454
+ bool operator==(const ggml_webgpu_concat_pipeline_key & other) const {
455
+ return type == other.type && src_overlap == other.src_overlap;
456
+ }
457
+ };
458
+
459
+ struct ggml_webgpu_concat_pipeline_key_hash {
460
+ size_t operator()(const ggml_webgpu_concat_pipeline_key & key) const {
461
+ size_t seed = 0;
462
+ ggml_webgpu_hash_combine(seed, key.type);
463
+ ggml_webgpu_hash_combine(seed, key.src_overlap);
464
+ return seed;
465
+ }
466
+ };
467
+
468
+ /** Repeat **/
469
+
470
+ struct ggml_webgpu_repeat_pipeline_key {
471
+ int type;
472
+
473
+ bool operator==(const ggml_webgpu_repeat_pipeline_key & other) const { return type == other.type; }
474
+ };
475
+
476
+ struct ggml_webgpu_repeat_pipeline_key_hash {
477
+ size_t operator()(const ggml_webgpu_repeat_pipeline_key & key) const {
478
+ size_t seed = 0;
479
+ ggml_webgpu_hash_combine(seed, key.type);
480
+ return seed;
481
+ }
482
+ };
483
+
484
+ /** Binary **/
485
+
486
+ struct ggml_webgpu_binary_pipeline_key {
487
+ int type;
488
+ int op;
489
+ bool inplace;
490
+ bool overlap;
491
+ bool src_overlap;
492
+
493
+ bool operator==(const ggml_webgpu_binary_pipeline_key & other) const {
494
+ return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap &&
495
+ src_overlap == other.src_overlap;
496
+ }
497
+ };
498
+
499
+ struct ggml_webgpu_binary_pipeline_key_hash {
500
+ size_t operator()(const ggml_webgpu_binary_pipeline_key & key) const {
501
+ size_t seed = 0;
502
+ ggml_webgpu_hash_combine(seed, key.type);
503
+ ggml_webgpu_hash_combine(seed, key.op);
504
+ ggml_webgpu_hash_combine(seed, key.inplace);
505
+ ggml_webgpu_hash_combine(seed, key.overlap);
506
+ ggml_webgpu_hash_combine(seed, key.src_overlap);
507
+ return seed;
508
+ }
509
+ };
510
+
511
+ /* Add_Id */
512
+
513
+ struct ggml_webgpu_add_id_pipeline_key {
514
+ bool inplace;
515
+
516
+ bool operator==(const ggml_webgpu_add_id_pipeline_key & other) const { return inplace == other.inplace; }
517
+ };
518
+
519
+ struct ggml_webgpu_add_id_pipeline_key_hash {
520
+ size_t operator()(const ggml_webgpu_add_id_pipeline_key & key) const {
521
+ size_t seed = 0;
522
+ ggml_webgpu_hash_combine(seed, key.inplace);
523
+ return seed;
524
+ }
525
+ };
526
+
527
+ /** Unary **/
528
+
529
+ struct ggml_webgpu_unary_pipeline_key {
530
+ int type;
531
+ int op;
532
+ bool is_unary; // many unary operators fall under the GGML_OP_UNARY umbrella
533
+ bool inplace;
534
+ ggml_tri_type ttype; // only used for GGML_OP_TRI
535
+
536
+ bool operator==(const ggml_webgpu_unary_pipeline_key & other) const {
537
+ return type == other.type && op == other.op && is_unary == other.is_unary && inplace == other.inplace &&
538
+ ttype == other.ttype;
539
+ }
540
+ };
541
+
542
+ struct ggml_webgpu_unary_pipeline_key_hash {
543
+ size_t operator()(const ggml_webgpu_unary_pipeline_key & key) const {
544
+ size_t seed = 0;
545
+ ggml_webgpu_hash_combine(seed, key.type);
546
+ ggml_webgpu_hash_combine(seed, key.op);
547
+ ggml_webgpu_hash_combine(seed, key.is_unary);
548
+ ggml_webgpu_hash_combine(seed, key.inplace);
549
+ ggml_webgpu_hash_combine(seed, key.ttype);
550
+ return seed;
551
+ }
552
+ };
553
+
554
+ /** FlashAttention */
555
+
556
+ struct ggml_webgpu_flash_attn_common_pipeline_key {
557
+ ggml_type q_type;
558
+ ggml_type k_type;
559
+ ggml_type v_type;
560
+ ggml_type dst_type;
19
561
  uint32_t head_dim_qk;
20
562
  uint32_t head_dim_v;
21
563
  bool kv_direct;
564
+ bool kv_overlap;
22
565
  bool has_mask;
23
566
  bool has_sinks;
24
567
  bool uses_logit_softcap;
25
- uint32_t sg_mat_m;
26
- uint32_t sg_mat_n;
27
- uint32_t sg_mat_k;
28
- size_t wg_mem_limit_bytes;
29
- uint32_t max_subgroup_size;
568
+
569
+ bool operator==(const ggml_webgpu_flash_attn_common_pipeline_key & other) const {
570
+ return q_type == other.q_type && k_type == other.k_type && v_type == other.v_type &&
571
+ dst_type == other.dst_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v &&
572
+ kv_direct == other.kv_direct && kv_overlap == other.kv_overlap && has_mask == other.has_mask &&
573
+ has_sinks == other.has_sinks && uses_logit_softcap == other.uses_logit_softcap;
574
+ }
30
575
  };
31
576
 
32
- struct ggml_webgpu_flash_attn_shader_decisions {
33
- uint32_t q_tile = 0;
34
- uint32_t kv_tile = 0;
35
- uint32_t wg_size = 0;
577
+ inline void ggml_webgpu_flash_attn_hash_common_pipeline_key(size_t & seed,
578
+ const ggml_webgpu_flash_attn_common_pipeline_key & key) {
579
+ ggml_webgpu_hash_combine(seed, key.q_type);
580
+ ggml_webgpu_hash_combine(seed, key.k_type);
581
+ ggml_webgpu_hash_combine(seed, key.v_type);
582
+ ggml_webgpu_hash_combine(seed, key.dst_type);
583
+ ggml_webgpu_hash_combine(seed, key.head_dim_qk);
584
+ ggml_webgpu_hash_combine(seed, key.head_dim_v);
585
+ ggml_webgpu_hash_combine(seed, key.kv_direct);
586
+ ggml_webgpu_hash_combine(seed, key.kv_overlap);
587
+ ggml_webgpu_hash_combine(seed, key.has_mask);
588
+ ggml_webgpu_hash_combine(seed, key.has_sinks);
589
+ ggml_webgpu_hash_combine(seed, key.uses_logit_softcap);
590
+ }
591
+
592
+ struct ggml_webgpu_flash_attn_vec_pipeline_key {
593
+ ggml_webgpu_flash_attn_common_pipeline_key common;
594
+
595
+ bool operator==(const ggml_webgpu_flash_attn_vec_pipeline_key & other) const { return common == other.common; }
36
596
  };
37
597
 
38
- struct ggml_webgpu_processed_shader {
39
- std::string wgsl;
40
- std::string variant;
41
- ggml_webgpu_flash_attn_shader_decisions decisions;
598
+ struct ggml_webgpu_flash_attn_vec_pipeline_key_hash {
599
+ size_t operator()(const ggml_webgpu_flash_attn_vec_pipeline_key & key) const {
600
+ size_t seed = 0;
601
+ ggml_webgpu_flash_attn_hash_common_pipeline_key(seed, key.common);
602
+ return seed;
603
+ }
42
604
  };
43
605
 
44
- // This is exposed because it's necessary in supports_op
45
- inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile,
46
- uint32_t kv_tile,
47
- uint32_t head_dim_qk,
48
- uint32_t head_dim_v,
49
- bool has_mask,
50
- bool kv_direct) {
51
- const uint32_t max_head_dim = std::max(head_dim_qk, head_dim_v);
52
- size_t f16_elems = 0;
53
- size_t f32_elems = 0;
54
- f16_elems += q_tile * head_dim_qk; // q_shmem
55
- if (!kv_direct) {
56
- f16_elems += kv_tile * max_head_dim; // kv_shmem
606
+ struct ggml_webgpu_flash_attn_pipeline_key {
607
+ ggml_webgpu_flash_attn_common_pipeline_key common;
608
+ bool use_sg_matrix;
609
+
610
+ bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const {
611
+ return common == other.common && use_sg_matrix == other.use_sg_matrix;
57
612
  }
58
- f16_elems += q_tile * head_dim_v; // o_shmem
59
- if (has_mask) {
60
- f16_elems += q_tile * kv_tile; // mask_shmem
613
+ };
614
+
615
+ struct ggml_webgpu_flash_attn_pipeline_key_hash {
616
+ size_t operator()(const ggml_webgpu_flash_attn_pipeline_key & key) const {
617
+ size_t seed = 0;
618
+ ggml_webgpu_flash_attn_hash_common_pipeline_key(seed, key.common);
619
+ ggml_webgpu_hash_combine(seed, key.use_sg_matrix);
620
+ return seed;
61
621
  }
62
- f16_elems += q_tile * kv_tile; // inter_shmem
63
- f32_elems += q_tile; // row_max_shmem
64
- f32_elems += q_tile; // exp_sum_shmem
65
- return f16_elems * GGML_WEBGPU_F16_SIZE_BYTES + f32_elems * GGML_WEBGPU_F32_SIZE_BYTES;
622
+ };
623
+
624
+ struct ggml_webgpu_flash_attn_vec_decisions {
625
+ uint32_t kv_tile = 0;
626
+ uint32_t wg_size = 0;
627
+ };
628
+
629
+ struct ggml_webgpu_flash_attn_decisions {
630
+ bool use_sg_matrix = false;
631
+ uint32_t q_tile = 0;
632
+ uint32_t kv_tile = 0;
633
+ uint32_t wg_size = 0;
634
+ };
635
+
636
+ inline constexpr uint32_t GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH = 4u;
637
+ inline constexpr uint32_t GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE = 4u;
638
+
639
+ inline size_t ggml_webgpu_flash_attn_tensor_offset(const ggml_tensor * tensor) {
640
+ constexpr uintptr_t ptr_base_addr = 0x1000u;
641
+ const ggml_tensor * base = tensor->view_src != nullptr ? tensor->view_src : tensor;
642
+ return reinterpret_cast<uintptr_t>(base->data) - ptr_base_addr + tensor->view_offs;
66
643
  }
67
644
 
68
- static uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_flash_attn_shader_lib_context & context) {
69
- const size_t limit_bytes = context.wg_mem_limit_bytes;
70
- const size_t q_tile = context.sg_mat_m;
71
- const size_t base_q_bytes = (context.head_dim_qk + context.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
72
- 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
73
- size_t bytes_per_kv = 0;
74
- if (!context.kv_direct) {
75
- bytes_per_kv += std::max(context.head_dim_qk, context.head_dim_v);
76
- }
77
- if (context.has_mask) {
78
- bytes_per_kv += q_tile;
79
- }
80
- bytes_per_kv += q_tile;
81
- bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES;
82
- const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv;
83
- return (max_kv_tile / context.sg_mat_n) * context.sg_mat_n;
645
+ inline bool ggml_webgpu_flash_attn_float_vec4_aligned(const ggml_tensor * K, size_t storage_offset_alignment) {
646
+ const uint32_t offset_elems =
647
+ (uint32_t) ((ggml_webgpu_flash_attn_tensor_offset(K) & (storage_offset_alignment - 1)) /
648
+ ggml_type_size(K->type));
649
+ return offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u;
650
+ }
651
+
652
+ inline bool ggml_webgpu_flash_attn_float_vec4_aligned(const ggml_tensor * K,
653
+ const ggml_tensor * V,
654
+ size_t storage_offset_alignment) {
655
+ return ggml_webgpu_flash_attn_float_vec4_aligned(K, storage_offset_alignment) &&
656
+ ggml_webgpu_flash_attn_float_vec4_aligned(V, storage_offset_alignment);
657
+ }
658
+
659
+ inline bool ggml_webgpu_flash_attn_kv_direct(const ggml_tensor * Q,
660
+ const ggml_tensor * K,
661
+ const ggml_tensor * V,
662
+ uint32_t kv_direct_align) {
663
+ return K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && (Q->ne[0] % kv_direct_align == 0) &&
664
+ (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);
665
+ }
666
+
667
+ inline ggml_webgpu_flash_attn_common_pipeline_key ggml_webgpu_flash_attn_make_common_pipeline_key(
668
+ const ggml_webgpu_shader_lib_context & context,
669
+ uint32_t kv_direct_align) {
670
+ ggml_webgpu_flash_attn_common_pipeline_key key = {};
671
+ key.q_type = context.src0->type;
672
+ key.k_type = context.src1->type;
673
+ key.v_type = context.src2->type;
674
+ key.dst_type = context.dst->type;
675
+ key.head_dim_qk = (uint32_t) context.src0->ne[0];
676
+ key.head_dim_v = (uint32_t) context.src2->ne[0];
677
+ key.kv_direct = ggml_webgpu_flash_attn_kv_direct(context.src0, context.src1, context.src2, kv_direct_align);
678
+ key.kv_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src2);
679
+ key.has_mask = context.src3 != nullptr;
680
+ key.has_sinks = context.src4 != nullptr;
681
+ key.uses_logit_softcap = ggml_get_op_params_f32(context.dst, 2) != 0.0f;
682
+ return key;
84
683
  }
85
684
 
86
- inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader(
87
- pre_wgsl::Preprocessor & preprocessor,
88
- const char * shader_src,
89
- const ggml_webgpu_flash_attn_shader_lib_context & context) {
685
+ inline std::vector<std::string> ggml_webgpu_flash_attn_common_defines(
686
+ const ggml_webgpu_flash_attn_common_pipeline_key & key,
687
+ std::string & variant,
688
+ uint32_t q_tile,
689
+ uint32_t kv_tile,
690
+ uint32_t wg_size) {
90
691
  std::vector<std::string> defines;
91
- std::string variant = "flash_attn";
92
692
 
93
- switch (context.kv_type) {
693
+ switch (key.k_type) {
694
+ case GGML_TYPE_F32:
695
+ defines.push_back("K_F32");
696
+ break;
697
+ case GGML_TYPE_F16:
698
+ defines.push_back("K_F16");
699
+ break;
700
+ case GGML_TYPE_Q4_0:
701
+ defines.push_back("K_Q4_0");
702
+ break;
703
+ case GGML_TYPE_Q8_0:
704
+ defines.push_back("K_Q8_0");
705
+ break;
706
+ default:
707
+ GGML_ABORT("Unsupported K type for flash attention shader");
708
+ }
709
+ variant += std::string("_k") + ggml_type_name(key.k_type);
710
+
711
+ switch (key.v_type) {
94
712
  case GGML_TYPE_F32:
95
- defines.push_back("KV_F32");
713
+ defines.push_back("V_F32");
96
714
  break;
97
715
  case GGML_TYPE_F16:
98
- defines.push_back("KV_F16");
716
+ defines.push_back("V_F16");
99
717
  break;
100
718
  case GGML_TYPE_Q4_0:
101
- defines.push_back("KV_Q4_0");
719
+ defines.push_back("V_Q4_0");
102
720
  break;
103
721
  case GGML_TYPE_Q8_0:
104
- defines.push_back("KV_Q8_0");
722
+ defines.push_back("V_Q8_0");
723
+ break;
724
+ default:
725
+ GGML_ABORT("Unsupported V type for flash attention shader");
726
+ }
727
+ variant += std::string("_v") + ggml_type_name(key.v_type);
728
+
729
+ switch (key.q_type) {
730
+ case GGML_TYPE_F32:
731
+ defines.push_back("Q_F32");
732
+ break;
733
+ case GGML_TYPE_F16:
734
+ defines.push_back("Q_F16");
735
+ break;
736
+ default:
737
+ GGML_ABORT("Unsupported Q type for flash attention shader");
738
+ }
739
+ variant += std::string("_q") + ggml_type_name(key.q_type);
740
+
741
+ switch (key.dst_type) {
742
+ case GGML_TYPE_F32:
743
+ defines.push_back("DST_F32");
744
+ break;
745
+ case GGML_TYPE_F16:
746
+ defines.push_back("DST_F16");
105
747
  break;
106
748
  default:
107
- GGML_ABORT("Unsupported KV type for flash attention shader");
749
+ GGML_ABORT("Unsupported dst type for flash attention shader");
108
750
  }
109
- variant += std::string("_") + ggml_type_name(context.kv_type);
751
+ variant += std::string("_dst") + ggml_type_name(key.dst_type);
110
752
 
111
- if (context.has_mask) {
753
+ if (key.has_mask) {
112
754
  defines.push_back("MASK");
113
755
  variant += "_mask";
114
756
  }
115
- if (context.has_sinks) {
757
+ if (key.has_sinks) {
116
758
  defines.push_back("SINKS");
117
759
  variant += "_sinks";
118
760
  }
119
- if (context.uses_logit_softcap) {
761
+ if (key.uses_logit_softcap) {
120
762
  defines.push_back("LOGIT_SOFTCAP");
121
763
  variant += "_lgsc";
122
764
  }
123
-
124
- if (context.kv_direct) {
765
+ if (key.kv_direct) {
125
766
  defines.push_back("KV_DIRECT");
126
767
  variant += "_kvdirect";
127
768
  }
769
+ if (key.kv_overlap) {
770
+ defines.push_back("KV_OVERLAP");
771
+ variant += "_kv_overlap";
772
+ }
128
773
 
129
- defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(context.head_dim_qk));
130
- variant += std::string("_hsqk") + std::to_string(context.head_dim_qk);
774
+ defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk));
775
+ variant += std::string("_hsqk") + std::to_string(key.head_dim_qk);
131
776
 
132
- defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.head_dim_v));
133
- variant += std::string("_hsv") + std::to_string(context.head_dim_v);
777
+ defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v));
778
+ variant += std::string("_hsv") + std::to_string(key.head_dim_v);
134
779
 
135
- // For now these are not part of the variant name
136
- defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m));
137
- defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n));
138
- defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k));
780
+ defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile));
781
+ defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile));
782
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
139
783
 
140
- // Add chosen Q/KV tile sizes
141
- uint32_t q_tile = context.sg_mat_m;
142
- uint32_t kv_tile = std::min(ggml_webgpu_flash_attn_max_kv_tile(context),
143
- context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
144
- if (context.kv_direct) {
145
- GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD);
146
- // Avoids having to use bounds-checks and decreasing performance for direct KV loads
147
- while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {
148
- kv_tile -= context.sg_mat_n;
784
+ if (ggml_is_quantized(key.k_type) || ggml_is_quantized(key.v_type)) {
785
+ defines.push_back("U32_DEQUANT_HELPERS");
786
+ }
787
+
788
+ return defines;
789
+ }
790
+
791
+ struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key {
792
+ uint32_t head_dim_v;
793
+ uint32_t wg_size;
794
+ ggml_type dst_type;
795
+ };
796
+
797
+ struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key_hash {
798
+ size_t operator()(const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & key) const {
799
+ size_t seed = 0;
800
+ ggml_webgpu_hash_combine(seed, key.head_dim_v);
801
+ ggml_webgpu_hash_combine(seed, key.wg_size);
802
+ ggml_webgpu_hash_combine(seed, key.dst_type);
803
+ return seed;
804
+ }
805
+ };
806
+
807
+ inline bool operator==(const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & lhs,
808
+ const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & rhs) {
809
+ return lhs.head_dim_v == rhs.head_dim_v && lhs.wg_size == rhs.wg_size && lhs.dst_type == rhs.dst_type;
810
+ }
811
+
812
+ struct ggml_webgpu_flash_attn_blk_pipeline_key {
813
+ uint32_t kv_tile;
814
+
815
+ bool operator==(const ggml_webgpu_flash_attn_blk_pipeline_key & other) const { return kv_tile == other.kv_tile; }
816
+ };
817
+
818
+ struct ggml_webgpu_flash_attn_blk_pipeline_key_hash {
819
+ size_t operator()(const ggml_webgpu_flash_attn_blk_pipeline_key & key) const {
820
+ size_t seed = 0;
821
+ ggml_webgpu_hash_combine(seed, key.kv_tile);
822
+ return seed;
823
+ }
824
+ };
825
+
826
+ // Note: this will slightly overestimate memory usage for vec path
827
+ // since row_max and exp_sum shmem are not needed.
828
+ inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile,
829
+ uint32_t kv_tile,
830
+ uint32_t head_dim_qk,
831
+ uint32_t head_dim_v,
832
+ bool has_mask,
833
+ bool kv_direct) {
834
+ const uint32_t max_head_dim = std::max(head_dim_qk, head_dim_v);
835
+ size_t f16_elems = 0;
836
+ size_t f32_elems = 0;
837
+
838
+ f32_elems += q_tile * head_dim_qk; // q_shmem
839
+ if (!kv_direct) {
840
+ f32_elems += kv_tile * max_head_dim; // kv_shmem
841
+ }
842
+ f32_elems += q_tile * head_dim_v; // o_shmem
843
+ if (has_mask) {
844
+ f32_elems += q_tile * kv_tile; // mask_shmem
845
+ }
846
+ f32_elems += q_tile * kv_tile; // inter_shmem
847
+ f32_elems += q_tile; // row_max_shmem
848
+ f32_elems += q_tile; // exp_sum_shmem
849
+ return f16_elems * GGML_WEBGPU_F16_SIZE_BYTES + f32_elems * GGML_WEBGPU_F32_SIZE_BYTES;
850
+ }
851
+
852
+ inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(size_t limit_bytes,
853
+ uint32_t q_tile,
854
+ uint32_t kv_granularity,
855
+ uint32_t head_dim_qk,
856
+ uint32_t head_dim_v,
857
+ bool has_mask,
858
+ bool kv_direct) {
859
+ const size_t base_q_bytes =
860
+ ggml_webgpu_flash_attn_wg_mem_bytes(q_tile, 0, head_dim_qk, head_dim_v, has_mask, kv_direct);
861
+ if (limit_bytes <= base_q_bytes) {
862
+ return 0;
863
+ }
864
+ const size_t one_kv_bytes =
865
+ ggml_webgpu_flash_attn_wg_mem_bytes(q_tile, 1, head_dim_qk, head_dim_v, has_mask, kv_direct);
866
+ const size_t bytes_per_kv = one_kv_bytes - base_q_bytes;
867
+ if (bytes_per_kv == 0) {
868
+ return 0;
869
+ }
870
+ const size_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv;
871
+ return (uint32_t) ((max_kv_tile / kv_granularity) * kv_granularity);
872
+ }
873
+
874
+ inline uint32_t ggml_webgpu_flash_attn_get_vec_kv_tile(size_t wg_mem_limit_bytes,
875
+ uint32_t head_dim_qk,
876
+ uint32_t head_dim_v,
877
+ bool has_mask,
878
+ bool kv_direct) {
879
+ const uint32_t max_kv_tile =
880
+ ggml_webgpu_flash_attn_max_kv_tile(wg_mem_limit_bytes, 1u, 1u, head_dim_qk, head_dim_v, has_mask, kv_direct);
881
+ GGML_ASSERT(max_kv_tile > 0);
882
+
883
+ uint32_t kv_tile = std::min(GGML_WEBGPU_FLASH_ATTN_VEC_MAX_KV_TILE, max_kv_tile);
884
+ if (kv_direct) {
885
+ kv_tile = std::min(kv_tile, GGML_WEBGPU_KV_SEQ_PAD);
886
+ while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {
887
+ kv_tile -= 1u;
149
888
  }
150
889
  }
151
890
 
152
- defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile));
153
- defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile));
891
+ return kv_tile;
892
+ }
154
893
 
155
- // workgroup size
156
- uint32_t wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);
894
+ inline bool ggml_webgpu_flash_attn_can_use_subgroup_matrix_path(bool supports_subgroup_matrix,
895
+ uint32_t sg_mat_k,
896
+ uint32_t sg_mat_n,
897
+ const ggml_tensor * Q,
898
+ const ggml_tensor * V) {
899
+ return supports_subgroup_matrix && Q->ne[0] % sg_mat_k == 0 && V->ne[0] % sg_mat_n == 0;
900
+ }
157
901
 
158
- defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
902
+ /** Matrix Multiplication **/
903
+
904
+ struct ggml_webgpu_mul_mat_vec_pipeline_key {
905
+ ggml_type src0_type;
906
+ ggml_type src1_type;
907
+ int vectorized;
908
+ bool use_mmvq;
909
+
910
+ bool operator==(const ggml_webgpu_mul_mat_vec_pipeline_key & other) const {
911
+ return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized &&
912
+ use_mmvq == other.use_mmvq;
913
+ }
914
+ };
915
+
916
+ struct ggml_webgpu_mul_mat_vec_pipeline_key_hash {
917
+ size_t operator()(const ggml_webgpu_mul_mat_vec_pipeline_key & key) const {
918
+ size_t seed = 0;
919
+ ggml_webgpu_hash_combine(seed, key.src0_type);
920
+ ggml_webgpu_hash_combine(seed, key.src1_type);
921
+ ggml_webgpu_hash_combine(seed, key.vectorized);
922
+ ggml_webgpu_hash_combine(seed, key.use_mmvq);
923
+ return seed;
924
+ }
925
+ };
926
+
927
// Launch choices for a matrix-vector shader variant.
// Fields default to 0, like the flash attention decision structs, so a
// default-constructed instance never carries indeterminate values.
struct ggml_webgpu_mul_mat_vec_shader_decisions {
    uint32_t wg_size        = 0;
    uint32_t outputs_per_wg = 0;
    uint32_t vec_size       = 0;
};
932
+
933
+ struct ggml_webgpu_quantize_q8_pipeline_key {
934
+ ggml_type src0_type;
935
+
936
+ bool operator==(const ggml_webgpu_quantize_q8_pipeline_key & other) const { return src0_type == other.src0_type; }
937
+ };
938
+
939
+ struct ggml_webgpu_quantize_q8_pipeline_key_hash {
940
+ size_t operator()(const ggml_webgpu_quantize_q8_pipeline_key & key) const {
941
+ size_t seed = 0;
942
+ ggml_webgpu_hash_combine(seed, key.src0_type);
943
+ return seed;
944
+ }
945
+ };
946
+
947
+ struct ggml_webgpu_mul_mat_pipeline_key {
948
+ ggml_type src0_type;
949
+ ggml_type src1_type;
950
+ int vectorized;
951
+ int use_subgroup_matrix;
159
952
 
160
- ggml_webgpu_processed_shader result;
161
- result.wgsl = preprocessor.preprocess(shader_src, defines);
162
- result.variant = variant;
163
- result.decisions.q_tile = q_tile;
164
- result.decisions.kv_tile = kv_tile;
165
- result.decisions.wg_size = wg_size;
166
- return result;
953
+ bool operator==(const ggml_webgpu_mul_mat_pipeline_key & other) const {
954
+ return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized &&
955
+ use_subgroup_matrix == other.use_subgroup_matrix;
956
+ }
957
+ };
958
+
959
+ struct ggml_webgpu_mul_mat_pipeline_key_hash {
960
+ size_t operator()(const ggml_webgpu_mul_mat_pipeline_key & key) const {
961
+ size_t seed = 0;
962
+ ggml_webgpu_hash_combine(seed, key.src0_type);
963
+ ggml_webgpu_hash_combine(seed, key.src1_type);
964
+ ggml_webgpu_hash_combine(seed, key.vectorized);
965
+ ggml_webgpu_hash_combine(seed, key.use_subgroup_matrix);
966
+ return seed;
967
+ }
968
+ };
969
+
970
// Launch and tiling choices for a matrix-matrix shader variant.
// Every field defaults to 0, like the flash attention decision structs, so a
// default-constructed instance never carries indeterminate values. Field order
// is unchanged, so positional aggregate initialization keeps working.
struct ggml_webgpu_mul_mat_shader_decisions {
    uint32_t tile_k              = 0;
    uint32_t wg_size_m           = 0;
    uint32_t wg_size_n           = 0;
    uint32_t wg_size             = 0;
    uint32_t outputs_per_wg      = 0;
    int      use_subgroup_matrix = 0;

    uint32_t tile_m = 0;
    uint32_t tile_n = 0;

    // Subgroup matrix parameters
    uint32_t subgroup_m        = 0;
    uint32_t subgroup_n        = 0;
    uint32_t subgroup_matrix_m = 0;
    uint32_t subgroup_matrix_n = 0;

    uint32_t mul_mat_wg_size = 0;
};
989
+
990
+ /** MUL_MAT_ID **/
991
+
992
+ struct ggml_webgpu_mul_mat_id_pipeline_key {
993
+ ggml_type src0_type;
994
+ ggml_type src1_type;
995
+ uint32_t n_experts;
996
+ int vectorized;
997
+
998
+ bool operator==(const ggml_webgpu_mul_mat_id_pipeline_key & other) const {
999
+ return src0_type == other.src0_type && src1_type == other.src1_type && n_experts == other.n_experts &&
1000
+ vectorized == other.vectorized;
1001
+ }
1002
+ };
1003
+
1004
+ struct ggml_webgpu_mul_mat_id_pipeline_key_hash {
1005
+ size_t operator()(const ggml_webgpu_mul_mat_id_pipeline_key & key) const {
1006
+ size_t seed = 0;
1007
+ ggml_webgpu_hash_combine(seed, key.src0_type);
1008
+ ggml_webgpu_hash_combine(seed, key.src1_type);
1009
+ ggml_webgpu_hash_combine(seed, key.n_experts);
1010
+ ggml_webgpu_hash_combine(seed, key.vectorized);
1011
+ return seed;
1012
+ }
1013
+ };
1014
+
1015
+ /** Cpy **/
1016
+
1017
+ struct ggml_webgpu_cpy_pipeline_key {
1018
+ ggml_type src_type;
1019
+ ggml_type dst_type;
1020
+
1021
+ bool operator==(const ggml_webgpu_cpy_pipeline_key & other) const {
1022
+ return src_type == other.src_type && dst_type == other.dst_type;
1023
+ }
1024
+ };
1025
+
1026
+ struct ggml_webgpu_cpy_pipeline_key_hash {
1027
+ size_t operator()(const ggml_webgpu_cpy_pipeline_key & key) const {
1028
+ size_t seed = 0;
1029
+ ggml_webgpu_hash_combine(seed, key.src_type);
1030
+ ggml_webgpu_hash_combine(seed, key.dst_type);
1031
+ return seed;
1032
+ }
1033
+ };
1034
+
1035
+ /** Glu **/
1036
+
1037
+ struct ggml_webgpu_glu_pipeline_key {
1038
+ ggml_glu_op glu_op;
1039
+ ggml_type type;
1040
+ bool split;
1041
+
1042
+ bool operator==(const ggml_webgpu_glu_pipeline_key & other) const {
1043
+ return glu_op == other.glu_op && type == other.type && split == other.split;
1044
+ }
1045
+ };
1046
+
1047
+ struct ggml_webgpu_glu_pipeline_key_hash {
1048
+ size_t operator()(const ggml_webgpu_glu_pipeline_key & key) const {
1049
+ size_t seed = 0;
1050
+ ggml_webgpu_hash_combine(seed, key.glu_op);
1051
+ ggml_webgpu_hash_combine(seed, key.type);
1052
+ ggml_webgpu_hash_combine(seed, key.split);
1053
+ return seed;
1054
+ }
1055
+ };
1056
+
1057
+ /** Rope **/
1058
+
1059
+ struct ggml_webgpu_rope_pipeline_key {
1060
+ ggml_type type;
1061
+ bool inplace;
1062
+ bool has_ff;
1063
+
1064
+ bool operator==(const ggml_webgpu_rope_pipeline_key & other) const {
1065
+ return type == other.type && inplace == other.inplace && has_ff == other.has_ff;
1066
+ }
1067
+ };
1068
+
1069
+ struct ggml_webgpu_rope_pipeline_key_hash {
1070
+ size_t operator()(const ggml_webgpu_rope_pipeline_key & key) const {
1071
+ size_t seed = 0;
1072
+ ggml_webgpu_hash_combine(seed, key.type);
1073
+ ggml_webgpu_hash_combine(seed, key.inplace);
1074
+ ggml_webgpu_hash_combine(seed, key.has_ff);
1075
+ return seed;
1076
+ }
1077
+ };
1078
+
1079
+ /** SoftMax **/
1080
+
1081
+ struct ggml_webgpu_soft_max_pipeline_key {
1082
+ ggml_type mask_type;
1083
+ bool has_mask;
1084
+ bool has_sink;
1085
+ bool inplace;
1086
+
1087
+ bool operator==(const ggml_webgpu_soft_max_pipeline_key & other) const {
1088
+ return mask_type == other.mask_type && has_mask == other.has_mask && has_sink == other.has_sink &&
1089
+ inplace == other.inplace;
1090
+ }
1091
+ };
1092
+
1093
+ struct ggml_webgpu_soft_max_pipeline_key_hash {
1094
+ size_t operator()(const ggml_webgpu_soft_max_pipeline_key & key) const {
1095
+ size_t seed = 0;
1096
+ ggml_webgpu_hash_combine(seed, key.mask_type);
1097
+ ggml_webgpu_hash_combine(seed, key.has_mask);
1098
+ ggml_webgpu_hash_combine(seed, key.has_sink);
1099
+ ggml_webgpu_hash_combine(seed, key.inplace);
1100
+ return seed;
1101
+ }
1102
+ };
1103
+
1104
+ /** MMVQ **/
1105
+
1106
+ inline bool ggml_webgpu_can_use_mmvq(const ggml_tensor * src0,
1107
+ const ggml_tensor * src1,
1108
+ bool supports_dot_product,
1109
+ const std::string & vendor) {
1110
+ if (src1->ne[1] == 1) {
1111
+ bool supports_dp4a = vendor == "amd" || vendor == "intel" || vendor == "nvidia";
1112
+ if (supports_dp4a && supports_dot_product) {
1113
+ switch (src1->type) {
1114
+ case GGML_TYPE_F32:
1115
+ switch (src0->type) {
1116
+ case GGML_TYPE_Q4_0:
1117
+ case GGML_TYPE_Q4_1:
1118
+ case GGML_TYPE_Q8_0:
1119
+ case GGML_TYPE_Q2_K:
1120
+ case GGML_TYPE_Q4_K:
1121
+ return src0->ne[0] % 4 == 0;
1122
+ default:
1123
+ break;
1124
+ }
1125
+ break;
1126
+ default:
1127
+ break;
1128
+ }
1129
+ }
1130
+ }
1131
+ return false;
167
1132
  }
168
1133
 
1134
+ class ggml_webgpu_shader_lib {
1135
+ wgpu::Device device;
1136
+ pre_wgsl::Preprocessor preprocessor;
1137
+
1138
+ std::unordered_map<int, webgpu_pipeline> sum_rows_pipelines; // key is fixed, no variants yet
1139
+ std::unordered_map<int, webgpu_pipeline> argmax_pipelines; // key is vec4
1140
+ std::unordered_map<int, webgpu_pipeline> argsort_pipelines; // key is order
1141
+ std::unordered_map<int, webgpu_pipeline> argsort_merge_pipelines; // key is order
1142
+ std::unordered_map<int, webgpu_pipeline> cumsum_pipelines; // key is fixed, no variants yet
1143
+ std::unordered_map<ggml_webgpu_row_norm_pipeline_key, webgpu_pipeline, ggml_webgpu_row_norm_pipeline_key_hash>
1144
+ row_norm_pipelines; // op/inplace
1145
+
1146
+ std::unordered_map<ggml_webgpu_get_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_get_rows_pipeline_key_hash>
1147
+ get_rows_pipelines; // src_type, vectorized
1148
+ std::unordered_map<ggml_webgpu_unary_pipeline_key, webgpu_pipeline, ggml_webgpu_unary_pipeline_key_hash>
1149
+ unary_pipelines; // type/op/inplace
1150
+ std::unordered_map<ggml_webgpu_scale_pipeline_key, webgpu_pipeline, ggml_webgpu_scale_pipeline_key_hash>
1151
+ scale_pipelines; // inplace
1152
+ std::unordered_map<ggml_webgpu_solve_tri_pipeline_key, webgpu_pipeline, ggml_webgpu_solve_tri_pipeline_key_hash>
1153
+ solve_tri_pipelines; // type
1154
+ std::unordered_map<ggml_webgpu_ssm_conv_pipeline_key, webgpu_pipeline, ggml_webgpu_ssm_conv_pipeline_key_hash>
1155
+ ssm_conv_pipelines; // type/vectorized
1156
+ std::unordered_map<ggml_webgpu_ssm_scan_pipeline_key, webgpu_pipeline, ggml_webgpu_ssm_scan_pipeline_key_hash>
1157
+ ssm_scan_pipelines; // type/d_state
1158
+ std::unordered_map<ggml_webgpu_gated_delta_net_pipeline_key,
1159
+ webgpu_pipeline,
1160
+ ggml_webgpu_gated_delta_net_pipeline_key_hash>
1161
+ gated_delta_net_pipelines; // type/S_v/kda
1162
+ std::unordered_map<ggml_webgpu_pad_pipeline_key, webgpu_pipeline, ggml_webgpu_pad_pipeline_key_hash>
1163
+ pad_pipelines; // circular/non-circular
1164
+ std::unordered_map<ggml_webgpu_binary_pipeline_key, webgpu_pipeline, ggml_webgpu_binary_pipeline_key_hash>
1165
+ binary_pipelines; // type/op/inplace/overlap/src_overlap
1166
+ std::unordered_map<ggml_webgpu_add_id_pipeline_key, webgpu_pipeline, ggml_webgpu_add_id_pipeline_key_hash>
1167
+ add_id_pipelines; // inplace
1168
+ std::unordered_map<ggml_webgpu_concat_pipeline_key, webgpu_pipeline, ggml_webgpu_concat_pipeline_key_hash>
1169
+ concat_pipelines; // type
1170
+ std::unordered_map<ggml_webgpu_repeat_pipeline_key, webgpu_pipeline, ggml_webgpu_repeat_pipeline_key_hash>
1171
+ repeat_pipelines; // type
1172
+ std::unordered_map<ggml_webgpu_flash_attn_vec_pipeline_key,
1173
+ webgpu_pipeline,
1174
+ ggml_webgpu_flash_attn_vec_pipeline_key_hash>
1175
+ flash_attn_vec_pipelines;
1176
+ std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
1177
+ flash_attn_pipelines;
1178
+ std::unordered_map<ggml_webgpu_flash_attn_vec_reduce_pipeline_key,
1179
+ webgpu_pipeline,
1180
+ ggml_webgpu_flash_attn_vec_reduce_pipeline_key_hash>
1181
+ flash_attn_vec_reduce_pipelines;
1182
+ std::unordered_map<ggml_webgpu_flash_attn_blk_pipeline_key,
1183
+ webgpu_pipeline,
1184
+ ggml_webgpu_flash_attn_blk_pipeline_key_hash>
1185
+ flash_attn_blk_pipelines;
1186
+ std::unordered_map<ggml_webgpu_mul_mat_vec_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_vec_pipeline_key_hash>
1187
+ mul_mat_vec_pipelines; // fast mat-vec (n==1)
1188
+ std::unordered_map<ggml_webgpu_mul_mat_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_pipeline_key_hash>
1189
+ mul_mat_fast_pipelines; // fast mat-mat (reg-tile or subgroup)
1190
+ std::unordered_map<ggml_webgpu_quantize_q8_pipeline_key, webgpu_pipeline, ggml_webgpu_quantize_q8_pipeline_key_hash>
1191
+ quantize_q8_pipelines;
1192
+ std::unordered_map<int, webgpu_pipeline> mul_mat_id_gather_pipelines; // key is fixed
1193
+ std::unordered_map<ggml_webgpu_mul_mat_id_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_id_pipeline_key_hash>
1194
+ mul_mat_id_pipelines; // src0_type/src1_type
1195
+ std::unordered_map<ggml_webgpu_mul_mat_id_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_id_pipeline_key_hash>
1196
+ mul_mat_id_vec_pipelines; // src0_type/src1_type
1197
+
1198
+ std::unordered_map<ggml_webgpu_set_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_set_rows_pipeline_key_hash>
1199
+ set_rows_pipelines;
1200
+ std::unordered_map<ggml_webgpu_set_pipeline_key, webgpu_pipeline, ggml_webgpu_set_pipeline_key_hash> set_pipelines;
1201
+ std::unordered_map<ggml_webgpu_cpy_pipeline_key, webgpu_pipeline, ggml_webgpu_cpy_pipeline_key_hash> cpy_pipelines;
1202
+ std::unordered_map<ggml_webgpu_glu_pipeline_key, webgpu_pipeline, ggml_webgpu_glu_pipeline_key_hash> glu_pipelines;
1203
+ std::unordered_map<ggml_webgpu_rope_pipeline_key, webgpu_pipeline, ggml_webgpu_rope_pipeline_key_hash>
1204
+ rope_pipelines;
1205
+ std::unordered_map<ggml_webgpu_soft_max_pipeline_key, webgpu_pipeline, ggml_webgpu_soft_max_pipeline_key_hash>
1206
+ soft_max_pipelines;
1207
+ std::unordered_map<ggml_webgpu_conv2d_pipeline_key, webgpu_pipeline, ggml_webgpu_conv2d_pipeline_key_hash>
1208
+ conv2d_pipelines;
1209
+ std::unordered_map<ggml_webgpu_im2col_pipeline_key, webgpu_pipeline, ggml_webgpu_im2col_pipeline_key_hash>
1210
+ im2col_pipelines;
1211
+
1212
+ std::unordered_map<ggml_webgpu_rms_norm_mul_pipeline_key,
1213
+ webgpu_pipeline,
1214
+ ggml_webgpu_rms_norm_mul_pipeline_key_hash>
1215
+ rms_norm_mul_pipelines;
1216
+ std::unordered_map<ggml_webgpu_upscale_pipeline_key, webgpu_pipeline, ggml_webgpu_upscale_pipeline_key_hash>
1217
+ upscale_pipelines;
1218
+
1219
+ public:
1220
+ ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; }
1221
+
1222
+ webgpu_pipeline get_sum_rows_pipeline(const ggml_webgpu_shader_lib_context & context) {
1223
+ auto it = sum_rows_pipelines.find(1);
1224
+ if (it != sum_rows_pipelines.end()) {
1225
+ return it->second;
1226
+ }
1227
+ std::vector<std::string> defines;
1228
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
1229
+
1230
+ auto processed = preprocessor.preprocess(wgsl_sum_rows, defines);
1231
+ sum_rows_pipelines[1] = ggml_webgpu_create_pipeline(device, processed, "sum_rows");
1232
+ return sum_rows_pipelines[1];
1233
+ }
1234
+
1235
+ webgpu_pipeline get_row_norm_pipeline(const ggml_webgpu_shader_lib_context & context) {
1236
+ ggml_webgpu_row_norm_pipeline_key key = {};
1237
+ key.op = context.dst->op;
1238
+ key.src_type = context.src0->type;
1239
+ key.dst_type = context.dst->type;
1240
+ key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);
1241
+
1242
+ auto it = row_norm_pipelines.find(key);
1243
+ if (it != row_norm_pipelines.end()) {
1244
+ return it->second;
1245
+ }
1246
+ std::vector<std::string> defines;
1247
+ std::string variant;
1248
+
1249
+ switch (key.op) {
1250
+ case GGML_OP_RMS_NORM:
1251
+ defines.push_back("RMS_NORM");
1252
+ variant = "rms_norm";
1253
+ break;
1254
+ case GGML_OP_NORM:
1255
+ defines.push_back("NORM");
1256
+ variant = "norm";
1257
+ break;
1258
+ case GGML_OP_L2_NORM:
1259
+ defines.push_back("L2_NORM");
1260
+ variant = "l2_norm";
1261
+ break;
1262
+ default:
1263
+ GGML_ABORT("Unsupported op for row_norm shader");
1264
+ }
1265
+
1266
+ if (key.inplace) {
1267
+ defines.push_back("INPLACE");
1268
+ variant += "_inplace";
1269
+ }
1270
+
1271
+ if (key.src_type == GGML_TYPE_F32) {
1272
+ defines.push_back("SRC_F32");
1273
+ variant += "_src_f32";
1274
+ } else if (key.src_type == GGML_TYPE_F16) {
1275
+ defines.push_back("SRC_F16");
1276
+ variant += "_src_f16";
1277
+ }
1278
+
1279
+ if (key.dst_type == GGML_TYPE_F32) {
1280
+ defines.push_back("DST_F32");
1281
+ variant += "_dst_f32";
1282
+ } else if (key.dst_type == GGML_TYPE_F16) {
1283
+ defines.push_back("DST_F16");
1284
+ variant += "_dst_f16";
1285
+ }
1286
+
1287
+ const uint32_t row_norm_wg_size = 128u;
1288
+ uint32_t wg_size = std::min(context.max_wg_size, row_norm_wg_size);
1289
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
1290
+ auto processed = preprocessor.preprocess(wgsl_row_norm, defines);
1291
+ auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
1292
+ decisions->wg_size = wg_size;
1293
+ decisions->inplace = key.inplace;
1294
+ row_norm_pipelines[key] = ggml_webgpu_create_pipeline(device, processed, variant);
1295
+ row_norm_pipelines[key].context = decisions;
1296
+ return row_norm_pipelines[key];
1297
+ }
1298
+
1299
+ webgpu_pipeline get_argmax_pipeline(const ggml_webgpu_shader_lib_context & context) {
1300
+ bool vec4 = context.src0->ne[0] % 4 == 0;
1301
+
1302
+ auto it = argmax_pipelines.find(vec4);
1303
+ if (it != argmax_pipelines.end()) {
1304
+ return it->second;
1305
+ }
1306
+ std::string variant = "argmax";
1307
+ std::vector<std::string> defines;
1308
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
1309
+ if (vec4) {
1310
+ defines.push_back("VEC4");
1311
+ variant += "_vec4";
1312
+ }
1313
+
1314
+ auto processed = preprocessor.preprocess(wgsl_argmax, defines);
1315
+ argmax_pipelines[vec4] = ggml_webgpu_create_pipeline(device, processed, variant);
1316
+ return argmax_pipelines.at(vec4);
1317
+ }
1318
+
1319
+ webgpu_pipeline get_set_rows_pipeline(const ggml_webgpu_shader_lib_context & context) {
1320
+ const bool quantized = ggml_is_quantized(context.dst->type);
1321
+ ggml_webgpu_set_rows_pipeline_key key = {};
1322
+ key.dst_type = context.dst->type;
1323
+ key.vec4 =
1324
+ (context.dst->type == GGML_TYPE_F32 || context.dst->type == GGML_TYPE_F16) && context.src0->ne[0] % 4 == 0;
1325
+ key.i64_idx = context.src1->type == GGML_TYPE_I64;
1326
+ key.pair_blocks = quantized && ((context.src0->ne[0] / ggml_blck_size(context.dst->type)) % 2 == 0);
1327
+
1328
+ auto it = set_rows_pipelines.find(key);
1329
+ if (it != set_rows_pipelines.end()) {
1330
+ return it->second;
1331
+ }
1332
+
1333
+ std::vector<std::string> defines;
1334
+ std::string variant = "set_rows";
1335
+
1336
+ switch (context.dst->type) {
1337
+ case GGML_TYPE_F32:
1338
+ defines.push_back("DST_F32");
1339
+ variant += "_dstf32";
1340
+ break;
1341
+ case GGML_TYPE_F16:
1342
+ defines.push_back("DST_F16");
1343
+ variant += "_dstf16";
1344
+ break;
1345
+ case GGML_TYPE_Q8_0:
1346
+ defines.push_back("DST_Q8_0");
1347
+ variant += "_dstq8_0";
1348
+ break;
1349
+ case GGML_TYPE_Q4_0:
1350
+ defines.push_back("DST_Q4_0");
1351
+ variant += "_dstq4_0";
1352
+ break;
1353
+ default:
1354
+ GGML_ABORT("Unsupported dst type for set_rows shader");
1355
+ }
1356
+
1357
+ if (key.vec4) {
1358
+ defines.push_back("VEC4");
1359
+ variant += "_vec4";
1360
+ }
1361
+ if (key.i64_idx) {
1362
+ defines.push_back("I64_IDX");
1363
+ variant += "_i64idx";
1364
+ }
1365
+ if (key.pair_blocks) {
1366
+ defines.push_back("PAIR_BLOCKS");
1367
+ variant += "_pair_blocks";
1368
+ }
1369
+
1370
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
1371
+
1372
+ const auto & shader_source = quantized ? wgsl_set_rows_quant : wgsl_set_rows;
1373
+ auto processed = preprocessor.preprocess(shader_source, defines);
1374
+ auto decisions = std::make_shared<ggml_webgpu_set_rows_shader_decisions>();
1375
+ decisions->vec4 = key.vec4;
1376
+ decisions->i64_idx = key.i64_idx;
1377
+ decisions->pair_blocks = key.pair_blocks;
1378
+ decisions->wg_size = context.max_wg_size;
1379
+ set_rows_pipelines[key] = ggml_webgpu_create_pipeline(device, processed, variant);
1380
+ set_rows_pipelines[key].context = decisions;
1381
+ return set_rows_pipelines[key];
1382
+ }
1383
+
1384
+ webgpu_pipeline get_set_pipeline(const ggml_webgpu_shader_lib_context & context) {
1385
+ ggml_webgpu_set_pipeline_key key = {};
1386
+ key.type = context.dst->type;
1387
+ key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);
1388
+
1389
+ auto it = set_pipelines.find(key);
1390
+ if (it != set_pipelines.end()) {
1391
+ return it->second;
1392
+ }
1393
+
1394
+ std::vector<std::string> defines;
1395
+ std::string variant = "set";
1396
+
1397
+ switch (key.type) {
1398
+ case GGML_TYPE_F32:
1399
+ defines.push_back("TYPE_F32");
1400
+ variant += "_f32";
1401
+ break;
1402
+ case GGML_TYPE_I32:
1403
+ defines.push_back("TYPE_I32");
1404
+ variant += "_i32";
1405
+ break;
1406
+ default:
1407
+ GGML_ABORT("Unsupported type for set shader");
1408
+ }
1409
+
1410
+ if (key.inplace) {
1411
+ defines.push_back("INPLACE");
1412
+ variant += "_inplace";
1413
+ }
1414
+
1415
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
1416
+
1417
+ auto processed = preprocessor.preprocess(wgsl_set, defines);
1418
+ auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
1419
+ decisions->wg_size = context.max_wg_size;
1420
+ decisions->inplace = key.inplace;
1421
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
1422
+ pipeline.context = decisions;
1423
+ set_pipelines[key] = pipeline;
1424
+ return set_pipelines[key];
1425
+ }
1426
+
1427
+ webgpu_pipeline get_cumsum_pipeline(const ggml_webgpu_shader_lib_context & context) {
1428
+ auto it = cumsum_pipelines.find(1);
1429
+ if (it != cumsum_pipelines.end()) {
1430
+ return it->second;
1431
+ }
1432
+
1433
+ std::vector<std::string> defines;
1434
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
1435
+
1436
+ auto processed = preprocessor.preprocess(wgsl_cumsum, defines);
1437
+ cumsum_pipelines[1] = ggml_webgpu_create_pipeline(device, processed, "cumsum");
1438
+ return cumsum_pipelines[1];
1439
+ }
1440
+
1441
+ webgpu_pipeline get_argsort_pipeline(const ggml_webgpu_shader_lib_context & context) {
1442
+ bool is_top_k = context.dst->op == GGML_OP_TOP_K;
1443
+ // ascending order is 0, descending order is 1
1444
+ const int32_t order =
1445
+ is_top_k ? (int32_t) GGML_SORT_ORDER_DESC : (int32_t) ggml_get_op_params_i32(context.dst, 0);
1446
+
1447
+ auto it = argsort_pipelines.find(order);
1448
+ if (it != argsort_pipelines.end()) {
1449
+ return it->second;
1450
+ }
1451
+
1452
+ std::vector<std::string> defines;
1453
+ std::string variant = "argsort";
1454
+ defines.push_back(std::string("ORDER=") + std::to_string(order));
1455
+ variant += std::string("_order") + std::to_string(order);
1456
+ uint32_t wg_size = 1;
1457
+ while (wg_size * 2 <= context.max_wg_size &&
1458
+ wg_size * GGML_WEBGPU_I32_SIZE_BYTES <= context.wg_mem_limit_bytes / 2) {
1459
+ wg_size *= 2;
1460
+ }
1461
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
1462
+ auto processed = preprocessor.preprocess(wgsl_argsort, defines);
1463
+ auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
1464
+ decisions->wg_size = wg_size;
1465
+ argsort_pipelines[order] = ggml_webgpu_create_pipeline(device, processed, variant);
1466
+ argsort_pipelines[order].context = decisions;
1467
+ return argsort_pipelines[order];
1468
+ }
1469
+
1470
+ webgpu_pipeline get_argsort_merge_pipeline(const ggml_webgpu_shader_lib_context & context) {
1471
+ bool is_top_k = context.dst->op == GGML_OP_TOP_K;
1472
+ // ascending order is 0, descending order is 1
1473
+ const int32_t order =
1474
+ is_top_k ? (int32_t) GGML_SORT_ORDER_DESC : (int32_t) ggml_get_op_params_i32(context.dst, 0);
1475
+
1476
+ auto it = argsort_merge_pipelines.find(order);
1477
+ if (it != argsort_merge_pipelines.end()) {
1478
+ return it->second;
1479
+ }
1480
+
1481
+ std::vector<std::string> defines;
1482
+ std::string variant = "argsort_merge";
1483
+ defines.push_back(std::string("ORDER=") + std::to_string(order));
1484
+ variant += std::string("_order") + std::to_string(order);
1485
+ uint32_t wg_size = std::min(GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE, context.max_wg_size);
1486
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
1487
+
1488
+ auto processed = preprocessor.preprocess(wgsl_argsort_merge, defines);
1489
+ argsort_merge_pipelines[order] = ggml_webgpu_create_pipeline(device, processed, variant);
1490
+ return argsort_merge_pipelines[order];
1491
+ }
1492
+
1493
+ webgpu_pipeline get_get_rows_pipeline(const ggml_webgpu_shader_lib_context & context) {
1494
+ const bool vectorized = context.src0->type == GGML_TYPE_F32 && context.dst->ne[0] % 4 == 0;
1495
+ ggml_webgpu_get_rows_pipeline_key key = {};
1496
+ key.src_type = context.src0->type;
1497
+ key.vectorized = (int) vectorized;
1498
+
1499
+ auto it = get_rows_pipelines.find(key);
1500
+ if (it != get_rows_pipelines.end()) {
1501
+ return it->second;
1502
+ }
1503
+
1504
+ std::vector<std::string> defines;
1505
+ std::string variant = "get_rows";
1506
+
1507
+ const struct ggml_type_traits * type_traits = ggml_get_type_traits(key.src_type);
1508
+ const char * type_str = type_traits->type_name;
1509
+
1510
+ switch (key.src_type) {
1511
+ case GGML_TYPE_F32:
1512
+ defines.push_back("FLOAT_PARALLEL");
1513
+ if (key.vectorized) {
1514
+ defines.push_back("F32_VEC");
1515
+ defines.push_back("SRC_TYPE=vec4<f32>");
1516
+ defines.push_back("DST_TYPE=vec4<f32>");
1517
+ defines.push_back("BLOCK_SIZE=4u");
1518
+ } else {
1519
+ defines.push_back("F32");
1520
+ defines.push_back("SRC_TYPE=f32");
1521
+ defines.push_back("DST_TYPE=f32");
1522
+ defines.push_back("BLOCK_SIZE=1u");
1523
+ }
1524
+ variant += "_f32";
1525
+ break;
1526
+ case GGML_TYPE_F16:
1527
+ defines.push_back("FLOAT_PARALLEL");
1528
+ defines.push_back("F16");
1529
+ defines.push_back("SRC_TYPE=f16");
1530
+ defines.push_back("DST_TYPE=f32");
1531
+ defines.push_back("BLOCK_SIZE=1u");
1532
+ variant += "_f16";
1533
+ break;
1534
+ case GGML_TYPE_I32:
1535
+ defines.push_back("FLOAT_PARALLEL");
1536
+ defines.push_back("I32");
1537
+ defines.push_back("SRC_TYPE=i32");
1538
+ defines.push_back("DST_TYPE=i32");
1539
+ defines.push_back("BLOCK_SIZE=1u");
1540
+ variant += "_i32";
1541
+ break;
1542
+ default:
1543
+ {
1544
+ std::string type_upper = type_str;
1545
+ std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
1546
+
1547
+ switch (key.src_type) {
1548
+ case GGML_TYPE_Q1_0:
1549
+ case GGML_TYPE_Q4_0:
1550
+ case GGML_TYPE_Q5_0:
1551
+ case GGML_TYPE_Q8_0:
1552
+ case GGML_TYPE_Q3_K:
1553
+ case GGML_TYPE_Q6_K:
1554
+ case GGML_TYPE_IQ2_XXS:
1555
+ case GGML_TYPE_IQ2_XS:
1556
+ case GGML_TYPE_IQ2_S:
1557
+ case GGML_TYPE_IQ3_XXS:
1558
+ case GGML_TYPE_IQ3_S:
1559
+ case GGML_TYPE_IQ1_S:
1560
+ case GGML_TYPE_IQ4_NL:
1561
+ case GGML_TYPE_MXFP4:
1562
+ {
1563
+ // Quantized types using u32 buffers for portability.
1564
+ defines.push_back("SRC_TYPE=u32");
1565
+ defines.push_back("U32_DEQUANT_HELPERS");
1566
+ break;
1567
+ }
1568
+ default:
1569
+ {
1570
+ defines.push_back(std::string("SRC_TYPE=") + type_str);
1571
+ }
1572
+ }
1573
+
1574
+ defines.push_back("BYTE_HELPERS");
1575
+ defines.push_back(type_upper + "_T");
1576
+ defines.push_back(type_upper);
1577
+ defines.push_back(type_upper + "_SCALE_MIN");
1578
+ defines.push_back(type_upper + "_TABLES");
1579
+ defines.push_back(type_upper + "_GRID");
1580
+ defines.push_back(type_upper + "_LUT");
1581
+
1582
+ variant += "_";
1583
+ variant += type_str;
1584
+
1585
+ defines.push_back("DST_TYPE=f32");
1586
+
1587
+ if (key.src_type == GGML_TYPE_Q1_0) {
1588
+ defines.push_back("BLOCK_SIZE=128u");
1589
+ } else if ((key.src_type >= GGML_TYPE_Q4_0 && key.src_type <= GGML_TYPE_Q8_1) ||
1590
+ key.src_type == GGML_TYPE_IQ4_NL || key.src_type == GGML_TYPE_MXFP4) {
1591
+ defines.push_back("BLOCK_SIZE=32u");
1592
+ } else if (key.src_type >= GGML_TYPE_Q2_K) {
1593
+ defines.push_back("BLOCK_SIZE=256u");
1594
+ } else {
1595
+ defines.push_back("BLOCK_SIZE=1u");
1596
+ }
1597
+ break;
1598
+ }
1599
+ }
1600
+
1601
+ if (key.vectorized) {
1602
+ variant += "_vec";
1603
+ }
1604
+
1605
+ defines.push_back("WG_SIZE=" + std::to_string(context.max_wg_size));
1606
+
1607
+ auto processed = preprocessor.preprocess(wgsl_get_rows, defines);
1608
+ auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
1609
+ decisions->wg_size = context.max_wg_size;
1610
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
1611
+ pipeline.context = decisions;
1612
+ get_rows_pipelines[key] = pipeline;
1613
+ return get_rows_pipelines[key];
1614
+ }
1615
+
1616
+ webgpu_pipeline get_scale_pipeline(const ggml_webgpu_shader_lib_context & context) {
1617
+ ggml_webgpu_scale_pipeline_key key = {};
1618
+ key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);
1619
+
1620
+ auto it = scale_pipelines.find(key);
1621
+ if (it != scale_pipelines.end()) {
1622
+ return it->second;
1623
+ }
1624
+
1625
+ std::vector<std::string> defines;
1626
+ std::string variant = "scale";
1627
+
1628
+ if (key.inplace) {
1629
+ defines.push_back("INPLACE");
1630
+ variant += "_inplace";
1631
+ }
1632
+
1633
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
1634
+
1635
+ auto processed = preprocessor.preprocess(wgsl_scale, defines);
1636
+ auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
1637
+ decisions->wg_size = context.max_wg_size;
1638
+ decisions->inplace = key.inplace;
1639
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
1640
+ pipeline.context = decisions;
1641
+ scale_pipelines[key] = pipeline;
1642
+ return scale_pipelines[key];
1643
+ }
1644
+
1645
+ webgpu_pipeline get_solve_tri_pipeline(const ggml_webgpu_shader_lib_context & context) {
1646
+ ggml_webgpu_solve_tri_pipeline_key key = {};
1647
+ key.type = context.dst->type;
1648
+ key.n = (int) context.src0->ne[0];
1649
+ key.k = (int) context.src1->ne[0];
1650
+
1651
+ auto it = solve_tri_pipelines.find(key);
1652
+ if (it != solve_tri_pipelines.end()) {
1653
+ return it->second;
1654
+ }
1655
+
1656
+ std::vector<std::string> defines;
1657
+ std::string variant = "solve_tri";
1658
+
1659
+ switch (key.type) {
1660
+ case GGML_TYPE_F32:
1661
+ variant += "_f32";
1662
+ break;
1663
+ default:
1664
+ GGML_ABORT("Unsupported type for solve_tri shader");
1665
+ }
1666
+
1667
+ const uint32_t wg_size = std::min((uint32_t) key.n, context.max_wg_size);
1668
+ const uint32_t k_tile = wg_size;
1669
+ const uint32_t bytes_per_row = ((uint32_t) key.n + wg_size) * GGML_WEBGPU_F32_SIZE_BYTES;
1670
+ const uint32_t batch_n = (uint32_t) (context.wg_mem_limit_bytes / bytes_per_row);
1671
+
1672
+ defines.push_back(std::string("N=") + std::to_string(key.n));
1673
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
1674
+ defines.push_back(std::string("K_TILE=") + std::to_string(k_tile));
1675
+ defines.push_back(std::string("BATCH_N=") + std::to_string(batch_n));
1676
+
1677
+ auto processed = preprocessor.preprocess(wgsl_solve_tri, defines);
1678
+ auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
1679
+ decisions->wg_size = wg_size;
1680
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
1681
+ pipeline.context = decisions;
1682
+ solve_tri_pipelines[key] = pipeline;
1683
+ return solve_tri_pipelines[key];
1684
+ }
1685
+
1686
+ webgpu_pipeline get_ssm_conv_pipeline(const ggml_webgpu_shader_lib_context & context) {
1687
+ ggml_webgpu_ssm_conv_pipeline_key key = {};
1688
+ key.type = context.dst->type;
1689
+ key.vectorized = context.src1->ne[0] == 4;
1690
+
1691
+ auto it = ssm_conv_pipelines.find(key);
1692
+ if (it != ssm_conv_pipelines.end()) {
1693
+ return it->second;
1694
+ }
1695
+
1696
+ std::vector<std::string> defines;
1697
+ std::string variant = "ssm_conv";
1698
+
1699
+ switch (key.type) {
1700
+ case GGML_TYPE_F32:
1701
+ variant += "_f32";
1702
+ break;
1703
+ default:
1704
+ GGML_ABORT("Unsupported type for ssm_conv shader");
1705
+ }
1706
+
1707
+ if (key.vectorized) {
1708
+ defines.push_back("VECTORIZED");
1709
+ variant += "_vec4";
1710
+ }
1711
+
1712
+ constexpr uint32_t block_size = 32u;
1713
+ constexpr uint32_t tokens_per_wg = 8u;
1714
+
1715
+ defines.push_back("BLOCK_SIZE=" + std::to_string(block_size) + "u");
1716
+ defines.push_back("TOKENS_PER_WG=" + std::to_string(tokens_per_wg) + "u");
1717
+
1718
+ auto processed = preprocessor.preprocess(wgsl_ssm_conv, defines);
1719
+ auto decisions = std::make_shared<ggml_webgpu_ssm_conv_shader_decisions>();
1720
+ decisions->block_size = block_size;
1721
+ decisions->tokens_per_wg = tokens_per_wg;
1722
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
1723
+ pipeline.context = decisions;
1724
+ ssm_conv_pipelines[key] = pipeline;
1725
+ return ssm_conv_pipelines[key];
1726
+ }
1727
+
1728
+ webgpu_pipeline get_ssm_scan_pipeline(const ggml_webgpu_shader_lib_context & context) {
1729
+ ggml_webgpu_ssm_scan_pipeline_key key = {};
1730
+ key.type = context.dst->type;
1731
+ key.d_state = (int) context.src0->ne[0];
1732
+ key.xbc_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src4) &&
1733
+ ggml_webgpu_tensor_overlap(context.src1, context.src5);
1734
+
1735
+ auto it = ssm_scan_pipelines.find(key);
1736
+ if (it != ssm_scan_pipelines.end()) {
1737
+ return it->second;
1738
+ }
1739
+
1740
+ std::vector<std::string> defines;
1741
+ std::string variant = "ssm_scan";
1742
+
1743
+ switch (key.type) {
1744
+ case GGML_TYPE_F32:
1745
+ variant += "_f32";
1746
+ break;
1747
+ default:
1748
+ GGML_ABORT("Unsupported type for ssm_scan shader");
1749
+ }
1750
+
1751
+ const uint32_t wg_size = (uint32_t) key.d_state;
1752
+
1753
+ constexpr uint32_t tokens_per_tile = 4u;
1754
+
1755
+ defines.push_back("WG_SIZE=" + std::to_string(wg_size) + "u");
1756
+ defines.push_back("TOKENS_PER_TILE=" + std::to_string(tokens_per_tile) + "u");
1757
+
1758
+ if (context.supports_subgroups) {
1759
+ defines.push_back("USE_SUBGROUP_REDUCTION");
1760
+ variant += "_sg_reduce";
1761
+ } else {
1762
+ variant += "_wg_reduce";
1763
+ }
1764
+
1765
+ if (key.xbc_overlap) {
1766
+ defines.push_back("XBC_OVERLAP");
1767
+ }
1768
+
1769
+ variant += "_d" + std::to_string(key.d_state);
1770
+
1771
+ auto processed = preprocessor.preprocess(wgsl_ssm_scan, defines);
1772
+ auto decisions = std::make_shared<ggml_webgpu_ssm_scan_shader_decisions>();
1773
+ decisions->wg_size = wg_size;
1774
+ decisions->tokens_per_tile = tokens_per_tile;
1775
+ decisions->xbc_overlap = key.xbc_overlap;
1776
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
1777
+ pipeline.context = decisions;
1778
+ ssm_scan_pipelines[key] = pipeline;
1779
+ return ssm_scan_pipelines[key];
1780
+ }
1781
+
1782
+ webgpu_pipeline get_gated_delta_net_pipeline(const ggml_webgpu_shader_lib_context & context) {
1783
+ ggml_webgpu_gated_delta_net_pipeline_key key = {};
1784
+ key.type = context.dst->type;
1785
+ key.s_v = (int) context.src2->ne[0];
1786
+ key.kda = context.src3->ne[0] == context.src2->ne[0];
1787
+
1788
+ auto it = gated_delta_net_pipelines.find(key);
1789
+ if (it != gated_delta_net_pipelines.end()) {
1790
+ return it->second;
1791
+ }
1792
+
1793
+ std::vector<std::string> defines;
1794
+ std::string variant = "gated_delta_net";
1795
+
1796
+ switch (key.type) {
1797
+ case GGML_TYPE_F32:
1798
+ variant += "_f32";
1799
+ break;
1800
+ default:
1801
+ GGML_ABORT("Unsupported type for gated_delta_net shader");
1802
+ }
1803
+
1804
+ if (key.kda) {
1805
+ defines.push_back("KDA");
1806
+ variant += "_kda";
1807
+ }
1808
+
1809
+ defines.push_back("S_V=" + std::to_string(key.s_v) + "u");
1810
+ defines.push_back("WG_SIZE=" + std::to_string(key.s_v) + "u");
1811
+
1812
+ auto processed = preprocessor.preprocess(wgsl_gated_delta_net, defines);
1813
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
1814
+ gated_delta_net_pipelines[key] = pipeline;
1815
+ return gated_delta_net_pipelines[key];
1816
+ }
1817
+
1818
+ webgpu_pipeline get_pad_pipeline(const ggml_webgpu_shader_lib_context & context) {
1819
+ ggml_webgpu_pad_pipeline_key key = {};
1820
+ key.circular = ggml_get_op_params_i32(context.dst, 8) != 0;
1821
+
1822
+ auto it = pad_pipelines.find(key);
1823
+ if (it != pad_pipelines.end()) {
1824
+ return it->second;
1825
+ }
1826
+
1827
+ std::vector<std::string> defines;
1828
+ std::string variant = "pad";
1829
+
1830
+ if (key.circular) {
1831
+ defines.push_back("CIRCULAR");
1832
+ variant += "_circular";
1833
+ }
1834
+
1835
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
1836
+
1837
+ auto processed = preprocessor.preprocess(wgsl_pad, defines);
1838
+ auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
1839
+ decisions->wg_size = context.max_wg_size;
1840
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
1841
+ pipeline.context = decisions;
1842
+ pad_pipelines[key] = pipeline;
1843
+ return pad_pipelines[key];
1844
+ }
1845
+
1846
+ webgpu_pipeline get_quantize_q8_pipeline(const ggml_webgpu_shader_lib_context & context) {
1847
+ ggml_webgpu_quantize_q8_pipeline_key key = {};
1848
+ key.src0_type = context.src0->type;
1849
+
1850
+ auto it = quantize_q8_pipelines.find(key);
1851
+ if (it != quantize_q8_pipelines.end()) {
1852
+ return it->second;
1853
+ }
1854
+ const char * shader_src = wgsl_quantize_q8;
1855
+ std::vector<std::string> defines;
1856
+ std::string variant = "quantize_q8";
1857
+
1858
+ uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE;
1859
+
1860
+ defines.push_back("SRC1_INNER_TYPE=f32");
1861
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
1862
+
1863
+ const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);
1864
+ std::string src0_name = src0_traits->type_name;
1865
+ std::string type_upper = src0_name;
1866
+ variant += "_" + src0_name;
1867
+ std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
1868
+
1869
+ defines.push_back("MUL_ACC_" + type_upper);
1870
+ defines.push_back("Q8_1_T");
1871
+
1872
+ defines.push_back(context.supports_subgroups ? "USE_SUBGROUP_REDUCTION" : "USE_WORKGROUP_REDUCTION");
1873
+ variant += context.supports_subgroups ? "_sg_reduce" : "_wg_reduce";
1874
+
1875
+ auto processed = preprocessor.preprocess(shader_src, defines);
1876
+ auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
1877
+ decisions->wg_size = wg_size;
1878
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
1879
+ pipeline.context = decisions;
1880
+ quantize_q8_pipelines[key] = pipeline;
1881
+ return quantize_q8_pipelines[key];
1882
+ }
1883
+
1884
+ webgpu_pipeline get_mul_mat_vec_pipeline(const ggml_webgpu_shader_lib_context & context) {
1885
+ ggml_webgpu_mul_mat_vec_pipeline_key key = {};
1886
+ key.src0_type = context.src0->type;
1887
+ key.src1_type = context.src1->type;
1888
+ key.vectorized = (context.src0->ne[0] % 4 == 0 &&
1889
+ (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
1890
+ 1 :
1891
+ 0;
1892
+ key.use_mmvq =
1893
+ ggml_webgpu_can_use_mmvq(context.src0, context.src1, context.supports_dot_product, context.vendor);
1894
+
1895
+ auto it = mul_mat_vec_pipelines.find(key);
1896
+ if (it != mul_mat_vec_pipelines.end()) {
1897
+ return it->second;
1898
+ }
1899
+
1900
+ std::vector<std::string> defines;
1901
+ std::string variant = "mul_mat_vec";
1902
+ const char * shader_src = wgsl_mul_mat_vec;
1903
+
1904
+ // src0 type (matrix row)
1905
+ switch (context.src0->type) {
1906
+ case GGML_TYPE_F32:
1907
+ defines.push_back("SRC0_INNER_TYPE=f32");
1908
+ defines.push_back("MUL_ACC_FLOAT");
1909
+ variant += "_f32";
1910
+ break;
1911
+ case GGML_TYPE_F16:
1912
+ defines.push_back("SRC0_INNER_TYPE=f16");
1913
+ defines.push_back("MUL_ACC_FLOAT");
1914
+ variant += "_f16";
1915
+ break;
1916
+ default:
1917
+ {
1918
+ // Quantized types: use helpers but accumulate in f16
1919
+ const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);
1920
+ std::string src0_name = src0_traits->type_name;
1921
+ std::string type_upper = src0_name;
1922
+ variant += "_" + src0_name;
1923
+ std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
1924
+
1925
+ defines.push_back("BYTE_HELPERS");
1926
+ defines.push_back("MUL_ACC_" + type_upper);
1927
+ defines.push_back("U32_DEQUANT_HELPERS");
1928
+ defines.push_back("SRC0_INNER_TYPE=u32");
1929
+ switch (context.src0->type) {
1930
+ case GGML_TYPE_Q8_0:
1931
+ case GGML_TYPE_Q4_0:
1932
+ case GGML_TYPE_Q4_1:
1933
+ if (key.use_mmvq) {
1934
+ defines.push_back("LEGACY_QUANTS");
1935
+ }
1936
+ break;
1937
+ case GGML_TYPE_Q2_K:
1938
+ case GGML_TYPE_Q4_K:
1939
+ if (key.use_mmvq) {
1940
+ defines.push_back("K_QUANTS");
1941
+ }
1942
+ break;
1943
+ case GGML_TYPE_IQ1_S:
1944
+ case GGML_TYPE_IQ1_M:
1945
+ case GGML_TYPE_IQ2_S:
1946
+ case GGML_TYPE_IQ3_S:
1947
+ case GGML_TYPE_IQ4_NL:
1948
+ case GGML_TYPE_IQ4_XS:
1949
+ defines.push_back(type_upper + "_GRID");
1950
+ break;
1951
+ case GGML_TYPE_IQ2_XXS:
1952
+ case GGML_TYPE_IQ2_XS:
1953
+ case GGML_TYPE_IQ3_XXS:
1954
+ defines.push_back(type_upper + "_GRID");
1955
+ defines.push_back(type_upper + "_TABLES");
1956
+ break;
1957
+ case GGML_TYPE_MXFP4:
1958
+ defines.push_back(type_upper + "_LUT");
1959
+ break;
1960
+ default:
1961
+ break;
1962
+ }
1963
+ break;
1964
+ }
1965
+ }
1966
+
1967
+ // src1 type (vector)
1968
+ switch (context.src1->type) {
1969
+ case GGML_TYPE_F32:
1970
+ defines.push_back("SRC1_INNER_TYPE=f32");
1971
+ variant += "_f32";
1972
+ break;
1973
+ case GGML_TYPE_F16:
1974
+ defines.push_back("SRC1_INNER_TYPE=f16");
1975
+ variant += "_f16";
1976
+ break;
1977
+ default:
1978
+ GGML_ABORT("Unsupported src1 type for mul_mat_vec shader");
1979
+ }
1980
+
1981
+ // VEC/SCALAR controls
1982
+ defines.push_back(key.vectorized ? "VEC" : "SCALAR");
1983
+
1984
+ uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE;
1985
+ uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG;
1986
+
1987
+ if (key.src0_type == GGML_TYPE_Q1_0) {
1988
+ outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG;
1989
+ } else if (key.src0_type >= GGML_TYPE_Q2_K) {
1990
+ outputs_per_wg = WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG;
1991
+ } else if (key.src0_type >= GGML_TYPE_Q4_0) {
1992
+ outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG;
1993
+ }
1994
+
1995
+ if (key.use_mmvq) {
1996
+ defines.push_back("MMVQ");
1997
+ defines.push_back("Q8_1_T");
1998
+ }
1999
+
2000
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
2001
+ defines.push_back(std::string("OUTPUTS_PER_WG=") + std::to_string(outputs_per_wg));
2002
+ defines.push_back(context.supports_subgroups ? "USE_SUBGROUP_REDUCTION" : "USE_WORKGROUP_REDUCTION");
2003
+ variant += context.supports_subgroups ? "_sg_reduce" : "_wg_reduce";
2004
+ if (key.vectorized) {
2005
+ variant += "_vectorized";
2006
+ }
2007
+
2008
+ auto processed = preprocessor.preprocess(shader_src, defines);
2009
+ auto decisions = std::make_shared<ggml_webgpu_mul_mat_vec_shader_decisions>();
2010
+ decisions->wg_size = wg_size;
2011
+ decisions->outputs_per_wg = outputs_per_wg;
2012
+ decisions->vec_size = key.vectorized ? 4 : 1;
2013
+
2014
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
2015
+ pipeline.context = decisions;
2016
+ mul_mat_vec_pipelines[key] = pipeline;
2017
+ return mul_mat_vec_pipelines[key];
2018
+ }
2019
+
2020
+ webgpu_pipeline get_mul_mat_fast_pipeline(const ggml_webgpu_shader_lib_context & context) {
2021
+ ggml_webgpu_mul_mat_pipeline_key key = {};
2022
+ key.src0_type = context.src0->type;
2023
+ key.src1_type = context.src1->type;
2024
+ key.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 &&
2025
+ (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
2026
+ 1 :
2027
+ 0;
2028
+ key.use_subgroup_matrix = context.supports_subgroup_matrix;
2029
+
2030
+ auto it = mul_mat_fast_pipelines.find(key);
2031
+ if (it != mul_mat_fast_pipelines.end()) {
2032
+ return it->second;
2033
+ }
2034
+
2035
+ const char * shader_src = key.use_subgroup_matrix ? wgsl_mul_mat_subgroup_matrix : wgsl_mul_mat_reg_tile;
2036
+ std::vector<std::string> defines;
2037
+ std::string variant = key.use_subgroup_matrix ? "mul_mat_subgroup_matrix" : "mul_mat_reg_tile";
2038
+
2039
+ // src1 type
2040
+ switch (context.src1->type) {
2041
+ case GGML_TYPE_F32:
2042
+ defines.push_back("SRC1_INNER_TYPE=f32");
2043
+ break;
2044
+ case GGML_TYPE_F16:
2045
+ defines.push_back("SRC1_INNER_TYPE=f16");
2046
+ break;
2047
+ default:
2048
+ GGML_ABORT("Unsupported src1 type for mul_mat fast shader");
2049
+ }
2050
+
2051
+ // src0 type
2052
+ const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);
2053
+ const char * src0_name = src0_traits->type_name;
2054
+
2055
+ switch (context.src0->type) {
2056
+ case GGML_TYPE_F32:
2057
+ defines.push_back("SRC0_INNER_TYPE=f32");
2058
+ defines.push_back("FLOAT");
2059
+ defines.push_back("MUL_ACC_FLOAT");
2060
+ defines.push_back("INIT_SRC0_SHMEM_FLOAT");
2061
+ defines.push_back("INIT_SRC1_SHMEM_FLOAT");
2062
+ variant += "_f32";
2063
+ break;
2064
+ case GGML_TYPE_F16:
2065
+ defines.push_back("SRC0_INNER_TYPE=f16");
2066
+ defines.push_back("FLOAT");
2067
+ defines.push_back("MUL_ACC_FLOAT");
2068
+ defines.push_back("INIT_SRC0_SHMEM_FLOAT");
2069
+ defines.push_back("INIT_SRC1_SHMEM_FLOAT");
2070
+ variant += "_f16";
2071
+ break;
2072
+ default:
2073
+ {
2074
+ std::string type_upper = src0_name;
2075
+ std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
2076
+
2077
+ defines.push_back("BYTE_HELPERS");
2078
+ defines.push_back("MUL_ACC_" + type_upper);
2079
+ defines.push_back("INIT_SRC0_SHMEM_" + type_upper);
2080
+ defines.push_back("INIT_SRC1_SHMEM_FLOAT");
2081
+ defines.push_back("U32_DEQUANT_HELPERS");
2082
+ defines.push_back("SRC0_INNER_TYPE=u32");
2083
+
2084
+ switch (context.src0->type) {
2085
+ case GGML_TYPE_IQ1_S:
2086
+ case GGML_TYPE_IQ1_M:
2087
+ case GGML_TYPE_IQ4_NL:
2088
+ case GGML_TYPE_IQ4_XS:
2089
+ defines.push_back(type_upper + "_GRID");
2090
+ break;
2091
+ case GGML_TYPE_IQ2_XXS:
2092
+ case GGML_TYPE_IQ2_XS:
2093
+ case GGML_TYPE_IQ2_S:
2094
+ case GGML_TYPE_IQ3_XXS:
2095
+ case GGML_TYPE_IQ3_S:
2096
+ defines.push_back(type_upper + "_GRID");
2097
+ defines.push_back(type_upper + "_TABLES");
2098
+ break;
2099
+ case GGML_TYPE_MXFP4:
2100
+ defines.push_back(type_upper + "_LUT");
2101
+ break;
2102
+ default:
2103
+ break;
2104
+ }
2105
+
2106
+ variant += std::string("_") + src0_name;
2107
+ break;
2108
+ }
2109
+ }
2110
+
2111
+ // VEC/SCALAR controls
2112
+ defines.push_back(key.vectorized ? "VEC" : "SCALAR");
2113
+
2114
+ const bool is_quant = ggml_is_quantized(context.src0->type);
2115
+
2116
+ uint32_t tile_k;
2117
+ if (key.use_subgroup_matrix) {
2118
+ tile_k = is_quant ? WEBGPU_MUL_MAT_SUBGROUP_TILE_K_QUANT : WEBGPU_MUL_MAT_SUBGROUP_TILE_K_FLOAT;
2119
+ } else {
2120
+ tile_k = is_quant ? WEBGPU_MUL_MAT_REG_TILE_K_QUANT : WEBGPU_MUL_MAT_REG_TILE_K_FLOAT;
2121
+ }
2122
+
2123
+ // Tiles
2124
+ defines.push_back("TILE_M=" + std::to_string(WEBGPU_MUL_MAT_TILE_M) + "u");
2125
+ defines.push_back("TILE_N=" + std::to_string(WEBGPU_MUL_MAT_TILE_N) + "u");
2126
+
2127
+ // Subgroup matrix specifics
2128
+ if (key.use_subgroup_matrix) {
2129
+ defines.push_back("TILE_K=" + std::to_string(tile_k) + "u");
2130
+ defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(context.max_subgroup_size) + "u");
2131
+ defines.push_back("SUBGROUP_M=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M) + "u");
2132
+ defines.push_back("SUBGROUP_N=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N) + "u");
2133
+ defines.push_back("SUBGROUP_MATRIX_M=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M) + "u");
2134
+ defines.push_back("SUBGROUP_MATRIX_N=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N) + "u");
2135
+ defines.push_back("SUBGROUP_MATRIX_M_SIZE=" + std::to_string(context.sg_mat_m) + "u");
2136
+ defines.push_back("SUBGROUP_MATRIX_N_SIZE=" + std::to_string(context.sg_mat_n) + "u");
2137
+ defines.push_back("SUBGROUP_MATRIX_K_SIZE=" + std::to_string(context.sg_mat_k) + "u");
2138
+ }
2139
+
2140
+ // variant suffix for src1 type
2141
+ variant += std::string("_") + (context.src1->type == GGML_TYPE_F32 ? "f32" : "f16");
2142
+ if (key.vectorized) {
2143
+ variant += "_vectorized";
2144
+ }
2145
+
2146
+ if (!key.use_subgroup_matrix) {
2147
+ defines.push_back("WORKGROUP_SIZE_M=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_M) + "u");
2148
+ defines.push_back("WORKGROUP_SIZE_N=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_N) + "u");
2149
+ defines.push_back("TILE_K=" + std::to_string(tile_k) + "u");
2150
+ }
2151
+
2152
+ auto processed = preprocessor.preprocess(shader_src, defines);
2153
+
2154
+ auto decisions = std::make_shared<ggml_webgpu_mul_mat_shader_decisions>();
2155
+ decisions->tile_k = tile_k;
2156
+ decisions->tile_m = WEBGPU_MUL_MAT_TILE_M;
2157
+ decisions->tile_n = WEBGPU_MUL_MAT_TILE_N;
2158
+ decisions->use_subgroup_matrix = key.use_subgroup_matrix;
2159
+ if (key.use_subgroup_matrix) {
2160
+ decisions->subgroup_m = WEBGPU_MUL_MAT_SUBGROUP_M;
2161
+ decisions->subgroup_n = WEBGPU_MUL_MAT_SUBGROUP_N;
2162
+ decisions->subgroup_matrix_m = WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M;
2163
+ decisions->subgroup_matrix_n = WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N;
2164
+ decisions->wg_size = context.max_subgroup_size;
2165
+ } else {
2166
+ decisions->wg_size_m = WEBGPU_MUL_MAT_WG_SIZE_M;
2167
+ decisions->wg_size_n = WEBGPU_MUL_MAT_WG_SIZE_N;
2168
+ decisions->wg_size = WEBGPU_MUL_MAT_WG_SIZE_M * WEBGPU_MUL_MAT_WG_SIZE_N;
2169
+ decisions->mul_mat_wg_size = WEBGPU_MUL_MAT_WG_SIZE;
2170
+ }
2171
+
2172
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
2173
+ pipeline.context = decisions;
2174
+ mul_mat_fast_pipelines[key] = pipeline;
2175
+ return mul_mat_fast_pipelines[key];
2176
+ }
2177
+
2178
+ webgpu_pipeline get_mul_mat_id_gather_pipeline(const ggml_webgpu_shader_lib_context & context) {
2179
+ auto it = mul_mat_id_gather_pipelines.find(1);
2180
+ if (it != mul_mat_id_gather_pipelines.end()) {
2181
+ return it->second;
2182
+ }
2183
+ std::vector<std::string> defines;
2184
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
2185
+
2186
+ auto processed = preprocessor.preprocess(wgsl_mul_mat_id_gather, defines);
2187
+ auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
2188
+ decisions->wg_size = context.max_wg_size;
2189
+
2190
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, "mul_mat_id_gather");
2191
+ pipeline.context = decisions;
2192
+ mul_mat_id_gather_pipelines[1] = pipeline;
2193
+ return pipeline;
2194
+ }
2195
+
2196
+ webgpu_pipeline get_mul_mat_id_pipeline(const ggml_webgpu_shader_lib_context & context) {
2197
+ ggml_webgpu_mul_mat_id_pipeline_key key = {};
2198
+ key.src0_type = context.src0->type;
2199
+ key.src1_type = context.src1->type;
2200
+ key.n_experts = context.src0->ne[2];
2201
+ key.vectorized = (context.src0->ne[0] % 4 == 0 && context.src0->ne[1] % 4 == 0 &&
2202
+ (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
2203
+ 1 :
2204
+ 0;
2205
+
2206
+ auto it = mul_mat_id_pipelines.find(key);
2207
+ if (it != mul_mat_id_pipelines.end()) {
2208
+ return it->second;
2209
+ }
2210
+
2211
+ std::vector<std::string> defines;
2212
+ std::string variant = "mul_mat_id";
2213
+ defines.push_back("MUL_MAT_ID");
2214
+
2215
+ // src1 type
2216
+ switch (context.src1->type) {
2217
+ case GGML_TYPE_F32:
2218
+ defines.push_back("SRC1_INNER_TYPE=f32");
2219
+ break;
2220
+ case GGML_TYPE_F16:
2221
+ defines.push_back("SRC1_INNER_TYPE=f16");
2222
+ break;
2223
+ default:
2224
+ GGML_ABORT("Unsupported src1 type for mul_mat fast shader");
2225
+ }
2226
+
2227
+ // src0 type
2228
+ const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);
2229
+ const char * src0_name = src0_traits->type_name;
2230
+
2231
+ switch (context.src0->type) {
2232
+ case GGML_TYPE_F32:
2233
+ defines.push_back("SRC0_INNER_TYPE=f32");
2234
+ defines.push_back("INIT_SRC0_SHMEM_FLOAT");
2235
+ defines.push_back("INIT_SRC1_SHMEM_FLOAT");
2236
+ variant += "_f32";
2237
+ break;
2238
+ case GGML_TYPE_F16:
2239
+ defines.push_back("SRC0_INNER_TYPE=f16");
2240
+ defines.push_back("INIT_SRC0_SHMEM_FLOAT");
2241
+ defines.push_back("INIT_SRC1_SHMEM_FLOAT");
2242
+ variant += "_f16";
2243
+ break;
2244
+ default:
2245
+ {
2246
+ std::string type_upper = src0_name;
2247
+ std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
2248
+
2249
+ defines.push_back("BYTE_HELPERS");
2250
+ defines.push_back("INIT_SRC0_SHMEM_" + type_upper);
2251
+ defines.push_back("INIT_SRC1_SHMEM_FLOAT");
2252
+ defines.push_back("U32_DEQUANT_HELPERS");
2253
+ defines.push_back("SRC0_INNER_TYPE=u32");
2254
+
2255
+ switch (context.src0->type) {
2256
+ case GGML_TYPE_IQ1_S:
2257
+ case GGML_TYPE_IQ1_M:
2258
+ case GGML_TYPE_IQ4_NL:
2259
+ case GGML_TYPE_IQ4_XS:
2260
+ defines.push_back(type_upper + "_GRID");
2261
+ break;
2262
+ case GGML_TYPE_IQ2_XXS:
2263
+ case GGML_TYPE_IQ2_XS:
2264
+ case GGML_TYPE_IQ2_S:
2265
+ case GGML_TYPE_IQ3_XXS:
2266
+ case GGML_TYPE_IQ3_S:
2267
+ defines.push_back(type_upper + "_GRID");
2268
+ defines.push_back(type_upper + "_TABLES");
2269
+ break;
2270
+ case GGML_TYPE_MXFP4:
2271
+ defines.push_back(type_upper + "_LUT");
2272
+ break;
2273
+ default:
2274
+ break;
2275
+ }
2276
+
2277
+ variant += std::string("_") + src0_name;
2278
+ break;
2279
+ }
2280
+ }
2281
+
2282
+ // VEC/SCALAR controls
2283
+ defines.push_back(key.vectorized ? "VEC" : "SCALAR");
2284
+
2285
+ // mul_mat_id is register-tile only.
2286
+ const uint32_t tile_k =
2287
+ ggml_is_quantized(context.src0->type) ? WEBGPU_MUL_MAT_REG_TILE_K_QUANT : WEBGPU_MUL_MAT_REG_TILE_K_FLOAT;
2288
+
2289
+ // Tiles
2290
+ defines.push_back("TILE_M=" + std::to_string(WEBGPU_MUL_MAT_TILE_M) + "u");
2291
+ defines.push_back("TILE_N=" + std::to_string(WEBGPU_MUL_MAT_TILE_N) + "u");
2292
+ defines.push_back("TILE_K=" + std::to_string(tile_k) + "u");
2293
+
2294
+ defines.push_back("WORKGROUP_SIZE_M=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_M) + "u");
2295
+ defines.push_back("WORKGROUP_SIZE_N=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_N) + "u");
2296
+
2297
+ // variant suffix for src1 type
2298
+ variant += std::string("_") + (context.src1->type == GGML_TYPE_F32 ? "f32" : "f16");
2299
+ if (key.vectorized) {
2300
+ variant += "_vectorized";
2301
+ }
2302
+
2303
+ auto processed = preprocessor.preprocess(wgsl_mul_mat_id, defines);
2304
+
2305
+ auto decisions = std::make_shared<ggml_webgpu_mul_mat_shader_decisions>();
2306
+ decisions->tile_k = tile_k;
2307
+ decisions->tile_m = WEBGPU_MUL_MAT_TILE_M;
2308
+ decisions->tile_n = WEBGPU_MUL_MAT_TILE_N;
2309
+ decisions->wg_size_m = WEBGPU_MUL_MAT_WG_SIZE_M;
2310
+ decisions->wg_size_n = WEBGPU_MUL_MAT_WG_SIZE_N;
2311
+ decisions->wg_size = WEBGPU_MUL_MAT_WG_SIZE_M * WEBGPU_MUL_MAT_WG_SIZE_N;
2312
+
2313
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
2314
+ pipeline.context = decisions;
2315
+ mul_mat_id_pipelines[key] = pipeline;
2316
+ return mul_mat_id_pipelines[key];
2317
+ }
2318
+
2319
+ webgpu_pipeline get_mul_mat_id_vec_pipeline(const ggml_webgpu_shader_lib_context & context) {
2320
+ ggml_webgpu_mul_mat_id_pipeline_key key = {};
2321
+ key.src0_type = context.src0->type;
2322
+ key.src1_type = context.src1->type;
2323
+ key.n_experts = context.src0->ne[2];
2324
+ key.vectorized = (context.src0->ne[0] % 4 == 0 &&
2325
+ (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
2326
+ 1 :
2327
+ 0;
2328
+
2329
+ auto it = mul_mat_id_vec_pipelines.find(key);
2330
+ if (it != mul_mat_id_vec_pipelines.end()) {
2331
+ return it->second;
2332
+ }
2333
+
2334
+ std::vector<std::string> defines;
2335
+ std::string variant = "mul_mat_id_vec";
2336
+ const char * shader_src = wgsl_mul_mat_id_vec;
2337
+
2338
+ // src1 type
2339
+ switch (context.src1->type) {
2340
+ case GGML_TYPE_F32:
2341
+ defines.push_back("SRC1_INNER_TYPE=f32");
2342
+ break;
2343
+ case GGML_TYPE_F16:
2344
+ defines.push_back("SRC1_INNER_TYPE=f16");
2345
+ break;
2346
+ default:
2347
+ GGML_ABORT("Unsupported src1 type for mul_mat fast shader");
2348
+ }
2349
+
2350
+ // src0 type
2351
+ switch (context.src0->type) {
2352
+ case GGML_TYPE_F32:
2353
+ defines.push_back("SRC0_INNER_TYPE=f32");
2354
+ defines.push_back("MUL_ACC_FLOAT");
2355
+ variant += "_f32";
2356
+ break;
2357
+ case GGML_TYPE_F16:
2358
+ defines.push_back("SRC0_INNER_TYPE=f16");
2359
+ defines.push_back("MUL_ACC_FLOAT");
2360
+ variant += "_f16";
2361
+ break;
2362
+ default:
2363
+ {
2364
+ // Quantized types: use helpers but accumulate in f16
2365
+ const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);
2366
+ std::string src0_name = src0_traits->type_name;
2367
+ std::string type_upper = src0_name;
2368
+ variant += "_" + src0_name;
2369
+ std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
2370
+
2371
+ defines.push_back("BYTE_HELPERS");
2372
+ defines.push_back("MUL_ACC_" + type_upper);
2373
+ defines.push_back("U32_DEQUANT_HELPERS");
2374
+ defines.push_back("SRC0_INNER_TYPE=u32");
2375
+ switch (context.src0->type) {
2376
+ case GGML_TYPE_IQ1_S:
2377
+ case GGML_TYPE_IQ1_M:
2378
+ case GGML_TYPE_IQ2_S:
2379
+ case GGML_TYPE_IQ3_S:
2380
+ case GGML_TYPE_IQ4_NL:
2381
+ case GGML_TYPE_IQ4_XS:
2382
+ defines.push_back(type_upper + "_GRID");
2383
+ break;
2384
+ case GGML_TYPE_IQ2_XXS:
2385
+ case GGML_TYPE_IQ2_XS:
2386
+ case GGML_TYPE_IQ3_XXS:
2387
+ defines.push_back(type_upper + "_GRID");
2388
+ defines.push_back(type_upper + "_TABLES");
2389
+ break;
2390
+ case GGML_TYPE_MXFP4:
2391
+ defines.push_back(type_upper + "_LUT");
2392
+ break;
2393
+ default:
2394
+ break;
2395
+ }
2396
+ break;
2397
+ }
2398
+ }
2399
+
2400
+ // VEC/SCALAR controls
2401
+ defines.push_back(key.vectorized ? "VEC" : "SCALAR");
2402
+
2403
+ uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE;
2404
+ uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG;
2405
+
2406
+ if (key.src0_type == GGML_TYPE_Q1_0) {
2407
+ outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG;
2408
+ } else if (key.src0_type >= GGML_TYPE_Q2_K) {
2409
+ outputs_per_wg = WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG;
2410
+ } else if (key.src0_type >= GGML_TYPE_Q4_0) {
2411
+ outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG;
2412
+ }
2413
+
2414
+ // variant suffix for src1 type
2415
+ variant += std::string("_") + (context.src1->type == GGML_TYPE_F32 ? "f32" : "f16");
2416
+
2417
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
2418
+ defines.push_back(std::string("OUTPUTS_PER_WG=") + std::to_string(outputs_per_wg));
2419
+ defines.push_back(context.supports_subgroups ? "USE_SUBGROUP_REDUCTION" : "USE_WORKGROUP_REDUCTION");
2420
+ variant += context.supports_subgroups ? "_sg_reduce" : "_wg_reduce";
2421
+ if (key.vectorized) {
2422
+ variant += "_vectorized";
2423
+ }
2424
+
2425
+ defines.push_back(std::string("N_EXPERTS=") + std::to_string(key.n_experts));
2426
+
2427
+ auto processed = preprocessor.preprocess(shader_src, defines);
2428
+
2429
+ auto decisions = std::make_shared<ggml_webgpu_mul_mat_vec_shader_decisions>();
2430
+ decisions->wg_size = wg_size;
2431
+ decisions->outputs_per_wg = outputs_per_wg;
2432
+
2433
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
2434
+ pipeline.context = decisions;
2435
+ mul_mat_id_vec_pipelines[key] = pipeline;
2436
+ return mul_mat_id_vec_pipelines[key];
2437
+ }
2438
+
2439
+ webgpu_pipeline get_unary_pipeline(const ggml_webgpu_shader_lib_context & context) {
2440
+ const bool is_unary = context.dst->op == GGML_OP_UNARY;
2441
+ const int op = is_unary ? (int) ggml_get_unary_op(context.dst) : context.dst->op;
2442
+ ggml_webgpu_unary_pipeline_key key = {};
2443
+ key.type = context.dst->type;
2444
+ key.op = op;
2445
+ key.is_unary = is_unary;
2446
+ key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst) || context.dst->op == GGML_OP_FILL;
2447
+ key.ttype = (ggml_tri_type) ggml_get_op_params_i32(context.dst, 0);
2448
+
2449
+ auto it = unary_pipelines.find(key);
2450
+ if (it != unary_pipelines.end()) {
2451
+ return it->second;
2452
+ }
2453
+
2454
+ std::vector<std::string> defines;
2455
+ std::string variant =
2456
+ key.is_unary ? ggml_unary_op_name((ggml_unary_op) key.op) : ggml_op_name((ggml_op) key.op);
2457
+ defines.push_back(variant);
2458
+
2459
+ switch (key.type) {
2460
+ case GGML_TYPE_F32:
2461
+ defines.push_back("TYPE_F32");
2462
+ variant += "_f32";
2463
+ break;
2464
+ case GGML_TYPE_F16:
2465
+ defines.push_back("TYPE_F16");
2466
+ variant += "_f16";
2467
+ break;
2468
+ default:
2469
+ GGML_ABORT("Unsupported type for unary shader");
2470
+ }
2471
+
2472
+ if (key.inplace) {
2473
+ defines.push_back("INPLACE");
2474
+ variant += "_inplace";
2475
+ }
2476
+
2477
+ if (op == GGML_OP_TRI) {
2478
+ switch (key.ttype) {
2479
+ case GGML_TRI_TYPE_LOWER:
2480
+ defines.push_back("TRI_TYPE_LOWER");
2481
+ variant += "_tri_type_lower";
2482
+ break;
2483
+ case GGML_TRI_TYPE_LOWER_DIAG:
2484
+ defines.push_back("TRI_TYPE_LOWER_DIAG");
2485
+ variant += "_tri_type_lower_diag";
2486
+ break;
2487
+ case GGML_TRI_TYPE_UPPER:
2488
+ defines.push_back("TRI_TYPE_UPPER");
2489
+ variant += "_tri_type_upper";
2490
+ break;
2491
+ case GGML_TRI_TYPE_UPPER_DIAG:
2492
+ defines.push_back("TRI_TYPE_UPPER_DIAG");
2493
+ variant += "_tri_upper_diag";
2494
+ break;
2495
+ default:
2496
+ GGML_ABORT("Unsupported ggml_tri_type for unary shader");
2497
+ }
2498
+ }
2499
+
2500
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
2501
+
2502
+ auto processed = preprocessor.preprocess(wgsl_unary, defines);
2503
+ auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
2504
+ decisions->wg_size = context.max_wg_size;
2505
+ decisions->inplace = key.inplace;
2506
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
2507
+ pipeline.context = decisions;
2508
+ unary_pipelines[key] = pipeline;
2509
+ return unary_pipelines[key];
2510
+ }
2511
+
2512
+ webgpu_pipeline get_rms_norm_mul_pipeline(const ggml_webgpu_shader_lib_context & context) {
2513
+ ggml_webgpu_rms_norm_mul_pipeline_key key = {};
2514
+ key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);
2515
+ key.overlap = ggml_webgpu_tensor_equal(context.src1, context.dst);
2516
+ key.src_overlap = ggml_webgpu_tensor_overlap(context.src0, context.src1);
2517
+
2518
+ auto it = rms_norm_mul_pipelines.find(key);
2519
+ if (it != rms_norm_mul_pipelines.end()) {
2520
+ return it->second;
2521
+ }
2522
+
2523
+ std::vector<std::string> defines;
2524
+ std::string op_name = "RMS_NORM_MUL";
2525
+ std::string variant = op_name;
2526
+
2527
+ if (key.inplace) {
2528
+ defines.push_back("INPLACE");
2529
+ variant += "_inplace";
2530
+ } else if (key.overlap) {
2531
+ defines.push_back("OVERLAP");
2532
+ variant += "_overlap";
2533
+ } else if (key.src_overlap) {
2534
+ defines.push_back("SRC_OVERLAP");
2535
+ variant += "_src_overlap";
2536
+ }
2537
+
2538
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
2539
+
2540
+ auto processed = preprocessor.preprocess(wgsl_rms_norm_mul, defines);
2541
+ auto pipeline_decisions = std::make_shared<ggml_webgpu_rms_norm_mul_shader_decisions>();
2542
+ pipeline_decisions->wg_size = context.max_wg_size;
2543
+ pipeline_decisions->inplace = key.inplace;
2544
+ pipeline_decisions->overlap = key.overlap;
2545
+ pipeline_decisions->src_overlap = key.src_overlap;
2546
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
2547
+ pipeline.context = pipeline_decisions;
2548
+ rms_norm_mul_pipelines[key] = pipeline;
2549
+ return rms_norm_mul_pipelines[key];
2550
+ }
2551
+
2552
+ webgpu_pipeline get_binary_pipeline(const ggml_webgpu_shader_lib_context & context) {
2553
+ ggml_webgpu_binary_pipeline_key key = {};
2554
+ key.type = context.dst->type;
2555
+ key.op = context.dst->op;
2556
+ key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);
2557
+ key.overlap = ggml_webgpu_tensor_equal(context.src1, context.dst);
2558
+ key.src_overlap = ggml_webgpu_tensor_overlap(context.src0, context.src1);
2559
+
2560
+ auto it = binary_pipelines.find(key);
2561
+ if (it != binary_pipelines.end()) {
2562
+ return it->second;
2563
+ }
2564
+
2565
+ std::vector<std::string> defines;
2566
+ std::string op_name = ggml_op_name((ggml_op) key.op);
2567
+ std::string variant = op_name;
2568
+
2569
+ defines.push_back(std::string("OP_") + op_name);
2570
+
2571
+ switch (key.type) {
2572
+ case GGML_TYPE_F32:
2573
+ defines.push_back("TYPE_F32");
2574
+ variant += "_f32";
2575
+ break;
2576
+ case GGML_TYPE_F16:
2577
+ defines.push_back("TYPE_F16");
2578
+ variant += "_f16";
2579
+ break;
2580
+ default:
2581
+ GGML_ABORT("Unsupported type for binary shader");
2582
+ }
2583
+
2584
+ if (key.inplace) {
2585
+ defines.push_back("INPLACE");
2586
+ variant += "_inplace";
2587
+ } else if (key.overlap) {
2588
+ defines.push_back("OVERLAP");
2589
+ variant += "_overlap";
2590
+ } else if (key.src_overlap) {
2591
+ defines.push_back("SRC_OVERLAP");
2592
+ variant += "_src_overlap";
2593
+ }
2594
+
2595
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
2596
+
2597
+ auto processed = preprocessor.preprocess(wgsl_binary, defines);
2598
+ auto pipeline_decisions = std::make_shared<ggml_webgpu_binary_shader_decisions>();
2599
+ pipeline_decisions->wg_size = context.max_wg_size;
2600
+ pipeline_decisions->inplace = key.inplace;
2601
+ pipeline_decisions->overlap = key.overlap;
2602
+ pipeline_decisions->src_overlap = key.src_overlap;
2603
+
2604
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
2605
+ pipeline.context = pipeline_decisions;
2606
+ binary_pipelines[key] = pipeline;
2607
+ return binary_pipelines[key];
2608
+ }
2609
+
2610
+ webgpu_pipeline get_add_id_pipeline(const ggml_webgpu_shader_lib_context & context) {
2611
+ ggml_webgpu_add_id_pipeline_key key = {};
2612
+ key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);
2613
+
2614
+ auto it = add_id_pipelines.find(key);
2615
+ if (it != add_id_pipelines.end()) {
2616
+ return it->second;
2617
+ }
2618
+
2619
+ std::vector<std::string> defines;
2620
+ std::string variant = "add_id";
2621
+ const char * shader_src = wgsl_add_id;
2622
+
2623
+ if (key.inplace) {
2624
+ defines.push_back("INPLACE");
2625
+ variant += "_inplace";
2626
+ }
2627
+
2628
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
2629
+
2630
+ auto processed = preprocessor.preprocess(shader_src, defines);
2631
+ auto pipeline_decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
2632
+ pipeline_decisions->wg_size = context.max_wg_size;
2633
+ pipeline_decisions->inplace = key.inplace;
2634
+
2635
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
2636
+ pipeline.context = pipeline_decisions;
2637
+ add_id_pipelines[key] = pipeline;
2638
+ return pipeline;
2639
+ }
2640
+
2641
+ webgpu_pipeline get_concat_pipeline(const ggml_webgpu_shader_lib_context & context) {
2642
+ ggml_webgpu_concat_pipeline_key key = {};
2643
+ key.type = context.dst->type;
2644
+ key.src_overlap = ggml_webgpu_tensor_overlap(context.src0, context.src1);
2645
+
2646
+ auto it = concat_pipelines.find(key);
2647
+ if (it != concat_pipelines.end()) {
2648
+ return it->second;
2649
+ }
2650
+
2651
+ std::vector<std::string> defines;
2652
+ std::string variant = "concat";
2653
+
2654
+ switch (key.type) {
2655
+ case GGML_TYPE_F32:
2656
+ defines.push_back("TYPE_F32");
2657
+ variant += "_f32";
2658
+ break;
2659
+ case GGML_TYPE_I32:
2660
+ defines.push_back("TYPE_I32");
2661
+ variant += "_i32";
2662
+ break;
2663
+ default:
2664
+ GGML_ABORT("Unsupported type for concat shader");
2665
+ }
2666
+
2667
+ if (key.src_overlap) {
2668
+ defines.push_back("SRC_OVERLAP");
2669
+ variant += "_src_overlap";
2670
+ }
2671
+
2672
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
2673
+
2674
+ auto processed = preprocessor.preprocess(wgsl_concat, defines);
2675
+ auto decisions = std::make_shared<ggml_webgpu_binary_shader_decisions>();
2676
+ decisions->wg_size = context.max_wg_size;
2677
+ decisions->src_overlap = key.src_overlap;
2678
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
2679
+ pipeline.context = decisions;
2680
+ concat_pipelines[key] = pipeline;
2681
+ return concat_pipelines[key];
2682
+ }
2683
+
2684
+ webgpu_pipeline get_repeat_pipeline(const ggml_webgpu_shader_lib_context & context) {
2685
+ ggml_webgpu_repeat_pipeline_key key = {};
2686
+ key.type = context.dst->type;
2687
+
2688
+ auto it = repeat_pipelines.find(key);
2689
+ if (it != repeat_pipelines.end()) {
2690
+ return it->second;
2691
+ }
2692
+
2693
+ std::vector<std::string> defines;
2694
+ std::string variant = "repeat";
2695
+
2696
+ switch (key.type) {
2697
+ case GGML_TYPE_F32:
2698
+ defines.push_back("TYPE_F32");
2699
+ variant += "_f32";
2700
+ break;
2701
+ case GGML_TYPE_I32:
2702
+ defines.push_back("TYPE_I32");
2703
+ variant += "_i32";
2704
+ break;
2705
+ case GGML_TYPE_I16:
2706
+ defines.push_back("TYPE_I16");
2707
+ variant += "_i16";
2708
+ break;
2709
+ default:
2710
+ GGML_ABORT("Unsupported type for repeat shader");
2711
+ }
2712
+
2713
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
2714
+
2715
+ auto processed = preprocessor.preprocess(wgsl_repeat, defines);
2716
+ auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
2717
+ decisions->wg_size = context.max_wg_size;
2718
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
2719
+ pipeline.context = decisions;
2720
+ repeat_pipelines[key] = pipeline;
2721
+ return repeat_pipelines[key];
2722
+ }
2723
+
2724
+ webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) {
2725
+ const bool can_use_subgroup_matrix = ggml_webgpu_flash_attn_can_use_subgroup_matrix_path(
2726
+ context.supports_subgroup_matrix, context.sg_mat_k, context.sg_mat_n, context.src0, context.src2);
2727
+ ggml_webgpu_flash_attn_decisions decisions = {};
2728
+ decisions.use_sg_matrix = can_use_subgroup_matrix;
2729
+ decisions.q_tile = decisions.use_sg_matrix ? context.sg_mat_m : GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE;
2730
+
2731
+ ggml_webgpu_flash_attn_pipeline_key key = {};
2732
+ key.common =
2733
+ ggml_webgpu_flash_attn_make_common_pipeline_key(context, decisions.use_sg_matrix ? context.sg_mat_k : 1u);
2734
+ key.common.kv_direct = decisions.use_sg_matrix && key.common.kv_direct;
2735
+ key.use_sg_matrix = decisions.use_sg_matrix;
2736
+
2737
+ const uint32_t max_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(
2738
+ context.wg_mem_limit_bytes, decisions.q_tile, decisions.use_sg_matrix ? context.sg_mat_n : 1u,
2739
+ key.common.head_dim_qk, key.common.head_dim_v, key.common.has_mask, key.common.kv_direct);
2740
+ GGML_ASSERT(max_kv_tile > 0);
2741
+
2742
+ decisions.kv_tile = decisions.use_sg_matrix ?
2743
+ std::min(max_kv_tile, context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES) :
2744
+ std::min(GGML_WEBGPU_FLASH_ATTN_TILE_MAX_KV_TILE, max_kv_tile);
2745
+ decisions.wg_size =
2746
+ decisions.use_sg_matrix ?
2747
+ std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE) :
2748
+ std::min(context.max_wg_size, std::max(GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE,
2749
+ GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * context.max_subgroup_size));
2750
+
2751
+ if (key.common.kv_direct) {
2752
+ decisions.kv_tile = std::min(decisions.kv_tile, GGML_WEBGPU_KV_SEQ_PAD);
2753
+ while (GGML_WEBGPU_KV_SEQ_PAD % decisions.kv_tile != 0) {
2754
+ decisions.kv_tile -= decisions.use_sg_matrix ? context.sg_mat_n : context.min_subgroup_size;
2755
+ }
2756
+ }
2757
+
2758
+ auto it = flash_attn_pipelines.find(key);
2759
+ if (it != flash_attn_pipelines.end()) {
2760
+ return it->second;
2761
+ }
2762
+
2763
+ std::string variant = decisions.use_sg_matrix ? "flash_attn" : "flash_attn_tile";
2764
+ std::vector<std::string> defines = ggml_webgpu_flash_attn_common_defines(key.common, variant, decisions.q_tile,
2765
+ decisions.kv_tile, decisions.wg_size);
2766
+ const char * shader_src = nullptr;
2767
+ if (!key.use_sg_matrix) {
2768
+ shader_src = wgsl_flash_attn_tile;
2769
+ defines.push_back("MIN_SUBGROUP_SIZE=" + std::to_string(context.min_subgroup_size) + "u");
2770
+ defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(context.max_subgroup_size) + "u");
2771
+ variant += "_tile_sg" + std::to_string(context.min_subgroup_size) + "_" +
2772
+ std::to_string(context.max_subgroup_size);
2773
+ } else {
2774
+ shader_src = wgsl_flash_attn;
2775
+ defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m));
2776
+ defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n));
2777
+ defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k));
2778
+ }
2779
+ auto pipeline_decisions = std::make_shared<ggml_webgpu_flash_attn_decisions>(decisions);
2780
+ webgpu_pipeline pipeline =
2781
+ ggml_webgpu_create_pipeline(device, preprocessor.preprocess(shader_src, defines), variant);
2782
+ pipeline.context = pipeline_decisions;
2783
+ flash_attn_pipelines[key] = pipeline;
2784
+ return flash_attn_pipelines[key];
2785
+ }
2786
+
2787
+ webgpu_pipeline get_flash_attn_vec_pipeline(const ggml_webgpu_shader_lib_context & context) {
2788
+ ggml_webgpu_flash_attn_vec_pipeline_key key = {};
2789
+ key.common = ggml_webgpu_flash_attn_make_common_pipeline_key(context, GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH);
2790
+
2791
+ auto it = flash_attn_vec_pipelines.find(key);
2792
+ if (it != flash_attn_vec_pipelines.end()) {
2793
+ return it->second;
2794
+ }
2795
+
2796
+ ggml_webgpu_flash_attn_vec_decisions decisions = {};
2797
+ decisions.kv_tile =
2798
+ ggml_webgpu_flash_attn_get_vec_kv_tile(context.wg_mem_limit_bytes, key.common.head_dim_qk,
2799
+ key.common.head_dim_v, key.common.has_mask, key.common.kv_direct);
2800
+ decisions.wg_size = context.max_subgroup_size;
2801
+
2802
+ std::string variant = "flash_attn_vec";
2803
+ std::vector<std::string> defines =
2804
+ ggml_webgpu_flash_attn_common_defines(key.common, variant, 1u, decisions.kv_tile, decisions.wg_size);
2805
+ if (key.common.has_mask) {
2806
+ defines.push_back("BLK");
2807
+ variant.resize(variant.size() - (sizeof("_mask") - 1));
2808
+ variant += "_mask_blk";
2809
+ }
2810
+ uint32_t vec_ne = 1u;
2811
+ if (key.common.k_type == GGML_TYPE_F16 && key.common.v_type == GGML_TYPE_F16 &&
2812
+ key.common.head_dim_qk == key.common.head_dim_v) {
2813
+ switch (key.common.head_dim_qk) {
2814
+ case 64:
2815
+ case 192:
2816
+ case 576:
2817
+ vec_ne = 2u;
2818
+ break;
2819
+ case 96:
2820
+ vec_ne = 4u;
2821
+ break;
2822
+ default:
2823
+ break;
2824
+ }
2825
+ }
2826
+ defines.push_back(std::string("VEC_NE=") + std::to_string(vec_ne) + "u");
2827
+
2828
+ auto pipeline_decisions = std::make_shared<ggml_webgpu_flash_attn_vec_decisions>(decisions);
2829
+ webgpu_pipeline pipeline =
2830
+ ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn_vec_split, defines), variant);
2831
+ pipeline.context = pipeline_decisions;
2832
+ flash_attn_vec_pipelines[key] = pipeline;
2833
+ return flash_attn_vec_pipelines[key];
2834
+ }
2835
+
2836
+ webgpu_pipeline get_flash_attn_blk_pipeline(const ggml_webgpu_shader_lib_context & context, uint32_t kv_tile) {
2837
+ ggml_webgpu_flash_attn_blk_pipeline_key key = {};
2838
+ key.kv_tile = kv_tile;
2839
+ auto it = flash_attn_blk_pipelines.find(key);
2840
+ if (it != flash_attn_blk_pipelines.end()) {
2841
+ return it->second;
2842
+ }
2843
+
2844
+ std::vector<std::string> defines;
2845
+ std::string variant = "flash_attn_vec_blk";
2846
+
2847
+ defines.push_back(std::string("KV_TILE=") + std::to_string(key.kv_tile));
2848
+ variant += std::string("_kvt") + std::to_string(key.kv_tile);
2849
+
2850
+ uint32_t wg_size = 1;
2851
+ while ((wg_size << 1) <= context.max_wg_size) {
2852
+ wg_size <<= 1;
2853
+ }
2854
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
2855
+ variant += std::string("_wg") + std::to_string(wg_size);
2856
+
2857
+ webgpu_pipeline pipeline =
2858
+ ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn_vec_blk, defines), variant);
2859
+ flash_attn_blk_pipelines[key] = pipeline;
2860
+ return flash_attn_blk_pipelines[key];
2861
+ }
2862
+
2863
+ webgpu_pipeline get_flash_attn_vec_reduce_pipeline(const ggml_webgpu_shader_lib_context & context) {
2864
+ ggml_webgpu_flash_attn_vec_reduce_pipeline_key key = {};
2865
+ key.head_dim_v = (uint32_t) context.src2->ne[0];
2866
+ key.dst_type = context.dst->type;
2867
+ key.wg_size = context.max_wg_size;
2868
+ auto it = flash_attn_vec_reduce_pipelines.find(key);
2869
+ if (it != flash_attn_vec_reduce_pipelines.end()) {
2870
+ return it->second;
2871
+ }
2872
+
2873
+ std::vector<std::string> defines;
2874
+ std::string variant = "flash_attn_vec_reduce";
2875
+
2876
+ switch (key.dst_type) {
2877
+ case GGML_TYPE_F32:
2878
+ defines.push_back("DST_F32");
2879
+ break;
2880
+ case GGML_TYPE_F16:
2881
+ defines.push_back("DST_F16");
2882
+ break;
2883
+ default:
2884
+ GGML_ABORT("Unsupported dst type for flash attention vec reduce shader");
2885
+ }
2886
+ variant += std::string("_dst") + ggml_type_name(key.dst_type);
2887
+
2888
+ defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v));
2889
+ variant += std::string("_hsv") + std::to_string(key.head_dim_v);
2890
+
2891
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
2892
+ variant += std::string("_wg") + std::to_string(context.max_wg_size);
2893
+
2894
+ webgpu_pipeline pipeline =
2895
+ ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn_vec_reduce, defines), variant);
2896
+ flash_attn_vec_reduce_pipelines[key] = pipeline;
2897
+ return flash_attn_vec_reduce_pipelines[key];
2898
+ }
2899
+
2900
+ webgpu_pipeline get_cpy_pipeline(const ggml_webgpu_shader_lib_context & context) {
2901
+ ggml_webgpu_cpy_pipeline_key key = {};
2902
+ key.src_type = context.src0->type;
2903
+ key.dst_type = context.dst->type;
2904
+
2905
+ auto it = cpy_pipelines.find(key);
2906
+ if (it != cpy_pipelines.end()) {
2907
+ return it->second;
2908
+ }
2909
+
2910
+ std::vector<std::string> defines;
2911
+ std::string variant = "cpy";
2912
+
2913
+ switch (key.src_type) {
2914
+ case GGML_TYPE_F32:
2915
+ defines.push_back("SRC_F32");
2916
+ variant += "_f32";
2917
+ break;
2918
+ case GGML_TYPE_F16:
2919
+ defines.push_back("SRC_F16");
2920
+ variant += "_f16";
2921
+ break;
2922
+ default:
2923
+ GGML_ABORT("Unsupported src type for cpy shader");
2924
+ }
2925
+
2926
+ switch (key.dst_type) {
2927
+ case GGML_TYPE_F32:
2928
+ defines.push_back("DST_F32");
2929
+ variant += "_f32";
2930
+ break;
2931
+ case GGML_TYPE_F16:
2932
+ defines.push_back("DST_F16");
2933
+ variant += "_f16";
2934
+ break;
2935
+ case GGML_TYPE_I32:
2936
+ defines.push_back("DST_I32");
2937
+ variant += "_i32";
2938
+ break;
2939
+ default:
2940
+ GGML_ABORT("Unsupported dst type for cpy shader");
2941
+ }
2942
+
2943
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
2944
+
2945
+ auto processed = preprocessor.preprocess(wgsl_cpy, defines);
2946
+ auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
2947
+ decisions->wg_size = context.max_wg_size;
2948
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
2949
+ pipeline.context = decisions;
2950
+ cpy_pipelines[key] = pipeline;
2951
+ return cpy_pipelines[key];
2952
+ }
2953
+
2954
+ webgpu_pipeline get_glu_pipeline(const ggml_webgpu_shader_lib_context & context) {
2955
+ ggml_webgpu_glu_pipeline_key key = {};
2956
+ key.glu_op = ggml_get_glu_op(context.dst);
2957
+ key.type = context.dst->type;
2958
+ key.split = (context.src1 != nullptr);
2959
+
2960
+ auto it = glu_pipelines.find(key);
2961
+ if (it != glu_pipelines.end()) {
2962
+ return it->second;
2963
+ }
2964
+
2965
+ std::vector<std::string> defines;
2966
+ std::string variant = "glu";
2967
+
2968
+ switch (key.glu_op) {
2969
+ case GGML_GLU_OP_REGLU:
2970
+ defines.push_back("OP_REGLU");
2971
+ variant += "_reglu";
2972
+ break;
2973
+ case GGML_GLU_OP_GEGLU:
2974
+ defines.push_back("OP_GEGLU");
2975
+ variant += "_geglu";
2976
+ break;
2977
+ case GGML_GLU_OP_SWIGLU:
2978
+ defines.push_back("OP_SWIGLU");
2979
+ variant += "_swiglu";
2980
+ break;
2981
+ case GGML_GLU_OP_SWIGLU_OAI:
2982
+ defines.push_back("OP_SWIGLU_OAI");
2983
+ variant += "_swiglu_oai";
2984
+ break;
2985
+ case GGML_GLU_OP_GEGLU_ERF:
2986
+ defines.push_back("OP_GEGLU_ERF");
2987
+ variant += "_geglu_erf";
2988
+ break;
2989
+ case GGML_GLU_OP_GEGLU_QUICK:
2990
+ defines.push_back("OP_GEGLU_QUICK");
2991
+ variant += "_geglu_quick";
2992
+ break;
2993
+ default:
2994
+ GGML_ABORT("Unsupported GLU op");
2995
+ }
2996
+ switch (key.type) {
2997
+ case GGML_TYPE_F32:
2998
+ defines.push_back("TYPE_F32");
2999
+ variant += "_f32";
3000
+ break;
3001
+ case GGML_TYPE_F16:
3002
+ defines.push_back("TYPE_F16");
3003
+ variant += "_f16";
3004
+ break;
3005
+ default:
3006
+ GGML_ABORT("Unsupported type for GLU shader");
3007
+ }
3008
+
3009
+ if (key.split) {
3010
+ variant += "_split";
3011
+ } else {
3012
+ defines.push_back("NO_SPLIT");
3013
+ }
3014
+
3015
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
3016
+
3017
+ auto processed = preprocessor.preprocess(wgsl_glu, defines);
3018
+ auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
3019
+ decisions->wg_size = context.max_wg_size;
3020
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
3021
+ pipeline.context = decisions;
3022
+ glu_pipelines[key] = pipeline;
3023
+ return glu_pipelines[key];
3024
+ }
3025
+
3026
+ webgpu_pipeline get_rope_pipeline(const ggml_webgpu_shader_lib_context & context) {
3027
+ ggml_webgpu_rope_pipeline_key key = {};
3028
+ key.type = context.dst->type;
3029
+ key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);
3030
+ key.has_ff = (context.src2 != nullptr);
3031
+
3032
+ auto it = rope_pipelines.find(key);
3033
+ if (it != rope_pipelines.end()) {
3034
+ return it->second;
3035
+ }
3036
+
3037
+ std::vector<std::string> defines;
3038
+ std::string variant = "rope";
3039
+
3040
+ switch (key.type) {
3041
+ case GGML_TYPE_F32:
3042
+ defines.push_back("TYPE_F32");
3043
+ variant += "_f32";
3044
+ break;
3045
+ case GGML_TYPE_F16:
3046
+ defines.push_back("TYPE_F16");
3047
+ variant += "_f16";
3048
+ break;
3049
+ default:
3050
+ GGML_ABORT("Unsupported type for ROPE shader");
3051
+ }
3052
+
3053
+ if (key.inplace) {
3054
+ defines.push_back("INPLACE");
3055
+ variant += "_inplace";
3056
+ }
3057
+
3058
+ if (key.has_ff) {
3059
+ defines.push_back("FF_FUNC");
3060
+ variant += "_ff";
3061
+ }
3062
+
3063
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
3064
+
3065
+ auto processed = preprocessor.preprocess(wgsl_rope, defines);
3066
+ auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
3067
+ decisions->wg_size = context.max_wg_size;
3068
+ decisions->inplace = key.inplace;
3069
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
3070
+ pipeline.context = decisions;
3071
+ rope_pipelines[key] = pipeline;
3072
+ return rope_pipelines[key];
3073
+ }
3074
+
3075
+ webgpu_pipeline get_soft_max_pipeline(const ggml_webgpu_shader_lib_context & context) {
3076
+ ggml_webgpu_soft_max_pipeline_key key = {};
3077
+ key.mask_type = context.src1 ? context.src1->type : GGML_TYPE_F32;
3078
+ key.has_mask = (context.src1 != nullptr);
3079
+ key.has_sink = (context.src2 != nullptr);
3080
+ key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);
3081
+
3082
+ auto it = soft_max_pipelines.find(key);
3083
+ if (it != soft_max_pipelines.end()) {
3084
+ return it->second;
3085
+ }
3086
+
3087
+ std::vector<std::string> defines;
3088
+ std::string variant = "soft_max";
3089
+
3090
+ if (key.has_mask) {
3091
+ defines.push_back("HAS_MASK");
3092
+ switch (key.mask_type) {
3093
+ case GGML_TYPE_F32:
3094
+ defines.push_back("MASK_F32");
3095
+ variant += "_mask_f32";
3096
+ break;
3097
+ case GGML_TYPE_F16:
3098
+ defines.push_back("MASK_F16");
3099
+ variant += "_mask_f16";
3100
+ break;
3101
+ default:
3102
+ GGML_ABORT("Unsupported type for SOFT_MAX shader");
3103
+ }
3104
+ }
3105
+
3106
+ if (key.has_sink) {
3107
+ defines.push_back("HAS_SINK");
3108
+ variant += "_sink";
3109
+ }
3110
+
3111
+ if (key.inplace) {
3112
+ defines.push_back("INPLACE");
3113
+ variant += "_inplace";
3114
+ }
3115
+
3116
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
3117
+
3118
+ auto processed = preprocessor.preprocess(wgsl_soft_max, defines);
3119
+ auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
3120
+ decisions->wg_size = context.max_wg_size;
3121
+ decisions->inplace = key.inplace;
3122
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
3123
+ pipeline.context = decisions;
3124
+ soft_max_pipelines[key] = pipeline;
3125
+ return soft_max_pipelines[key];
3126
+ }
3127
+
3128
+ webgpu_pipeline get_conv2d_pipeline(const ggml_webgpu_shader_lib_context & context) {
3129
+ ggml_webgpu_conv2d_pipeline_key key = {};
3130
+ key.weight_type = context.src0->type;
3131
+ key.input_type = context.src1->type;
3132
+ key.output_type = context.dst->type;
3133
+
3134
+ auto it = conv2d_pipelines.find(key);
3135
+ if (it != conv2d_pipelines.end()) {
3136
+ return it->second;
3137
+ }
3138
+
3139
+ std::vector<std::string> defines;
3140
+ std::string variant = "conv_2d";
3141
+
3142
+ auto push_type_defines = [&](const char * prefix, ggml_type type) {
3143
+ std::string s_prefix = prefix;
3144
+ if (type == GGML_TYPE_F32) {
3145
+ defines.push_back(s_prefix + "_F32");
3146
+ } else if (type == GGML_TYPE_F16) {
3147
+ defines.push_back(s_prefix + "_F16");
3148
+ } else {
3149
+ GGML_ABORT("Unsupported type for CONV_2D shader");
3150
+ }
3151
+ };
3152
+
3153
+ push_type_defines("WEIGHT", key.weight_type);
3154
+ push_type_defines("INPUT", key.input_type);
3155
+ push_type_defines("OUTPUT", key.output_type);
3156
+
3157
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
3158
+
3159
+ auto processed = preprocessor.preprocess(wgsl_conv2d, defines);
3160
+ auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
3161
+ decisions->wg_size = context.max_wg_size;
3162
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
3163
+ pipeline.context = decisions;
3164
+ conv2d_pipelines[key] = pipeline;
3165
+ return conv2d_pipelines[key];
3166
+ }
3167
+
3168
+ webgpu_pipeline get_im2col_pipeline(const ggml_webgpu_shader_lib_context & context) {
3169
+ ggml_webgpu_im2col_pipeline_key key = {};
3170
+ key.input_type = context.src1->type;
3171
+ key.output_type = context.dst->type;
3172
+
3173
+ auto it = im2col_pipelines.find(key);
3174
+ if (it != im2col_pipelines.end()) {
3175
+ return it->second;
3176
+ }
3177
+
3178
+ std::vector<std::string> defines;
3179
+ std::string variant = "im2col";
3180
+
3181
+ auto push_type_defines = [&](const char * prefix, ggml_type type) {
3182
+ std::string s_prefix = prefix;
3183
+ if (type == GGML_TYPE_F32) {
3184
+ defines.push_back(s_prefix + "_F32");
3185
+ } else if (type == GGML_TYPE_F16) {
3186
+ defines.push_back(s_prefix + "_F16");
3187
+ } else {
3188
+ GGML_ABORT("Unsupported type for IM2COL shader");
3189
+ }
3190
+ };
3191
+
3192
+ push_type_defines("INPUT", key.input_type);
3193
+ push_type_defines("OUTPUT", key.output_type);
3194
+
3195
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
3196
+
3197
+ auto processed = preprocessor.preprocess(wgsl_im2col, defines);
3198
+ auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
3199
+ decisions->wg_size = context.max_wg_size;
3200
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
3201
+ pipeline.context = decisions;
3202
+ im2col_pipelines[key] = pipeline;
3203
+ return im2col_pipelines[key];
3204
+ }
3205
+
3206
+ webgpu_pipeline get_upscale_pipeline(const ggml_webgpu_shader_lib_context & context) {
3207
+ const uint32_t mode_flags = (uint32_t) ggml_get_op_params_i32(context.dst, 0);
3208
+ const uint32_t base_mode = mode_flags & 0xFFu;
3209
+ const bool antialias = (mode_flags & GGML_SCALE_FLAG_ANTIALIAS) != 0u;
3210
+
3211
+ ggml_webgpu_upscale_pipeline_key key = {};
3212
+ key.input_type = context.src0->type;
3213
+ key.output_type = context.dst->type;
3214
+ key.base_mode = base_mode;
3215
+ key.antialias = antialias;
3216
+
3217
+ auto it = upscale_pipelines.find(key);
3218
+ if (it != upscale_pipelines.end()) {
3219
+ return it->second;
3220
+ }
3221
+
3222
+ std::vector<std::string> defines;
3223
+ std::string variant = "upscale";
3224
+
3225
+ if (key.input_type == GGML_TYPE_F16) {
3226
+ defines.push_back("SRC_F16");
3227
+ variant += "_src_f16";
3228
+ } else {
3229
+ variant += "_src_f32";
3230
+ }
3231
+
3232
+ if (key.output_type == GGML_TYPE_F16) {
3233
+ defines.push_back("DST_F16");
3234
+ variant += "_dst_f16";
3235
+ } else {
3236
+ variant += "_dst_f32";
3237
+ }
3238
+
3239
+ switch (base_mode) {
3240
+ case GGML_SCALE_MODE_NEAREST:
3241
+ defines.push_back("NEAREST");
3242
+ variant += "_nearest";
3243
+ break;
3244
+ case GGML_SCALE_MODE_BILINEAR:
3245
+ defines.push_back("BILINEAR");
3246
+ variant += "_bilinear";
3247
+ break;
3248
+ case GGML_SCALE_MODE_BICUBIC:
3249
+ defines.push_back("BICUBIC");
3250
+ variant += "_bicubic";
3251
+ break;
3252
+ default:
3253
+ GGML_ABORT("Unsupported upscale mode");
3254
+ }
3255
+
3256
+ if (antialias) {
3257
+ defines.push_back("ANTIALIAS");
3258
+ variant += "_aa";
3259
+ }
3260
+
3261
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
3262
+
3263
+ auto processed = preprocessor.preprocess(wgsl_upscale, defines);
3264
+ auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
3265
+ decisions->wg_size = context.max_wg_size;
3266
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
3267
+ pipeline.context = decisions;
3268
+ upscale_pipelines[key] = pipeline;
3269
+ return upscale_pipelines[key];
3270
+ }
3271
+
3272
+ private:
3273
+ static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device,
3274
+ std::string shader_code,
3275
+ std::string label) {
3276
+ wgpu::ShaderSourceWGSL shader_source;
3277
+ shader_source.code = shader_code.c_str();
3278
+
3279
+ wgpu::ShaderModuleDescriptor shader_desc;
3280
+ shader_desc.nextInChain = &shader_source;
3281
+
3282
+ wgpu::ShaderModule shader_module = device.CreateShaderModule(&shader_desc);
3283
+
3284
+ wgpu::ComputePipelineDescriptor pipeline_desc;
3285
+ pipeline_desc.label = label.c_str();
3286
+ pipeline_desc.compute.module = shader_module;
3287
+ pipeline_desc.compute.entryPoint = "main"; // Entry point in the WGSL code
3288
+ pipeline_desc.layout = nullptr; // nullptr means auto layout
3289
+ return { device.CreateComputePipeline(&pipeline_desc), label };
3290
+ }
3291
+ };
3292
+
169
3293
  #endif // GGML_WEBGPU_SHADER_LIB_HPP