local-llm-rn 1.0.0

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (626)
  1. package/cpp/CMakeLists.txt +285 -0
  2. package/cpp/common/CMakeLists.txt +149 -0
  3. package/cpp/common/arg.cpp +3799 -0
  4. package/cpp/common/arg.h +131 -0
  5. package/cpp/common/base64.hpp +392 -0
  6. package/cpp/common/build-info.cpp.in +4 -0
  7. package/cpp/common/chat-parser-xml-toolcall.cpp +879 -0
  8. package/cpp/common/chat-parser-xml-toolcall.h +45 -0
  9. package/cpp/common/chat-parser.cpp +1649 -0
  10. package/cpp/common/chat-parser.h +133 -0
  11. package/cpp/common/chat-peg-parser.cpp +124 -0
  12. package/cpp/common/chat-peg-parser.h +105 -0
  13. package/cpp/common/chat.cpp +3355 -0
  14. package/cpp/common/chat.h +252 -0
  15. package/cpp/common/common.cpp +1824 -0
  16. package/cpp/common/common.h +930 -0
  17. package/cpp/common/console.cpp +1137 -0
  18. package/cpp/common/console.h +41 -0
  19. package/cpp/common/debug.cpp +167 -0
  20. package/cpp/common/debug.h +43 -0
  21. package/cpp/common/download.cpp +792 -0
  22. package/cpp/common/download.h +84 -0
  23. package/cpp/common/http.h +84 -0
  24. package/cpp/common/jinja/README.md +88 -0
  25. package/cpp/common/jinja/caps.cpp +285 -0
  26. package/cpp/common/jinja/caps.h +30 -0
  27. package/cpp/common/jinja/lexer.cpp +341 -0
  28. package/cpp/common/jinja/lexer.h +157 -0
  29. package/cpp/common/jinja/parser.cpp +591 -0
  30. package/cpp/common/jinja/parser.h +21 -0
  31. package/cpp/common/jinja/runtime.cpp +867 -0
  32. package/cpp/common/jinja/runtime.h +638 -0
  33. package/cpp/common/jinja/string.cpp +213 -0
  34. package/cpp/common/jinja/string.h +61 -0
  35. package/cpp/common/jinja/utils.h +149 -0
  36. package/cpp/common/jinja/value.cpp +1393 -0
  37. package/cpp/common/jinja/value.h +756 -0
  38. package/cpp/common/json-partial.cpp +324 -0
  39. package/cpp/common/json-partial.h +39 -0
  40. package/cpp/common/json-schema-to-grammar.cpp +1153 -0
  41. package/cpp/common/json-schema-to-grammar.h +43 -0
  42. package/cpp/common/llguidance.cpp +258 -0
  43. package/cpp/common/log.cpp +446 -0
  44. package/cpp/common/log.h +119 -0
  45. package/cpp/common/ngram-cache.cpp +285 -0
  46. package/cpp/common/ngram-cache.h +101 -0
  47. package/cpp/common/ngram-map.cpp +530 -0
  48. package/cpp/common/ngram-map.h +115 -0
  49. package/cpp/common/ngram-mod.cpp +60 -0
  50. package/cpp/common/ngram-mod.h +38 -0
  51. package/cpp/common/peg-parser.cpp +1712 -0
  52. package/cpp/common/peg-parser.h +459 -0
  53. package/cpp/common/preset.cpp +483 -0
  54. package/cpp/common/preset.h +83 -0
  55. package/cpp/common/regex-partial.cpp +204 -0
  56. package/cpp/common/regex-partial.h +56 -0
  57. package/cpp/common/sampling.cpp +745 -0
  58. package/cpp/common/sampling.h +119 -0
  59. package/cpp/common/speculative.cpp +1074 -0
  60. package/cpp/common/speculative.h +41 -0
  61. package/cpp/common/unicode.cpp +64 -0
  62. package/cpp/common/unicode.h +22 -0
  63. package/cpp/ggml/CMakeLists.txt +494 -0
  64. package/cpp/ggml/cmake/GitVars.cmake +22 -0
  65. package/cpp/ggml/cmake/common.cmake +50 -0
  66. package/cpp/ggml/cmake/ggml-config.cmake.in +191 -0
  67. package/cpp/ggml/include/ggml-alloc.h +85 -0
  68. package/cpp/ggml/include/ggml-backend.h +373 -0
  69. package/cpp/ggml/include/ggml-blas.h +25 -0
  70. package/cpp/ggml/include/ggml-cann.h +123 -0
  71. package/cpp/ggml/include/ggml-cpp.h +39 -0
  72. package/cpp/ggml/include/ggml-cpu.h +151 -0
  73. package/cpp/ggml/include/ggml-cuda.h +47 -0
  74. package/cpp/ggml/include/ggml-hexagon.h +19 -0
  75. package/cpp/ggml/include/ggml-metal.h +61 -0
  76. package/cpp/ggml/include/ggml-opencl.h +26 -0
  77. package/cpp/ggml/include/ggml-opt.h +256 -0
  78. package/cpp/ggml/include/ggml-rpc.h +30 -0
  79. package/cpp/ggml/include/ggml-sycl.h +49 -0
  80. package/cpp/ggml/include/ggml-virtgpu.h +14 -0
  81. package/cpp/ggml/include/ggml-vulkan.h +29 -0
  82. package/cpp/ggml/include/ggml-webgpu.h +19 -0
  83. package/cpp/ggml/include/ggml-zdnn.h +17 -0
  84. package/cpp/ggml/include/ggml-zendnn.h +22 -0
  85. package/cpp/ggml/include/ggml.h +2753 -0
  86. package/cpp/ggml/include/gguf.h +204 -0
  87. package/cpp/ggml/src/CMakeLists.txt +492 -0
  88. package/cpp/ggml/src/ggml-alloc.c +1244 -0
  89. package/cpp/ggml/src/ggml-backend-dl.cpp +48 -0
  90. package/cpp/ggml/src/ggml-backend-dl.h +45 -0
  91. package/cpp/ggml/src/ggml-backend-impl.h +255 -0
  92. package/cpp/ggml/src/ggml-backend-reg.cpp +566 -0
  93. package/cpp/ggml/src/ggml-backend.cpp +2270 -0
  94. package/cpp/ggml/src/ggml-blas/CMakeLists.txt +101 -0
  95. package/cpp/ggml/src/ggml-blas/ggml-blas.cpp +518 -0
  96. package/cpp/ggml/src/ggml-common.h +1878 -0
  97. package/cpp/ggml/src/ggml-cpu/CMakeLists.txt +691 -0
  98. package/cpp/ggml/src/ggml-cpu/amx/amx.cpp +247 -0
  99. package/cpp/ggml/src/ggml-cpu/amx/amx.h +8 -0
  100. package/cpp/ggml/src/ggml-cpu/amx/common.h +91 -0
  101. package/cpp/ggml/src/ggml-cpu/amx/mmq.cpp +2512 -0
  102. package/cpp/ggml/src/ggml-cpu/amx/mmq.h +10 -0
  103. package/cpp/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +98 -0
  104. package/cpp/ggml/src/ggml-cpu/arch/arm/quants.c +4052 -0
  105. package/cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +4935 -0
  106. package/cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +2159 -0
  107. package/cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  108. package/cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +2305 -0
  109. package/cpp/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  110. package/cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +2726 -0
  111. package/cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +342 -0
  112. package/cpp/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  113. package/cpp/ggml/src/ggml-cpu/arch/s390/quants.c +1468 -0
  114. package/cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +1221 -0
  115. package/cpp/ggml/src/ggml-cpu/arch/x86/cpu-feats.cpp +327 -0
  116. package/cpp/ggml/src/ggml-cpu/arch/x86/quants.c +3820 -0
  117. package/cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +6307 -0
  118. package/cpp/ggml/src/ggml-cpu/arch-fallback.h +313 -0
  119. package/cpp/ggml/src/ggml-cpu/binary-ops.cpp +154 -0
  120. package/cpp/ggml/src/ggml-cpu/binary-ops.h +16 -0
  121. package/cpp/ggml/src/ggml-cpu/cmake/FindSIMD.cmake +100 -0
  122. package/cpp/ggml/src/ggml-cpu/common.h +95 -0
  123. package/cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +529 -0
  124. package/cpp/ggml/src/ggml-cpu/ggml-cpu.c +3734 -0
  125. package/cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +701 -0
  126. package/cpp/ggml/src/ggml-cpu/hbm.cpp +55 -0
  127. package/cpp/ggml/src/ggml-cpu/hbm.h +8 -0
  128. package/cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +938 -0
  129. package/cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +90 -0
  130. package/cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +798 -0
  131. package/cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.h +17 -0
  132. package/cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +4033 -0
  133. package/cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +25 -0
  134. package/cpp/ggml/src/ggml-cpu/ops.cpp +10978 -0
  135. package/cpp/ggml/src/ggml-cpu/ops.h +116 -0
  136. package/cpp/ggml/src/ggml-cpu/quants.c +1193 -0
  137. package/cpp/ggml/src/ggml-cpu/quants.h +97 -0
  138. package/cpp/ggml/src/ggml-cpu/repack.cpp +3316 -0
  139. package/cpp/ggml/src/ggml-cpu/repack.h +173 -0
  140. package/cpp/ggml/src/ggml-cpu/simd-gemm.h +136 -0
  141. package/cpp/ggml/src/ggml-cpu/simd-mappings.h +1279 -0
  142. package/cpp/ggml/src/ggml-cpu/spacemit/ime.cpp +1025 -0
  143. package/cpp/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
  144. package/cpp/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
  145. package/cpp/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
  146. package/cpp/ggml/src/ggml-cpu/traits.cpp +36 -0
  147. package/cpp/ggml/src/ggml-cpu/traits.h +38 -0
  148. package/cpp/ggml/src/ggml-cpu/unary-ops.cpp +337 -0
  149. package/cpp/ggml/src/ggml-cpu/unary-ops.h +35 -0
  150. package/cpp/ggml/src/ggml-cpu/vec.cpp +629 -0
  151. package/cpp/ggml/src/ggml-cpu/vec.h +1585 -0
  152. package/cpp/ggml/src/ggml-hexagon/CMakeLists.txt +117 -0
  153. package/cpp/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3232 -0
  154. package/cpp/ggml/src/ggml-hexagon/htp/CMakeLists.txt +45 -0
  155. package/cpp/ggml/src/ggml-hexagon/htp/act-ops.c +815 -0
  156. package/cpp/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
  157. package/cpp/ggml/src/ggml-hexagon/htp/binary-ops.c +827 -0
  158. package/cpp/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  159. package/cpp/ggml/src/ggml-hexagon/htp/cpy-ops.c +251 -0
  160. package/cpp/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +666 -0
  161. package/cpp/ggml/src/ggml-hexagon/htp/get-rows-ops.c +111 -0
  162. package/cpp/ggml/src/ggml-hexagon/htp/hex-dma.c +63 -0
  163. package/cpp/ggml/src/ggml-hexagon/htp/hex-dma.h +182 -0
  164. package/cpp/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
  165. package/cpp/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  166. package/cpp/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
  167. package/cpp/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  168. package/cpp/ggml/src/ggml-hexagon/htp/htp-msg.h +154 -0
  169. package/cpp/ggml/src/ggml-hexagon/htp/htp-ops.h +65 -0
  170. package/cpp/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  171. package/cpp/ggml/src/ggml-hexagon/htp/hvx-arith.h +470 -0
  172. package/cpp/ggml/src/ggml-hexagon/htp/hvx-base.h +173 -0
  173. package/cpp/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
  174. package/cpp/ggml/src/ggml-hexagon/htp/hvx-div.h +116 -0
  175. package/cpp/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  176. package/cpp/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
  177. package/cpp/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  178. package/cpp/ggml/src/ggml-hexagon/htp/hvx-inverse.h +176 -0
  179. package/cpp/ggml/src/ggml-hexagon/htp/hvx-reduce.h +266 -0
  180. package/cpp/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  181. package/cpp/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
  182. package/cpp/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  183. package/cpp/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  184. package/cpp/ggml/src/ggml-hexagon/htp/hvx-utils.h +18 -0
  185. package/cpp/ggml/src/ggml-hexagon/htp/main.c +1150 -0
  186. package/cpp/ggml/src/ggml-hexagon/htp/matmul-ops.c +2595 -0
  187. package/cpp/ggml/src/ggml-hexagon/htp/rope-ops.c +498 -0
  188. package/cpp/ggml/src/ggml-hexagon/htp/set-rows-ops.c +167 -0
  189. package/cpp/ggml/src/ggml-hexagon/htp/softmax-ops.c +421 -0
  190. package/cpp/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +130 -0
  191. package/cpp/ggml/src/ggml-hexagon/htp/unary-ops.c +384 -0
  192. package/cpp/ggml/src/ggml-hexagon/htp/worker-pool.c +293 -0
  193. package/cpp/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  194. package/cpp/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  195. package/cpp/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  196. package/cpp/ggml/src/ggml-hexagon/libdl.h +79 -0
  197. package/cpp/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
  198. package/cpp/ggml/src/ggml-hexagon/op-desc.h +153 -0
  199. package/cpp/ggml/src/ggml-impl.h +724 -0
  200. package/cpp/ggml/src/ggml-metal/CMakeLists.txt +124 -0
  201. package/cpp/ggml/src/ggml-metal/ggml-metal-common.cpp +457 -0
  202. package/cpp/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
  203. package/cpp/ggml/src/ggml-metal/ggml-metal-context.h +41 -0
  204. package/cpp/ggml/src/ggml-metal/ggml-metal-context.m +702 -0
  205. package/cpp/ggml/src/ggml-metal/ggml-metal-device.cpp +1890 -0
  206. package/cpp/ggml/src/ggml-metal/ggml-metal-device.h +290 -0
  207. package/cpp/ggml/src/ggml-metal/ggml-metal-device.m +1749 -0
  208. package/cpp/ggml/src/ggml-metal/ggml-metal-impl.h +1054 -0
  209. package/cpp/ggml/src/ggml-metal/ggml-metal-ops.cpp +4370 -0
  210. package/cpp/ggml/src/ggml-metal/ggml-metal-ops.h +94 -0
  211. package/cpp/ggml/src/ggml-metal/ggml-metal.cpp +937 -0
  212. package/cpp/ggml/src/ggml-metal/ggml-metal.metal +9819 -0
  213. package/cpp/ggml/src/ggml-musa/CMakeLists.txt +125 -0
  214. package/cpp/ggml/src/ggml-musa/mudnn.cu +112 -0
  215. package/cpp/ggml/src/ggml-musa/mudnn.cuh +12 -0
  216. package/cpp/ggml/src/ggml-opencl/CMakeLists.txt +150 -0
  217. package/cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +11553 -0
  218. package/cpp/ggml/src/ggml-opencl/kernels/add.cl +190 -0
  219. package/cpp/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  220. package/cpp/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
  221. package/cpp/ggml/src/ggml-opencl/kernels/clamp.cl +20 -0
  222. package/cpp/ggml/src/ggml-opencl/kernels/concat.cl +51 -0
  223. package/cpp/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  224. package/cpp/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  225. package/cpp/ggml/src/ggml-opencl/kernels/cpy.cl +184 -0
  226. package/cpp/ggml/src/ggml-opencl/kernels/cvt.cl +417 -0
  227. package/cpp/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl +58 -0
  228. package/cpp/ggml/src/ggml-opencl/kernels/div.cl +138 -0
  229. package/cpp/ggml/src/ggml-opencl/kernels/embed_kernel.py +26 -0
  230. package/cpp/ggml/src/ggml-opencl/kernels/expm1.cl +113 -0
  231. package/cpp/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  232. package/cpp/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
  233. package/cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +371 -0
  234. package/cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
  235. package/cpp/ggml/src/ggml-opencl/kernels/gelu.cl +89 -0
  236. package/cpp/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  237. package/cpp/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  238. package/cpp/ggml/src/ggml-opencl/kernels/gemv_noshuffle.cl +268 -0
  239. package/cpp/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general.cl +274 -0
  240. package/cpp/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
  241. package/cpp/ggml/src/ggml-opencl/kernels/get_rows.cl +187 -0
  242. package/cpp/ggml/src/ggml-opencl/kernels/glu.cl +378 -0
  243. package/cpp/ggml/src/ggml-opencl/kernels/group_norm.cl +121 -0
  244. package/cpp/ggml/src/ggml-opencl/kernels/im2col_f16.cl +57 -0
  245. package/cpp/ggml/src/ggml-opencl/kernels/im2col_f32.cl +57 -0
  246. package/cpp/ggml/src/ggml-opencl/kernels/mean.cl +140 -0
  247. package/cpp/ggml/src/ggml-opencl/kernels/mul.cl +152 -0
  248. package/cpp/ggml/src/ggml-opencl/kernels/mul_mat_Ab_Bi_8x4.cl +139 -0
  249. package/cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  250. package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  251. package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +146 -0
  252. package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +147 -0
  253. package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  254. package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  255. package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  256. package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
  257. package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  258. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_f16_f16.cl +118 -0
  259. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32.cl +118 -0
  260. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_1row.cl +94 -0
  261. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_l4.cl +84 -0
  262. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_f32_f32.cl +118 -0
  263. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  264. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
  265. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  266. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
  267. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
  268. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  269. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
  270. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32.cl +192 -0
  271. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_16x_flat.cl +307 -0
  272. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_8x_flat.cl +265 -0
  273. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_8x_flat.cl +272 -0
  274. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_v.cl +254 -0
  275. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  276. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  277. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  278. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32.cl +194 -0
  279. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
  280. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
  281. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
  282. package/cpp/ggml/src/ggml-opencl/kernels/norm.cl +161 -0
  283. package/cpp/ggml/src/ggml-opencl/kernels/pad.cl +39 -0
  284. package/cpp/ggml/src/ggml-opencl/kernels/relu.cl +16 -0
  285. package/cpp/ggml/src/ggml-opencl/kernels/repeat.cl +38 -0
  286. package/cpp/ggml/src/ggml-opencl/kernels/rms_norm.cl +190 -0
  287. package/cpp/ggml/src/ggml-opencl/kernels/rope.cl +747 -0
  288. package/cpp/ggml/src/ggml-opencl/kernels/scale.cl +27 -0
  289. package/cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +208 -0
  290. package/cpp/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
  291. package/cpp/ggml/src/ggml-opencl/kernels/silu.cl +30 -0
  292. package/cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +108 -0
  293. package/cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +108 -0
  294. package/cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +107 -0
  295. package/cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +107 -0
  296. package/cpp/ggml/src/ggml-opencl/kernels/softplus.cl +116 -0
  297. package/cpp/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  298. package/cpp/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  299. package/cpp/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  300. package/cpp/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  301. package/cpp/ggml/src/ggml-opencl/kernels/sub.cl +138 -0
  302. package/cpp/ggml/src/ggml-opencl/kernels/sum_rows.cl +140 -0
  303. package/cpp/ggml/src/ggml-opencl/kernels/tanh.cl +109 -0
  304. package/cpp/ggml/src/ggml-opencl/kernels/transpose.cl +117 -0
  305. package/cpp/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  306. package/cpp/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  307. package/cpp/ggml/src/ggml-opencl/kernels/upscale.cl +120 -0
  308. package/cpp/ggml/src/ggml-opt.cpp +1093 -0
  309. package/cpp/ggml/src/ggml-quants.c +5325 -0
  310. package/cpp/ggml/src/ggml-quants.h +106 -0
  311. package/cpp/ggml/src/ggml-rpc/CMakeLists.txt +9 -0
  312. package/cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +2118 -0
  313. package/cpp/ggml/src/ggml-threading.cpp +12 -0
  314. package/cpp/ggml/src/ggml-threading.h +14 -0
  315. package/cpp/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  316. package/cpp/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  317. package/cpp/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  318. package/cpp/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  319. package/cpp/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  320. package/cpp/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  321. package/cpp/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  322. package/cpp/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  323. package/cpp/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  324. package/cpp/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  325. package/cpp/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  326. package/cpp/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  327. package/cpp/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  328. package/cpp/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  329. package/cpp/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  330. package/cpp/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  331. package/cpp/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  332. package/cpp/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  333. package/cpp/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  334. package/cpp/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  335. package/cpp/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  336. package/cpp/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
  337. package/cpp/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
  338. package/cpp/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  339. package/cpp/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
  340. package/cpp/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  341. package/cpp/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  342. package/cpp/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  343. package/cpp/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
  344. package/cpp/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  345. package/cpp/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  346. package/cpp/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  347. package/cpp/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  348. package/cpp/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  349. package/cpp/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  350. package/cpp/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  351. package/cpp/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
  352. package/cpp/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  353. package/cpp/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  354. package/cpp/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  355. package/cpp/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
  356. package/cpp/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
  357. package/cpp/ggml/src/ggml-webgpu/CMakeLists.txt +80 -0
  358. package/cpp/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1231 -0
  359. package/cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3150 -0
  360. package/cpp/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  361. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  362. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  363. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  364. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +107 -0
  365. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +923 -0
  366. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +107 -0
  367. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  368. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +182 -0
  369. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +636 -0
  370. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl +668 -0
  371. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  372. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  373. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +713 -0
  374. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +103 -0
  375. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +138 -0
  376. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +188 -0
  377. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +194 -0
  378. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  379. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +123 -0
  380. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  381. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +63 -0
  382. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +109 -0
  383. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  384. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  385. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
  386. package/cpp/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  387. package/cpp/ggml/src/ggml-zdnn/common.hpp +59 -0
  388. package/cpp/ggml/src/ggml-zdnn/ggml-zdnn.cpp +633 -0
  389. package/cpp/ggml/src/ggml-zdnn/mmf.cpp +80 -0
  390. package/cpp/ggml/src/ggml-zdnn/mmf.hpp +12 -0
  391. package/cpp/ggml/src/ggml-zdnn/utils.cpp +79 -0
  392. package/cpp/ggml/src/ggml-zdnn/utils.hpp +19 -0
  393. package/cpp/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
  394. package/cpp/ggml/src/ggml-zendnn/ggml-zendnn.cpp +469 -0
  395. package/cpp/ggml/src/ggml.c +7669 -0
  396. package/cpp/ggml/src/ggml.cpp +26 -0
  397. package/cpp/ggml/src/gguf.cpp +1699 -0
  398. package/cpp/include/llama-cpp.h +32 -0
  399. package/cpp/include/llama.h +1568 -0
  400. package/cpp/mtmd/CMakeLists.txt +98 -0
  401. package/cpp/mtmd/README.md +63 -0
  402. package/cpp/mtmd/clip-graph.h +117 -0
  403. package/cpp/mtmd/clip-impl.h +586 -0
  404. package/cpp/mtmd/clip-model.h +390 -0
  405. package/cpp/mtmd/clip.cpp +4154 -0
  406. package/cpp/mtmd/clip.h +121 -0
  407. package/cpp/mtmd/deprecation-warning.cpp +22 -0
  408. package/cpp/mtmd/legacy-models/convert_image_encoder_to_gguf.py +412 -0
  409. package/cpp/mtmd/legacy-models/glmedge-convert-image-encoder-to-gguf.py +280 -0
  410. package/cpp/mtmd/legacy-models/glmedge-surgery.py +33 -0
  411. package/cpp/mtmd/legacy-models/llava_surgery.py +38 -0
  412. package/cpp/mtmd/legacy-models/llava_surgery_v2.py +180 -0
  413. package/cpp/mtmd/legacy-models/minicpmv-convert-image-encoder-to-gguf.py +892 -0
  414. package/cpp/mtmd/legacy-models/minicpmv-surgery.py +47 -0
  415. package/cpp/mtmd/models/cogvlm.cpp +98 -0
  416. package/cpp/mtmd/models/conformer.cpp +216 -0
  417. package/cpp/mtmd/models/glm4v.cpp +122 -0
  418. package/cpp/mtmd/models/internvl.cpp +69 -0
  419. package/cpp/mtmd/models/kimik25.cpp +101 -0
  420. package/cpp/mtmd/models/kimivl.cpp +63 -0
  421. package/cpp/mtmd/models/llama4.cpp +96 -0
  422. package/cpp/mtmd/models/llava.cpp +374 -0
  423. package/cpp/mtmd/models/minicpmv.cpp +114 -0
  424. package/cpp/mtmd/models/mobilenetv5.cpp +451 -0
  425. package/cpp/mtmd/models/models.h +128 -0
  426. package/cpp/mtmd/models/nemotron-v2-vl.cpp +35 -0
  427. package/cpp/mtmd/models/paddleocr.cpp +52 -0
  428. package/cpp/mtmd/models/pixtral.cpp +86 -0
  429. package/cpp/mtmd/models/qwen2vl.cpp +183 -0
  430. package/cpp/mtmd/models/qwen3vl.cpp +193 -0
  431. package/cpp/mtmd/models/siglip.cpp +86 -0
  432. package/cpp/mtmd/models/whisper-enc.cpp +115 -0
  433. package/cpp/mtmd/models/youtuvl.cpp +179 -0
  434. package/cpp/mtmd/mtmd-audio.cpp +730 -0
  435. package/cpp/mtmd/mtmd-audio.h +113 -0
  436. package/cpp/mtmd/mtmd-cli.cpp +437 -0
  437. package/cpp/mtmd/mtmd-helper.cpp +521 -0
  438. package/cpp/mtmd/mtmd-helper.h +96 -0
  439. package/cpp/mtmd/mtmd.cpp +1156 -0
  440. package/cpp/mtmd/mtmd.h +319 -0
  441. package/cpp/mtmd/requirements.txt +5 -0
  442. package/cpp/mtmd/test-1.jpeg +0 -0
  443. package/cpp/mtmd/test-2.mp3 +0 -0
  444. package/cpp/mtmd/tests.sh +192 -0
  445. package/cpp/src/CMakeLists.txt +169 -0
  446. package/cpp/src/llama-adapter.cpp +488 -0
  447. package/cpp/src/llama-adapter.h +89 -0
  448. package/cpp/src/llama-arch.cpp +2855 -0
  449. package/cpp/src/llama-arch.h +619 -0
  450. package/cpp/src/llama-batch.cpp +917 -0
  451. package/cpp/src/llama-batch.h +173 -0
  452. package/cpp/src/llama-chat.cpp +896 -0
  453. package/cpp/src/llama-chat.h +71 -0
  454. package/cpp/src/llama-context.cpp +3512 -0
  455. package/cpp/src/llama-context.h +359 -0
  456. package/cpp/src/llama-cparams.cpp +5 -0
  457. package/cpp/src/llama-cparams.h +44 -0
  458. package/cpp/src/llama-grammar.cpp +1464 -0
  459. package/cpp/src/llama-grammar.h +194 -0
  460. package/cpp/src/llama-graph.cpp +2685 -0
  461. package/cpp/src/llama-graph.h +1026 -0
  462. package/cpp/src/llama-hparams.cpp +234 -0
  463. package/cpp/src/llama-hparams.h +339 -0
  464. package/cpp/src/llama-impl.cpp +171 -0
  465. package/cpp/src/llama-impl.h +73 -0
  466. package/cpp/src/llama-io.cpp +15 -0
  467. package/cpp/src/llama-io.h +35 -0
  468. package/cpp/src/llama-kv-cache-iswa.cpp +330 -0
  469. package/cpp/src/llama-kv-cache-iswa.h +137 -0
  470. package/cpp/src/llama-kv-cache.cpp +2271 -0
  471. package/cpp/src/llama-kv-cache.h +388 -0
  472. package/cpp/src/llama-kv-cells.h +533 -0
  473. package/cpp/src/llama-memory-hybrid-iswa.cpp +275 -0
  474. package/cpp/src/llama-memory-hybrid-iswa.h +140 -0
  475. package/cpp/src/llama-memory-hybrid.cpp +268 -0
  476. package/cpp/src/llama-memory-hybrid.h +139 -0
  477. package/cpp/src/llama-memory-recurrent.cpp +1165 -0
  478. package/cpp/src/llama-memory-recurrent.h +182 -0
  479. package/cpp/src/llama-memory.cpp +59 -0
  480. package/cpp/src/llama-memory.h +122 -0
  481. package/cpp/src/llama-mmap.cpp +785 -0
  482. package/cpp/src/llama-mmap.h +92 -0
  483. package/cpp/src/llama-model-loader.cpp +1414 -0
  484. package/cpp/src/llama-model-loader.h +203 -0
  485. package/cpp/src/llama-model-saver.cpp +286 -0
  486. package/cpp/src/llama-model-saver.h +37 -0
  487. package/cpp/src/llama-model.cpp +9253 -0
  488. package/cpp/src/llama-model.h +576 -0
  489. package/cpp/src/llama-quant.cpp +1119 -0
  490. package/cpp/src/llama-quant.h +1 -0
  491. package/cpp/src/llama-sampler.cpp +3885 -0
  492. package/cpp/src/llama-sampler.h +42 -0
  493. package/cpp/src/llama-vocab.cpp +3970 -0
  494. package/cpp/src/llama-vocab.h +187 -0
  495. package/cpp/src/llama.cpp +1313 -0
  496. package/cpp/src/models/afmoe.cpp +191 -0
  497. package/cpp/src/models/apertus.cpp +125 -0
  498. package/cpp/src/models/arcee.cpp +135 -0
  499. package/cpp/src/models/arctic.cpp +138 -0
  500. package/cpp/src/models/arwkv7.cpp +86 -0
  501. package/cpp/src/models/baichuan.cpp +122 -0
  502. package/cpp/src/models/bailingmoe.cpp +144 -0
  503. package/cpp/src/models/bailingmoe2.cpp +135 -0
  504. package/cpp/src/models/bert.cpp +178 -0
  505. package/cpp/src/models/bitnet.cpp +160 -0
  506. package/cpp/src/models/bloom.cpp +101 -0
  507. package/cpp/src/models/chameleon.cpp +178 -0
  508. package/cpp/src/models/chatglm.cpp +132 -0
  509. package/cpp/src/models/codeshell.cpp +111 -0
  510. package/cpp/src/models/cogvlm.cpp +102 -0
  511. package/cpp/src/models/cohere2-iswa.cpp +134 -0
  512. package/cpp/src/models/command-r.cpp +122 -0
  513. package/cpp/src/models/dbrx.cpp +123 -0
  514. package/cpp/src/models/deci.cpp +135 -0
  515. package/cpp/src/models/deepseek.cpp +144 -0
  516. package/cpp/src/models/deepseek2.cpp +262 -0
  517. package/cpp/src/models/delta-net-base.cpp +376 -0
  518. package/cpp/src/models/dots1.cpp +134 -0
  519. package/cpp/src/models/dream.cpp +105 -0
  520. package/cpp/src/models/ernie4-5-moe.cpp +150 -0
  521. package/cpp/src/models/ernie4-5.cpp +110 -0
  522. package/cpp/src/models/eurobert.cpp +97 -0
  523. package/cpp/src/models/exaone-moe.cpp +146 -0
  524. package/cpp/src/models/exaone.cpp +114 -0
  525. package/cpp/src/models/exaone4.cpp +123 -0
  526. package/cpp/src/models/falcon-h1.cpp +111 -0
  527. package/cpp/src/models/falcon.cpp +120 -0
  528. package/cpp/src/models/gemma-embedding.cpp +116 -0
  529. package/cpp/src/models/gemma.cpp +112 -0
  530. package/cpp/src/models/gemma2-iswa.cpp +128 -0
  531. package/cpp/src/models/gemma3.cpp +155 -0
  532. package/cpp/src/models/gemma3n-iswa.cpp +384 -0
  533. package/cpp/src/models/glm4-moe.cpp +170 -0
  534. package/cpp/src/models/glm4.cpp +157 -0
  535. package/cpp/src/models/gpt2.cpp +105 -0
  536. package/cpp/src/models/gptneox.cpp +144 -0
  537. package/cpp/src/models/granite-hybrid.cpp +196 -0
  538. package/cpp/src/models/granite.cpp +211 -0
  539. package/cpp/src/models/grok.cpp +159 -0
  540. package/cpp/src/models/grovemoe.cpp +141 -0
  541. package/cpp/src/models/hunyuan-dense.cpp +132 -0
  542. package/cpp/src/models/hunyuan-moe.cpp +154 -0
  543. package/cpp/src/models/internlm2.cpp +120 -0
  544. package/cpp/src/models/jais.cpp +86 -0
  545. package/cpp/src/models/jais2.cpp +123 -0
  546. package/cpp/src/models/jamba.cpp +106 -0
  547. package/cpp/src/models/kimi-linear.cpp +392 -0
  548. package/cpp/src/models/lfm2.cpp +190 -0
  549. package/cpp/src/models/llada-moe.cpp +122 -0
  550. package/cpp/src/models/llada.cpp +99 -0
  551. package/cpp/src/models/llama-iswa.cpp +178 -0
  552. package/cpp/src/models/llama.cpp +168 -0
  553. package/cpp/src/models/maincoder.cpp +117 -0
  554. package/cpp/src/models/mamba-base.cpp +285 -0
  555. package/cpp/src/models/mamba.cpp +54 -0
  556. package/cpp/src/models/mimo2-iswa.cpp +123 -0
  557. package/cpp/src/models/minicpm3.cpp +200 -0
  558. package/cpp/src/models/minimax-m2.cpp +124 -0
  559. package/cpp/src/models/mistral3.cpp +160 -0
  560. package/cpp/src/models/models.h +684 -0
  561. package/cpp/src/models/modern-bert.cpp +109 -0
  562. package/cpp/src/models/mpt.cpp +126 -0
  563. package/cpp/src/models/nemotron-h.cpp +148 -0
  564. package/cpp/src/models/nemotron.cpp +122 -0
  565. package/cpp/src/models/neo-bert.cpp +104 -0
  566. package/cpp/src/models/olmo.cpp +121 -0
  567. package/cpp/src/models/olmo2.cpp +150 -0
  568. package/cpp/src/models/olmoe.cpp +124 -0
  569. package/cpp/src/models/openai-moe-iswa.cpp +127 -0
  570. package/cpp/src/models/openelm.cpp +124 -0
  571. package/cpp/src/models/orion.cpp +123 -0
  572. package/cpp/src/models/paddleocr.cpp +122 -0
  573. package/cpp/src/models/pangu-embedded.cpp +121 -0
  574. package/cpp/src/models/phi2.cpp +121 -0
  575. package/cpp/src/models/phi3.cpp +152 -0
  576. package/cpp/src/models/plamo.cpp +110 -0
  577. package/cpp/src/models/plamo2.cpp +318 -0
  578. package/cpp/src/models/plamo3.cpp +128 -0
  579. package/cpp/src/models/plm.cpp +169 -0
  580. package/cpp/src/models/qwen.cpp +108 -0
  581. package/cpp/src/models/qwen2.cpp +126 -0
  582. package/cpp/src/models/qwen2moe.cpp +151 -0
  583. package/cpp/src/models/qwen2vl.cpp +117 -0
  584. package/cpp/src/models/qwen3.cpp +117 -0
  585. package/cpp/src/models/qwen35.cpp +386 -0
  586. package/cpp/src/models/qwen35moe.cpp +420 -0
  587. package/cpp/src/models/qwen3moe.cpp +124 -0
  588. package/cpp/src/models/qwen3next.cpp +525 -0
  589. package/cpp/src/models/qwen3vl-moe.cpp +140 -0
  590. package/cpp/src/models/qwen3vl.cpp +132 -0
  591. package/cpp/src/models/refact.cpp +94 -0
  592. package/cpp/src/models/rnd1.cpp +126 -0
  593. package/cpp/src/models/rwkv6-base.cpp +164 -0
  594. package/cpp/src/models/rwkv6.cpp +94 -0
  595. package/cpp/src/models/rwkv6qwen2.cpp +86 -0
  596. package/cpp/src/models/rwkv7-base.cpp +137 -0
  597. package/cpp/src/models/rwkv7.cpp +90 -0
  598. package/cpp/src/models/seed-oss.cpp +124 -0
  599. package/cpp/src/models/smallthinker.cpp +126 -0
  600. package/cpp/src/models/smollm3.cpp +128 -0
  601. package/cpp/src/models/stablelm.cpp +146 -0
  602. package/cpp/src/models/starcoder.cpp +100 -0
  603. package/cpp/src/models/starcoder2.cpp +121 -0
  604. package/cpp/src/models/step35-iswa.cpp +168 -0
  605. package/cpp/src/models/t5-dec.cpp +166 -0
  606. package/cpp/src/models/t5-enc.cpp +96 -0
  607. package/cpp/src/models/wavtokenizer-dec.cpp +149 -0
  608. package/cpp/src/models/xverse.cpp +108 -0
  609. package/cpp/src/unicode-data.cpp +7034 -0
  610. package/cpp/src/unicode-data.h +20 -0
  611. package/cpp/src/unicode.cpp +1103 -0
  612. package/cpp/src/unicode.h +111 -0
  613. package/cpp/vendor/nlohmann/json.hpp +25526 -0
  614. package/cpp/vendor/nlohmann/json_fwd.hpp +187 -0
  615. package/cpp/vendor/stb/stb_image.h +7988 -0
  616. package/ios/LocalLLM-Bridging-Header.h +2 -0
  617. package/ios/LocalLLM.h +5 -0
  618. package/ios/LocalLLM.mm +1267 -0
  619. package/local-llm-rn.podspec +60 -0
  620. package/package.json +35 -0
  621. package/src/NativeLocalLLM.ts +73 -0
  622. package/src/device.ts +50 -0
  623. package/src/download-adapter.ts +17 -0
  624. package/src/index.ts +21 -0
  625. package/src/native-bridge.ts +142 -0
  626. package/src/rn-downloader.ts +37 -0
@@ -0,0 +1,3150 @@
1
+ /*
2
+ WebGPU backend implementation.
3
+ Note: Use ClangFormat to format this file.
4
+ */
5
+
6
+ #include "ggml-webgpu.h"
7
+
8
+ #include "ggml-backend-impl.h"
9
+ #include "ggml-impl.h"
10
+ #include "ggml-webgpu-shader-lib.hpp"
11
+ #include "pre_wgsl.hpp"
12
+
13
+ #ifdef __EMSCRIPTEN__
14
+ # include <emscripten/emscripten.h>
15
+ #endif
16
+
17
+ #include <webgpu/webgpu_cpp.h>
18
+
19
+ #include <atomic>
20
+ #include <condition_variable>
21
+ #include <cstdint>
22
+ #include <cstring>
23
+ #include <iostream>
24
+ #include <map>
25
+ #include <memory>
26
+ #include <mutex>
27
+ #include <optional>
28
+ #include <string>
29
+ #include <vector>
30
+
31
// Rounds `value` up to the next multiple of `align`. The mask trick only works when
// `align` is a power of two. Arguments may be evaluated more than once.
#define ROUNDUP_POW2(value, align) (((value) + ((align) - 1)) & ~((align) - 1))
// Integer division of `num` by `den`, rounded up. Arguments may be evaluated more than once.
#define CEIL_DIV(num, den)         (((num) + (den) - 1) / (den))
33
+
34
// Debug logging. With GGML_WEBGPU_DEBUG defined the argument is streamed to std::cout;
// otherwise the macro expands to a no-op and the argument is never evaluated.
#if defined(GGML_WEBGPU_DEBUG)
#    define WEBGPU_LOG_DEBUG(expr) std::cout << expr << std::endl
#    define WEBGPU_DEBUG_BUF_ELEMS 512
#else
#    define WEBGPU_LOG_DEBUG(expr) ((void) 0)
#endif  // GGML_WEBGPU_DEBUG
40
+
41
// CPU profiling helpers. Each START/END pair must use the same `tag` within one scope:
// START declares a local time point and END accumulates the elapsed milliseconds into a
// map on `ctx`, keyed by the stringified tag. All four expand to nothing when
// GGML_WEBGPU_CPU_PROFILE is not defined.
#if defined(GGML_WEBGPU_CPU_PROFILE)
// total timing (aggregated into ctx->cpu_time_ms)
#    define WEBGPU_CPU_PROFILE_TOTAL_START(tag) auto cpu_total_start_##tag = std::chrono::high_resolution_clock::now();

#    define WEBGPU_CPU_PROFILE_TOTAL_END(tag, ctx)                                                          \
        auto   cpu_total_end_##tag = std::chrono::high_resolution_clock::now();                             \
        double cpu_total_time_##tag =                                                                       \
            std::chrono::duration<double, std::milli>(cpu_total_end_##tag - cpu_total_start_##tag).count(); \
        (ctx)->cpu_time_ms[#tag] += cpu_total_time_##tag;
// fine-grained timing (accumulated into ctx->cpu_detail_ms, not included in totals)
#    define WEBGPU_CPU_PROFILE_DETAIL_START(tag) auto cpu_detail_start_##tag = std::chrono::high_resolution_clock::now();

#    define WEBGPU_CPU_PROFILE_DETAIL_END(tag, ctx)                                                           \
        auto   cpu_detail_end_##tag = std::chrono::high_resolution_clock::now();                              \
        double cpu_detail_time_##tag =                                                                        \
            std::chrono::duration<double, std::milli>(cpu_detail_end_##tag - cpu_detail_start_##tag).count(); \
        (ctx)->cpu_detail_ms[#tag] += cpu_detail_time_##tag;
#else
#    define WEBGPU_CPU_PROFILE_TOTAL_START(tag)
#    define WEBGPU_CPU_PROFILE_TOTAL_END(tag, ctx)
#    define WEBGPU_CPU_PROFILE_DETAIL_START(tag)
#    define WEBGPU_CPU_PROFILE_DETAIL_END(tag, ctx)
#endif  // GGML_WEBGPU_CPU_PROFILE

// GPU profiling: sizing of the timestamp query buffer pool.
#if defined(GGML_WEBGPU_GPU_PROFILE)
#    define WEBGPU_NUM_TIMESTAMP_QUERY_BUFS       24
#    define WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES 16  // e.g. enough for two timestamps
#endif
69
+
70
/* Constants */

#define WEBGPU_NUM_PARAM_BUFS                16u
#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE     8u
#define WEBGPU_WAIT_ANY_TIMEOUT_MS           0
// Maximum number of in-flight submissions per-thread, to avoid exhausting the
// parameter buffer pool. The expansion is parenthesized so that the macro behaves
// as a single value inside larger expressions (e.g. `x % MACRO` or `x * MACRO`).
#define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD  (WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE)
#define WEBGPU_PARAMS_BUF_SIZE_BYTES         128  // enough for 32 parameters
#define WEBGPU_NUM_SET_ROWS_ERROR_BUFS       16
#define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4
#define WEBGPU_STORAGE_BUF_BINDING_MULT      4  // a storage buffer binding size must be a multiple of 4

// For operations which process a row in parallel, this seems like a reasonable
// default
#define WEBGPU_ROW_SPLIT_WG_SIZE 64

// Track https://github.com/gpuweb/gpuweb/issues/5315 for fixes to
// implementations so this can be removed, necessary only for get_rows right now
#define WEBGPU_MAX_WG_SIZE 288

/* End Constants */
92
+
93
// WebGPU buffers have no pointer to their location, so a fixed "fake" base address is
// used instead; webgpu_tensor_offset() measures tensor data pointers relative to it.
static void * const webgpu_ptr_base = reinterpret_cast<void *>(static_cast<uintptr_t>(0x1000));  // NOLINT
96
+
97
+ // Always returns the base offset of a tensor, regardless of views.
98
+ static uint64_t webgpu_tensor_offset(const ggml_tensor * tensor) {
99
+ if (tensor->view_src) {
100
+ return (uint8_t *) tensor->view_src->data - (uint8_t *) webgpu_ptr_base;
101
+ }
102
+ return (uint8_t *) tensor->data - (uint8_t *) webgpu_ptr_base;
103
+ }
104
+
105
+ /* Struct definitions */
106
+
107
+ // Forward reference
108
+ static void ggml_webgpu_create_buffer(wgpu::Device & device,
109
+ wgpu::Buffer & buffer,
110
+ size_t size,
111
+ wgpu::BufferUsage usage,
112
+ const char * label);
113
+
114
+ struct webgpu_pool_bufs {
115
+ wgpu::Buffer host_buf;
116
+ wgpu::Buffer dev_buf;
117
+ };
118
+
119
+ // The futures to wait on for a single queue submission
120
+ struct webgpu_submission_futures {
121
+ std::vector<wgpu::FutureWaitInfo> futures;
122
+ };
123
+
124
+ // Holds a pool of parameter buffers for WebGPU operations
125
+ struct webgpu_buf_pool {
126
+ std::vector<webgpu_pool_bufs> free;
127
+
128
+ // The pool must be synchronized because
129
+ // 1. The memset pool is shared globally by every ggml buffer,
130
+ // since allocating a pool per ggml buffer would consume too much memory.
131
+ // 2. For the per-thread buffer pools in webgpu_context,
132
+ // buffers are allocated and freed in Dawn callbacks,
133
+ // which can run on a different thread than the calling thread.
134
+ std::mutex mutex;
135
+ std::condition_variable cv;
136
+
137
+ void init(wgpu::Device device,
138
+ int num_bufs,
139
+ size_t buf_size,
140
+ wgpu::BufferUsage dev_buf_usage,
141
+ wgpu::BufferUsage host_buf_usage) {
142
+ for (int i = 0; i < num_bufs; i++) {
143
+ wgpu::Buffer host_buf;
144
+ wgpu::Buffer dev_buf;
145
+ ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_pool_buf");
146
+ ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf");
147
+ free.push_back({ host_buf, dev_buf });
148
+ }
149
+ }
150
+
151
+ webgpu_pool_bufs alloc_bufs() {
152
+ std::unique_lock<std::mutex> lock(mutex);
153
+ cv.wait(lock, [this] { return !free.empty(); });
154
+ webgpu_pool_bufs bufs = free.back();
155
+ free.pop_back();
156
+ return bufs;
157
+ }
158
+
159
+ void free_bufs(std::vector<webgpu_pool_bufs> bufs) {
160
+ std::lock_guard<std::mutex> lock(mutex);
161
+ free.insert(free.end(), bufs.begin(), bufs.end());
162
+ cv.notify_all();
163
+ }
164
+
165
+ void cleanup() {
166
+ std::lock_guard<std::mutex> lock(mutex);
167
+ for (auto & bufs : free) {
168
+ if (bufs.host_buf) {
169
+ bufs.host_buf.Destroy();
170
+ }
171
+ if (bufs.dev_buf) {
172
+ bufs.dev_buf.Destroy();
173
+ }
174
+ }
175
+ free.clear();
176
+ }
177
+
178
+ ~webgpu_buf_pool() { this->cleanup(); }
179
+ };
180
+
181
+ #ifdef GGML_WEBGPU_GPU_PROFILE
182
+ struct webgpu_gpu_profile_bufs {
183
+ wgpu::Buffer host_buf;
184
+ wgpu::Buffer dev_buf;
185
+ wgpu::QuerySet query_set;
186
+ };
187
+
188
+ // Holds a pool of parameter buffers for WebGPU operations
189
+ struct webgpu_gpu_profile_buf_pool {
190
+ std::vector<webgpu_gpu_profile_bufs> free;
191
+
192
+ std::mutex mutex;
193
+
194
+ std::condition_variable cv;
195
+
196
+ void init(wgpu::Device device,
197
+ int num_bufs,
198
+ size_t buf_size,
199
+ wgpu::BufferUsage dev_buf_usage,
200
+ wgpu::BufferUsage host_buf_usage) {
201
+ for (int i = 0; i < num_bufs; i++) {
202
+ wgpu::Buffer host_buf;
203
+ wgpu::Buffer dev_buf;
204
+ ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_profile_buf");
205
+ ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_profile_buf");
206
+ // Create a query set for 2 timestamps
207
+ wgpu::QuerySetDescriptor ts_query_set_desc = {};
208
+
209
+ ts_query_set_desc.type = wgpu::QueryType::Timestamp;
210
+ ts_query_set_desc.count = 2;
211
+ wgpu::QuerySet ts_query_set = device.CreateQuerySet(&ts_query_set_desc);
212
+
213
+ free.push_back({ host_buf, dev_buf, ts_query_set });
214
+ }
215
+ }
216
+
217
+ webgpu_gpu_profile_bufs alloc_bufs() {
218
+ std::unique_lock<std::mutex> lock(mutex);
219
+ cv.wait(lock, [this] { return !free.empty(); });
220
+ webgpu_gpu_profile_bufs bufs = free.back();
221
+ free.pop_back();
222
+ return bufs;
223
+ }
224
+
225
+ void free_bufs(std::vector<webgpu_gpu_profile_bufs> bufs) {
226
+ std::lock_guard<std::mutex> lock(mutex);
227
+ free.insert(free.end(), bufs.begin(), bufs.end());
228
+ cv.notify_all();
229
+ }
230
+
231
+ void cleanup() {
232
+ std::lock_guard<std::mutex> lock(mutex);
233
+ for (auto & bufs : free) {
234
+ bufs.host_buf.Destroy();
235
+ bufs.dev_buf.Destroy();
236
+ bufs.query_set.Destroy();
237
+ }
238
+ free.clear();
239
+ }
240
+
241
+ ~webgpu_gpu_profile_buf_pool() { this->cleanup(); }
242
+ };
243
+ #endif
244
+
245
+ struct webgpu_command {
246
+ wgpu::CommandBuffer commands;
247
+ std::vector<webgpu_pool_bufs> params_bufs;
248
+ std::optional<webgpu_pool_bufs> set_rows_error_bufs;
249
+ #ifdef GGML_WEBGPU_GPU_PROFILE
250
+ webgpu_gpu_profile_bufs timestamp_query_bufs;
251
+ std::string pipeline_name;
252
+ #endif
253
+ };
254
+
255
+ struct webgpu_capabilities {
256
+ wgpu::Limits limits;
257
+ bool supports_subgroup_matrix = false;
258
+
259
+ uint32_t sg_mat_m = 0;
260
+ uint32_t sg_mat_n = 0;
261
+ uint32_t sg_mat_k = 0;
262
+
263
+ uint32_t subgroup_size = 0;
264
+ uint32_t max_subgroup_size = 0;
265
+ size_t memset_bytes_per_thread;
266
+ };
267
+
268
+ // Stores global webgpu members
269
+ struct webgpu_global_context_struct {
270
+ wgpu::Instance instance;
271
+ wgpu::Adapter adapter;
272
+ wgpu::Device device;
273
+ wgpu::Queue queue;
274
+
275
+ webgpu_capabilities capabilities;
276
+ // Shared buffer to move data from device to host
277
+ wgpu::Buffer get_tensor_staging_buf;
278
+ // Global mutex for pipeline and staging buffer, will be refactored to exclude pipeline caches.
279
+ std::recursive_mutex mutex;
280
+
281
+ webgpu_buf_pool memset_buf_pool;
282
+ std::map<int, webgpu_pipeline> memset_pipelines; // variant or type index
283
+ std::atomic_uint inflight_threads = 0;
284
+
285
+ #ifdef GGML_WEBGPU_CPU_PROFILE
286
+ // Profiling: labeled CPU time in ms (total)
287
+ std::unordered_map<std::string, double> cpu_time_ms;
288
+ // Profiling: detailed CPU time in ms
289
+ std::unordered_map<std::string, double> cpu_detail_ms;
290
+ #endif
291
+
292
+ #ifdef GGML_WEBGPU_GPU_PROFILE
293
+ // Profiling: per-shader GPU time in ms
294
+ std::unordered_map<std::string, double> shader_gpu_time_ms;
295
+ // Profiling: pool of timestamp query buffers (one per operation)
296
+ webgpu_gpu_profile_buf_pool timestamp_query_buf_pool;
297
+ #endif
298
+
299
+ #ifdef GGML_WEBGPU_DEBUG
300
+ wgpu::Buffer debug_host_buf;
301
+ wgpu::Buffer debug_dev_buf;
302
+ #endif
303
+
304
+ ~webgpu_global_context_struct() {
305
+ if (this->get_tensor_staging_buf) {
306
+ this->get_tensor_staging_buf.Destroy();
307
+ this->get_tensor_staging_buf = nullptr;
308
+ }
309
+ #ifdef GGML_WEBGPU_DEBUG
310
+ if (this->debug_host_buf) {
311
+ this->debug_host_buf.Destroy();
312
+ this->debug_host_buf = nullptr;
313
+ }
314
+ if (this->debug_dev_buf) {
315
+ this->debug_dev_buf.Destroy();
316
+ this->debug_dev_buf = nullptr;
317
+ }
318
+ #endif
319
+ }
320
+ };
321
+
322
+ typedef std::shared_ptr<webgpu_global_context_struct> webgpu_global_context;
323
+
324
+ // All the base objects needed to run operations on a WebGPU device
325
+ struct webgpu_context_struct {
326
+ // Points to global instances owned by ggml_backend_webgpu_reg_context
327
+ webgpu_global_context global_ctx;
328
+
329
+ std::unique_ptr<ggml_webgpu_shader_lib> shader_lib;
330
+
331
+ webgpu_buf_pool param_buf_pool;
332
+ webgpu_buf_pool set_rows_error_buf_pool;
333
+
334
+ std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines; // src_type, dst_type
335
+
336
+ std::map<int, webgpu_pipeline> rms_norm_pipelines; // inplace
337
+ std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> rope_pipelines; // type, ff, inplace
338
+ std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> glu_pipelines; // glu_op, type, split
339
+
340
+ std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> soft_max_pipelines; // mask_type, has_sink, inplace
341
+
342
+ size_t memset_bytes_per_thread;
343
+ };
344
+
345
+ typedef std::shared_ptr<webgpu_context_struct> webgpu_context;
346
+
347
+ // Metadata required for the ggml backend registration/discovery interface
348
+ struct ggml_backend_webgpu_reg_context {
349
+ // Since the Instance is a global entrypoint into the WebGPU API, it lives here
350
+ webgpu_global_context webgpu_global_ctx;
351
+ size_t device_count;
352
+ const char * name;
353
+ };
354
+
355
+ // Per-device struct for the global logical device interface
356
+ struct ggml_backend_webgpu_device_context {
357
+ webgpu_global_context webgpu_global_ctx;
358
+ std::string device_name;
359
+ std::string device_desc;
360
+ };
361
+
362
+ // Per-thread data required to actually run WebGPU operations in a backend instance
363
+ struct ggml_backend_webgpu_context {
364
+ webgpu_context webgpu_ctx;
365
+ std::string name;
366
+ };
367
+
368
+ // Per-thread data related to buffers
369
+ struct ggml_backend_webgpu_buffer_context {
370
+ wgpu::Buffer buffer;
371
+ std::string label;
372
+ webgpu_global_context global_ctx;
373
+
374
+ ggml_backend_webgpu_buffer_context(wgpu::Buffer buf, std::string lbl, webgpu_global_context global_ctx_) :
375
+ buffer(std::move(buf)),
376
+ label(std::move(lbl)),
377
+ global_ctx(std::move(global_ctx_)) {}
378
+ };
379
+
380
+ /* WebGPU object initializations */
381
+
382
+ static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device,
383
+ const char * shader_code,
384
+ const char * label,
385
+ const std::vector<wgpu::ConstantEntry> & constants = {}) {
386
+ wgpu::ShaderSourceWGSL shader_source;
387
+ shader_source.code = shader_code;
388
+
389
+ wgpu::ShaderModuleDescriptor shader_desc;
390
+ shader_desc.nextInChain = &shader_source;
391
+
392
+ wgpu::ShaderModule shader_module = device.CreateShaderModule(&shader_desc);
393
+
394
+ wgpu::ComputePipelineDescriptor pipeline_desc;
395
+ pipeline_desc.label = label;
396
+ pipeline_desc.compute.module = shader_module;
397
+ pipeline_desc.compute.entryPoint = "main"; // Entry point in the WGSL code
398
+ pipeline_desc.layout = nullptr; // nullptr means auto layout
399
+ if (constants.size() > 0) {
400
+ pipeline_desc.compute.constants = constants.data();
401
+ pipeline_desc.compute.constantCount = constants.size();
402
+ }
403
+ return { device.CreateComputePipeline(&pipeline_desc), label };
404
+ }
405
+
406
+ static void ggml_webgpu_create_buffer(wgpu::Device & device,
407
+ wgpu::Buffer & buffer,
408
+ size_t size,
409
+ wgpu::BufferUsage usage,
410
+ const char * label) {
411
+ wgpu::BufferDescriptor buffer_desc;
412
+ buffer_desc.size = size;
413
+ buffer_desc.usage = usage;
414
+ buffer_desc.label = label;
415
+ buffer_desc.mappedAtCreation = false;
416
+
417
+ // TODO: error handling
418
+ buffer = device.CreateBuffer(&buffer_desc);
419
+ }
420
+
421
+ /** End WebGPU object initializations */
422
+
423
+ /** WebGPU Actions */
424
+
425
+ // Wait for the queue to finish processing all submitted work
426
+ static void ggml_backend_webgpu_wait(webgpu_global_context & ctx,
427
+ std::vector<webgpu_submission_futures> & futures,
428
+ bool block = true) {
429
+ // If we have too many in-flight submissions, wait on the oldest one first. If
430
+ // there are many threads, inflight_max may be 0, meaning that we must wait on
431
+ // all futures.
432
+ uint64_t timeout_ms = block ? UINT64_MAX : 0;
433
+ uint32_t inflight_threads = ctx->inflight_threads;
434
+ uint32_t inflight_max = WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD / std::max(inflight_threads, 1u);
435
+ while (futures.size() >= inflight_max && futures.size() > 0) {
436
+ ctx->instance.WaitAny(futures[0].futures.size(), futures[0].futures.data(), UINT64_MAX);
437
+ futures.erase(futures.begin());
438
+ }
439
+ size_t i = 0;
440
+ while (i < futures.size()) {
441
+ auto waitStatus = ctx->instance.WaitAny(futures[i].futures.size(), futures[i].futures.data(), timeout_ms);
442
+ switch (waitStatus) {
443
+ case wgpu::WaitStatus::Success:
444
+ futures.erase(futures.begin() + i);
445
+ break;
446
+ case wgpu::WaitStatus::TimedOut:
447
+ i++;
448
+ break;
449
+ case wgpu::WaitStatus::Error:
450
+ GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
451
+ break;
452
+ default:
453
+ GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n");
454
+ break;
455
+ }
456
+ }
457
+ }
458
+
459
+ static void ggml_backend_webgpu_map_buffer(webgpu_global_context & ctx,
460
+ wgpu::Buffer & buffer,
461
+ wgpu::MapMode mode,
462
+ size_t offset,
463
+ size_t size) {
464
+ ctx->instance.WaitAny(buffer.MapAsync(mode, offset, size, wgpu::CallbackMode::AllowSpontaneous,
465
+ [](wgpu::MapAsyncStatus status, wgpu::StringView message) {
466
+ if (status != wgpu::MapAsyncStatus::Success) {
467
+ GGML_LOG_ERROR("ggml_webgpu: Failed to map buffer: %s\n",
468
+ message.data);
469
+ }
470
+ }),
471
+ UINT64_MAX);
472
+ }
473
+
474
#if defined(GGML_WEBGPU_DEBUG)
// WebGPU shaders cannot print, so debug values are written to a buffer instead.
// To use: add a bind group entry for the debug buffer to the shader under test, write
// debug values in the shader, then call this function after encoding and submitting the
// commands. It copies the device debug buffer to the host and prints its first float.
static void ggml_backend_webgpu_debug(webgpu_global_context & ctx) {
    const auto readback_size = ctx->debug_host_buf.GetSize();

    // Copy device -> host.
    wgpu::CommandEncoder copy_encoder = ctx->device.CreateCommandEncoder();
    copy_encoder.CopyBufferToBuffer(ctx->debug_dev_buf, 0, ctx->debug_host_buf, 0, readback_size);
    wgpu::CommandBuffer copy_commands = copy_encoder.Finish();
    ctx->queue.Submit(1, &copy_commands);

    // Map, print, unmap.
    ggml_backend_webgpu_map_buffer(ctx, ctx->debug_host_buf, wgpu::MapMode::Read, 0, readback_size);
    const float * values = (const float *) ctx->debug_host_buf.GetConstMappedRange();
    std::cout << "debug[0]: " << values[0] << "\n";
    ctx->debug_host_buf.Unmap();
}
#endif
489
+
490
+ static webgpu_submission_futures ggml_backend_webgpu_submit(webgpu_global_context ctx,
491
+ std::vector<webgpu_command> commands,
492
+ webgpu_buf_pool & param_buf_pool,
493
+ webgpu_buf_pool * set_rows_error_buf_pool = nullptr) {
494
+ std::vector<wgpu::CommandBuffer> command_buffers;
495
+ std::vector<webgpu_pool_bufs> params_bufs;
496
+ std::vector<webgpu_pool_bufs> set_rows_error_bufs;
497
+ #ifdef GGML_WEBGPU_GPU_PROFILE
498
+ std::vector<std::pair<std::string, webgpu_gpu_profile_bufs>> pipeline_name_and_ts_bufs;
499
+ #endif
500
+
501
+ for (const auto & command : commands) {
502
+ command_buffers.push_back(command.commands);
503
+ params_bufs.insert(params_bufs.end(), command.params_bufs.begin(), command.params_bufs.end());
504
+ if (command.set_rows_error_bufs) {
505
+ set_rows_error_bufs.push_back(command.set_rows_error_bufs.value());
506
+ }
507
+ }
508
+ ctx->queue.Submit(command_buffers.size(), command_buffers.data());
509
+
510
+ std::vector<wgpu::FutureWaitInfo> futures;
511
+
512
+ wgpu::Future p_f = ctx->queue.OnSubmittedWorkDone(
513
+ wgpu::CallbackMode::AllowSpontaneous,
514
+ [&param_buf_pool, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
515
+ if (status != wgpu::QueueWorkDoneStatus::Success) {
516
+ GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", std::string(message).c_str());
517
+ }
518
+ // Free the staged buffers
519
+ param_buf_pool.free_bufs(params_bufs);
520
+ });
521
+ futures.push_back({ p_f });
522
+
523
+ for (const auto & bufs : set_rows_error_bufs) {
524
+ wgpu::Future f = bufs.host_buf.MapAsync(
525
+ wgpu::MapMode::Read, 0, bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous,
526
+ [set_rows_error_buf_pool, bufs](wgpu::MapAsyncStatus status, wgpu::StringView message) {
527
+ if (status != wgpu::MapAsyncStatus::Success) {
528
+ GGML_LOG_ERROR("ggml_webgpu: Failed to map error buffer: %s\n", std::string(message).c_str());
529
+ } else {
530
+ const uint32_t * error_data = (const uint32_t *) bufs.host_buf.GetConstMappedRange();
531
+ if (*error_data) {
532
+ GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported.");
533
+ }
534
+ // We can't unmap in here due to WebGPU reentrancy limitations.
535
+ if (set_rows_error_buf_pool) {
536
+ set_rows_error_buf_pool->free_bufs({ bufs });
537
+ }
538
+ }
539
+ });
540
+ futures.push_back({ f });
541
+ }
542
+
543
+ #ifdef GGML_WEBGPU_GPU_PROFILE
544
+ for (const auto & command : commands) {
545
+ auto label = command.pipeline_name;
546
+ auto ts_bufs = command.timestamp_query_bufs;
547
+
548
+ wgpu::Future f = ts_bufs.host_buf.MapAsync(
549
+ wgpu::MapMode::Read, 0, ts_bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous,
550
+ [ctx, ts_bufs, label](wgpu::MapAsyncStatus status, wgpu::StringView message) {
551
+ if (status != wgpu::MapAsyncStatus::Success) {
552
+ GGML_LOG_ERROR("ggml_webgpu: Failed to map timestamp buffer: %s\n", std::string(message).c_str());
553
+ } else {
554
+ const uint64_t * ts_data = (const uint64_t *) ts_bufs.host_buf.GetConstMappedRange();
555
+ // WebGPU timestamps are in ns; convert to ms
556
+ double elapsed_ms = double(ts_data[1] - ts_data[0]) * 1e-6;
557
+ ctx->shader_gpu_time_ms[label] += elapsed_ms;
558
+ // We can't unmap in here due to WebGPU reentrancy limitations.
559
+ ctx->timestamp_query_buf_pool.free_bufs({ ts_bufs });
560
+ }
561
+ });
562
+ futures.push_back({ f });
563
+ }
564
+ #endif
565
+ return { futures };
566
+ }
567
+
568
+ static webgpu_command ggml_backend_webgpu_build_multi(
569
+ webgpu_global_context & ctx,
570
+ webgpu_buf_pool & param_buf_pool,
571
+ const std::vector<webgpu_pipeline> & pipelines,
572
+ const std::vector<std::vector<uint32_t>> & params_list,
573
+ const std::vector<std::vector<wgpu::BindGroupEntry>> & bind_group_entries_list,
574
+ const std::vector<std::pair<uint32_t, uint32_t>> & workgroups_list,
575
+ const std::optional<webgpu_pool_bufs> & set_rows_error_bufs = std::nullopt) {
576
+ GGML_ASSERT(pipelines.size() == params_list.size());
577
+ GGML_ASSERT(pipelines.size() == bind_group_entries_list.size());
578
+ GGML_ASSERT(pipelines.size() == workgroups_list.size());
579
+
580
+ std::vector<webgpu_pool_bufs> params_bufs_list;
581
+ std::vector<wgpu::BindGroup> bind_groups;
582
+
583
+ for (size_t i = 0; i < pipelines.size(); i++) {
584
+ webgpu_pool_bufs params_bufs = param_buf_pool.alloc_bufs();
585
+
586
+ ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0,
587
+ params_bufs.host_buf.GetSize());
588
+ uint32_t * _params = (uint32_t *) params_bufs.host_buf.GetMappedRange();
589
+ for (size_t j = 0; j < params_list[i].size(); j++) {
590
+ _params[j] = params_list[i][j];
591
+ }
592
+ params_bufs.host_buf.Unmap();
593
+
594
+ std::vector<wgpu::BindGroupEntry> entries = bind_group_entries_list[i];
595
+ uint32_t params_binding_num = entries.size();
596
+ entries.push_back({ .binding = params_binding_num,
597
+ .buffer = params_bufs.dev_buf,
598
+ .offset = 0,
599
+ .size = params_bufs.dev_buf.GetSize() });
600
+
601
+ wgpu::BindGroupDescriptor bind_group_desc;
602
+ bind_group_desc.layout = pipelines[i].pipeline.GetBindGroupLayout(0);
603
+ bind_group_desc.entryCount = entries.size();
604
+ bind_group_desc.entries = entries.data();
605
+ bind_group_desc.label = pipelines[i].name.c_str();
606
+ bind_groups.push_back(ctx->device.CreateBindGroup(&bind_group_desc));
607
+
608
+ params_bufs_list.push_back(params_bufs);
609
+ }
610
+
611
+ wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
612
+ for (const auto & params_bufs : params_bufs_list) {
613
+ encoder.CopyBufferToBuffer(params_bufs.host_buf, 0, params_bufs.dev_buf, 0, params_bufs.dev_buf.GetSize());
614
+ }
615
+
616
+ // If there are SET_ROWS operations in this submission, copy their error
617
+ // buffers to the host.
618
+ if (set_rows_error_bufs) {
619
+ encoder.CopyBufferToBuffer(set_rows_error_bufs->dev_buf, 0, set_rows_error_bufs->host_buf, 0,
620
+ set_rows_error_bufs->host_buf.GetSize());
621
+ }
622
+
623
+ #ifdef GGML_WEBGPU_GPU_PROFILE
624
+ webgpu_gpu_profile_bufs ts_bufs = ctx->timestamp_query_buf_pool.alloc_bufs();
625
+ if (ts_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) {
626
+ ts_bufs.host_buf.Unmap();
627
+ }
628
+
629
+ wgpu::PassTimestampWrites ts_writes = { .querySet = ts_bufs.query_set,
630
+ .beginningOfPassWriteIndex = 0,
631
+ .endOfPassWriteIndex = 1 };
632
+ wgpu::ComputePassDescriptor pass_desc = { .timestampWrites = &ts_writes };
633
+ wgpu::ComputePassEncoder pass = encoder.BeginComputePass(&pass_desc);
634
+ #else
635
+ wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
636
+ #endif
637
+ for (size_t i = 0; i < pipelines.size(); i++) {
638
+ pass.SetPipeline(pipelines[i].pipeline);
639
+ pass.SetBindGroup(0, bind_groups[i]);
640
+ pass.DispatchWorkgroups(workgroups_list[i].first, workgroups_list[i].second, 1);
641
+ }
642
+ pass.End();
643
+
644
+ #ifdef GGML_WEBGPU_GPU_PROFILE
645
+ encoder.ResolveQuerySet(ts_bufs.query_set, 0, 2, ts_bufs.dev_buf, 0);
646
+ encoder.CopyBufferToBuffer(ts_bufs.dev_buf, 0, ts_bufs.host_buf, 0, ts_bufs.host_buf.GetSize());
647
+ #endif
648
+
649
+ wgpu::CommandBuffer commands = encoder.Finish();
650
+ webgpu_command result = {};
651
+ result.commands = commands;
652
+ result.params_bufs = params_bufs_list;
653
+ result.set_rows_error_bufs = set_rows_error_bufs;
654
+ #ifdef GGML_WEBGPU_GPU_PROFILE
655
+ result.timestamp_query_bufs = ts_bufs;
656
+ // TODO: handle multiple pipeline names
657
+ result.pipeline_name = pipelines.front().name;
658
+ #endif
659
+ return result;
660
+ }
661
+
662
+ static webgpu_command ggml_backend_webgpu_build(webgpu_global_context & ctx,
663
+ webgpu_buf_pool & param_buf_pool,
664
+ webgpu_pipeline & pipeline,
665
+ std::vector<uint32_t> params,
666
+ std::vector<wgpu::BindGroupEntry> bind_group_entries,
667
+ uint32_t wg_x,
668
+ uint32_t wg_y = 1,
669
+ std::optional<webgpu_pool_bufs> set_rows_error_bufs = std::nullopt) {
670
+ return ggml_backend_webgpu_build_multi(ctx, param_buf_pool,
671
+ {
672
+ pipeline
673
+ },
674
+ { params }, { bind_group_entries }, { { wg_x, wg_y } }, set_rows_error_bufs);
675
+ }
676
+
677
+ static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx,
678
+ wgpu::Buffer & buf,
679
+ uint32_t value,
680
+ size_t offset,
681
+ size_t size) {
682
+ std::vector<uint32_t> params = { (uint32_t) offset, (uint32_t) size, value };
683
+ std::vector<wgpu::BindGroupEntry> entries = {
684
+ { .binding = 0, .buffer = buf, .offset = 0, .size = buf.GetSize() }
685
+ };
686
+ size_t bytes_per_wg = WEBGPU_MAX_WG_SIZE * ctx->capabilities.memset_bytes_per_thread;
687
+ uint32_t wg_x = CEIL_DIV(size + 3, bytes_per_wg);
688
+
689
+ webgpu_command command =
690
+ ggml_backend_webgpu_build(ctx, ctx->memset_buf_pool, ctx->memset_pipelines[0], params, entries, wg_x);
691
+ std::vector<webgpu_submission_futures> futures = { ggml_backend_webgpu_submit(ctx, { command },
692
+ ctx->memset_buf_pool) };
693
+ ggml_backend_webgpu_wait(ctx, futures);
694
+ }
695
+
696
+ /** End WebGPU Actions */
697
+
698
+ /** GGML Backend Interface */
699
+
700
+ static const char * ggml_backend_webgpu_name(ggml_backend_t backend) {
701
+ ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *) backend->context;
702
+ return ctx->name.c_str();
703
+ }
704
+
705
+ static void ggml_backend_webgpu_free(ggml_backend_t backend) {
706
+ ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *) backend->context;
707
+ WEBGPU_LOG_DEBUG("ggml_backend_webgpu_free(" << ctx->name << ")");
708
+
709
+ #ifdef GGML_WEBGPU_CPU_PROFILE
710
+ std::cout << "\n[ggml_webgpu cpu profiling summary]\n";
711
+ double total_cpu = 0.0;
712
+ for (const auto & kv : ctx->webgpu_ctx->global_ctx->cpu_time_ms) {
713
+ total_cpu += kv.second;
714
+ }
715
+ std::cout << "ggml_webgpu: total cpu time: " << total_cpu << " ms\n";
716
+ std::cout << "ggml_webgpu: cpu breakdown:\n";
717
+ for (const auto & kv : ctx->webgpu_ctx->global_ctx->cpu_time_ms) {
718
+ double pct = (total_cpu > 0.0) ? (kv.second / total_cpu * 100.0) : 0.0;
719
+ std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n";
720
+ }
721
+ if (ctx->webgpu_ctx->global_ctx->cpu_detail_ms.size() > 0) {
722
+ std::cout << "ggml_webgpu: cpu detailed breakdown:\n";
723
+ }
724
+ for (const auto & kv : ctx->webgpu_ctx->global_ctx->cpu_detail_ms) {
725
+ double pct = (total_cpu > 0.0) ? (kv.second / total_cpu * 100.0) : 0.0;
726
+ std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n";
727
+ }
728
+ #endif
729
+
730
+ #ifdef GGML_WEBGPU_GPU_PROFILE
731
+ std::cout << "\n[ggml_webgpu gpu profiling summary]\n";
732
+ double total_gpu = 0.0;
733
+ for (const auto & kv : ctx->webgpu_ctx->global_ctx->shader_gpu_time_ms) {
734
+ total_gpu += kv.second;
735
+ }
736
+ std::cout << "ggml_webgpu: total gpu time (all shaders): " << total_gpu << " ms\n";
737
+ std::cout << "\nggml_webgpu: gpu breakdown:\n";
738
+ for (const auto & kv : ctx->webgpu_ctx->global_ctx->shader_gpu_time_ms) {
739
+ double pct = (total_gpu > 0.0) ? (kv.second / total_gpu * 100.0) : 0.0;
740
+ std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n";
741
+ }
742
+ #endif
743
+
744
+ #if defined(GGML_WEBGPU_CPU_PROFILE) && defined(GGML_WEBGPU_GPU_PROFILE)
745
+ std::cout << "ggml_webgpu: gpu/cpu ratio: " << (total_cpu > 0.0 ? total_gpu / total_cpu : 0.0) << "\n";
746
+ #endif
747
+
748
+ delete ctx;
749
+ delete backend;
750
+ }
751
+
752
+ static size_t ggml_webgpu_tensor_offset(const ggml_tensor * tensor) {
753
+ return webgpu_tensor_offset(tensor) + tensor->view_offs;
754
+ }
755
+
756
+ static wgpu::Buffer ggml_webgpu_tensor_buf(const ggml_tensor * tensor) {
757
+ ggml_backend_webgpu_buffer_context * ctx = (ggml_backend_webgpu_buffer_context *) tensor->buffer->context;
758
+ return ctx->buffer;
759
+ }
760
+
761
+ static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, const ggml_tensor * t) {
762
+ size_t offset = ggml_webgpu_tensor_offset(t);
763
+ return offset & (ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1);
764
+ }
765
+
766
+ static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, const ggml_tensor * t) {
767
+ size_t offset = ggml_webgpu_tensor_offset(t);
768
+ return offset & ~(ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1);
769
+ }
770
+
771
+ static size_t ggml_webgpu_tensor_binding_size(webgpu_context & ctx, ggml_tensor * t) {
772
+ return ROUNDUP_POW2(ggml_nbytes(t) + ggml_webgpu_tensor_misalignment(ctx, t), WEBGPU_STORAGE_BUF_BINDING_MULT);
773
+ }
774
+
775
+ // Used to determine if two tensors are the same for in-place operations
776
+ static bool ggml_webgpu_tensor_equal(ggml_tensor * a, ggml_tensor * b) {
777
+ return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) &&
778
+ (ggml_webgpu_tensor_offset(a) == ggml_webgpu_tensor_offset(b));
779
+ }
780
+
781
+ // Used to determine if two tensors share the same buffer and their byte ranges overlap,
782
+ static bool ggml_webgpu_tensor_overlap(ggml_tensor * a, ggml_tensor * b) {
783
+ return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) &&
784
+ ggml_webgpu_tensor_offset(a) < (ggml_webgpu_tensor_offset(b) + ggml_nbytes(b)) &&
785
+ ggml_webgpu_tensor_offset(b) < (ggml_webgpu_tensor_offset(a) + ggml_nbytes(a));
786
+ }
787
+
788
// Aliasing relationships between the operands of a binary op and its destination.
struct binary_overlap_flags {
    // src0 and dst are the same tensor (same buffer and offset)
    bool inplace;
    // src1 and dst share a buffer and their byte ranges overlap
    bool overlap;
};
792
+
793
+ static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0,
794
+ ggml_tensor * src1,
795
+ ggml_tensor * dst) {
796
+ binary_overlap_flags flags = {};
797
+ flags.inplace = ggml_webgpu_tensor_equal(src0, dst);
798
+ flags.overlap = ggml_webgpu_tensor_overlap(src1, dst);
799
+
800
+ return flags;
801
+ }
802
+
803
+ static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
804
+ uint32_t ne = (uint32_t) ggml_nelements(dst);
805
+
806
+ std::vector<uint32_t> params = {
807
+ ne, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
808
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
809
+ // Convert byte-strides to element-strides
810
+ (uint32_t) (src->nb[0] / ggml_type_size(src->type)), (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
811
+ (uint32_t) (src->nb[2] / ggml_type_size(src->type)), (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
812
+ (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
813
+ (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
814
+ // Logical shapes
815
+ (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) dst->ne[0],
816
+ (uint32_t) dst->ne[1], (uint32_t) dst->ne[2]
817
+ };
818
+
819
+ std::vector<wgpu::BindGroupEntry> entries = {
820
+ { .binding = 0,
821
+ .buffer = ggml_webgpu_tensor_buf(src),
822
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src),
823
+ .size = ggml_webgpu_tensor_binding_size(ctx, src) },
824
+ { .binding = 1,
825
+ .buffer = ggml_webgpu_tensor_buf(dst),
826
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
827
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
828
+ };
829
+
830
+ uint32_t wg_x = CEIL_DIV(ne, WEBGPU_MAX_WG_SIZE);
831
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, ctx->cpy_pipelines[src->type][dst->type],
832
+ params, entries, wg_x);
833
+ }
834
+
835
+ static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
836
+ ggml_webgpu_shader_lib_context shader_lib_ctx = {
837
+ .src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
838
+ };
839
+
840
+ webgpu_pipeline pipeline = ctx->shader_lib->get_pad_pipeline(shader_lib_ctx);
841
+
842
+ auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
843
+
844
+ const uint32_t ne = (uint32_t) ggml_nelements(dst);
845
+
846
+ std::vector<uint32_t> params = {
847
+ ne,
848
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
849
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
850
+ // Strides (in elements)
851
+ (uint32_t) (src->nb[0] / ggml_type_size(src->type)),
852
+ (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
853
+ (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
854
+ (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
855
+ // Shapes
856
+ (uint32_t) src->ne[0],
857
+ (uint32_t) src->ne[1],
858
+ (uint32_t) src->ne[2],
859
+ (uint32_t) src->ne[3],
860
+ (uint32_t) dst->ne[0],
861
+ (uint32_t) dst->ne[1],
862
+ (uint32_t) dst->ne[2],
863
+ (uint32_t) dst->ne[3],
864
+ // Pad sizes
865
+ (uint32_t) ggml_get_op_params_i32(dst, 0),
866
+ (uint32_t) ggml_get_op_params_i32(dst, 1),
867
+ (uint32_t) ggml_get_op_params_i32(dst, 2),
868
+ (uint32_t) ggml_get_op_params_i32(dst, 3),
869
+ (uint32_t) ggml_get_op_params_i32(dst, 4),
870
+ (uint32_t) ggml_get_op_params_i32(dst, 5),
871
+ (uint32_t) ggml_get_op_params_i32(dst, 6),
872
+ (uint32_t) ggml_get_op_params_i32(dst, 7),
873
+ };
874
+
875
+ std::vector<wgpu::BindGroupEntry> entries = {
876
+ { .binding = 0,
877
+ .buffer = ggml_webgpu_tensor_buf(src),
878
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src),
879
+ .size = ggml_webgpu_tensor_binding_size(ctx, src) },
880
+ { .binding = 1,
881
+ .buffer = ggml_webgpu_tensor_buf(dst),
882
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
883
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
884
+ };
885
+
886
+ uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
887
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
888
+ }
889
+
890
+ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
891
+ ggml_tensor * src,
892
+ ggml_tensor * idx,
893
+ ggml_tensor * dst) {
894
+ // For set rows specifically, we need to check if src and idx are empty
895
+ // tensors.
896
+ if (ggml_is_empty(src) || ggml_is_empty(idx)) {
897
+ return std::nullopt;
898
+ }
899
+
900
+ ggml_webgpu_shader_lib_context shader_lib_ctx = {
901
+ .src0 = src,
902
+ .src1 = idx,
903
+ .dst = dst,
904
+ .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
905
+ };
906
+
907
+ webgpu_pipeline pipeline = ctx->shader_lib->get_set_rows_pipeline(shader_lib_ctx);
908
+
909
+ auto * decisions = static_cast<ggml_webgpu_set_rows_shader_decisions *>(pipeline.context.get());
910
+
911
+ std::optional<webgpu_pool_bufs> error_bufs = std::nullopt;
912
+ if (decisions->i64_idx) {
913
+ error_bufs = ctx->set_rows_error_buf_pool.alloc_bufs();
914
+ if (error_bufs->host_buf.GetMapState() == wgpu::BufferMapState::Mapped) {
915
+ error_bufs->host_buf.Unmap();
916
+ }
917
+ }
918
+
919
+ std::vector<uint32_t> params = {
920
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
921
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)),
922
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
923
+ // Convert byte-strides to element-strides
924
+ (uint32_t) (src->nb[1] / ggml_type_size(src->type)), (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
925
+ (uint32_t) (src->nb[3] / ggml_type_size(src->type)), (uint32_t) (idx->nb[0] / ggml_type_size(idx->type)),
926
+ (uint32_t) (idx->nb[1] / ggml_type_size(idx->type)), (uint32_t) (idx->nb[2] / ggml_type_size(idx->type)),
927
+ (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
928
+ (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
929
+ // Shape of src
930
+ (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) src->ne[3],
931
+ // Shape of idx
932
+ (uint32_t) (idx->ne[1]), (uint32_t) (idx->ne[2])
933
+ };
934
+
935
+ std::vector<wgpu::BindGroupEntry> entries = {
936
+ { .binding = 0,
937
+ .buffer = ggml_webgpu_tensor_buf(src),
938
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src),
939
+ .size = ggml_webgpu_tensor_binding_size(ctx, src) },
940
+ { .binding = 1,
941
+ .buffer = ggml_webgpu_tensor_buf(idx),
942
+ .offset = ggml_webgpu_tensor_align_offset(ctx, idx),
943
+ .size = ggml_webgpu_tensor_binding_size(ctx, idx) },
944
+ { .binding = 2,
945
+ .buffer = ggml_webgpu_tensor_buf(dst),
946
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
947
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
948
+ };
949
+
950
+ if (decisions->i64_idx) {
951
+ entries.push_back(
952
+ { .binding = 3, .buffer = error_bufs->dev_buf, .offset = 0, .size = error_bufs->dev_buf.GetSize() });
953
+ }
954
+
955
+ uint32_t threads;
956
+ if (decisions->vec4) {
957
+ threads = (src->ne[1] * src->ne[2] * src->ne[3]) * (src->ne[0] / 4);
958
+ } else {
959
+ threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3];
960
+ }
961
+ uint32_t wg_x = CEIL_DIV(threads, decisions->wg_size);
962
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, 1,
963
+ error_bufs);
964
+ }
965
+
966
+ // Workgroup size is a common constant
967
+ static std::vector<wgpu::ConstantEntry> ggml_webgpu_wg_size_entry(uint32_t wg_size) {
968
+ std::vector<wgpu::ConstantEntry> constants(1);
969
+ constants[0].key = "wg_size";
970
+ constants[0].value = wg_size;
971
+ return constants;
972
+ }
973
+
974
+ static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx,
975
+ ggml_tensor * src,
976
+ ggml_tensor * idx,
977
+ ggml_tensor * dst) {
978
+ ggml_webgpu_shader_lib_context shader_lib_ctx = {
979
+ .src0 = src,
980
+ .src1 = nullptr,
981
+ .dst = dst,
982
+ .max_wg_size = WEBGPU_MAX_WG_SIZE,
983
+ };
984
+
985
+ webgpu_pipeline pipeline = ctx->shader_lib->get_get_rows_pipeline(shader_lib_ctx);
986
+ auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
987
+
988
+ std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
989
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)),
990
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
991
+ (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
992
+ (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
993
+ (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
994
+ (uint32_t) (idx->nb[0] / ggml_type_size(idx->type)),
995
+ (uint32_t) (idx->nb[1] / ggml_type_size(idx->type)),
996
+ (uint32_t) (idx->nb[2] / ggml_type_size(idx->type)),
997
+ (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
998
+ (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
999
+ (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
1000
+ (uint32_t) dst->ne[0],
1001
+ (uint32_t) dst->ne[1],
1002
+ (uint32_t) dst->ne[2],
1003
+ (uint32_t) dst->ne[3],
1004
+ (uint32_t) (idx->ne[1]),
1005
+ (uint32_t) (idx->ne[2]) };
1006
+
1007
+ std::vector<wgpu::BindGroupEntry> entries = {
1008
+ { .binding = 0,
1009
+ .buffer = ggml_webgpu_tensor_buf(src),
1010
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src),
1011
+ .size = ggml_webgpu_tensor_binding_size(ctx, src) },
1012
+ { .binding = 1,
1013
+ .buffer = ggml_webgpu_tensor_buf(idx),
1014
+ .offset = ggml_webgpu_tensor_align_offset(ctx, idx),
1015
+ .size = ggml_webgpu_tensor_binding_size(ctx, idx) },
1016
+ { .binding = 2,
1017
+ .buffer = ggml_webgpu_tensor_buf(dst),
1018
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
1019
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
1020
+ };
1021
+
1022
+ uint32_t wg_x = CEIL_DIV(dst->ne[1] * dst->ne[2] * dst->ne[3], decisions->wg_size);
1023
+
1024
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
1025
+ }
1026
+
1027
+ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
1028
+ ggml_tensor * src0,
1029
+ ggml_tensor * src1,
1030
+ ggml_tensor * dst) {
1031
+ // Determine if this is a mat-vec operation
1032
+ bool is_vec = (dst->ne[1] == 1);
1033
+
1034
+ // Determine if we should use fast path
1035
+ bool use_fast = false;
1036
+ switch (src1->type) {
1037
+ case GGML_TYPE_F16:
1038
+ use_fast = (src0->type == GGML_TYPE_F16);
1039
+ break;
1040
+ case GGML_TYPE_F32:
1041
+ switch (src0->type) {
1042
+ case GGML_TYPE_F32:
1043
+ case GGML_TYPE_F16:
1044
+ case GGML_TYPE_Q4_0:
1045
+ use_fast = true;
1046
+ break;
1047
+ default:
1048
+ break;
1049
+ }
1050
+ break;
1051
+ default:
1052
+ break;
1053
+ }
1054
+
1055
+ ggml_webgpu_shader_lib_context shader_lib_ctx = {
1056
+ .src0 = src0,
1057
+ .src1 = src1,
1058
+ .dst = dst,
1059
+ .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
1060
+ .supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix,
1061
+ .sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m,
1062
+ .sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n,
1063
+ .sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k,
1064
+ .max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size,
1065
+ };
1066
+
1067
+ // Get or create pipeline
1068
+ webgpu_pipeline pipeline;
1069
+
1070
+ if (use_fast && is_vec) {
1071
+ pipeline = ctx->shader_lib->get_mul_mat_vec_pipeline(shader_lib_ctx);
1072
+ } else if (use_fast) {
1073
+ pipeline = ctx->shader_lib->get_mul_mat_fast_pipeline(shader_lib_ctx);
1074
+ } else {
1075
+ pipeline = ctx->shader_lib->get_mul_mat_legacy_pipeline(shader_lib_ctx);
1076
+ }
1077
+
1078
+ // Build params
1079
+ std::vector<uint32_t> params = {
1080
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
1081
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
1082
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1083
+ (uint32_t) dst->ne[0],
1084
+ (uint32_t) dst->ne[1],
1085
+ (uint32_t) src0->ne[0],
1086
+ (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
1087
+ (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
1088
+ (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
1089
+ (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
1090
+ (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
1091
+ (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
1092
+ (uint32_t) src0->ne[2],
1093
+ (uint32_t) src0->ne[3],
1094
+ (uint32_t) (src1->ne[2] / src0->ne[2]),
1095
+ (uint32_t) (src1->ne[3] / src0->ne[3])
1096
+ };
1097
+
1098
+ // Build bind group entries
1099
+ std::vector<wgpu::BindGroupEntry> entries = {
1100
+ { .binding = 0,
1101
+ .buffer = ggml_webgpu_tensor_buf(src0),
1102
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src0),
1103
+ .size = ggml_webgpu_tensor_binding_size(ctx, src0) },
1104
+ { .binding = 1,
1105
+ .buffer = ggml_webgpu_tensor_buf(src1),
1106
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
1107
+ .size = ggml_webgpu_tensor_binding_size(ctx, src1) },
1108
+ { .binding = 2,
1109
+ .buffer = ggml_webgpu_tensor_buf(dst),
1110
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
1111
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) },
1112
+ };
1113
+
1114
+ // Calculate workgroup dimensions
1115
+ uint32_t wg_x = 1;
1116
+ uint32_t wg_y = 1;
1117
+
1118
+ if (use_fast && is_vec) {
1119
+ auto decisions = static_cast<ggml_webgpu_mul_mat_vec_shader_decisions *>(pipeline.context.get());
1120
+
1121
+ uint32_t batches = dst->ne[2] * dst->ne[3];
1122
+ uint32_t output_groups = CEIL_DIV(dst->ne[0], decisions->outputs_per_wg);
1123
+ uint32_t total_wg = output_groups * batches;
1124
+ // TODO: split large sizes into multiple batches to avoid way over-provisioning workgroups
1125
+ wg_x = std::min(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension);
1126
+ wg_y = CEIL_DIV(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension);
1127
+ } else if (use_fast) {
1128
+ auto decisions = static_cast<ggml_webgpu_mul_mat_shader_decisions *>(pipeline.context.get());
1129
+
1130
+ // Fast-path tiled/subgroup calculations
1131
+ uint32_t wg_m, wg_n;
1132
+ if (decisions->use_subgroup_matrix) {
1133
+ uint32_t wg_m_sg_tile =
1134
+ decisions->subgroup_m * decisions->subgroup_matrix_m * ctx->global_ctx->capabilities.sg_mat_m;
1135
+ wg_m = CEIL_DIV(dst->ne[0], wg_m_sg_tile);
1136
+ uint32_t wg_n_sg_tile =
1137
+ decisions->subgroup_n * decisions->subgroup_matrix_n * ctx->global_ctx->capabilities.sg_mat_n;
1138
+ wg_n = CEIL_DIV(dst->ne[1], wg_n_sg_tile);
1139
+ } else {
1140
+ uint32_t tile_m_s = decisions->tile_m * decisions->wg_size_m;
1141
+ uint32_t tile_n_s = decisions->tile_n * decisions->wg_size_n;
1142
+ wg_m = CEIL_DIV(dst->ne[0], tile_m_s);
1143
+ wg_n = CEIL_DIV(dst->ne[1], tile_n_s);
1144
+ }
1145
+ wg_x = wg_m * wg_n * dst->ne[2] * dst->ne[3];
1146
+ } else { // legacy
1147
+ auto decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
1148
+ uint32_t wg_size = decisions->wg_size;
1149
+ wg_x = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], wg_size);
1150
+ wg_y = 1;
1151
+ }
1152
+
1153
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y);
1154
+ }
1155
+
1156
+ #ifndef __EMSCRIPTEN__
1157
+ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx,
1158
+ ggml_tensor * Q,
1159
+ ggml_tensor * K,
1160
+ ggml_tensor * V,
1161
+ ggml_tensor * mask,
1162
+ ggml_tensor * sinks,
1163
+ ggml_tensor * dst) {
1164
+ float scale = *(float *) dst->op_params;
1165
+ float max_bias;
1166
+ memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
1167
+ float logit_softcap;
1168
+ memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));
1169
+ if (logit_softcap != 0.0f) {
1170
+ scale /= logit_softcap;
1171
+ }
1172
+ float n_head_log2 = float(1u << (uint32_t) floor(log2(Q->ne[2])));
1173
+ float m0 = powf(2.0f, -(max_bias) / n_head_log2);
1174
+ float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
1175
+
1176
+ const int has_mask = (mask != nullptr);
1177
+ const int has_sinks = (sinks != nullptr);
1178
+
1179
+ std::vector<uint32_t> params = {
1180
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, Q) / ggml_type_size(Q->type)),
1181
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type)),
1182
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type)),
1183
+ has_mask ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)) : 0,
1184
+ has_sinks ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, sinks) / ggml_type_size(sinks->type)) : 0,
1185
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1186
+ (uint32_t) Q->ne[2], // number of heads
1187
+ (uint32_t) Q->ne[1], // sequence length (Q)
1188
+ (uint32_t) K->ne[1], // sequence length (K/V)
1189
+ (uint32_t) (Q->nb[1] / ggml_type_size(Q->type)), // stride (elements/blocks) of Q in dimension 1
1190
+ (uint32_t) (Q->nb[2] / ggml_type_size(Q->type)), // stride (elements/blocks) of Q in dimension 2
1191
+ (uint32_t) (Q->nb[3] / ggml_type_size(Q->type)), // stride (elements/blocks) of Q in dimension 3
1192
+ (uint32_t) (K->nb[1] / ggml_type_size(K->type)), // stride (elements/blocks) of K in dimension 1
1193
+ (uint32_t) (K->nb[2] / ggml_type_size(K->type)), // stride (elements/blocks) of K in dimension 2
1194
+ (uint32_t) (K->nb[3] / ggml_type_size(K->type)), // stride (elements/blocks) of K in dimension 3
1195
+ (uint32_t) (V->nb[1] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 1
1196
+ (uint32_t) (V->nb[2] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 2
1197
+ (uint32_t) (V->nb[3] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 3
1198
+ has_mask ? (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)) : 0, // stride of mask dim 3
1199
+ (uint32_t) (Q->ne[2] / K->ne[2]), // repeat factor for K/V in dim 2 (MHA/MQA/GQA)
1200
+ *(uint32_t *) &scale, // scale (possibly adjusted for logit softcap)
1201
+ *(uint32_t *) &max_bias,
1202
+ *(uint32_t *) &logit_softcap,
1203
+ *(uint32_t *) &n_head_log2,
1204
+ *(uint32_t *) &m0,
1205
+ *(uint32_t *) &m1
1206
+
1207
+ };
1208
+ std::vector<wgpu::BindGroupEntry> entries = {
1209
+ { .binding = 0,
1210
+ .buffer = ggml_webgpu_tensor_buf(Q),
1211
+ .offset = ggml_webgpu_tensor_align_offset(ctx, Q),
1212
+ .size = ggml_webgpu_tensor_binding_size(ctx, Q) },
1213
+ { .binding = 1,
1214
+ .buffer = ggml_webgpu_tensor_buf(K),
1215
+ .offset = ggml_webgpu_tensor_align_offset(ctx, K),
1216
+ .size = ggml_webgpu_tensor_binding_size(ctx, K) },
1217
+ { .binding = 2,
1218
+ .buffer = ggml_webgpu_tensor_buf(V),
1219
+ .offset = ggml_webgpu_tensor_align_offset(ctx, V),
1220
+ .size = ggml_webgpu_tensor_binding_size(ctx, V) }
1221
+ };
1222
+ uint32_t binding_index = 3;
1223
+ if (has_mask) {
1224
+ entries.push_back({ .binding = binding_index++,
1225
+ .buffer = ggml_webgpu_tensor_buf(mask),
1226
+ .offset = ggml_webgpu_tensor_align_offset(ctx, mask),
1227
+ .size = ggml_webgpu_tensor_binding_size(ctx, mask) });
1228
+ }
1229
+ if (has_sinks) {
1230
+ entries.push_back({ .binding = binding_index++,
1231
+ .buffer = ggml_webgpu_tensor_buf(sinks),
1232
+ .offset = ggml_webgpu_tensor_align_offset(ctx, sinks),
1233
+ .size = ggml_webgpu_tensor_binding_size(ctx, sinks) });
1234
+ }
1235
+ entries.push_back({ .binding = binding_index++,
1236
+ .buffer = ggml_webgpu_tensor_buf(dst),
1237
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
1238
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
1239
+
1240
+ ggml_webgpu_shader_lib_context shader_lib_ctx = {
1241
+ .src0 = Q,
1242
+ .src1 = K,
1243
+ .src2 = V,
1244
+ .src3 = mask,
1245
+ .src4 = sinks,
1246
+ .dst = dst,
1247
+ .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
1248
+ .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
1249
+ .sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m,
1250
+ .sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n,
1251
+ .sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k,
1252
+ .max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size,
1253
+ };
1254
+
1255
+ webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline(shader_lib_ctx);
1256
+
1257
+ auto * decisions = static_cast<ggml_webgpu_flash_attn_shader_decisions *>(pipeline.context.get());
1258
+
1259
+ uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions->q_tile);
1260
+ uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches
1261
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
1262
+ }
1263
+ #endif
1264
+
1265
+ static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
1266
+ bool is_unary = dst->op == GGML_OP_UNARY;
1267
+ bool inplace = ggml_webgpu_tensor_equal(src, dst) || (dst->op == GGML_OP_FILL);
1268
+
1269
+ ggml_webgpu_shader_lib_context shader_lib_ctx = {
1270
+ .src0 = src,
1271
+ .src1 = nullptr,
1272
+ .dst = dst,
1273
+ .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
1274
+ .inplace = inplace,
1275
+ };
1276
+
1277
+ webgpu_pipeline pipeline = ctx->shader_lib->get_unary_pipeline(shader_lib_ctx);
1278
+
1279
+ auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
1280
+
1281
+ uint32_t ne = (uint32_t) ggml_nelements(dst);
1282
+
1283
+ std::vector<uint32_t> params = { ne,
1284
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
1285
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1286
+ (uint32_t) (src->nb[0] / ggml_type_size(src->type)),
1287
+ (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
1288
+ (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
1289
+ (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
1290
+ (uint32_t) src->ne[0],
1291
+ (uint32_t) src->ne[1],
1292
+ (uint32_t) src->ne[2] };
1293
+
1294
+ ggml_tensor * effective_src = src;
1295
+ if (is_unary) {
1296
+ ggml_unary_op unary_op = ggml_get_unary_op(dst);
1297
+ switch (unary_op) {
1298
+ case GGML_UNARY_OP_XIELU:
1299
+ {
1300
+ // Get float parameters and reinterpret their bit patterns as uint32_t
1301
+ // for passing through the params buffer
1302
+ float alpha_n = ggml_get_op_params_f32(dst, 1);
1303
+ float alpha_p = ggml_get_op_params_f32(dst, 2);
1304
+ float beta = ggml_get_op_params_f32(dst, 3);
1305
+ float eps = ggml_get_op_params_f32(dst, 4);
1306
+ params.push_back(*reinterpret_cast<const uint32_t *>(&alpha_n));
1307
+ params.push_back(*reinterpret_cast<const uint32_t *>(&alpha_p));
1308
+ params.push_back(*reinterpret_cast<const uint32_t *>(&beta));
1309
+ params.push_back(*reinterpret_cast<const uint32_t *>(&eps));
1310
+ break;
1311
+ }
1312
+ default:
1313
+ break;
1314
+ }
1315
+ } else if (dst->op == GGML_OP_CLAMP) {
1316
+ float clamp_min = ggml_get_op_params_f32(dst, 0);
1317
+ float clamp_max = ggml_get_op_params_f32(dst, 1);
1318
+ params.push_back(*reinterpret_cast<const uint32_t *>(&clamp_min));
1319
+ params.push_back(*reinterpret_cast<const uint32_t *>(&clamp_max));
1320
+ } else if (dst->op == GGML_OP_FILL) {
1321
+ float fill_val = ggml_get_op_params_f32(dst, 0);
1322
+ params.push_back(*reinterpret_cast<const uint32_t *>(&fill_val));
1323
+ effective_src = dst; // fill simply fills dst
1324
+ }
1325
+
1326
+ std::vector<wgpu::BindGroupEntry> entries = {
1327
+ { .binding = 0,
1328
+ .buffer = ggml_webgpu_tensor_buf(effective_src),
1329
+ .offset = ggml_webgpu_tensor_align_offset(ctx, effective_src),
1330
+ .size = ggml_webgpu_tensor_binding_size(ctx, effective_src) },
1331
+ };
1332
+ if (!inplace) {
1333
+ entries.push_back({ .binding = 1,
1334
+ .buffer = ggml_webgpu_tensor_buf(dst),
1335
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
1336
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
1337
+ }
1338
+
1339
+ uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
1340
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
1341
+ }
1342
+
1343
+ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
1344
+ ggml_tensor * src0,
1345
+ ggml_tensor * src1,
1346
+ ggml_tensor * dst) {
1347
+ binary_overlap_flags flags = ggml_webgpu_detect_binary_overlap(src0, src1, dst);
1348
+
1349
+ ggml_webgpu_shader_lib_context shader_lib_ctx = {
1350
+ .src0 = src0,
1351
+ .src1 = src1,
1352
+ .dst = dst,
1353
+ .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
1354
+ .inplace = flags.inplace,
1355
+ .overlap = flags.overlap,
1356
+ };
1357
+
1358
+ webgpu_pipeline pipeline = ctx->shader_lib->get_binary_pipeline(shader_lib_ctx);
1359
+
1360
+ auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
1361
+
1362
+ uint32_t ne = (uint32_t) ggml_nelements(dst);
1363
+
1364
+ std::vector<uint32_t> params = {
1365
+ ne,
1366
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
1367
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
1368
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1369
+ (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)),
1370
+ (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
1371
+ (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
1372
+ (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
1373
+ (uint32_t) src0->ne[0],
1374
+ (uint32_t) src0->ne[1],
1375
+ (uint32_t) src0->ne[2],
1376
+ (uint32_t) src1->ne[0],
1377
+ (uint32_t) src1->ne[1],
1378
+ (uint32_t) src1->ne[2],
1379
+ (uint32_t) src1->ne[3],
1380
+ };
1381
+
1382
+ std::vector<wgpu::BindGroupEntry> entries;
1383
+
1384
+ entries.push_back({
1385
+ .binding = 0,
1386
+ .buffer = ggml_webgpu_tensor_buf(src0),
1387
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src0),
1388
+ .size = ggml_webgpu_tensor_binding_size(ctx, src0),
1389
+ });
1390
+
1391
+ entries.push_back({
1392
+ .binding = 1,
1393
+ .buffer = ggml_webgpu_tensor_buf(src1),
1394
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
1395
+ .size = ggml_webgpu_tensor_binding_size(ctx, src1),
1396
+ });
1397
+
1398
+ if (!flags.inplace && !flags.overlap) {
1399
+ entries.push_back({ .binding = 2,
1400
+ .buffer = ggml_webgpu_tensor_buf(dst),
1401
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
1402
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
1403
+ }
1404
+
1405
+ uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
1406
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
1407
+ }
1408
+
1409
+ static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
1410
+ int inplace = ggml_webgpu_tensor_equal(src, dst);
1411
+
1412
+ std::vector<uint32_t> params = {
1413
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
1414
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1415
+ (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
1416
+ (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
1417
+ (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
1418
+ (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
1419
+ (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
1420
+ (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
1421
+ (uint32_t) src->ne[0],
1422
+ (uint32_t) src->ne[1],
1423
+ (uint32_t) src->ne[2],
1424
+ (uint32_t) src->ne[3],
1425
+ *(uint32_t *) dst->op_params // epsilon, treated as f32 in the shader
1426
+ };
1427
+
1428
+ std::vector<wgpu::BindGroupEntry> entries = {
1429
+ { .binding = 0,
1430
+ .buffer = ggml_webgpu_tensor_buf(src),
1431
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src),
1432
+ .size = ggml_webgpu_tensor_binding_size(ctx, src) }
1433
+ };
1434
+ if (!inplace) {
1435
+ entries.push_back({ .binding = 1,
1436
+ .buffer = ggml_webgpu_tensor_buf(dst),
1437
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
1438
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
1439
+ }
1440
+
1441
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, ctx->rms_norm_pipelines[inplace], params,
1442
+ entries, ggml_nrows(src));
1443
+ }
1444
+
1445
+ static webgpu_command ggml_webgpu_rope(webgpu_context & ctx,
1446
+ ggml_tensor * src0,
1447
+ ggml_tensor * src1,
1448
+ ggml_tensor * src2,
1449
+ ggml_tensor * dst) {
1450
+ const int inplace = ggml_webgpu_tensor_equal(src0, dst);
1451
+ const int has_freq_factor = (src2 != nullptr);
1452
+
1453
+ const int n_dims = ((int32_t *) dst->op_params)[1];
1454
+ const int mode = ((int32_t *) dst->op_params)[2];
1455
+ const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
1456
+
1457
+ float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
1458
+ memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
1459
+ memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
1460
+ memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
1461
+ memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
1462
+ memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
1463
+ memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
1464
+
1465
+ int sections[4];
1466
+ memcpy(sections, (int32_t *) dst->op_params + 11, 4 * sizeof(int));
1467
+
1468
+ float theta_scale = powf(freq_base, -2.0f / n_dims);
1469
+
1470
+ float corr_dims[2];
1471
+ ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
1472
+
1473
+ std::vector<uint32_t> params = {
1474
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
1475
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
1476
+ src2 != nullptr ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)) : 0,
1477
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1478
+ (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
1479
+ (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
1480
+ (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
1481
+ (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
1482
+ (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
1483
+ (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
1484
+ (uint32_t) ggml_nelements(src0) / 2,
1485
+ (uint32_t) src0->ne[0],
1486
+ (uint32_t) src0->ne[1],
1487
+ (uint32_t) src0->ne[2],
1488
+ (uint32_t) n_dims,
1489
+ (uint32_t) mode,
1490
+ *(uint32_t *) &theta_scale,
1491
+ *(uint32_t *) &attn_factor,
1492
+ *(uint32_t *) &freq_scale,
1493
+ *(uint32_t *) &ext_factor,
1494
+ *(uint32_t *) &corr_dims[0],
1495
+ *(uint32_t *) &corr_dims[1],
1496
+ (uint32_t) sections[0],
1497
+ (uint32_t) sections[1],
1498
+ (uint32_t) sections[2],
1499
+ (uint32_t) sections[3]
1500
+ };
1501
+
1502
+ std::vector<wgpu::BindGroupEntry> entries = {
1503
+ { .binding = 0,
1504
+ .buffer = ggml_webgpu_tensor_buf(src0),
1505
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src0),
1506
+ .size = ggml_webgpu_tensor_binding_size(ctx, src0) },
1507
+ { .binding = 1,
1508
+ .buffer = ggml_webgpu_tensor_buf(src1),
1509
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
1510
+ .size = ggml_webgpu_tensor_binding_size(ctx, src1) }
1511
+ };
1512
+ uint32_t dst_binding = 2;
1513
+ if (has_freq_factor) {
1514
+ dst_binding = 3;
1515
+ entries.push_back({ .binding = 2,
1516
+ .buffer = ggml_webgpu_tensor_buf(src2),
1517
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src2),
1518
+ .size = ggml_webgpu_tensor_binding_size(ctx, src2) });
1519
+ }
1520
+ if (!inplace) {
1521
+ entries.push_back({ .binding = dst_binding,
1522
+ .buffer = ggml_webgpu_tensor_buf(dst),
1523
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
1524
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
1525
+ }
1526
+
1527
+ webgpu_pipeline pipeline = ctx->rope_pipelines[dst->type][has_freq_factor][inplace];
1528
+ uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
1529
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
1530
+ }
1531
+
1532
+ static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
1533
+ const int split = (src1 != nullptr);
1534
+
1535
+ std::vector<uint32_t> params = {
1536
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
1537
+ src1 != nullptr ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)) : 0,
1538
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1539
+ (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
1540
+ (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
1541
+ (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
1542
+ src1 != nullptr ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) :
1543
+ (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
1544
+ src1 != nullptr ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) :
1545
+ (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
1546
+ src1 != nullptr ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) :
1547
+ (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
1548
+ (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
1549
+ (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
1550
+ (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
1551
+ (uint32_t) ggml_nelements(dst),
1552
+ (uint32_t) dst->ne[0],
1553
+ (uint32_t) dst->ne[1],
1554
+ (uint32_t) dst->ne[2],
1555
+ (uint32_t) ((int32_t *) dst->op_params)[1], // swapped
1556
+ *(uint32_t *) &dst->op_params[2], // alpha, for swiglu_oai
1557
+ *(uint32_t *) &dst->op_params[3], // limit, for swiglu_oai
1558
+ };
1559
+
1560
+ std::vector<wgpu::BindGroupEntry> entries = {
1561
+ { .binding = 0,
1562
+ .buffer = ggml_webgpu_tensor_buf(src0),
1563
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src0),
1564
+ .size = ggml_webgpu_tensor_binding_size(ctx, src0) },
1565
+ };
1566
+ uint32_t dst_binding = 1;
1567
+ if (split) {
1568
+ dst_binding = 2;
1569
+ entries.push_back({ .binding = 1,
1570
+ .buffer = ggml_webgpu_tensor_buf(src1),
1571
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
1572
+ .size = ggml_webgpu_tensor_binding_size(ctx, src1) });
1573
+ }
1574
+ entries.push_back({ .binding = dst_binding,
1575
+ .buffer = ggml_webgpu_tensor_buf(dst),
1576
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
1577
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
1578
+
1579
+ webgpu_pipeline pipeline = ctx->glu_pipelines[ggml_get_glu_op(dst)][dst->type][split];
1580
+ uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
1581
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
1582
+ }
1583
+
1584
+ static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
1585
+ bool inplace = ggml_webgpu_tensor_equal(src, dst);
1586
+
1587
+ ggml_webgpu_shader_lib_context shader_lib_ctx = {
1588
+ .src0 = src,
1589
+ .src1 = nullptr,
1590
+ .dst = dst,
1591
+ .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
1592
+ .inplace = inplace,
1593
+ };
1594
+
1595
+ webgpu_pipeline pipeline = ctx->shader_lib->get_scale_pipeline(shader_lib_ctx);
1596
+ auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
1597
+
1598
+ // params unchanged
1599
+ std::vector<uint32_t> params = {
1600
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
1601
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1602
+ (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
1603
+ (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
1604
+ (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
1605
+ (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
1606
+ (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
1607
+ (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
1608
+ (uint32_t) ggml_nelements(dst),
1609
+ (uint32_t) src->ne[0],
1610
+ (uint32_t) src->ne[1],
1611
+ (uint32_t) src->ne[2],
1612
+ *(uint32_t *) dst->op_params, // scale
1613
+ *(uint32_t *) &dst->op_params[1] // bias
1614
+ };
1615
+
1616
+ // bindgroups unchanged
1617
+ std::vector<wgpu::BindGroupEntry> entries = {
1618
+ { .binding = 0,
1619
+ .buffer = ggml_webgpu_tensor_buf(src),
1620
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src),
1621
+ .size = ggml_webgpu_tensor_binding_size(ctx, src) }
1622
+ };
1623
+
1624
+ if (!inplace) {
1625
+ entries.push_back({ .binding = 1,
1626
+ .buffer = ggml_webgpu_tensor_buf(dst),
1627
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
1628
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
1629
+ }
1630
+
1631
+ uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size);
1632
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
1633
+ }
1634
+
1635
+ static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx,
1636
+ ggml_tensor * src0,
1637
+ ggml_tensor * src1,
1638
+ ggml_tensor * src2,
1639
+ ggml_tensor * dst) {
1640
+ const int inplace = ggml_webgpu_tensor_equal(src0, dst);
1641
+ const int mask_type = (src1 != nullptr) ? src1->type : 2; // use 2 for no mask here
1642
+ const int has_sink = (src2 != nullptr);
1643
+ float max_bias;
1644
+ memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
1645
+ float n_head_log2 = float(1u << (uint32_t) floor(log2(src0->ne[2])));
1646
+ float m0 = powf(2.0f, -(max_bias) / n_head_log2);
1647
+ float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
1648
+
1649
+ std::vector<uint32_t> params = {
1650
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
1651
+ mask_type < 2 ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)) : 0,
1652
+ has_sink ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)) : 0,
1653
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1654
+ (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
1655
+ (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
1656
+ (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
1657
+ mask_type < 2 ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) : 0,
1658
+ mask_type < 2 ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) : 0,
1659
+ mask_type < 2 ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) : 0,
1660
+ (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
1661
+ (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
1662
+ (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
1663
+ (uint32_t) ggml_nelements(dst),
1664
+ (uint32_t) src0->ne[0],
1665
+ (uint32_t) src0->ne[1],
1666
+ (uint32_t) src0->ne[2],
1667
+ mask_type < 2 ? (uint32_t) src1->ne[2] : 0,
1668
+ mask_type < 2 ? (uint32_t) src1->ne[3] : 0,
1669
+ *(uint32_t *) dst->op_params, // scale
1670
+ *(uint32_t *) &max_bias,
1671
+ *(uint32_t *) &n_head_log2,
1672
+ *(uint32_t *) &m0,
1673
+ *(uint32_t *) &m1
1674
+ };
1675
+
1676
+ std::vector<wgpu::BindGroupEntry> entries = {
1677
+ { .binding = 0,
1678
+ .buffer = ggml_webgpu_tensor_buf(src0),
1679
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src0),
1680
+ .size = ggml_webgpu_tensor_binding_size(ctx, src0) }
1681
+ };
1682
+ uint32_t binding_num = 1;
1683
+ if (mask_type < 2) {
1684
+ entries.push_back({ .binding = binding_num,
1685
+ .buffer = ggml_webgpu_tensor_buf(src1),
1686
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
1687
+ .size = ggml_webgpu_tensor_binding_size(ctx, src1) });
1688
+ binding_num++;
1689
+ }
1690
+ if (has_sink) {
1691
+ entries.push_back({ .binding = binding_num,
1692
+ .buffer = ggml_webgpu_tensor_buf(src2),
1693
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src2),
1694
+ .size = ggml_webgpu_tensor_binding_size(ctx, src2) });
1695
+ binding_num++;
1696
+ }
1697
+ if (!inplace) {
1698
+ entries.push_back({ .binding = binding_num,
1699
+ .buffer = ggml_webgpu_tensor_buf(dst),
1700
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
1701
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
1702
+ }
1703
+
1704
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool,
1705
+ ctx->soft_max_pipelines[mask_type][has_sink][inplace], params, entries,
1706
+ ggml_nrows(dst));
1707
+ }
1708
+
1709
+ static webgpu_command ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
1710
+ std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
1711
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1712
+ (uint32_t) src->ne[0] };
1713
+
1714
+ std::vector<wgpu::BindGroupEntry> entries = {
1715
+ { .binding = 0,
1716
+ .buffer = ggml_webgpu_tensor_buf(src),
1717
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src),
1718
+ .size = ggml_webgpu_tensor_binding_size(ctx, src) },
1719
+ { .binding = 1,
1720
+ .buffer = ggml_webgpu_tensor_buf(dst),
1721
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
1722
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
1723
+ };
1724
+
1725
+ ggml_webgpu_shader_lib_context shader_lib_ctx = {
1726
+ .src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
1727
+ };
1728
+
1729
+ webgpu_pipeline pipeline = ctx->shader_lib->get_argmax_pipeline(shader_lib_ctx);
1730
+ uint32_t wg_x = ggml_nelements(dst);
1731
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
1732
+ }
1733
+
1734
+ static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
1735
+ bool is_top_k = dst->op == GGML_OP_TOP_K;
1736
+
1737
+ ggml_webgpu_shader_lib_context shader_lib_ctx = {
1738
+ .src0 = src,
1739
+ .src1 = nullptr,
1740
+ .dst = dst,
1741
+ .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
1742
+ .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
1743
+ };
1744
+
1745
+ webgpu_pipeline argsort_pipeline = ctx->shader_lib->get_argsort_pipeline(shader_lib_ctx);
1746
+ auto * argsort_decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(argsort_pipeline.context.get());
1747
+
1748
+ webgpu_pipeline argsort_merge_pipeline = ctx->shader_lib->get_argsort_merge_pipeline(shader_lib_ctx);
1749
+
1750
+ const uint32_t src_ne0 = (uint32_t) src->ne[0];
1751
+ const uint32_t nrows = (uint32_t) ggml_nrows(src);
1752
+ const uint32_t npr = CEIL_DIV(src_ne0, argsort_decisions->wg_size);
1753
+ const uint32_t block_size =
1754
+ is_top_k ? std::min(argsort_decisions->wg_size, (uint32_t) dst->ne[0]) : argsort_decisions->wg_size;
1755
+ uint32_t out_ne0 = src_ne0;
1756
+ if (is_top_k) {
1757
+ if (npr > 1) {
1758
+ const uint32_t last_tile = src_ne0 - (npr - 1) * argsort_decisions->wg_size;
1759
+ out_ne0 = (npr - 1) * block_size + std::min(last_tile, block_size);
1760
+ } else {
1761
+ out_ne0 = block_size;
1762
+ }
1763
+ }
1764
+
1765
+ uint32_t merge_len = block_size;
1766
+ uint32_t merge_passes = 0;
1767
+ while (merge_len < out_ne0) {
1768
+ merge_len <<= 1;
1769
+ merge_passes++;
1770
+ }
1771
+
1772
+ const bool start_in_tmp = (merge_passes % 2) == 1;
1773
+
1774
+ const size_t dst_offset = ggml_webgpu_tensor_offset(dst);
1775
+ const size_t idx_nbytes = out_ne0 * ggml_nrows(dst) * sizeof(int32_t);
1776
+ const size_t tmp_offset =
1777
+ ROUNDUP_POW2(dst_offset + idx_nbytes, ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
1778
+ const size_t tmp_binding_size = ROUNDUP_POW2(idx_nbytes, WEBGPU_STORAGE_BUF_BINDING_MULT);
1779
+ const size_t dst_binding_size =
1780
+ ROUNDUP_POW2(idx_nbytes + ggml_webgpu_tensor_misalignment(ctx, dst), WEBGPU_STORAGE_BUF_BINDING_MULT);
1781
+
1782
+ const uint32_t offset_src = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type));
1783
+ const uint32_t offset_dst = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type));
1784
+ const uint32_t offset_tmp = 0;
1785
+ const uint32_t stride_src1 = (uint32_t) (src->nb[1] / ggml_type_size(src->type));
1786
+ const uint32_t stride_src2 = (uint32_t) (src->nb[2] / ggml_type_size(src->type));
1787
+ const uint32_t stride_src3 = (uint32_t) (src->nb[3] / ggml_type_size(src->type));
1788
+ const uint32_t stride_idx1 = out_ne0;
1789
+ const uint32_t stride_idx2 = out_ne0 * (uint32_t) dst->ne[1];
1790
+ const uint32_t stride_idx3 = stride_idx2 * (uint32_t) dst->ne[2];
1791
+
1792
+ std::vector<webgpu_pipeline> pipelines;
1793
+ std::vector<std::vector<uint32_t>> params_list;
1794
+ std::vector<std::vector<wgpu::BindGroupEntry>> entries_list;
1795
+ std::vector<std::pair<uint32_t, uint32_t>> workgroups_list;
1796
+
1797
+ const uint32_t init_offset = start_in_tmp ? offset_tmp : offset_dst;
1798
+ const size_t init_align_offset = start_in_tmp ? tmp_offset : ggml_webgpu_tensor_align_offset(ctx, dst);
1799
+ const size_t init_binding_size = start_in_tmp ? tmp_binding_size : dst_binding_size;
1800
+
1801
+ std::vector<uint32_t> init_params = {
1802
+ offset_src, init_offset, stride_src1, stride_src2, stride_src3, stride_idx1,
1803
+ stride_idx2, stride_idx3, src_ne0, (uint32_t) src->ne[1], (uint32_t) src->ne[2], out_ne0,
1804
+ block_size, npr, nrows
1805
+ };
1806
+
1807
+ const uint32_t total_wg_init = npr * nrows;
1808
+ const uint32_t max_wg = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
1809
+ const uint32_t wg_x_init = std::min(total_wg_init, max_wg);
1810
+ const uint32_t wg_y_init = CEIL_DIV(total_wg_init, wg_x_init);
1811
+ std::vector<wgpu::BindGroupEntry> init_entries = {
1812
+ { .binding = 0,
1813
+ .buffer = ggml_webgpu_tensor_buf(src),
1814
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src),
1815
+ .size = ggml_webgpu_tensor_binding_size(ctx, src) },
1816
+ { .binding = 1, .buffer = ggml_webgpu_tensor_buf(dst), .offset = init_align_offset, .size = init_binding_size }
1817
+ };
1818
+
1819
+ pipelines.push_back(argsort_pipeline);
1820
+ params_list.push_back(std::move(init_params));
1821
+ entries_list.push_back(std::move(init_entries));
1822
+ workgroups_list.push_back({ wg_x_init, wg_y_init });
1823
+
1824
+ if (merge_passes == 0) {
1825
+ return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_buf_pool, pipelines, params_list,
1826
+ entries_list, workgroups_list);
1827
+ }
1828
+
1829
+ bool in_is_tmp = start_in_tmp;
1830
+ uint32_t len = block_size;
1831
+ while (len < out_ne0) {
1832
+ const uint32_t nm = CEIL_DIV(out_ne0, 2 * len);
1833
+
1834
+ const bool out_is_tmp = !in_is_tmp;
1835
+ const uint32_t offset_in = in_is_tmp ? offset_tmp : offset_dst;
1836
+ const uint32_t offset_out = out_is_tmp ? offset_tmp : offset_dst;
1837
+ const size_t align_in = in_is_tmp ? tmp_offset : ggml_webgpu_tensor_align_offset(ctx, dst);
1838
+ const size_t align_out = out_is_tmp ? tmp_offset : ggml_webgpu_tensor_align_offset(ctx, dst);
1839
+ const size_t size_in = in_is_tmp ? tmp_binding_size : dst_binding_size;
1840
+ const size_t size_out = out_is_tmp ? tmp_binding_size : dst_binding_size;
1841
+ const uint32_t top_k_out = (is_top_k && nm == 1) ? (uint32_t) dst->ne[0] : out_ne0;
1842
+ const uint32_t stride_out1 = top_k_out;
1843
+ const uint32_t stride_out2 = top_k_out * (uint32_t) dst->ne[1];
1844
+ const uint32_t stride_out3 = stride_out2 * (uint32_t) dst->ne[2];
1845
+
1846
+ std::vector<uint32_t> merge_params = { offset_src,
1847
+ offset_in,
1848
+ offset_out,
1849
+ stride_src1,
1850
+ stride_src2,
1851
+ stride_src3,
1852
+ stride_idx1,
1853
+ stride_idx2,
1854
+ stride_idx3,
1855
+ stride_out1,
1856
+ stride_out2,
1857
+ stride_out3,
1858
+ out_ne0,
1859
+ (uint32_t) src->ne[1],
1860
+ (uint32_t) src->ne[2],
1861
+ top_k_out,
1862
+ len,
1863
+ nm,
1864
+ nrows };
1865
+
1866
+ std::vector<wgpu::BindGroupEntry> merge_entries = {
1867
+ { .binding = 0,
1868
+ .buffer = ggml_webgpu_tensor_buf(src),
1869
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src),
1870
+ .size = ggml_webgpu_tensor_binding_size(ctx, src) },
1871
+ { .binding = 1, .buffer = ggml_webgpu_tensor_buf(dst), .offset = align_in, .size = size_in },
1872
+ { .binding = 2, .buffer = ggml_webgpu_tensor_buf(dst), .offset = align_out, .size = size_out }
1873
+ };
1874
+
1875
+ const uint32_t total_wg_merge = nm * nrows;
1876
+ const uint32_t wg_x_merge = std::min(total_wg_merge, max_wg);
1877
+ const uint32_t wg_y_merge = CEIL_DIV(total_wg_merge, wg_x_merge);
1878
+ workgroups_list.push_back({ wg_x_merge, wg_y_merge });
1879
+ pipelines.push_back(argsort_merge_pipeline);
1880
+ params_list.push_back(std::move(merge_params));
1881
+ entries_list.push_back(std::move(merge_entries));
1882
+
1883
+ len <<= 1;
1884
+ in_is_tmp = !in_is_tmp;
1885
+ }
1886
+
1887
+ return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_buf_pool, pipelines, params_list, entries_list,
1888
+ workgroups_list);
1889
+ }
1890
+
1891
+ static webgpu_command ggml_webgpu_cumsum(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
1892
+ std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
1893
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1894
+ (uint32_t) src->ne[0] };
1895
+
1896
+ std::vector<wgpu::BindGroupEntry> entries = {
1897
+ { .binding = 0,
1898
+ .buffer = ggml_webgpu_tensor_buf(src),
1899
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src),
1900
+ .size = ggml_webgpu_tensor_binding_size(ctx, src) },
1901
+ { .binding = 1,
1902
+ .buffer = ggml_webgpu_tensor_buf(dst),
1903
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
1904
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
1905
+ };
1906
+
1907
+ ggml_webgpu_shader_lib_context shader_lib_ctx = {
1908
+ .src0 = src,
1909
+ .src1 = nullptr,
1910
+ .dst = dst,
1911
+ .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
1912
+ };
1913
+
1914
+ webgpu_pipeline pipeline = ctx->shader_lib->get_cumsum_pipeline(shader_lib_ctx);
1915
+ uint32_t wg_x = ggml_nrows(dst);
1916
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
1917
+ }
1918
+
1919
+ static webgpu_command ggml_webgpu_sum_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
1920
+ bool total_sum = dst->op == GGML_OP_SUM;
1921
+ std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
1922
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1923
+ total_sum ? 0 : (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
1924
+ total_sum ? 0 : (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
1925
+ total_sum ? 0 : (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
1926
+ total_sum ? static_cast<uint32_t>(ggml_nelements(src)) : (uint32_t) src->ne[0],
1927
+ total_sum ? 1 : (uint32_t) src->ne[1],
1928
+ total_sum ? 1 : (uint32_t) src->ne[2] };
1929
+
1930
+ std::vector<wgpu::BindGroupEntry> entries = {
1931
+ { .binding = 0,
1932
+ .buffer = ggml_webgpu_tensor_buf(src),
1933
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src),
1934
+ .size = ggml_webgpu_tensor_binding_size(ctx, src) },
1935
+ { .binding = 1,
1936
+ .buffer = ggml_webgpu_tensor_buf(dst),
1937
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
1938
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
1939
+ };
1940
+
1941
+ ggml_webgpu_shader_lib_context shader_lib_ctx = {
1942
+ .src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
1943
+ };
1944
+
1945
+ webgpu_pipeline pipeline = ctx->shader_lib->get_sum_rows_pipeline(shader_lib_ctx);
1946
+
1947
+ uint32_t wg_x = total_sum ? 1 : ggml_nrows(dst);
1948
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
1949
+ }
1950
+
1951
+ // Returns the encoded command, or std::nullopt if the operation is a no-op
1952
+ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
1953
+ if (ggml_is_empty(node)) {
1954
+ return std::nullopt;
1955
+ }
1956
+ if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
1957
+ return std::nullopt;
1958
+ }
1959
+ WEBGPU_LOG_DEBUG("ggml_webgpu_encode_node(" << node << ", " << ggml_op_name(node->op) << ")");
1960
+
1961
+ ggml_tensor * src0 = node->src[0];
1962
+ ggml_tensor * src1 = node->src[1];
1963
+ ggml_tensor * src2 = node->src[2];
1964
+
1965
+ switch (node->op) {
1966
+ // no-ops
1967
+ case GGML_OP_NONE:
1968
+ case GGML_OP_VIEW:
1969
+ case GGML_OP_PERMUTE:
1970
+ case GGML_OP_TRANSPOSE:
1971
+ case GGML_OP_RESHAPE:
1972
+ return std::nullopt;
1973
+ case GGML_OP_CPY:
1974
+ case GGML_OP_CONT:
1975
+ return ggml_webgpu_cpy(ctx, src0, node);
1976
+ case GGML_OP_SET_ROWS:
1977
+ return ggml_webgpu_set_rows(ctx, src0, src1, node);
1978
+ case GGML_OP_GET_ROWS:
1979
+ return ggml_webgpu_get_rows(ctx, src0, src1, node);
1980
+ case GGML_OP_MUL_MAT:
1981
+ return ggml_webgpu_mul_mat(ctx, src0, src1, node);
1982
+ case GGML_OP_FLASH_ATTN_EXT:
1983
+ #ifndef __EMSCRIPTEN__
1984
+ return ggml_webgpu_flash_attn(ctx, src0, src1, src2, node->src[3], node->src[4], node);
1985
+ #else
1986
+ return std::nullopt;
1987
+ #endif
1988
+ case GGML_OP_ADD:
1989
+ case GGML_OP_SUB:
1990
+ case GGML_OP_MUL:
1991
+ case GGML_OP_DIV:
1992
+ return ggml_webgpu_binary_op(ctx, src0, src1, node);
1993
+ case GGML_OP_RMS_NORM:
1994
+ return ggml_webgpu_rms_norm(ctx, src0, node);
1995
+ case GGML_OP_ROPE:
1996
+ return ggml_webgpu_rope(ctx, src0, src1, src2, node);
1997
+ case GGML_OP_GLU:
1998
+ return ggml_webgpu_glu(ctx, src0, src1, node);
1999
+ case GGML_OP_SCALE:
2000
+ return ggml_webgpu_scale(ctx, src0, node);
2001
+ case GGML_OP_SOFT_MAX:
2002
+ return ggml_webgpu_soft_max(ctx, src0, src1, src2, node);
2003
+ case GGML_OP_UNARY:
2004
+ return ggml_webgpu_unary_op(ctx, src0, node);
2005
+ case GGML_OP_CLAMP:
2006
+ return ggml_webgpu_unary_op(ctx, src0, node);
2007
+ case GGML_OP_FILL:
2008
+ return ggml_webgpu_unary_op(ctx, src0, node);
2009
+ case GGML_OP_LOG:
2010
+ return ggml_webgpu_unary_op(ctx, src0, node);
2011
+ case GGML_OP_SQR:
2012
+ return ggml_webgpu_unary_op(ctx, src0, node);
2013
+ case GGML_OP_SQRT:
2014
+ return ggml_webgpu_unary_op(ctx, src0, node);
2015
+ case GGML_OP_SIN:
2016
+ return ggml_webgpu_unary_op(ctx, src0, node);
2017
+ case GGML_OP_COS:
2018
+ return ggml_webgpu_unary_op(ctx, src0, node);
2019
+ case GGML_OP_PAD:
2020
+ return ggml_webgpu_pad(ctx, src0, node);
2021
+ case GGML_OP_ARGMAX:
2022
+ return ggml_webgpu_argmax(ctx, src0, node);
2023
+ case GGML_OP_ARGSORT:
2024
+ return ggml_webgpu_argsort(ctx, src0, node);
2025
+ case GGML_OP_TOP_K:
2026
+ // we reuse the same argsort implementation for top_k
2027
+ return ggml_webgpu_argsort(ctx, src0, node);
2028
+ case GGML_OP_CUMSUM:
2029
+ return ggml_webgpu_cumsum(ctx, src0, node);
2030
+ case GGML_OP_SUM:
2031
+ case GGML_OP_SUM_ROWS:
2032
+ return ggml_webgpu_sum_rows(ctx, src0, node);
2033
+ default:
2034
+ return std::nullopt;
2035
+ }
2036
+ }
2037
+
2038
+ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
2039
+ WEBGPU_LOG_DEBUG("ggml_backend_webgpu_graph_compute(" << cgraph->n_nodes << " nodes)");
2040
+
2041
+ ggml_backend_webgpu_context * backend_ctx = (ggml_backend_webgpu_context *) backend->context;
2042
+ webgpu_context ctx = backend_ctx->webgpu_ctx;
2043
+
2044
+ WEBGPU_CPU_PROFILE_TOTAL_START(graph_compute);
2045
+
2046
+ ctx->global_ctx->inflight_threads++;
2047
+
2048
+ std::vector<webgpu_command> commands;
2049
+ std::vector<webgpu_submission_futures> futures;
2050
+ for (int i = 0; i < cgraph->n_nodes; i++) {
2051
+ if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) {
2052
+ commands.push_back(*cmd);
2053
+ }
2054
+ // compute the batch size based on the number of inflight threads
2055
+ uint32_t inflight_threads = ctx->global_ctx->inflight_threads;
2056
+ uint32_t batch_size = std::min(std::max(1u, WEBGPU_NUM_PARAM_BUFS / std::max(inflight_threads, 1u)),
2057
+ WEBGPU_COMMAND_SUBMIT_BATCH_SIZE);
2058
+ if (commands.size() >= batch_size) {
2059
+ futures.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool,
2060
+ &ctx->set_rows_error_buf_pool));
2061
+ // Process events and check for completed submissions
2062
+ ctx->global_ctx->instance.ProcessEvents();
2063
+ ggml_backend_webgpu_wait(ctx->global_ctx, futures, false);
2064
+ commands.clear();
2065
+ }
2066
+ }
2067
+ if (!commands.empty()) {
2068
+ webgpu_submission_futures new_futures =
2069
+ ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool, &ctx->set_rows_error_buf_pool);
2070
+ futures.push_back(new_futures);
2071
+ }
2072
+
2073
+ ggml_backend_webgpu_wait(ctx->global_ctx, futures);
2074
+ ctx->global_ctx->inflight_threads--;
2075
+ WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx->global_ctx);
2076
+ return GGML_STATUS_SUCCESS;
2077
+ }
2078
+
2079
+ static ggml_backend_i ggml_backend_webgpu_i = {
2080
+ /* .get_name = */ ggml_backend_webgpu_name,
2081
+ /* .free = */ ggml_backend_webgpu_free,
2082
+ /* .set_tensor_async = */ NULL,
2083
+ /* .get_tensor_async = */ NULL,
2084
+ /* .cpy_tensor_async = */ NULL,
2085
+ /* .synchronize = */ NULL,
2086
+ /* .graph_plan_create = */ NULL,
2087
+ /* .graph_plan_free = */ NULL,
2088
+ /* .graph_plan_update = */ NULL,
2089
+ /* .graph_plan_compute = */ NULL,
2090
+ /* .graph_compute = */ ggml_backend_webgpu_graph_compute,
2091
+ /* .event_record = */ NULL,
2092
+ /* .event_wait = */ NULL,
2093
+ /* .graph_optimize = */ NULL,
2094
+ };
2095
+
2096
+ /* End GGML Backend Interface */
2097
+
2098
+ /* GGML Backend Buffer Interface */
2099
+
2100
+ static void ggml_backend_webgpu_buffer_free_buffer(ggml_backend_buffer_t buffer) {
2101
+ ggml_backend_webgpu_buffer_context * ctx = static_cast<ggml_backend_webgpu_buffer_context *>(buffer->context);
2102
+ if (ctx != nullptr && ctx->buffer != nullptr) {
2103
+ ctx->buffer.Destroy();
2104
+ delete ctx;
2105
+ }
2106
+ }
2107
+
2108
+ // Returns the "fake" base pointer.
2109
+ static void * ggml_backend_webgpu_buffer_get_base(ggml_backend_buffer_t buffer) {
2110
+ GGML_UNUSED(buffer);
2111
+ return webgpu_ptr_base;
2112
+ }
2113
+
2114
+ static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffer,
2115
+ ggml_tensor * tensor,
2116
+ uint8_t value,
2117
+ size_t offset,
2118
+ size_t size) {
2119
+ if (size == 0) {
2120
+ WEBGPU_LOG_DEBUG(
2121
+ "ggml_backend_webgpu_buffer_memset_tensor: size is zero, "
2122
+ "nothing to do.");
2123
+ return;
2124
+ }
2125
+
2126
+ WEBGPU_CPU_PROFILE_TOTAL_START(memset_tensor);
2127
+
2128
+ ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
2129
+
2130
+ WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor(" << buf_ctx->label << ", " << tensor << ", " << value
2131
+ << ", " << offset << ", " << size << ")");
2132
+
2133
+ size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
2134
+
2135
+ // This is a trick to set all bytes of a u32 to the same 1 byte value.
2136
+ uint32_t val32 = (uint32_t) value * 0x01010101;
2137
+ ggml_backend_webgpu_buffer_memset(buf_ctx->global_ctx, buf_ctx->buffer, val32, total_offset, size);
2138
+ WEBGPU_CPU_PROFILE_TOTAL_END(memset_tensor, buf_ctx->global_ctx);
2139
+ }
2140
+
2141
+ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
2142
+ ggml_tensor * tensor,
2143
+ const void * data,
2144
+ size_t offset,
2145
+ size_t size) {
2146
+ WEBGPU_CPU_PROFILE_TOTAL_START(set_tensor);
2147
+ ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
2148
+
2149
+ WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_set_tensor(" << buf_ctx->label << ", " << tensor << ", " << data
2150
+ << ", " << offset << ", " << size << ")");
2151
+
2152
+ size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
2153
+
2154
+ buf_ctx->global_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size / 4) * 4);
2155
+
2156
+ if (size % 4 != 0) {
2157
+ // If size is not a multiple of 4, we need to memset the remaining bytes
2158
+ size_t remaining_size = size % 4;
2159
+
2160
+ // pack the remaining bytes into a uint32_t
2161
+ uint32_t val32 = 0;
2162
+
2163
+ for (size_t i = 0; i < remaining_size; i++) {
2164
+ ((uint8_t *) &val32)[i] = ((const uint8_t *) data)[size - remaining_size + i];
2165
+ }
2166
+ // memset the remaining bytes
2167
+ ggml_backend_webgpu_buffer_memset(buf_ctx->global_ctx, buf_ctx->buffer, val32,
2168
+ total_offset + (size - remaining_size), remaining_size);
2169
+ } else {
2170
+ // wait for WriteBuffer to complete
2171
+ buf_ctx->global_ctx->instance.WaitAny(buf_ctx->global_ctx->queue.OnSubmittedWorkDone(
2172
+ wgpu::CallbackMode::AllowSpontaneous,
2173
+ [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
2174
+ if (status != wgpu::QueueWorkDoneStatus::Success) {
2175
+ GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n",
2176
+ std::string(message).c_str());
2177
+ }
2178
+ }),
2179
+ UINT64_MAX);
2180
+ }
2181
+ WEBGPU_CPU_PROFILE_TOTAL_END(set_tensor, buf_ctx->global_ctx);
2182
+ }
2183
+
2184
+ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
2185
+ const ggml_tensor * tensor,
2186
+ void * data,
2187
+ size_t offset,
2188
+ size_t size) {
2189
+ WEBGPU_CPU_PROFILE_TOTAL_START(get_tensor);
2190
+ ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
2191
+ WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_get_tensor(" << buf_ctx->label << ", " << tensor << ", " << data
2192
+ << ", " << offset << ", " << size << ")");
2193
+ wgpu::Device device = buf_ctx->global_ctx->device;
2194
+
2195
+ size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
2196
+
2197
+ size_t final_size = size;
2198
+ if (size % 4 != 0) {
2199
+ // If size is not a multiple of 4, we need to round it up to the next
2200
+ // multiple of 4
2201
+ final_size = size + (4 - (size % 4));
2202
+ }
2203
+
2204
+ std::lock_guard<std::recursive_mutex> lock(buf_ctx->global_ctx->mutex);
2205
+
2206
+ if (buf_ctx->global_ctx->get_tensor_staging_buf == nullptr ||
2207
+ buf_ctx->global_ctx->get_tensor_staging_buf.GetSize() < final_size) {
2208
+ // Create a new staging buffer if it doesn't exist or is too small
2209
+ if (buf_ctx->global_ctx->get_tensor_staging_buf) {
2210
+ buf_ctx->global_ctx->get_tensor_staging_buf.Destroy();
2211
+ }
2212
+ ggml_webgpu_create_buffer(device, buf_ctx->global_ctx->get_tensor_staging_buf, final_size,
2213
+ wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "get_tensor_staging_buf");
2214
+ }
2215
+
2216
+ // Copy the data from the buffer to the staging buffer
2217
+ wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
2218
+ encoder.CopyBufferToBuffer(buf_ctx->buffer, total_offset, buf_ctx->global_ctx->get_tensor_staging_buf, 0,
2219
+ final_size);
2220
+ wgpu::CommandBuffer commands = encoder.Finish();
2221
+
2222
+ // Submit the command buffer to the queue
2223
+ buf_ctx->global_ctx->queue.Submit(1, &commands);
2224
+
2225
+ // Map the staging buffer to read the data
2226
+ ggml_backend_webgpu_map_buffer(buf_ctx->global_ctx, buf_ctx->global_ctx->get_tensor_staging_buf,
2227
+ wgpu::MapMode::Read, 0, final_size);
2228
+ // Must specify size here since the staging buffer might be larger than the tensor size
2229
+ const void * mapped_range = buf_ctx->global_ctx->get_tensor_staging_buf.GetConstMappedRange(0, final_size);
2230
+
2231
+ // Copy the data from the mapped range to the output buffer
2232
+ std::memcpy(data, mapped_range, size);
2233
+ buf_ctx->global_ctx->get_tensor_staging_buf.Unmap();
2234
+ WEBGPU_CPU_PROFILE_TOTAL_END(get_tensor, buf_ctx->global_ctx);
2235
+ }
2236
+
2237
+ static void ggml_backend_webgpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
2238
+ WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_clear(" << buffer << ", " << (uint32_t) value << ")");
2239
+ WEBGPU_CPU_PROFILE_TOTAL_START(clear);
2240
+ ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
2241
+ ggml_backend_webgpu_buffer_memset(buf_ctx->global_ctx, buf_ctx->buffer, value, 0, buffer->size);
2242
+ WEBGPU_CPU_PROFILE_TOTAL_END(clear, buf_ctx->global_ctx);
2243
+ }
2244
+
2245
+ static ggml_backend_buffer_i ggml_backend_webgpu_buffer_interface = {
2246
+ /* .free_buffer = */ ggml_backend_webgpu_buffer_free_buffer,
2247
+ /* .get_base = */ ggml_backend_webgpu_buffer_get_base,
2248
+ /* .init_tensor = */ NULL, // TODO: optional, needed?
2249
+ /* .memset_tensor = */ ggml_backend_webgpu_buffer_memset_tensor,
2250
+ /* .set_tensor = */ ggml_backend_webgpu_buffer_set_tensor,
2251
+ /* .get_tensor = */ ggml_backend_webgpu_buffer_get_tensor,
2252
+ /* .cpy_tensor = */ NULL, // TODO: optional, implement this
2253
+ /* .clear = */ ggml_backend_webgpu_buffer_clear,
2254
+ /* .reset = */ NULL, // TODO: optional, think it coordinates with
2255
+ // .init_tensor
2256
+ };
2257
+
2258
+ /* End GGML Backend Buffer Interface */
2259
+
2260
+ /* GGML Backend Buffer Type Interface */
2261
+
2262
+ static const char * ggml_backend_webgpu_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
2263
+ ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
2264
+ return ctx->device_name.c_str();
2265
+ }
2266
+
2267
+ static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
2268
+ size_t size) {
2269
+ static std::atomic<int> buffer_count;
2270
+ int buffer_id = buffer_count++;
2271
+ std::string buf_name = "tensor_buf" + std::to_string(buffer_id);
2272
+ WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_type_alloc_buffer_" << buffer_id << ": " << size << " bytes");
2273
+
2274
+ ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
2275
+ wgpu::Buffer buf;
2276
+ ggml_webgpu_create_buffer(ctx->webgpu_global_ctx->device, buf, ROUNDUP_POW2(size, WEBGPU_STORAGE_BUF_BINDING_MULT),
2277
+ wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst,
2278
+ buf_name.c_str());
2279
+
2280
+ ggml_backend_webgpu_buffer_context * buf_ctx =
2281
+ new ggml_backend_webgpu_buffer_context(buf, buf_name, ctx->webgpu_global_ctx);
2282
+
2283
+ return ggml_backend_buffer_init(buft, ggml_backend_webgpu_buffer_interface, buf_ctx, size);
2284
+ }
2285
+
2286
+ static size_t ggml_backend_webgpu_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
2287
+ ggml_backend_webgpu_device_context * dev_ctx =
2288
+ static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
2289
+ return dev_ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment;
2290
+ }
2291
+
2292
+ // maxBufferSize might be larger, but you can't bind more than
2293
+ // maxStorageBufferBindingSize to a single binding.
2294
+ static size_t ggml_backend_webgpu_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
2295
+ ggml_backend_webgpu_device_context * dev_ctx =
2296
+ static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
2297
+ return dev_ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize;
2298
+ }
2299
+
2300
+ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft,
2301
+ const ggml_tensor * tensor) {
2302
+ ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
2303
+ size_t res = ggml_nbytes(tensor);
2304
+ switch (tensor->op) {
2305
+ case GGML_OP_ARGSORT:
2306
+ res = ROUNDUP_POW2(res * 2 + ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment,
2307
+ WEBGPU_STORAGE_BUF_BINDING_MULT);
2308
+ break;
2309
+ case GGML_OP_TOP_K:
2310
+ {
2311
+ const ggml_tensor * src0 = tensor->src[0];
2312
+ if (src0) {
2313
+ const size_t full = sizeof(int32_t) * ggml_nelements(src0);
2314
+ res = ROUNDUP_POW2(
2315
+ full * 2 + ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment,
2316
+ WEBGPU_STORAGE_BUF_BINDING_MULT);
2317
+ }
2318
+ }
2319
+ break;
2320
+ default:
2321
+ break;
2322
+ }
2323
+ return res;
2324
+ }
2325
+
2326
+ /* End GGML Backend Buffer Type Interface */
2327
+
2328
+ /* GGML Backend Device Interface */
2329
+
2330
+ static const char * ggml_backend_webgpu_device_get_name(ggml_backend_dev_t dev) {
2331
+ ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
2332
+ return ctx->device_name.c_str();
2333
+ }
2334
+
2335
+ static const char * ggml_backend_webgpu_device_get_description(ggml_backend_dev_t dev) {
2336
+ ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
2337
+ return ctx->device_desc.c_str();
2338
+ }
2339
+
2340
+ static void ggml_backend_webgpu_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
2341
+ ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
2342
+ // TODO: for now, return maxBufferSize as both free and total memory
2343
+ // Track https://github.com/gpuweb/gpuweb/issues/5505 for updates.
2344
+ uint64_t max_buffer_size = ctx->webgpu_global_ctx->capabilities.limits.maxBufferSize;
2345
+ // If we're on a 32-bit system, clamp to UINTPTR_MAX
2346
+ #if UINTPTR_MAX < UINT64_MAX
2347
+ uint64_t max_ptr_size = static_cast<uint64_t>(UINTPTR_MAX);
2348
+ if (max_buffer_size > max_ptr_size) {
2349
+ max_buffer_size = max_ptr_size;
2350
+ }
2351
+ #endif
2352
+ *free = static_cast<size_t>(max_buffer_size);
2353
+ *total = static_cast<size_t>(max_buffer_size);
2354
+ }
2355
+
2356
+ static enum ggml_backend_dev_type ggml_backend_webgpu_device_get_type(ggml_backend_dev_t dev) {
2357
+ GGML_UNUSED(dev);
2358
+ return GGML_BACKEND_DEVICE_TYPE_GPU;
2359
+ }
2360
+
2361
+ static void ggml_backend_webgpu_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
2362
+ props->name = ggml_backend_webgpu_device_get_name(dev);
2363
+ props->description = ggml_backend_webgpu_device_get_description(dev);
2364
+ props->type = ggml_backend_webgpu_device_get_type(dev);
2365
+ ggml_backend_webgpu_device_get_memory(dev, &props->memory_free, &props->memory_total);
2366
+ props->caps = {
2367
+ /* .async = */ false,
2368
+ /* .host_buffer = */ false,
2369
+ /* .buffer_from_host_ptr = */ false,
2370
+ /* .events = */ false,
2371
+ };
2372
+ }
2373
+
2374
+ static ggml_guid_t ggml_backend_webgpu_guid(void) {
2375
+ static const char * guid_str = "__ggml_webgpu :)";
2376
+ return reinterpret_cast<ggml_guid_t>((void *) guid_str);
2377
+ }
2378
+
2379
+ static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) {
2380
+ // we use the maximum workgroup size for the memset pipeline
2381
+ size_t max_threads = WEBGPU_MAX_WG_SIZE * ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
2382
+ // Size the bytes_per_thread so that the largest buffer size can be handled
2383
+ ctx->capabilities.memset_bytes_per_thread =
2384
+ CEIL_DIV(ctx->capabilities.limits.maxStorageBufferBindingSize, max_threads);
2385
+ std::vector<wgpu::ConstantEntry> constants(2);
2386
+ constants[0].key = "wg_size";
2387
+ constants[0].value = WEBGPU_MAX_WG_SIZE;
2388
+ constants[1].key = "bytes_per_thread";
2389
+ constants[1].value = ctx->capabilities.memset_bytes_per_thread;
2390
+ ctx->memset_pipelines[0] = ggml_webgpu_create_pipeline(ctx->device, wgsl_memset, "memset", constants);
2391
+ }
2392
+
2393
+ static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
2394
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
2395
+
2396
+ webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F32] =
2397
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_f32, "cpy_f32_f32", constants);
2398
+ webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_I32] =
2399
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_i32, "cpy_f32_i32", constants);
2400
+ webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F16] =
2401
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_f16, "cpy_f32_f16", constants);
2402
+ webgpu_ctx->cpy_pipelines[GGML_TYPE_F16][GGML_TYPE_F32] =
2403
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f32, "cpy_f16_f32", constants);
2404
+ webgpu_ctx->cpy_pipelines[GGML_TYPE_F16][GGML_TYPE_F16] =
2405
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f16, "cpy_f16_f16", constants);
2406
+ }
2407
+
2408
+ static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) {
2409
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);
2410
+
2411
+ webgpu_ctx->rms_norm_pipelines[0] =
2412
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rms_norm, "rms_norm", constants);
2413
+ webgpu_ctx->rms_norm_pipelines[1] = ggml_webgpu_create_pipeline(
2414
+ webgpu_ctx->global_ctx->device, wgsl_rms_norm_inplace, "rms_norm_inplace", constants);
2415
+ }
2416
+
2417
+ static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) {
2418
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
2419
+
2420
+ webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][0] =
2421
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f32, "rope_f32", constants);
2422
+ webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][1] = ggml_webgpu_create_pipeline(
2423
+ webgpu_ctx->global_ctx->device, wgsl_rope_f32_inplace, "rope_f32_inplace", constants);
2424
+ webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][0] =
2425
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f32_ff, "rope_f32_ff", constants);
2426
+ webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][1] = ggml_webgpu_create_pipeline(
2427
+ webgpu_ctx->global_ctx->device, wgsl_rope_f32_ff_inplace, "rope_f32_ff_inplace", constants);
2428
+
2429
+ webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][0] =
2430
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f16, "rope_f16", constants);
2431
+ webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][1] = ggml_webgpu_create_pipeline(
2432
+ webgpu_ctx->global_ctx->device, wgsl_rope_f16_inplace, "rope_f16_inplace", constants);
2433
+ webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][0] =
2434
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f16_ff, "rope_f16_ff", constants);
2435
+ webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][1] = ggml_webgpu_create_pipeline(
2436
+ webgpu_ctx->global_ctx->device, wgsl_rope_f16_ff_inplace, "rope_f16_ff_inplace", constants);
2437
+ }
2438
+
2439
+ static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) {
2440
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
2441
+
2442
+ // REGLU
2443
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F32][0] =
2444
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f32, "reglu_f32", constants);
2445
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F16][0] =
2446
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f16, "reglu_f16", constants);
2447
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F32][1] =
2448
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f32_split, "reglu_f32_split", constants);
2449
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F16][1] =
2450
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f16_split, "reglu_f16_split", constants);
2451
+
2452
+ // GEGLU
2453
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][0] =
2454
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f32, "geglu_f32", constants);
2455
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][0] =
2456
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f16, "geglu_f16", constants);
2457
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][1] =
2458
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f32_split, "geglu_f32_split", constants);
2459
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][1] =
2460
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f16_split, "geglu_f16_split", constants);
2461
+
2462
+ // SWIGLU
2463
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][0] =
2464
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_f32, "swiglu_f32", constants);
2465
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][0] =
2466
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_f16, "swiglu_f16", constants);
2467
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
2468
+ webgpu_ctx->global_ctx->device, wgsl_swiglu_f32_split, "swiglu_f32_split", constants);
2469
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
2470
+ webgpu_ctx->global_ctx->device, wgsl_swiglu_f16_split, "swiglu_f16_split", constants);
2471
+
2472
+ // SWIGLU_OAI
2473
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][0] =
2474
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_oai_f32, "swiglu_oai_f32", constants);
2475
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
2476
+ webgpu_ctx->global_ctx->device, wgsl_swiglu_oai_f32_split, "swiglu_oai_f32_split", constants);
2477
+
2478
+ // GEGLU_ERF
2479
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][0] =
2480
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f32, "geglu_erf_f32", constants);
2481
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][0] =
2482
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f16, "geglu_erf_f16", constants);
2483
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
2484
+ webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f32_split, "geglu_erf_f32_split", constants);
2485
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
2486
+ webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f16_split, "geglu_erf_f16_split", constants);
2487
+
2488
+ // GEGLU_QUICK
2489
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][0] =
2490
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f32, "geglu_quick_f32", constants);
2491
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][0] =
2492
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f16, "geglu_quick_f16", constants);
2493
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
2494
+ webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f32_split, "geglu_quick_f32_split", constants);
2495
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
2496
+ webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f16_split, "geglu_quick_f16_split", constants);
2497
+ }
2498
+
2499
+ static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) {
2500
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);
2501
+
2502
+ // f32 (no mask)
2503
+ webgpu_ctx->soft_max_pipelines[2][0][0] =
2504
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32, "soft_max_f32", constants);
2505
+ webgpu_ctx->soft_max_pipelines[2][0][1] = ggml_webgpu_create_pipeline(
2506
+ webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_inplace, "soft_max_f32_inplace", constants);
2507
+ webgpu_ctx->soft_max_pipelines[2][1][0] = ggml_webgpu_create_pipeline(
2508
+ webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_sink, "soft_max_f32_sink", constants);
2509
+ webgpu_ctx->soft_max_pipelines[2][1][1] = ggml_webgpu_create_pipeline(
2510
+ webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_sink_inplace, "soft_max_f32_sink_inplace", constants);
2511
+
2512
+ // f32 mask (mask_type = 0)
2513
+ webgpu_ctx->soft_max_pipelines[0][0][0] = ggml_webgpu_create_pipeline(
2514
+ webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32, "soft_max_f32_mask_f32", constants);
2515
+ webgpu_ctx->soft_max_pipelines[0][0][1] = ggml_webgpu_create_pipeline(
2516
+ webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_inplace, "soft_max_f32_mask_f32_inplace", constants);
2517
+ webgpu_ctx->soft_max_pipelines[0][1][0] = ggml_webgpu_create_pipeline(
2518
+ webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_sink, "soft_max_f32_mask_f32_sink", constants);
2519
+ webgpu_ctx->soft_max_pipelines[0][1][1] =
2520
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_sink_inplace,
2521
+ "soft_max_f32_mask_f32_sink_inplace", constants);
2522
+
2523
+ // f16 mask (mask_type = 1)
2524
+ webgpu_ctx->soft_max_pipelines[1][0][0] = ggml_webgpu_create_pipeline(
2525
+ webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16, "soft_max_f32_mask_f16", constants);
2526
+ webgpu_ctx->soft_max_pipelines[1][0][1] = ggml_webgpu_create_pipeline(
2527
+ webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_inplace, "soft_max_f32_mask_f16_inplace", constants);
2528
+ webgpu_ctx->soft_max_pipelines[1][1][0] = ggml_webgpu_create_pipeline(
2529
+ webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_sink, "soft_max_f32_mask_f16_sink", constants);
2530
+ webgpu_ctx->soft_max_pipelines[1][1][1] =
2531
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_sink_inplace,
2532
+ "soft_max_f32_mask_f16_sink_inplace", constants);
2533
+ }
2534
+
2535
+ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
2536
+ wgpu::RequestAdapterOptions options = {};
2537
+
2538
+ #ifndef __EMSCRIPTEN__
2539
+ // TODO: track need for these toggles: https://issues.chromium.org/issues/42251215
2540
+ const char * const adapterEnabledToggles[] = { "vulkan_enable_f16_on_nvidia", "use_vulkan_memory_model" };
2541
+ wgpu::DawnTogglesDescriptor adapterTogglesDesc;
2542
+ adapterTogglesDesc.enabledToggles = adapterEnabledToggles;
2543
+ adapterTogglesDesc.enabledToggleCount = 2;
2544
+ options.nextInChain = &adapterTogglesDesc;
2545
+ #endif
2546
+
2547
+ ctx->webgpu_global_ctx->instance.WaitAny(
2548
+ ctx->webgpu_global_ctx->instance.RequestAdapter(
2549
+ &options, wgpu::CallbackMode::AllowSpontaneous,
2550
+ [&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) {
2551
+ if (status != wgpu::RequestAdapterStatus::Success) {
2552
+ GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
2553
+ return;
2554
+ }
2555
+ ctx->webgpu_global_ctx->adapter = std::move(adapter);
2556
+ }),
2557
+ UINT64_MAX);
2558
+ GGML_ASSERT(ctx->webgpu_global_ctx->adapter != nullptr);
2559
+
2560
+ ctx->webgpu_global_ctx->adapter.GetLimits(&ctx->webgpu_global_ctx->capabilities.limits);
2561
+
2562
+ wgpu::AdapterInfo info{};
2563
+ #ifndef __EMSCRIPTEN__
2564
+ wgpu::AdapterPropertiesSubgroupMatrixConfigs subgroup_matrix_configs{};
2565
+ if (ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
2566
+ info.nextInChain = &subgroup_matrix_configs;
2567
+ }
2568
+ #endif
2569
+ ctx->webgpu_global_ctx->adapter.GetInfo(&info);
2570
+ wgpu::SupportedFeatures features;
2571
+ ctx->webgpu_global_ctx->adapter.GetFeatures(&features);
2572
+ // we require f16 support
2573
+ GGML_ASSERT(ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16));
2574
+
2575
+ #ifndef __EMSCRIPTEN__
2576
+ // Only support square f16 matrices of size 8 or 16 for now
2577
+ bool valid_subgroup_matrix_config = false;
2578
+ if (ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
2579
+ for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) {
2580
+ const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i];
2581
+ if (config.M == config.N && config.N == config.K && (config.K == 8 || config.K == 16) &&
2582
+ config.componentType == wgpu::SubgroupMatrixComponentType::F16 &&
2583
+ config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) {
2584
+ ctx->webgpu_global_ctx->capabilities.sg_mat_m = config.M;
2585
+ ctx->webgpu_global_ctx->capabilities.sg_mat_n = config.N;
2586
+ ctx->webgpu_global_ctx->capabilities.sg_mat_k = config.K;
2587
+ valid_subgroup_matrix_config = true;
2588
+ break;
2589
+ }
2590
+ }
2591
+ }
2592
+ ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix = valid_subgroup_matrix_config;
2593
+ #endif
2594
+
2595
+ // For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate.
2596
+ // Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter.
2597
+ ctx->webgpu_global_ctx->capabilities.max_subgroup_size = info.subgroupMaxSize;
2598
+ // Initialize device
2599
+ std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16 };
2600
+
2601
+ #ifndef __EMSCRIPTEN__
2602
+ required_features.push_back(wgpu::FeatureName::ImplicitDeviceSynchronization);
2603
+ if (ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) {
2604
+ required_features.push_back(wgpu::FeatureName::Subgroups);
2605
+ required_features.push_back(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix);
2606
+ }
2607
+ #endif
2608
+
2609
+ #ifdef GGML_WEBGPU_GPU_PROFILE
2610
+ required_features.push_back(wgpu::FeatureName::TimestampQuery);
2611
+ #endif
2612
+
2613
+ wgpu::DeviceDescriptor dev_desc;
2614
+ dev_desc.requiredLimits = &ctx->webgpu_global_ctx->capabilities.limits;
2615
+ dev_desc.requiredFeatures = required_features.data();
2616
+ dev_desc.requiredFeatureCount = required_features.size();
2617
+ dev_desc.SetDeviceLostCallback(
2618
+ wgpu::CallbackMode::AllowSpontaneous,
2619
+ [](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
2620
+ if (reason == wgpu::DeviceLostReason::Destroyed) {
2621
+ return;
2622
+ }
2623
+ GGML_UNUSED(device);
2624
+ GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason),
2625
+ std::string(message).c_str());
2626
+ });
2627
+ dev_desc.SetUncapturedErrorCallback(
2628
+ [](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) {
2629
+ GGML_UNUSED(device);
2630
+ GGML_ABORT("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast<int>(reason),
2631
+ std::string(message).c_str());
2632
+ });
2633
+
2634
+ #ifndef __EMSCRIPTEN__
2635
+ // Enable Dawn-specific toggles to increase native performance
2636
+ // TODO: Maybe WebGPU needs a "fast" mode where you can request compilers skip adding checks like these,
2637
+ // only for native performance?
2638
+ const char * const deviceEnabledToggles[] = { "skip_validation", "disable_robustness", "disable_workgroup_init",
2639
+ "disable_polyfills_on_integer_div_and_mod" };
2640
+ const char * const deviceDisabledToggles[] = { "timestamp_quantization" };
2641
+ wgpu::DawnTogglesDescriptor deviceTogglesDesc;
2642
+ deviceTogglesDesc.enabledToggles = deviceEnabledToggles;
2643
+ deviceTogglesDesc.enabledToggleCount = 4;
2644
+ deviceTogglesDesc.disabledToggles = deviceDisabledToggles;
2645
+ deviceTogglesDesc.disabledToggleCount = 1;
2646
+
2647
+ dev_desc.nextInChain = &deviceTogglesDesc;
2648
+ #endif
2649
+
2650
+ ctx->webgpu_global_ctx->instance.WaitAny(
2651
+ ctx->webgpu_global_ctx->adapter.RequestDevice(
2652
+ &dev_desc, wgpu::CallbackMode::AllowSpontaneous,
2653
+ [ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
2654
+ if (status != wgpu::RequestDeviceStatus::Success) {
2655
+ GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", std::string(message).c_str());
2656
+ return;
2657
+ }
2658
+ ctx->webgpu_global_ctx->device = std::move(device);
2659
+ }),
2660
+ UINT64_MAX);
2661
+ GGML_ASSERT(ctx->webgpu_global_ctx->device != nullptr);
2662
+
2663
+ ggml_webgpu_init_memset_pipeline(ctx->webgpu_global_ctx);
2664
+ ctx->webgpu_global_ctx->memset_buf_pool.init(ctx->webgpu_global_ctx->device, 1, WEBGPU_PARAMS_BUF_SIZE_BYTES,
2665
+ wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
2666
+ wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite);
2667
+ ctx->webgpu_global_ctx->queue = ctx->webgpu_global_ctx->device.GetQueue();
2668
+
2669
+ #ifdef GGML_WEBGPU_GPU_PROFILE
2670
+ // Initialize buffer pool for timestamp queries, used for profiling
2671
+ ctx->webgpu_global_ctx->timestamp_query_buf_pool.init(
2672
+ ctx->webgpu_global_ctx->device, WEBGPU_NUM_TIMESTAMP_QUERY_BUFS, WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES,
2673
+ wgpu::BufferUsage::QueryResolve | wgpu::BufferUsage::CopySrc,
2674
+ wgpu::BufferUsage::MapRead | wgpu::BufferUsage::CopyDst);
2675
+ #endif
2676
+
2677
+ GGML_LOG_INFO(
2678
+ "ggml_webgpu: adapter_info: vendor_id: %u | vendor: %s | architecture: %s | device_id: %u | name: %s | "
2679
+ "device_desc: %s\n",
2680
+ info.vendorID, std::string(info.vendor).c_str(), std::string(info.architecture).c_str(), info.deviceID,
2681
+ std::string(info.device).c_str(), std::string(info.description).c_str());
2682
+ return true;
2683
+ }
2684
+
2685
+ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) {
2686
+ ggml_backend_webgpu_device_context * dev_ctx = (ggml_backend_webgpu_device_context *) dev->context;
2687
+ webgpu_context webgpu_ctx = std::make_shared<webgpu_context_struct>();
2688
+ webgpu_ctx->global_ctx = dev_ctx->webgpu_global_ctx;
2689
+ webgpu_ctx->shader_lib = std::make_unique<ggml_webgpu_shader_lib>(dev_ctx->webgpu_global_ctx->device);
2690
+ webgpu_ctx->param_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES,
2691
+ wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
2692
+ wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite);
2693
+ webgpu_ctx->set_rows_error_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_SET_ROWS_ERROR_BUFS,
2694
+ WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
2695
+ wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage,
2696
+ wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead);
2697
+
2698
+ ggml_webgpu_init_cpy_pipeline(webgpu_ctx);
2699
+ ggml_webgpu_init_rms_norm_pipeline(webgpu_ctx);
2700
+ ggml_webgpu_init_rope_pipeline(webgpu_ctx);
2701
+ ggml_webgpu_init_glu_pipeline(webgpu_ctx);
2702
+ ggml_webgpu_init_soft_max_pipeline(webgpu_ctx);
2703
+ #ifdef GGML_WEBGPU_DEBUG
2704
+ // Initialize debug buffers
2705
+ ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->global_ctx->debug_host_buf,
2706
+ WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
2707
+ wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "debug_host_buf");
2708
+ ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->global_ctx->debug_dev_buf,
2709
+ WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
2710
+ wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, "debug_dev_buf");
2711
+ #endif
2712
+ return webgpu_ctx;
2713
+ }
2714
+
2715
+ static ggml_backend_t ggml_backend_webgpu_backend_init(ggml_backend_dev_t dev, const char * params) {
2716
+ GGML_UNUSED(params);
2717
+
2718
+ WEBGPU_LOG_DEBUG("ggml_backend_webgpu_backend_init()");
2719
+
2720
+ ggml_backend_webgpu_device_context * dev_ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
2721
+
2722
+ auto * backend_ctx = new ggml_backend_webgpu_context();
2723
+ backend_ctx->name = GGML_WEBGPU_NAME + std::string(": ") + dev_ctx->device_name;
2724
+ backend_ctx->webgpu_ctx = initialize_webgpu_context(dev);
2725
+
2726
+ // See GGML Backend Interface section
2727
+ auto * backend = new ggml_backend();
2728
+ *backend = {
2729
+ /* .guid = */ ggml_backend_webgpu_guid(),
2730
+ /* .interface = */ ggml_backend_webgpu_i,
2731
+ /* .device = */ dev,
2732
+ /* .context = */ backend_ctx,
2733
+ };
2734
+ return backend;
2735
+ }
2736
+
2737
+ static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggml_backend_dev_t dev) {
2738
+ // See GGML Backend Buffer Type Interface section
2739
+
2740
+ static struct ggml_backend_buffer_type ggml_backend_webgpu_buffer_type = {
2741
+ /* .iface = */ {
2742
+ /* .get_name = */ ggml_backend_webgpu_buffer_type_get_name,
2743
+ /* .alloc_buffer = */
2744
+ ggml_backend_webgpu_buffer_type_alloc_buffer, /* .get_alignment = */
2745
+ ggml_backend_webgpu_buffer_type_get_alignment, /* .get_max_size = */
2746
+ ggml_backend_webgpu_buffer_type_get_max_size, /* .get_alloc_size = */
2747
+ ggml_backend_webgpu_buffer_type_get_alloc_size, /* .is_host = */ NULL, // defaults to false
2748
+ },
2749
+ /* .device = */
2750
+ dev,
2751
+ /* .context = */
2752
+ NULL
2753
+ };
2754
+
2755
+ return &ggml_backend_webgpu_buffer_type;
2756
+ }
2757
+
2758
+ static bool ggml_backend_webgpu_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
2759
+ GGML_UNUSED(dev);
2760
+ return buft->iface.get_name == ggml_backend_webgpu_buffer_type_get_name;
2761
+ }
2762
+
2763
+ static bool ggml_webgpu_supported_qtype(ggml_type type) {
2764
+ switch (type) {
2765
+ case GGML_TYPE_Q4_0:
2766
+ case GGML_TYPE_Q4_1:
2767
+ case GGML_TYPE_Q5_0:
2768
+ case GGML_TYPE_Q5_1:
2769
+ case GGML_TYPE_Q8_0:
2770
+ case GGML_TYPE_Q2_K:
2771
+ case GGML_TYPE_Q3_K:
2772
+ case GGML_TYPE_Q4_K:
2773
+ case GGML_TYPE_Q5_K:
2774
+ case GGML_TYPE_Q6_K:
2775
+ case GGML_TYPE_IQ2_XXS:
2776
+ case GGML_TYPE_IQ2_XS:
2777
+ case GGML_TYPE_IQ2_S:
2778
+ case GGML_TYPE_IQ3_XXS:
2779
+ case GGML_TYPE_IQ3_S:
2780
+ case GGML_TYPE_IQ1_S:
2781
+ case GGML_TYPE_IQ1_M:
2782
+ case GGML_TYPE_IQ4_NL:
2783
+ case GGML_TYPE_IQ4_XS:
2784
+ return true;
2785
+ default:
2786
+ return false;
2787
+ }
2788
+ }
2789
+
2790
+ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
2791
+ ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
2792
+
2793
+ ggml_tensor * src0 = op->src[0];
2794
+ ggml_tensor * src1 = op->src[1];
2795
+ ggml_tensor * src2 = op->src[2];
2796
+
2797
+ // on smaller devices (or CI), tensors may be larger than the max storage buffer size
2798
+ if (ggml_nbytes(op) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize ||
2799
+ (src0 != nullptr &&
2800
+ ggml_nbytes(src0) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize) ||
2801
+ (src1 != nullptr &&
2802
+ ggml_nbytes(src1) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize)) {
2803
+ return false;
2804
+ }
2805
+
2806
+ bool supports_op = false;
2807
+ switch (op->op) {
2808
+ case GGML_OP_NONE:
2809
+ case GGML_OP_VIEW:
2810
+ case GGML_OP_PERMUTE:
2811
+ case GGML_OP_TRANSPOSE:
2812
+ case GGML_OP_RESHAPE:
2813
+ supports_op = true;
2814
+ break;
2815
+ case GGML_OP_ADD:
2816
+ case GGML_OP_SUB:
2817
+ case GGML_OP_MUL:
2818
+ case GGML_OP_DIV:
2819
+ // TODO: support non-contiguous tensors, e.g. for MOE_EXPERT_REDUCE
2820
+ // see https://github.com/ggml-org/llama.cpp/pull/16857
2821
+ supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type) &&
2822
+ (src1->type == op->type) && ggml_is_contiguous(src0) && ggml_is_contiguous(src1);
2823
+ break;
2824
+ case GGML_OP_CPY:
2825
+ case GGML_OP_CONT:
2826
+ supports_op = ((op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
2827
+ (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) ||
2828
+ (op->type == GGML_TYPE_I32 && src0->type == GGML_TYPE_F32);
2829
+ break;
2830
+ case GGML_OP_SET_ROWS:
2831
+ supports_op = ((op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32) && src0->type == GGML_TYPE_F32 &&
2832
+ (src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32));
2833
+ break;
2834
+ case GGML_OP_GET_ROWS:
2835
+ if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_webgpu_supported_qtype(src0->type)) {
2836
+ supports_op = (op->type == GGML_TYPE_F32);
2837
+ } else if (src0->type == GGML_TYPE_I32) {
2838
+ supports_op = op->type == GGML_TYPE_I32;
2839
+ }
2840
+ break;
2841
+ case GGML_OP_MUL_MAT:
2842
+ {
2843
+ switch (src1->type) {
2844
+ case GGML_TYPE_F16:
2845
+ supports_op |= (src0->type == GGML_TYPE_F16);
2846
+ break;
2847
+ case GGML_TYPE_F32:
2848
+ switch (src0->type) {
2849
+ case GGML_TYPE_F32:
2850
+ case GGML_TYPE_F16:
2851
+ case GGML_TYPE_Q4_0:
2852
+ case GGML_TYPE_Q4_1:
2853
+ case GGML_TYPE_Q5_0:
2854
+ case GGML_TYPE_Q5_1:
2855
+ case GGML_TYPE_Q8_0:
2856
+ case GGML_TYPE_Q2_K:
2857
+ case GGML_TYPE_Q3_K:
2858
+ case GGML_TYPE_Q4_K:
2859
+ case GGML_TYPE_Q5_K:
2860
+ case GGML_TYPE_Q6_K:
2861
+ case GGML_TYPE_IQ2_XXS:
2862
+ case GGML_TYPE_IQ2_XS:
2863
+ case GGML_TYPE_IQ2_S:
2864
+ case GGML_TYPE_IQ3_XXS:
2865
+ case GGML_TYPE_IQ3_S:
2866
+ case GGML_TYPE_IQ1_S:
2867
+ case GGML_TYPE_IQ1_M:
2868
+ case GGML_TYPE_IQ4_NL:
2869
+ case GGML_TYPE_IQ4_XS:
2870
+ supports_op = true;
2871
+ break;
2872
+ default:
2873
+ break;
2874
+ }
2875
+ default:
2876
+ break;
2877
+ }
2878
+ break;
2879
+ }
2880
+ case GGML_OP_FLASH_ATTN_EXT:
2881
+ {
2882
+ #ifndef __EMSCRIPTEN__
2883
+ if (!ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) {
2884
+ break;
2885
+ }
2886
+ // Head dimensions must fit in workgroup memory with minimum tile sizes
2887
+ size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
2888
+ const bool has_mask = op->src[3] != nullptr;
2889
+ const bool kv_direct = src1->type == GGML_TYPE_F16 &&
2890
+ (src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k) == 0 &&
2891
+ (src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0;
2892
+ const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
2893
+ ctx->webgpu_global_ctx->capabilities.sg_mat_m, ctx->webgpu_global_ctx->capabilities.sg_mat_n,
2894
+ (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, kv_direct);
2895
+ if (min_bytes > limit_bytes) {
2896
+ break;
2897
+ }
2898
+
2899
+ supports_op = src0->type == GGML_TYPE_F32 &&
2900
+ (src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16 ||
2901
+ src1->type == GGML_TYPE_Q4_0 || src1->type == GGML_TYPE_Q8_0) &&
2902
+ src2->type == src1->type && op->type == GGML_TYPE_F32;
2903
+ #endif
2904
+ break;
2905
+ }
2906
+ case GGML_OP_RMS_NORM:
2907
+ supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
2908
+ break;
2909
+ case GGML_OP_ROPE:
2910
+ supports_op = op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16;
2911
+ break;
2912
+ case GGML_OP_GLU:
2913
+ switch (ggml_get_glu_op(op)) {
2914
+ case GGML_GLU_OP_REGLU:
2915
+ case GGML_GLU_OP_GEGLU:
2916
+ case GGML_GLU_OP_SWIGLU:
2917
+ case GGML_GLU_OP_GEGLU_ERF:
2918
+ case GGML_GLU_OP_GEGLU_QUICK:
2919
+ supports_op = op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16;
2920
+ break;
2921
+ case GGML_GLU_OP_SWIGLU_OAI:
2922
+ supports_op = op->type == GGML_TYPE_F32;
2923
+ break;
2924
+ default:
2925
+ break;
2926
+ }
2927
+ break;
2928
+ case GGML_OP_SCALE:
2929
+ supports_op = op->type == GGML_TYPE_F32;
2930
+ break;
2931
+ case GGML_OP_SOFT_MAX:
2932
+ supports_op = op->type == GGML_TYPE_F32;
2933
+ break;
2934
+ case GGML_OP_UNARY:
2935
+ {
2936
+ const ggml_unary_op UNARY_OP = ggml_get_unary_op(op);
2937
+
2938
+ switch (UNARY_OP) {
2939
+ case GGML_UNARY_OP_ABS:
2940
+ case GGML_UNARY_OP_SGN:
2941
+ case GGML_UNARY_OP_NEG:
2942
+ case GGML_UNARY_OP_STEP:
2943
+ case GGML_UNARY_OP_TANH:
2944
+ case GGML_UNARY_OP_ELU:
2945
+ case GGML_UNARY_OP_RELU:
2946
+ case GGML_UNARY_OP_SIGMOID:
2947
+ case GGML_UNARY_OP_GELU:
2948
+ case GGML_UNARY_OP_GELU_QUICK:
2949
+ case GGML_UNARY_OP_SILU:
2950
+ case GGML_UNARY_OP_HARDSWISH:
2951
+ case GGML_UNARY_OP_HARDSIGMOID:
2952
+ case GGML_UNARY_OP_EXP:
2953
+ case GGML_UNARY_OP_GELU_ERF:
2954
+ case GGML_UNARY_OP_SOFTPLUS:
2955
+ case GGML_UNARY_OP_EXPM1:
2956
+ case GGML_UNARY_OP_FLOOR:
2957
+ case GGML_UNARY_OP_CEIL:
2958
+ case GGML_UNARY_OP_ROUND:
2959
+ case GGML_UNARY_OP_TRUNC:
2960
+ case GGML_UNARY_OP_XIELU:
2961
+ supports_op =
2962
+ (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
2963
+ break;
2964
+ default:
2965
+ break;
2966
+ }
2967
+ }
2968
+ break;
2969
+ case GGML_OP_CLAMP:
2970
+ supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
2971
+ break;
2972
+ case GGML_OP_FILL:
2973
+ supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
2974
+ break;
2975
+ case GGML_OP_LOG:
2976
+ supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
2977
+ break;
2978
+ case GGML_OP_SQR:
2979
+ supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
2980
+ break;
2981
+ case GGML_OP_SQRT:
2982
+ supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
2983
+ break;
2984
+ case GGML_OP_SIN:
2985
+ supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
2986
+ break;
2987
+ case GGML_OP_COS:
2988
+ supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
2989
+ break;
2990
+ case GGML_OP_PAD:
2991
+ supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
2992
+ break;
2993
+ case GGML_OP_ARGMAX:
2994
+ supports_op = op->type == GGML_TYPE_I32 && src0->type == GGML_TYPE_F32;
2995
+ break;
2996
+ case GGML_OP_ARGSORT:
2997
+ supports_op = op->type == GGML_TYPE_I32 && src0->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(src0);
2998
+ break;
2999
+ case GGML_OP_TOP_K:
3000
+ supports_op = op->type == GGML_TYPE_I32 && src0->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(src0);
3001
+ break;
3002
+ case GGML_OP_CUMSUM:
3003
+ supports_op = op->type == GGML_TYPE_F32 && src0->type == op->type;
3004
+ break;
3005
+ case GGML_OP_SUM:
3006
+ case GGML_OP_SUM_ROWS:
3007
+ supports_op = op->type == GGML_TYPE_F32 && src0->type == op->type && ggml_is_contiguous_rows(src0);
3008
+ break;
3009
+ default:
3010
+ break;
3011
+ }
3012
+ if (ggml_nbytes(op) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize ||
3013
+ (src0 != nullptr &&
3014
+ ggml_nbytes(src0) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize) ||
3015
+ (src1 != nullptr &&
3016
+ ggml_nbytes(src1) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize) ||
3017
+ (src2 != nullptr &&
3018
+ ggml_nbytes(src2) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize)) {
3019
+ supports_op = false;
3020
+ WEBGPU_LOG_DEBUG("ggml_webgpu op not supported due to size: ");
3021
+ }
3022
+
3023
+ if (!supports_op) {
3024
+ WEBGPU_LOG_DEBUG("ggml_webgpu op not supported: "
3025
+ << ggml_op_name(op->op) << " with types dst: " << ggml_type_name(op->type)
3026
+ << ", src0: " << (op->src[0] ? ggml_type_name(op->src[0]->type) : "null")
3027
+ << ", src1: " << (op->src[1] ? ggml_type_name(op->src[1]->type) : "null"));
3028
+ } else {
3029
+ WEBGPU_LOG_DEBUG("ggml_webgpu op supported: "
3030
+ << ggml_op_name(op->op) << " with types dst: " << ggml_type_name(op->type)
3031
+ << ", src0: " << (op->src[0] ? ggml_type_name(op->src[0]->type) : "null")
3032
+ << ", src1: " << (op->src[1] ? ggml_type_name(op->src[1]->type) : "null"));
3033
+ }
3034
+ return supports_op;
3035
+ }
3036
+
3037
+ static struct ggml_backend_device_i ggml_backend_webgpu_device_i = {
3038
+ /* .get_name = */ ggml_backend_webgpu_device_get_name,
3039
+ /* .get_description = */ ggml_backend_webgpu_device_get_description,
3040
+ /* .get_memory = */ ggml_backend_webgpu_device_get_memory,
3041
+ /* .get_type = */ ggml_backend_webgpu_device_get_type,
3042
+ /* .get_props = */ ggml_backend_webgpu_device_get_props,
3043
+ /* .init_backend = */ ggml_backend_webgpu_backend_init,
3044
+ /* .get_buffer_type = */ ggml_backend_webgpu_device_get_buffer_type,
3045
+ /* .get_host_buffer_type = */ NULL,
3046
+ /* .buffer_from_host_ptr = */ NULL,
3047
+ /* .supports_op = */ ggml_backend_webgpu_device_supports_op,
3048
+ /* .supports_buft = */ ggml_backend_webgpu_device_supports_buft,
3049
+ /* .offload_op = */ NULL,
3050
+ /* .event_new = */ NULL,
3051
+ /* .event_free = */ NULL,
3052
+ /* .event_synchronize = */ NULL,
3053
+ };
3054
+
3055
+ /* End GGML Backend Device Interface */
3056
+
3057
+ /* GGML Backend Registration Interface */
3058
+
3059
+ static const char * ggml_backend_webgpu_reg_get_name(ggml_backend_reg_t reg) {
3060
+ ggml_backend_webgpu_reg_context * ctx = static_cast<ggml_backend_webgpu_reg_context *>(reg->context);
3061
+ return ctx->name;
3062
+ }
3063
+
3064
+ static size_t ggml_backend_webgpu_reg_get_device_count(ggml_backend_reg_t reg) {
3065
+ ggml_backend_webgpu_reg_context * ctx = static_cast<ggml_backend_webgpu_reg_context *>(reg->context);
3066
+ return ctx->device_count;
3067
+ }
3068
+
3069
+ // Only one device is supported for now
3070
+ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t reg, size_t index) {
3071
+ GGML_ASSERT(index == 0);
3072
+ WEBGPU_LOG_DEBUG("ggml_backend_reg_get_device()");
3073
+
3074
+ WEBGPU_CPU_PROFILE_TOTAL_START(reg_get_device);
3075
+
3076
+ ggml_backend_webgpu_reg_context * reg_ctx = static_cast<ggml_backend_webgpu_reg_context *>(reg->context);
3077
+
3078
+ create_webgpu_device(reg_ctx);
3079
+
3080
+ static ggml_backend_webgpu_device_context device_ctx;
3081
+ device_ctx.device_name = GGML_WEBGPU_NAME;
3082
+ device_ctx.device_desc = GGML_WEBGPU_NAME;
3083
+ device_ctx.webgpu_global_ctx = reg_ctx->webgpu_global_ctx;
3084
+ // See GGML Backend Device Interface section
3085
+ static ggml_backend_device device = {
3086
+ /* .iface = */ ggml_backend_webgpu_device_i,
3087
+ /* .reg = */ reg,
3088
+ /* .context = */ &device_ctx,
3089
+ };
3090
+
3091
+ WEBGPU_CPU_PROFILE_TOTAL_END(reg_get_device, reg_ctx->webgpu_global_ctx);
3092
+ return &device;
3093
+ }
3094
+
3095
+ static const struct ggml_backend_reg_i ggml_backend_webgpu_reg_i = {
3096
+ /* .get_name = */ ggml_backend_webgpu_reg_get_name,
3097
+ /* .get_device_count = */ ggml_backend_webgpu_reg_get_device_count,
3098
+ /* .get_device = */ ggml_backend_webgpu_reg_get_device,
3099
+ /* .get_proc_address = */ NULL,
3100
+ };
3101
+
3102
+ /* End GGML Backend Registration Interface */
3103
+
3104
+ ggml_backend_reg_t ggml_backend_webgpu_reg() {
3105
+ WEBGPU_LOG_DEBUG("ggml_backend_webgpu_reg()");
3106
+
3107
+ static ggml_backend_webgpu_reg_context ctx;
3108
+ ctx.name = GGML_WEBGPU_NAME;
3109
+ ctx.device_count = 1;
3110
+
3111
+ wgpu::InstanceDescriptor instance_descriptor{};
3112
+ std::vector<wgpu::InstanceFeatureName> instance_features = { wgpu::InstanceFeatureName::TimedWaitAny };
3113
+ instance_descriptor.requiredFeatures = instance_features.data();
3114
+ instance_descriptor.requiredFeatureCount = instance_features.size();
3115
+
3116
+ #ifndef __EMSCRIPTEN__
3117
+ const char * const instanceEnabledToggles[] = { "allow_unsafe_apis" };
3118
+ wgpu::DawnTogglesDescriptor instanceTogglesDesc;
3119
+ instanceTogglesDesc.enabledToggles = instanceEnabledToggles;
3120
+ instanceTogglesDesc.enabledToggleCount = 1;
3121
+ instance_descriptor.nextInChain = &instanceTogglesDesc;
3122
+ #endif
3123
+
3124
+ wgpu::Instance inst = wgpu::CreateInstance(&instance_descriptor);
3125
+ ctx.webgpu_global_ctx = webgpu_global_context(new webgpu_global_context_struct());
3126
+ ctx.webgpu_global_ctx->instance = std::move(inst);
3127
+
3128
+ #ifdef __EMSCRIPTEN__
3129
+ if (ctx.webgpu_global_ctx->instance == nullptr) {
3130
+ GGML_LOG_ERROR("ggml_webgpu: Failed to create WebGPU instance. Make sure either -sASYNCIFY or -sJSPI is set\n");
3131
+ return nullptr;
3132
+ }
3133
+ #endif
3134
+ GGML_ASSERT(ctx.webgpu_global_ctx->instance != nullptr);
3135
+
3136
+ static ggml_backend_reg reg = {
3137
+ /* .api_version = */ GGML_BACKEND_API_VERSION,
3138
+ /* .iface = */ ggml_backend_webgpu_reg_i,
3139
+ /* .context = */ &ctx,
3140
+ };
3141
+ return &reg;
3142
+ }
3143
+
3144
+ ggml_backend_t ggml_backend_webgpu_init(void) {
3145
+ ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_webgpu_reg(), 0);
3146
+
3147
+ return ggml_backend_webgpu_backend_init(dev, nullptr);
3148
+ }
3149
+
3150
+ GGML_BACKEND_DL_IMPL(ggml_backend_webgpu_reg)