whispercpp 1.3.3 → 1.3.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (586)
  1. checksums.yaml +4 -4
  2. data/ext/ruby_whisper_params.c +55 -25
  3. data/ext/sources/CMakeLists.txt +1 -1
  4. data/ext/sources/bindings/javascript/package.json +1 -1
  5. data/ext/sources/build-xcframework.sh +24 -0
  6. data/ext/sources/examples/CMakeLists.txt +1 -0
  7. data/ext/sources/examples/addon.node/addon.cpp +19 -19
  8. data/ext/sources/examples/addon.node/index.js +7 -5
  9. data/ext/sources/examples/bench/bench.cpp +26 -16
  10. data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
  11. data/ext/sources/examples/cli/cli.cpp +4 -2
  12. data/ext/sources/examples/command/command.cpp +26 -24
  13. data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
  14. data/ext/sources/examples/common-ggml.cpp +2 -0
  15. data/ext/sources/examples/lsp/lsp.cpp +19 -17
  16. data/ext/sources/examples/server/server.cpp +24 -13
  17. data/ext/sources/examples/server.py +6 -1
  18. data/ext/sources/examples/stream/stream.cpp +4 -2
  19. data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
  20. data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
  21. data/ext/sources/examples/talk-llama/CMakeLists.txt +2 -2
  22. data/ext/sources/examples/talk-llama/llama-adapter.cpp +101 -4
  23. data/ext/sources/examples/talk-llama/llama-adapter.h +6 -0
  24. data/ext/sources/examples/talk-llama/llama-arch.cpp +588 -15
  25. data/ext/sources/examples/talk-llama/llama-arch.h +58 -1
  26. data/ext/sources/examples/talk-llama/llama-batch.cpp +103 -71
  27. data/ext/sources/examples/talk-llama/llama-batch.h +31 -18
  28. data/ext/sources/examples/talk-llama/llama-chat.cpp +120 -5
  29. data/ext/sources/examples/talk-llama/llama-chat.h +7 -0
  30. data/ext/sources/examples/talk-llama/llama-context.cpp +460 -357
  31. data/ext/sources/examples/talk-llama/llama-context.h +44 -29
  32. data/ext/sources/examples/talk-llama/llama-cparams.h +4 -4
  33. data/ext/sources/examples/talk-llama/llama-graph.cpp +543 -271
  34. data/ext/sources/examples/talk-llama/llama-graph.h +278 -168
  35. data/ext/sources/examples/talk-llama/llama-hparams.cpp +118 -4
  36. data/ext/sources/examples/talk-llama/llama-hparams.h +61 -15
  37. data/ext/sources/examples/talk-llama/llama-impl.h +2 -0
  38. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +326 -0
  39. data/ext/sources/examples/talk-llama/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +38 -29
  40. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2020 -0
  41. data/ext/sources/examples/talk-llama/llama-kv-cache.h +358 -27
  42. data/ext/sources/examples/talk-llama/llama-kv-cells.h +80 -28
  43. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +56 -36
  44. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +30 -29
  45. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +48 -19
  46. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +13 -14
  47. data/ext/sources/examples/talk-llama/llama-memory.h +16 -10
  48. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +2 -0
  49. data/ext/sources/examples/talk-llama/llama-model-loader.h +3 -2
  50. data/ext/sources/examples/talk-llama/llama-model.cpp +7165 -2336
  51. data/ext/sources/examples/talk-llama/llama-model.h +60 -9
  52. data/ext/sources/examples/talk-llama/llama-quant.cpp +48 -10
  53. data/ext/sources/examples/talk-llama/llama-sampling.cpp +226 -126
  54. data/ext/sources/examples/talk-llama/llama-vocab.cpp +440 -13
  55. data/ext/sources/examples/talk-llama/llama-vocab.h +45 -0
  56. data/ext/sources/examples/talk-llama/llama.cpp +65 -10
  57. data/ext/sources/examples/talk-llama/llama.h +95 -177
  58. data/ext/sources/examples/talk-llama/talk-llama.cpp +9 -6
  59. data/ext/sources/examples/talk-llama/unicode.cpp +207 -0
  60. data/ext/sources/examples/talk-llama/unicode.h +45 -0
  61. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +4 -2
  62. data/ext/sources/examples/whisper.wasm/index-tmpl.html +17 -16
  63. data/ext/sources/ggml/CMakeLists.txt +59 -31
  64. data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
  65. data/ext/sources/ggml/include/ggml-backend.h +17 -1
  66. data/ext/sources/ggml/include/ggml-cpu.h +1 -1
  67. data/ext/sources/ggml/include/ggml-metal.h +1 -6
  68. data/ext/sources/ggml/include/ggml-opt.h +25 -6
  69. data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
  70. data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
  71. data/ext/sources/ggml/include/ggml.h +221 -16
  72. data/ext/sources/ggml/src/CMakeLists.txt +17 -2
  73. data/ext/sources/ggml/src/ggml-alloc.c +265 -141
  74. data/ext/sources/ggml/src/ggml-backend-impl.h +4 -1
  75. data/ext/sources/ggml/src/ggml-backend-reg.cpp +30 -13
  76. data/ext/sources/ggml/src/ggml-backend.cpp +221 -38
  77. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  78. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -4
  79. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +14 -0
  80. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
  81. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +903 -717
  82. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +143 -25
  83. data/ext/sources/ggml/src/ggml-cann/common.h +143 -1
  84. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +488 -69
  85. data/ext/sources/ggml/src/ggml-common.h +17 -0
  86. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +40 -18
  87. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +4 -2
  88. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +132 -596
  89. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +14 -286
  90. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +103 -582
  91. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
  92. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +265 -437
  93. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
  94. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +521 -353
  95. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
  96. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
  97. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +4679 -1657
  98. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +32 -2
  99. data/ext/sources/ggml/src/ggml-cpu/common.h +14 -0
  100. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +13 -6
  101. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +70 -42
  102. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +35 -28
  103. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +152 -18
  104. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +7 -1
  105. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +227 -97
  106. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +474 -1116
  107. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1587 -1177
  108. data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -8
  109. data/ext/sources/ggml/src/ggml-cpu/quants.c +35 -0
  110. data/ext/sources/ggml/src/ggml-cpu/quants.h +8 -0
  111. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +458 -47
  112. data/ext/sources/ggml/src/ggml-cpu/repack.h +22 -0
  113. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +89 -60
  114. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1024 -0
  115. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
  116. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
  117. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
  118. data/ext/sources/ggml/src/ggml-cpu/traits.cpp +2 -2
  119. data/ext/sources/ggml/src/ggml-cpu/traits.h +1 -1
  120. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +170 -26
  121. data/ext/sources/ggml/src/ggml-cpu/vec.h +506 -63
  122. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +20 -16
  123. data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
  124. data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
  125. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
  126. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
  127. data/ext/sources/ggml/src/ggml-cuda/common.cuh +250 -63
  128. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
  129. data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
  130. data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
  131. data/ext/sources/ggml/src/ggml-cuda/convert.cu +95 -22
  132. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +15 -0
  133. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
  134. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +64 -307
  135. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  136. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
  137. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +498 -367
  138. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +137 -91
  139. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +755 -0
  140. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +3 -0
  141. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +593 -0
  142. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +86 -50
  143. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +185 -198
  144. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  145. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
  146. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +379 -107
  147. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
  148. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
  149. data/ext/sources/ggml/src/ggml-cuda/mean.cu +56 -2
  150. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +198 -45
  151. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +123 -0
  152. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +496 -0
  153. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +206 -57
  154. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1262 -721
  155. data/ext/sources/ggml/src/ggml-cuda/{mmv.cu → mmvf.cu} +53 -53
  156. data/ext/sources/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +3 -3
  157. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +64 -73
  158. data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
  159. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
  160. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
  161. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
  162. data/ext/sources/ggml/src/ggml-cuda/pad.cu +46 -23
  163. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
  164. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
  165. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +12 -10
  166. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
  167. data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
  168. data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
  169. data/ext/sources/ggml/src/ggml-cuda/rope.cu +21 -27
  170. data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
  171. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +276 -0
  172. data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  173. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
  174. data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
  175. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +126 -59
  176. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  177. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +322 -100
  178. data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
  179. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +21 -4
  180. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
  181. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
  182. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
  183. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
  184. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
  185. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
  186. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
  187. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
  188. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
  189. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
  190. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
  191. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
  192. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
  193. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
  194. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
  195. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
  196. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
  197. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
  198. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
  199. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
  200. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
  201. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
  202. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
  203. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
  204. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
  205. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
  206. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
  207. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
  208. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
  209. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
  210. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
  211. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
  212. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
  213. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
  214. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
  215. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
  216. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +21 -18
  217. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
  218. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
  219. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
  220. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
  221. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
  222. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
  223. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
  224. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
  225. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
  226. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
  227. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
  228. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
  229. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
  230. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
  231. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
  232. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
  233. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
  234. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +259 -0
  235. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +14 -0
  236. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
  237. data/ext/sources/ggml/src/ggml-cuda/unary.cu +90 -0
  238. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +8 -0
  239. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +92 -6
  240. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
  241. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  242. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +58 -36
  243. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +4 -3
  244. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +10 -2
  245. data/ext/sources/ggml/src/ggml-impl.h +119 -9
  246. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -7
  247. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
  248. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
  249. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
  250. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +600 -0
  251. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1376 -0
  252. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +226 -0
  253. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1308 -0
  254. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +136 -63
  255. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +3158 -0
  256. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +82 -0
  257. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +718 -0
  258. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +2854 -1503
  259. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +18 -8
  260. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +18 -0
  261. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +2510 -242
  262. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
  263. data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  264. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  265. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  266. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +84 -0
  267. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +66 -0
  268. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
  269. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +370 -0
  270. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
  271. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  272. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +177 -0
  273. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
  274. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  275. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  276. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
  277. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  278. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
  279. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
  280. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  281. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
  282. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
  283. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
  284. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  285. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
  286. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
  287. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
  288. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
  289. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
  290. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  291. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +189 -0
  292. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
  293. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
  294. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
  295. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
  296. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
  297. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
  298. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +2 -2
  299. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  300. data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
  301. data/ext/sources/ggml/src/ggml-quants.c +111 -16
  302. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  303. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +67 -47
  304. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +2 -0
  305. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +15 -5
  306. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +6 -0
  307. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +25 -16
  308. data/ext/sources/ggml/src/ggml-sycl/conv.cpp +10 -4
  309. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +166 -99
  310. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +72 -306
  311. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
  312. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +67 -49
  313. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +1 -31
  314. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +79 -29
  315. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
  316. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +14 -26
  317. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +9 -6
  318. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +328 -323
  319. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
  320. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +2 -2
  321. data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +80 -60
  322. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +201 -132
  323. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +74 -55
  324. data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
  325. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +8 -9
  326. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +35 -42
  327. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
  328. data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  329. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
  330. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +12 -6
  331. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
  332. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +16 -12
  333. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +3492 -883
  334. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
  335. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
  336. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
  337. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
  338. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +349 -0
  339. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
  340. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +66 -12
  341. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +154 -0
  342. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
  343. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  344. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +2 -1
  345. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +6 -5
  346. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +4 -2
  347. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
  348. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
  349. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  350. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
  351. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
  352. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  353. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
  354. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +69 -24
  355. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +60 -20
  356. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +98 -42
  357. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +64 -27
  358. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +74 -13
  359. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  360. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  361. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  362. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +4 -17
  363. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +19 -10
  364. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +25 -15
  365. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +4 -0
  366. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
  367. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
  368. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +18 -14
  369. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +126 -0
  370. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +65 -1
  371. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
  372. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +140 -0
  373. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -531
  374. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +206 -38
  375. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp +556 -0
  376. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +12 -5
  377. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +15 -9
  378. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
  379. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
  380. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +24 -3
  381. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +53 -3
  382. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +55 -11
  383. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
  384. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  385. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
  386. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  387. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  388. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  389. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
  390. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  391. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +29 -7
  392. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +4 -0
  393. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
  394. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
  395. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
  396. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +4 -3
  397. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +101 -9
  398. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  399. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
  400. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -77
  401. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
  402. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1558 -0
  403. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +44 -0
  404. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +41 -0
  405. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
  406. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
  407. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
  408. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +124 -0
  409. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
  410. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  411. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +44 -0
  412. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +41 -0
  413. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
  414. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +57 -0
  415. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +48 -0
  416. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
  417. data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  418. data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
  419. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
  420. data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
  421. data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
  422. data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
  423. data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
  424. data/ext/sources/ggml/src/ggml.c +478 -98
  425. data/ext/sources/ggml/src/gguf.cpp +8 -1
  426. data/ext/sources/src/whisper.cpp +23 -46
  427. data/ext/sources/tests/CMakeLists.txt +8 -1
  428. data/ext/sources/tests/test-vad-full.cpp +3 -3
  429. data/ext/sources/tests/test-vad.cpp +2 -2
  430. data/lib/whisper/model/uri.rb +1 -1
  431. data/sig/whisper.rbs +7 -0
  432. data/test/test_params.rb +8 -0
  433. data/test/test_whisper.rb +1 -1
  434. data/whispercpp.gemspec +1 -1
  435. metadata +164 -157
  436. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +0 -279
  437. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +0 -1841
  438. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +0 -303
  439. data/ext/sources/ggml/include/ggml-kompute.h +0 -50
  440. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
  441. data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
  442. data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
  443. data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
  444. data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
  445. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  446. data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  447. data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  448. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  449. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  450. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  451. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  452. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  453. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  454. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
  455. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
  456. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
  457. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
  458. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
  459. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
  460. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
  461. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
  462. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
  463. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
  464. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
  465. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
  466. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
  467. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
  468. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
  469. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
  470. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
  471. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
  472. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
  473. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
  474. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
  475. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
  476. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
  477. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
  478. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
  479. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
  480. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
  481. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
  482. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
  483. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
  484. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
  485. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
  486. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
  487. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
  488. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
  489. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
  490. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
  491. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
  492. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
  493. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
  494. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
  495. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
  496. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
  497. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
  498. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
  499. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
  500. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
  501. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
  502. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
  503. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
  504. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
  505. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
  506. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
  507. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
  508. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
  509. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
  510. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
  511. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
  512. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
  513. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
  514. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
  515. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
  516. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
  517. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
  518. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
  519. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
  520. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
  521. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
  522. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
  523. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
  524. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
  525. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
  526. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
  527. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
  528. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
  529. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
  530. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
  531. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
  532. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
  533. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
  534. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
  535. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
  536. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
  537. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
  538. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
  539. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
  540. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
  541. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
  542. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
  543. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
  544. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
  545. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
  546. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
  547. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  548. data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  549. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  550. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  551. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  552. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  553. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  554. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  555. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  556. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  557. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  558. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  559. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  560. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  561. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  562. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  563. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  564. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  565. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  566. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  567. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  568. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  569. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  570. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  571. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  572. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  573. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  574. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  575. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  576. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  577. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  578. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  579. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  580. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  581. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  582. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  583. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  584. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  585. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
  586. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -6280
@@ -25,6 +25,7 @@
 #include <vector>
 #include <string>
 #include <cmath>
+#include <map>
 #include <memory>
 #include <charconv>
 #include <mutex>
@@ -33,6 +34,7 @@
 #undef MAX
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
+#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
 
 #define UNUSED(x) (void)(x)
 
@@ -331,8 +333,10 @@ struct ggml_backend_opencl_context {
 
     cl_int alignment;
     size_t max_alloc_size;
+    size_t max_workgroup_size;
     bool fp16_support;
     bool has_vector_subgroup_broadcast;
+    bool disable_fusion;
     ggml_cl_compiler_version adreno_cl_compiler_version;
 
     int adreno_wave_size;
@@ -343,6 +347,7 @@ struct ggml_backend_opencl_context {
     cl_command_queue queue;
 
     cl_program program_add;
+    cl_program program_add_id;
     cl_program program_clamp;
     cl_program program_cpy;
     cl_program program_cvt;
@@ -351,6 +356,7 @@ struct ggml_backend_opencl_context {
     cl_program program_gemv_noshuffle_general;
     cl_program program_gemv_noshuffle;
     cl_program program_get_rows;
+    cl_program program_set_rows;
     cl_program program_glu;
     cl_program program_im2col_f16;
     cl_program program_im2col_f32;
@@ -361,12 +367,16 @@ struct ggml_backend_opencl_context {
     cl_program program_mul_mv_q4_0_f32_1d_8x_flat;
     cl_program program_mul_mv_q4_0_f32_1d_16x_flat;
     cl_program program_mul_mv_q6_K;
+    cl_program program_mul_mv_q8_0_f32, program_mul_mv_q8_0_f32_flat;
+    cl_program program_mul_mv_mxfp4_f32;
+    cl_program program_mul_mv_mxfp4_f32_flat;
     cl_program program_mul_mv_f16_f16;
     cl_program program_mul_mv_f16_f32_1row;
     cl_program program_mul_mv_f16_f32_l4;
     cl_program program_mul_mv_f16_f32;
     cl_program program_mul_mv_f32_f32;
     cl_program program_mul;
+    cl_program program_mul_mat_f16_f32_tiled;
     cl_program program_div;
     cl_program program_sub;
     cl_program program_norm;
@@ -388,29 +398,48 @@ struct ggml_backend_opencl_context {
     cl_program program_tanh;
     cl_program program_upscale;
     cl_program program_concat;
+    cl_program program_conv_2d_f16;
+    cl_program program_conv_2d_f32;
+    cl_program program_conv_2d_f16_f32;
     cl_program program_tsembd;
     cl_program program_mul_mv_id_q4_0_f32_8x_flat;
-
-    cl_kernel kernel_add, kernel_add_row;
-    cl_kernel kernel_mul, kernel_mul_row;
-    cl_kernel kernel_div, kernel_div_row;
-    cl_kernel kernel_sub, kernel_sub_row;
+    cl_program program_mul_mv_id_q8_0_f32, program_mul_mv_id_q8_0_f32_flat;
+    cl_program program_mul_mv_id_mxfp4_f32;
+    cl_program program_mul_mv_id_mxfp4_f32_flat;
+    cl_program program_mul_mm_f32_f32_l4_lm;
+    cl_program program_mul_mm_f16_f32_l4_lm;
+
+    cl_kernel kernel_add, kernel_add_row, kernel_add_f16, kernel_add_row_f16;
+    cl_kernel kernel_mul, kernel_mul_row, kernel_mul_f16, kernel_mul_row_f16;
+    cl_kernel kernel_div, kernel_div_row, kernel_div_f16, kernel_div_row_f16;
+    cl_kernel kernel_sub, kernel_sub_row, kernel_sub_f16, kernel_sub_row_f16;
+    cl_kernel kernel_add_id;
     cl_kernel kernel_scale;
     cl_kernel kernel_silu, kernel_silu_4;
     cl_kernel kernel_gelu, kernel_gelu_4;
+    cl_kernel kernel_gelu_erf, kernel_gelu_erf_4;
    cl_kernel kernel_gelu_quick, kernel_gelu_quick_4;
     cl_kernel kernel_relu;
     cl_kernel kernel_sigmoid_f32, kernel_sigmoid_f16;
     cl_kernel kernel_clamp;
-    cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu,
-              kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16;
-    cl_kernel kernel_norm;
-    cl_kernel kernel_rms_norm;
-    cl_kernel kernel_group_norm;
+    cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_swiglu_oai, kernel_geglu_erf, kernel_geglu_quick,
+              kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16, kernel_geglu_erf_f16, kernel_geglu_quick_f16;
+    cl_kernel kernel_norm, kernel_norm_mul_add;
+    cl_kernel kernel_rms_norm, kernel_rms_norm_mul;
+    cl_kernel kernel_group_norm, kernel_group_norm_mul_add;
     cl_kernel kernel_diag_mask_inf, kernel_diag_mask_inf_8;
     cl_kernel kernel_soft_max, kernel_soft_max_4;
     cl_kernel kernel_soft_max_f16, kernel_soft_max_4_f16;
+    std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f16;
+    std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f16_q1;
+    std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32;
+    std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_q1;
+    std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_f16;
+    std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_f16_q1;
+    std::map<std::pair<int, int>, int> kernels_flash_attn_bm;
+    std::map<std::pair<int, int>, int> kernels_flash_attn_bn;
     cl_kernel kernel_get_rows_f32, kernel_get_rows_f16, kernel_get_rows_q4_0;
+    cl_kernel kernel_set_rows_f32_i64, kernel_set_rows_f32_i32, kernel_set_rows_f16_i64, kernel_set_rows_f16_i32;
     cl_kernel kernel_rope_norm_f32, kernel_rope_norm_f16, kernel_rope_neox_f32, kernel_rope_neox_f16;
     cl_kernel kernel_rope_multi_f32, kernel_rope_multi_f16, kernel_rope_vision_f32, kernel_rope_vision_f16;
     cl_kernel kernel_cpy_f16_f16, kernel_cpy_f16_f32, kernel_cpy_f32_f16, kernel_cpy_f32_f32;
@@ -419,12 +448,17 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_mul_mat_f16_f32_1row;
     cl_kernel kernel_mul_mat_f16_f32;
     cl_kernel kernel_mul_mat_f16_f32_l4;
+    cl_kernel kernel_mul_mat_f16_f32_tiled;
     cl_kernel kernel_mul_mat_q4_0_f32, kernel_mul_mat_q4_0_f32_v;
     cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0;
+    cl_kernel kernel_convert_block_mxfp4, kernel_restore_block_mxfp4;
+    cl_kernel kernel_convert_block_q8_0, kernel_restore_block_q8_0;
     cl_kernel kernel_mul_mat_q4_0_f32_8x_flat;
     cl_kernel kernel_convert_block_q4_0_noshuffle;
     cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat;
     cl_kernel kernel_mul_mv_q6_K_f32;
+    cl_kernel kernel_mul_mv_mxfp4_f32, kernel_mul_mv_mxfp4_f32_flat;
+    cl_kernel kernel_mul_mv_q8_0_f32, kernel_mul_mv_q8_0_f32_flat;
     cl_kernel kernel_im2col_f32, kernel_im2col_f16;
     cl_kernel kernel_argsort_f32_i32;
     cl_kernel kernel_sum_rows_f32;
@@ -436,8 +470,16 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_upscale_bilinear;
     cl_kernel kernel_concat_f32_contiguous;
     cl_kernel kernel_concat_f32_non_contiguous;
+    cl_kernel kernel_conv_2d_f16;
+    cl_kernel kernel_conv_2d_f32;
+    cl_kernel kernel_conv_2d_f16_f32;
     cl_kernel kernel_timestep_embedding;
     cl_kernel kernel_mul_mv_id_q4_0_f32_8x_flat;
+    cl_kernel kernel_mul_mv_id_q8_0_f32, kernel_mul_mv_id_q8_0_f32_flat;
+    cl_kernel kernel_mul_mv_id_mxfp4_f32;
+    cl_kernel kernel_mul_mv_id_mxfp4_f32_flat;
+    cl_kernel kernel_mul_mm_f32_f32_l4_lm;
+    cl_kernel kernel_mul_mm_f16_f32_l4_lm;
 
     std::vector<ProfilingInfo> profiling_info;
 
@@ -528,6 +570,16 @@ struct ggml_backend_opencl_context {
         fclose(ftrace);
     }
 
+    size_t get_kernel_workgroup_size(cl_kernel kernel) const {
+        size_t workgroup_size = 0;
+        size_t ret_size = 0;
+        CL_CHECK(
+            clGetKernelWorkGroupInfo(kernel, device, CL_KERNEL_WORK_GROUP_SIZE,
+                sizeof(size_t), &workgroup_size, &ret_size));
+        GGML_ASSERT(sizeof(size_t) == ret_size);
+        return workgroup_size;
+    }
+
     void enqueue_ndrange_kernel(cl_kernel kernel, cl_uint work_dim, size_t *global_work_size, size_t *local_work_size, const ggml_tensor * tensor) {
 #ifdef GGML_OPENCL_PROFILING
         cl_event evt;
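Aside, not part of the diff: the get_kernel_workgroup_size() helper added above queries the per-kernel CL_KERNEL_WORK_GROUP_SIZE limit through clGetKernelWorkGroupInfo. A minimal sketch of how such a limit is commonly combined with the new CEIL_DIV macro when choosing launch sizes; launch_1d, n_elements and requested_local are hypothetical names for illustration, not identifiers from this package:

    // Hypothetical helper (illustration only): clamp the local work size to the
    // kernel's reported limit and round the global size up to a whole multiple of it.
    static void launch_1d(ggml_backend_opencl_context * ctx, cl_kernel kernel,
                          size_t n_elements, size_t requested_local, const ggml_tensor * tensor) {
        size_t local  = MIN(requested_local, ctx->get_kernel_workgroup_size(kernel));
        size_t global = CEIL_DIV(n_elements, local) * local; // global range must cover every element
        ctx->enqueue_ndrange_kernel(kernel, 1, &global, &local, tensor);
    }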
@@ -548,6 +600,7 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_transpose_32;
     cl_kernel kernel_transpose_32_16;
     cl_kernel kernel_transpose_16;
+    cl_kernel kernel_transpose_16_4x1;
 
     cl_mem A_s_d_max; // max scale buffer size for transpose
     cl_mem A_q_d_max; // max weight buffer size for transpose
@@ -573,6 +626,7 @@ struct ggml_backend_opencl_context {
         if (ref_count == 0) {
 #ifdef GGML_OPENCL_PROFILING
             write_profiling_info();
+            profiling_info.clear();
 #endif
         }
     }
@@ -647,8 +701,26 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         backend_ctx->program_add =
             build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
 
-        CL_CHECK((backend_ctx->kernel_add = clCreateKernel(backend_ctx->program_add, "kernel_add", &err), err));
-        CL_CHECK((backend_ctx->kernel_add_row = clCreateKernel(backend_ctx->program_add, "kernel_add_row", &err), err));
+        CL_CHECK((backend_ctx->kernel_add = clCreateKernel(backend_ctx->program_add, "kernel_add", &err), err));
+        CL_CHECK((backend_ctx->kernel_add_row = clCreateKernel(backend_ctx->program_add, "kernel_add_row", &err), err));
+        CL_CHECK((backend_ctx->kernel_add_f16 = clCreateKernel(backend_ctx->program_add, "kernel_add_f16", &err), err));
+        CL_CHECK((backend_ctx->kernel_add_row_f16 = clCreateKernel(backend_ctx->program_add, "kernel_add_row_f16", &err), err));
+        GGML_LOG_CONT(".");
+    }
+
+    // add_id
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "add_id.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("add_id.cl");
+#endif
+        backend_ctx->program_add_id =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_add_id = clCreateKernel(backend_ctx->program_add_id, "kernel_add_id", &err), err));
         GGML_LOG_CONT(".");
     }
 
@@ -702,6 +774,10 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         CL_CHECK((backend_ctx->kernel_convert_block_q4_0_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_0_noshuffle", &err), err));
         CL_CHECK((backend_ctx->kernel_convert_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_0", &err), err));
         CL_CHECK((backend_ctx->kernel_restore_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_0", &err), err));
+        CL_CHECK((backend_ctx->kernel_convert_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4", &err), err));
+        CL_CHECK((backend_ctx->kernel_restore_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4", &err), err));
+        CL_CHECK((backend_ctx->kernel_convert_block_q8_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q8_0", &err), err));
+        CL_CHECK((backend_ctx->kernel_restore_block_q8_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q8_0", &err), err));
         GGML_LOG_CONT(".");
     }
 
@@ -736,6 +812,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
 
         CL_CHECK((backend_ctx->kernel_gelu = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu", &err), err));
        CL_CHECK((backend_ctx->kernel_gelu_4 = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu_4", &err), err));
+        CL_CHECK((backend_ctx->kernel_gelu_erf = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu_erf", &err), err));
+        CL_CHECK((backend_ctx->kernel_gelu_erf_4 = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu_erf_4", &err), err));
         CL_CHECK((backend_ctx->kernel_gelu_quick = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu_quick", &err), err));
         CL_CHECK((backend_ctx->kernel_gelu_quick_4 = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu_quick_4", &err), err));
         GGML_LOG_CONT(".");
@@ -753,12 +831,17 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         backend_ctx->program_glu =
             build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
 
-        CL_CHECK((backend_ctx->kernel_geglu = clCreateKernel(backend_ctx->program_glu, "kernel_geglu", &err), err));
-        CL_CHECK((backend_ctx->kernel_reglu = clCreateKernel(backend_ctx->program_glu, "kernel_reglu", &err), err));
-        CL_CHECK((backend_ctx->kernel_swiglu = clCreateKernel(backend_ctx->program_glu, "kernel_swiglu", &err), err));
-        CL_CHECK((backend_ctx->kernel_geglu_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_f16", &err), err));
-        CL_CHECK((backend_ctx->kernel_reglu_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_reglu_f16", &err), err));
-        CL_CHECK((backend_ctx->kernel_swiglu_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_swiglu_f16", &err), err));
+        CL_CHECK((backend_ctx->kernel_geglu = clCreateKernel(backend_ctx->program_glu, "kernel_geglu", &err), err));
+        CL_CHECK((backend_ctx->kernel_reglu = clCreateKernel(backend_ctx->program_glu, "kernel_reglu", &err), err));
+        CL_CHECK((backend_ctx->kernel_swiglu = clCreateKernel(backend_ctx->program_glu, "kernel_swiglu", &err), err));
+        CL_CHECK((backend_ctx->kernel_swiglu_oai = clCreateKernel(backend_ctx->program_glu, "kernel_swiglu_oai", &err), err));
+        CL_CHECK((backend_ctx->kernel_geglu_erf = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_erf", &err), err));
+        CL_CHECK((backend_ctx->kernel_geglu_quick = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_quick", &err), err));
+        CL_CHECK((backend_ctx->kernel_geglu_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_f16", &err), err));
+        CL_CHECK((backend_ctx->kernel_reglu_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_reglu_f16", &err), err));
842
+ CL_CHECK((backend_ctx->kernel_swiglu_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_swiglu_f16", &err), err));
843
+ CL_CHECK((backend_ctx->kernel_geglu_erf_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_erf_f16", &err), err));
844
+ CL_CHECK((backend_ctx->kernel_geglu_quick_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_quick_f16", &err), err));
762
845
  GGML_LOG_CONT(".");
763
846
  }
764
847
 
@@ -916,6 +999,70 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
916
999
  GGML_LOG_CONT(".");
917
1000
  }
918
1001
 
1002
+ // mul_mv_q8_0_f32
1003
+ {
1004
+ #ifdef GGML_OPENCL_EMBED_KERNELS
1005
+ const std::string kernel_src {
1006
+ #include "mul_mv_q8_0_f32.cl.h"
1007
+ };
1008
+ #else
1009
+ const std::string kernel_src = read_file("mul_mv_q8_0_f32.cl");
1010
+ #endif
1011
+ backend_ctx->program_mul_mv_q8_0_f32 =
1012
+ build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
1013
+
1014
+ CL_CHECK((backend_ctx->kernel_mul_mv_q8_0_f32 = clCreateKernel(backend_ctx->program_mul_mv_q8_0_f32, "kernel_mul_mv_q8_0_f32", &err), err));
1015
+ GGML_LOG_CONT(".");
1016
+ }
1017
+
1018
+ // mul_mv_q8_0_f32_flat
1019
+ {
1020
+ #ifdef GGML_OPENCL_EMBED_KERNELS
1021
+ const std::string kernel_src {
1022
+ #include "mul_mv_q8_0_f32_flat.cl.h"
1023
+ };
1024
+ #else
1025
+ const std::string kernel_src = read_file("mul_mv_q8_0_f32_flat.cl");
1026
+ #endif
1027
+ backend_ctx->program_mul_mv_q8_0_f32_flat =
1028
+ build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
1029
+
1030
+ CL_CHECK((backend_ctx->kernel_mul_mv_q8_0_f32_flat = clCreateKernel(backend_ctx->program_mul_mv_q8_0_f32_flat, "kernel_mul_mv_q8_0_f32_flat", &err), err));
1031
+ GGML_LOG_CONT(".");
1032
+ }
1033
+
1034
+ // mul_mv_mxfp4_f32
1035
+ {
1036
+ #ifdef GGML_OPENCL_EMBED_KERNELS
1037
+ const std::string kernel_src {
1038
+ #include "mul_mv_mxfp4_f32.cl.h"
1039
+ };
1040
+ #else
1041
+ const std::string kernel_src = read_file("mul_mv_mxfp4_f32.cl");
1042
+ #endif
1043
+ backend_ctx->program_mul_mv_mxfp4_f32 =
1044
+ build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
1045
+
1046
+ CL_CHECK((backend_ctx->kernel_mul_mv_mxfp4_f32 = clCreateKernel(backend_ctx->program_mul_mv_mxfp4_f32, "kernel_mul_mv_mxfp4_f32", &err), err));
1047
+ GGML_LOG_CONT(".");
1048
+ }
1049
+
1050
+ // mul_mv_mxfp4_f32_flat
1051
+ {
1052
+ #ifdef GGML_OPENCL_EMBED_KERNELS
1053
+ const std::string kernel_src {
1054
+ #include "mul_mv_mxfp4_f32_flat.cl.h"
1055
+ };
1056
+ #else
1057
+ const std::string kernel_src = read_file("mul_mv_mxfp4_f32_flat.cl");
1058
+ #endif
1059
+ backend_ctx->program_mul_mv_mxfp4_f32_flat =
1060
+ build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
1061
+
1062
+ CL_CHECK((backend_ctx->kernel_mul_mv_mxfp4_f32_flat = clCreateKernel(backend_ctx->program_mul_mv_mxfp4_f32_flat, "kernel_mul_mv_mxfp4_f32_flat", &err), err));
1063
+ GGML_LOG_CONT(".");
1064
+ }
1065
+
919
1066
  // mul_mv_f16_f16
920
1067
  {
921
1068
  #ifdef GGML_OPENCL_EMBED_KERNELS
@@ -996,6 +1143,54 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
996
1143
  GGML_LOG_CONT(".");
997
1144
  }
998
1145
 
1146
+ // mul_mat_f16_f32_tiled
1147
+ {
1148
+ #ifdef GGML_OPENCL_EMBED_KERNELS
1149
+ const std::string kernel_src {
1150
+ #include "mul_mat_f16_f32.cl.h"
1151
+ };
1152
+ #else
1153
+ const std::string kernel_src = read_file("mul_mat_f16_f32.cl");
1154
+ #endif
1155
+ backend_ctx->program_mul_mat_f16_f32_tiled =
1156
+ build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
1157
+
1158
+ CL_CHECK((backend_ctx->kernel_mul_mat_f16_f32_tiled = clCreateKernel(backend_ctx->program_mul_mat_f16_f32_tiled, "mul_mat_f16_f32", &err), err));
1159
+ GGML_LOG_CONT(".");
1160
+ }
1161
+
1162
+ // mul_mm_f32_f32_l4_lm
1163
+ {
1164
+ #ifdef GGML_OPENCL_EMBED_KERNELS
1165
+ const std::string kernel_src {
1166
+ #include "mul_mm_f32_f32_l4_lm.cl.h"
1167
+ };
1168
+ #else
1169
+ const std::string kernel_src = read_file("mul_mm_f32_f32_l4_lm.cl");
1170
+ #endif
1171
+ backend_ctx->program_mul_mm_f32_f32_l4_lm =
1172
+ build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
1173
+
1174
+ CL_CHECK((backend_ctx->kernel_mul_mm_f32_f32_l4_lm = clCreateKernel(backend_ctx->program_mul_mm_f32_f32_l4_lm, "kernel_mul_mm_f32_f32_l4_lm", &err), err));
1175
+ GGML_LOG_CONT(".");
1176
+ }
1177
+
1178
+ // mul_mm_f16_f32_l4_lm
1179
+ {
1180
+ #ifdef GGML_OPENCL_EMBED_KERNELS
1181
+ const std::string kernel_src {
1182
+ #include "mul_mm_f16_f32_l4_lm.cl.h"
1183
+ };
1184
+ #else
1185
+ const std::string kernel_src = read_file("mul_mm_f16_f32_l4_lm.cl");
1186
+ #endif
1187
+ backend_ctx->program_mul_mm_f16_f32_l4_lm =
1188
+ build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
1189
+
1190
+ CL_CHECK((backend_ctx->kernel_mul_mm_f16_f32_l4_lm = clCreateKernel(backend_ctx->program_mul_mm_f16_f32_l4_lm, "kernel_mul_mm_f16_f32_l4_lm", &err), err));
1191
+ GGML_LOG_CONT(".");
1192
+ }
1193
+
999
1194
  // mul
1000
1195
  {
1001
1196
  #ifdef GGML_OPENCL_EMBED_KERNELS
@@ -1008,8 +1203,10 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
1008
1203
  backend_ctx->program_mul =
1009
1204
  build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
1010
1205
 
1011
- CL_CHECK((backend_ctx->kernel_mul = clCreateKernel(backend_ctx->program_mul, "kernel_mul", &err), err));
1012
- CL_CHECK((backend_ctx->kernel_mul_row = clCreateKernel(backend_ctx->program_mul, "kernel_mul_row", &err), err));
1206
+ CL_CHECK((backend_ctx->kernel_mul = clCreateKernel(backend_ctx->program_mul, "kernel_mul", &err), err));
1207
+ CL_CHECK((backend_ctx->kernel_mul_row = clCreateKernel(backend_ctx->program_mul, "kernel_mul_row", &err), err));
1208
+ CL_CHECK((backend_ctx->kernel_mul_f16 = clCreateKernel(backend_ctx->program_mul, "kernel_mul_f16", &err), err));
1209
+ CL_CHECK((backend_ctx->kernel_mul_row_f16 = clCreateKernel(backend_ctx->program_mul, "kernel_mul_row_f16", &err), err));
1013
1210
  GGML_LOG_CONT(".");
1014
1211
  }
1015
1212
 
@@ -1025,7 +1222,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
1025
1222
  backend_ctx->program_norm =
1026
1223
  build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
1027
1224
 
1028
- CL_CHECK((backend_ctx->kernel_norm = clCreateKernel(backend_ctx->program_norm, "kernel_norm", &err), err));
1225
+ CL_CHECK((backend_ctx->kernel_norm = clCreateKernel(backend_ctx->program_norm, "kernel_norm", &err), err));
1226
+ CL_CHECK((backend_ctx->kernel_norm_mul_add = clCreateKernel(backend_ctx->program_norm, "kernel_norm_mul_add", &err), err));
1029
1227
  GGML_LOG_CONT(".");
1030
1228
  }
1031
1229
 
@@ -1057,7 +1255,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
1057
1255
  backend_ctx->program_rms_norm =
1058
1256
  build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
1059
1257
 
1060
- CL_CHECK((backend_ctx->kernel_rms_norm = clCreateKernel(backend_ctx->program_rms_norm, "kernel_rms_norm", &err), err));
1258
+ CL_CHECK((backend_ctx->kernel_rms_norm = clCreateKernel(backend_ctx->program_rms_norm, "kernel_rms_norm", &err), err));
1259
+ CL_CHECK((backend_ctx->kernel_rms_norm_mul = clCreateKernel(backend_ctx->program_rms_norm, "kernel_rms_norm_mul", &err), err));
1061
1260
  GGML_LOG_CONT(".");
1062
1261
  }
1063
1262
 
@@ -1181,6 +1380,73 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
1181
1380
  GGML_LOG_CONT(".");
1182
1381
  }
1183
1382
 
1383
+ // flash_attn
1384
+ {
1385
+ #ifdef GGML_OPENCL_EMBED_KERNELS
1386
+ const std::string kernel_src_f16 {
1387
+ #include "flash_attn_f16.cl.h"
1388
+ };
1389
+ const std::string kernel_src_f32 {
1390
+ #include "flash_attn_f32.cl.h"
1391
+ };
1392
+ const std::string kernel_src_f32_f16 {
1393
+ #include "flash_attn_f32_f16.cl.h"
1394
+ };
1395
+ #else
1396
+ const std::string kernel_src_f16 = read_file("flash_attn_f16.cl");
1397
+ const std::string kernel_src_f32 = read_file("flash_attn_f32.cl");
1398
+ const std::string kernel_src_f32_f16 = read_file("flash_attn_f32_f16.cl");
1399
+ #endif
1400
+
1401
+ if (!kernel_src_f16.empty() && !kernel_src_f32.empty() && !kernel_src_f32_f16.empty()) {
1402
+ const struct { int dk; int dv; int bm; int bn; } fa_dims[] = {
1403
+ { 40, 40, 32, 32}, { 64, 64, 64, 64}, { 80, 80, 64, 32}, { 96, 96, 64, 32},
1404
+ {112, 112, 32, 32}, {128, 128, 32, 32}, {192, 128, 16, 16},
1405
+ {192, 192, 16, 16}, {256, 256, 16, 16},
1406
+ };
1407
+
1408
+ for (size_t i = 0; i < sizeof(fa_dims)/sizeof(fa_dims[0]); ++i) {
1409
+ const int dk = fa_dims[i].dk;
1410
+ const int dv = fa_dims[i].dv;
1411
+ const int bm = fa_dims[i].bm;
1412
+ const int bn = fa_dims[i].bn;
1413
+ std::string OPTS = compile_opts +
1414
+ " -D DK=" + std::to_string(dk) +
1415
+ " -D DV=" + std::to_string(dv) +
1416
+ " -D BLOCK_M=" + std::to_string(bm) +
1417
+ " -D BLOCK_N=" + std::to_string(bn);
1418
+
1419
+ cl_program prog_f16 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f16.c_str(), OPTS);
1420
+ cl_kernel k_f16, k_f16_q1;
1421
+ CL_CHECK((k_f16 = clCreateKernel(prog_f16, "flash_attn_f16", &err), err));
1422
+ CL_CHECK((k_f16_q1 = clCreateKernel(prog_f16, "flash_attn_f16_q1", &err), err));
1423
+ backend_ctx->kernels_flash_attn_f16[{dk, dv}] = k_f16;
1424
+ backend_ctx->kernels_flash_attn_f16_q1[{dk, dv}] = k_f16_q1;
1425
+ CL_CHECK(clReleaseProgram(prog_f16));
1426
+
1427
+ cl_program prog_f32 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f32.c_str(), OPTS);
1428
+ cl_kernel k_f32, k_f32_q1;
1429
+ CL_CHECK((k_f32 = clCreateKernel(prog_f32, "flash_attn_f32", &err), err));
1430
+ CL_CHECK((k_f32_q1 = clCreateKernel(prog_f32, "flash_attn_f32_q1", &err), err));
1431
+ backend_ctx->kernels_flash_attn_f32[{dk, dv}] = k_f32;
1432
+ backend_ctx->kernels_flash_attn_f32_q1[{dk, dv}] = k_f32_q1;
1433
+ CL_CHECK(clReleaseProgram(prog_f32));
1434
+
1435
+ cl_program prog_f32_f16 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f32_f16.c_str(), OPTS);
1436
+ cl_kernel k_f32_f16, k_f32_f16_q1;
1437
+ CL_CHECK((k_f32_f16 = clCreateKernel(prog_f32_f16, "flash_attn_f32_f16", &err), err));
1438
+ CL_CHECK((k_f32_f16_q1 = clCreateKernel(prog_f32_f16, "flash_attn_f32_f16_q1", &err), err));
1439
+ backend_ctx->kernels_flash_attn_f32_f16[{dk, dv}] = k_f32_f16;
1440
+ backend_ctx->kernels_flash_attn_f32_f16_q1[{dk, dv}] = k_f32_f16_q1;
1441
+ CL_CHECK(clReleaseProgram(prog_f32_f16));
1442
+
1443
+ backend_ctx->kernels_flash_attn_bm[{dk, dv}] = bm;
1444
+ backend_ctx->kernels_flash_attn_bn[{dk, dv}] = bn;
1445
+ }
1446
+ GGML_LOG_CONT(".");
1447
+ }
1448
+ }
1449
+
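Each (DK, DV) pair above gets its own specialized program, keyed by head size. The dispatch side is not shown in this hunk; a hedged sketch of the expected lookup (map names are taken from the code above, the use of .at() is an assumption):

    const int dk = q->ne[0];                                            // K head size
    const int dv = v->ne[0];                                            // V head size
    cl_kernel k  = backend_ctx->kernels_flash_attn_f16.at({dk, dv});    // pick the F16 variant
    const int bm = backend_ctx->kernels_flash_attn_bm.at({dk, dv});     // BLOCK_M it was built with
    const int bn = backend_ctx->kernels_flash_attn_bn.at({dk, dv});     // BLOCK_N it was built with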
1184
1450
  // argsort
1185
1451
  {
1186
1452
  #ifdef GGML_OPENCL_EMBED_KERNELS
@@ -1206,11 +1472,16 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
1206
1472
  #else
1207
1473
  const std::string kernel_src = read_file("div.cl");
1208
1474
  #endif
1475
+ std::string compile_opts = std::string("-cl-std=") + opencl_c_std +
1476
+ " -cl-mad-enable -cl-finite-math-only ";
1477
+
1209
1478
  backend_ctx->program_div =
1210
1479
  build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
1211
1480
 
1212
- CL_CHECK((backend_ctx->kernel_div = clCreateKernel(backend_ctx->program_div, "kernel_div", &err), err));
1213
- CL_CHECK((backend_ctx->kernel_div_row = clCreateKernel(backend_ctx->program_div, "kernel_div_row", &err), err));
1481
+ CL_CHECK((backend_ctx->kernel_div = clCreateKernel(backend_ctx->program_div, "kernel_div", &err), err));
1482
+ CL_CHECK((backend_ctx->kernel_div_row = clCreateKernel(backend_ctx->program_div, "kernel_div_row", &err), err));
1483
+ CL_CHECK((backend_ctx->kernel_div_f16 = clCreateKernel(backend_ctx->program_div, "kernel_div_f16", &err), err));
1484
+ CL_CHECK((backend_ctx->kernel_div_row_f16 = clCreateKernel(backend_ctx->program_div, "kernel_div_row_f16", &err), err));
1214
1485
  GGML_LOG_CONT(".");
1215
1486
  }
1216
1487
 
@@ -1226,8 +1497,10 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
1226
1497
  backend_ctx->program_sub =
1227
1498
  build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
1228
1499
 
1229
- CL_CHECK((backend_ctx->kernel_sub = clCreateKernel(backend_ctx->program_sub, "kernel_sub", &err), err));
1230
- CL_CHECK((backend_ctx->kernel_sub_row = clCreateKernel(backend_ctx->program_sub, "kernel_sub_row", &err), err));
1500
+ CL_CHECK((backend_ctx->kernel_sub = clCreateKernel(backend_ctx->program_sub, "kernel_sub", &err), err));
1501
+ CL_CHECK((backend_ctx->kernel_sub_row = clCreateKernel(backend_ctx->program_sub, "kernel_sub_row", &err), err));
1502
+ CL_CHECK((backend_ctx->kernel_sub_f16 = clCreateKernel(backend_ctx->program_sub, "kernel_sub_f16", &err), err));
1503
+ CL_CHECK((backend_ctx->kernel_sub_row_f16 = clCreateKernel(backend_ctx->program_sub, "kernel_sub_row_f16", &err), err));
1231
1504
  GGML_LOG_CONT(".");
1232
1505
  }
1233
1506
 
@@ -1276,7 +1549,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
1276
1549
  backend_ctx->program_group_norm =
1277
1550
  build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
1278
1551
 
1279
- CL_CHECK((backend_ctx->kernel_group_norm = clCreateKernel(backend_ctx->program_group_norm, "kernel_group_norm", &err), err));
1552
+ CL_CHECK((backend_ctx->kernel_group_norm = clCreateKernel(backend_ctx->program_group_norm, "kernel_group_norm", &err), err));
1553
+ CL_CHECK((backend_ctx->kernel_group_norm_mul_add = clCreateKernel(backend_ctx->program_group_norm, "kernel_group_norm_mul_add", &err), err));
1280
1554
  GGML_LOG_CONT(".");
1281
1555
  }
1282
1556
 
@@ -1424,6 +1698,66 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
1424
1698
  }
1425
1699
  }
1426
1700
 
1701
+ // set_rows
1702
+ {
1703
+ #ifdef GGML_OPENCL_EMBED_KERNELS
1704
+ const std::string kernel_src {
1705
+ #include "set_rows.cl.h"
1706
+ };
1707
+ #else
1708
+ const std::string kernel_src = read_file("set_rows.cl");
1709
+ #endif
1710
+ backend_ctx->program_set_rows =
1711
+ build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
1712
+
1713
+ CL_CHECK((backend_ctx->kernel_set_rows_f32_i64 = clCreateKernel(backend_ctx->program_set_rows, "kernel_set_rows_f32_i64", &err), err));
1714
+ CL_CHECK((backend_ctx->kernel_set_rows_f32_i32 = clCreateKernel(backend_ctx->program_set_rows, "kernel_set_rows_f32_i32", &err), err));
1715
+ CL_CHECK((backend_ctx->kernel_set_rows_f16_i64 = clCreateKernel(backend_ctx->program_set_rows, "kernel_set_rows_f16_i64", &err), err));
1716
+ CL_CHECK((backend_ctx->kernel_set_rows_f16_i32 = clCreateKernel(backend_ctx->program_set_rows, "kernel_set_rows_f16_i32", &err), err));
1717
+ GGML_LOG_CONT(".");
1718
+ }
1719
+
1720
+ // conv2d
1721
+ {
1722
+ #ifdef GGML_OPENCL_EMBED_KERNELS
1723
+ const std::string kernel_src {
1724
+ #include "conv2d.cl.h"
1725
+ };
1726
+ const std::string kernel_src_f16_f32 {
1727
+ #include "conv2d_f16_f32.cl.h"
1728
+ };
1729
+ #else
1730
+ const std::string kernel_src = read_file("conv2d.cl");
1731
+ const std::string kernel_src_f16_f32 = read_file("conv2d_f16_f32.cl");
1732
+ #endif
1733
+ if (!kernel_src.empty()) {
1734
+ backend_ctx->program_conv_2d_f16 =
1735
+ build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), (std::string(compile_opts) + " -DUSE_FP16=1").c_str());
1736
+ CL_CHECK((backend_ctx->kernel_conv_2d_f16 = clCreateKernel(backend_ctx->program_conv_2d_f16, "kernel_conv_2d", &err), err));
1737
+ GGML_LOG_CONT(".");
1738
+ backend_ctx->program_conv_2d_f32 =
1739
+ build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
1740
+ CL_CHECK((backend_ctx->kernel_conv_2d_f32 = clCreateKernel(backend_ctx->program_conv_2d_f32, "kernel_conv_2d", &err), err));
1741
+ GGML_LOG_CONT(".");
1742
+ } else {
1743
+ GGML_LOG_WARN("ggml_opencl: conv2d kernel source not found or empty. This op will not be available.\n");
1744
+ backend_ctx->program_conv_2d_f16 = nullptr;
1745
+ backend_ctx->kernel_conv_2d_f16 = nullptr;
1746
+ backend_ctx->program_conv_2d_f32 = nullptr;
1747
+ backend_ctx->kernel_conv_2d_f32 = nullptr;
1748
+ }
1749
+ if (!kernel_src_f16_f32.empty()) {
1750
+ backend_ctx->program_conv_2d_f16_f32 =
1751
+ build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f16_f32.c_str(), compile_opts);
1752
+ CL_CHECK((backend_ctx->kernel_conv_2d_f16_f32 = clCreateKernel(backend_ctx->program_conv_2d_f16_f32, "kernel_conv_2d", &err), err));
1753
+ GGML_LOG_CONT(".");
1754
+ } else {
1755
+ GGML_LOG_WARN("ggml_opencl: conv2d_f16_f32 kernel source not found or empty. This op will not be available.\n");
1756
+ backend_ctx->program_conv_2d_f16_f32 = nullptr;
1757
+ backend_ctx->kernel_conv_2d_f16_f32 = nullptr;
1758
+ }
1759
+ }
1760
+
1427
1761
  // mul_mv_id_q4_0_f32_8x_flat
1428
1762
  {
1429
1763
  #ifdef GGML_OPENCL_EMBED_KERNELS
@@ -1440,6 +1774,70 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
1440
1774
  GGML_LOG_CONT(".");
1441
1775
  }
1442
1776
 
1777
+ // mul_mv_id_q8_0_f32
1778
+ {
1779
+ #ifdef GGML_OPENCL_EMBED_KERNELS
1780
+ const std::string kernel_src {
1781
+ #include "mul_mv_id_q8_0_f32.cl.h"
1782
+ };
1783
+ #else
1784
+ const std::string kernel_src = read_file("mul_mv_id_q8_0_f32.cl");
1785
+ #endif
1786
+ backend_ctx->program_mul_mv_id_q8_0_f32 =
1787
+ build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
1788
+
1789
+ CL_CHECK((backend_ctx->kernel_mul_mv_id_q8_0_f32 = clCreateKernel(backend_ctx->program_mul_mv_id_q8_0_f32, "kernel_mul_mv_id_q8_0_f32", &err), err));
1790
+ GGML_LOG_CONT(".");
1791
+ }
1792
+
1793
+ // mul_mv_id_q8_0_f32_flat
1794
+ {
1795
+ #ifdef GGML_OPENCL_EMBED_KERNELS
1796
+ const std::string kernel_src {
1797
+ #include "mul_mv_id_q8_0_f32_flat.cl.h"
1798
+ };
1799
+ #else
1800
+ const std::string kernel_src = read_file("mul_mv_id_q8_0_f32_flat.cl");
1801
+ #endif
1802
+ backend_ctx->program_mul_mv_id_q8_0_f32_flat =
1803
+ build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
1804
+
1805
+ CL_CHECK((backend_ctx->kernel_mul_mv_id_q8_0_f32_flat = clCreateKernel(backend_ctx->program_mul_mv_id_q8_0_f32_flat, "kernel_mul_mv_id_q8_0_f32_flat", &err), err));
1806
+ GGML_LOG_CONT(".");
1807
+ }
1808
+
1809
+ // mul_mv_id_mxfp4_f32
1810
+ {
1811
+ #ifdef GGML_OPENCL_EMBED_KERNELS
1812
+ const std::string kernel_src {
1813
+ #include "mul_mv_id_mxfp4_f32.cl.h"
1814
+ };
1815
+ #else
1816
+ const std::string kernel_src = read_file("mul_mv_id_mxfp4_f32.cl");
1817
+ #endif
1818
+ backend_ctx->program_mul_mv_id_mxfp4_f32 =
1819
+ build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
1820
+
1821
+ CL_CHECK((backend_ctx->kernel_mul_mv_id_mxfp4_f32 = clCreateKernel(backend_ctx->program_mul_mv_id_mxfp4_f32, "kernel_mul_mv_id_mxfp4_f32", &err), err));
1822
+ GGML_LOG_CONT(".");
1823
+ }
1824
+
1825
+ // mul_mv_id_mxfp4_f32_flat
1826
+ {
1827
+ #ifdef GGML_OPENCL_EMBED_KERNELS
1828
+ const std::string kernel_src {
1829
+ #include "mul_mv_id_mxfp4_f32_flat.cl.h"
1830
+ };
1831
+ #else
1832
+ const std::string kernel_src = read_file("mul_mv_id_mxfp4_f32_flat.cl");
1833
+ #endif
1834
+ backend_ctx->program_mul_mv_id_mxfp4_f32_flat =
1835
+ build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
1836
+
1837
+ CL_CHECK((backend_ctx->kernel_mul_mv_id_mxfp4_f32_flat = clCreateKernel(backend_ctx->program_mul_mv_id_mxfp4_f32_flat, "kernel_mul_mv_id_mxfp4_f32_flat", &err), err));
1838
+ GGML_LOG_CONT(".");
1839
+ }
1840
+
1443
1841
  // Adreno kernels
1444
1842
  #ifdef GGML_OPENCL_USE_ADRENO_KERNELS
1445
1843
  // transpose
@@ -1457,6 +1855,7 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
1457
1855
  CL_CHECK((backend_ctx->kernel_transpose_32_16 = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_32_16", &err), err));
1458
1856
  CL_CHECK((backend_ctx->kernel_transpose_32 = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_32", &err), err));
1459
1857
  CL_CHECK((backend_ctx->kernel_transpose_16 = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_16", &err), err));
1858
+ CL_CHECK((backend_ctx->kernel_transpose_16_4x1 = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_16_4x1", &err), err));
1460
1859
  GGML_LOG_CONT(".");
1461
1860
  }
1462
1861
 
@@ -1895,8 +2294,8 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
1895
2294
 
1896
2295
  backend_ctx->adreno_cl_compiler_version = get_adreno_cl_compiler_version(driver_version);
1897
2296
  backend_ctx->has_vector_subgroup_broadcast =
1898
- backend_ctx->adreno_cl_compiler_version.major >= 47 ||
1899
- backend_ctx->adreno_cl_compiler_version.major == 17;
2297
+ (backend_ctx->adreno_cl_compiler_version.type == E031 && backend_ctx->adreno_cl_compiler_version.major >= 47) ||
2298
+ (backend_ctx->adreno_cl_compiler_version.type == DX && backend_ctx->adreno_cl_compiler_version.major >= 17);
1900
2299
  GGML_LOG_INFO("ggml_opencl: vector subgroup broadcast support: %s\n",
1901
2300
  backend_ctx->has_vector_subgroup_broadcast ? "true" : "false");
1902
2301
 
@@ -1933,6 +2332,9 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
1933
2332
  clGetDeviceInfo(device, CL_DEVICE_MAX_MEM_ALLOC_SIZE, sizeof(size_t), &backend_ctx->max_alloc_size, NULL);
1934
2333
  GGML_LOG_INFO("ggml_opencl: max mem alloc size: %zu MB\n", backend_ctx->max_alloc_size/1024/1024);
1935
2334
 
2335
+ clGetDeviceInfo(device, CL_DEVICE_MAX_WORK_GROUP_SIZE, sizeof(size_t), &backend_ctx->max_workgroup_size, NULL);
2336
+ GGML_LOG_INFO("ggml_opencl: device max workgroup size: %lu\n", backend_ctx->max_workgroup_size);
2337
+
1936
2338
  // Check SVM.
1937
2339
  cl_device_svm_capabilities svm_caps;
1938
2340
  CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_SVM_CAPABILITIES, sizeof(cl_device_svm_capabilities), &svm_caps, 0));
@@ -2009,6 +2411,8 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
2009
2411
  CL_CHECK((backend_ctx->B_d_max = clCreateBuffer(context, 0, max_B_d_bytes, NULL, &err), err));
2010
2412
  #endif // GGML_OPENCL_USE_ADRENO_KERNELS
2011
2413
 
2414
+ backend_ctx->disable_fusion = getenv("GGML_OPENCL_DISABLE_FUSION") != nullptr;
2415
+
2012
2416
  dev_ctx->backend_ctx = backend_ctx.release();
2013
2417
  return dev_ctx->backend_ctx;
2014
2418
  }
@@ -2098,6 +2502,84 @@ struct ggml_tensor_extra_cl_q4_0 {
2098
2502
  }
2099
2503
  };
2100
2504
 
2505
+ struct ggml_tensor_extra_cl_mxfp4 {
2506
+ // Quantized values.
2507
+ cl_mem q = nullptr;
2508
+ // Quantized values in image1d_buffer_t.
2509
+ cl_mem q_img = nullptr;
2510
+ // Scales in E8M0.
2511
+ cl_mem e = nullptr;
2512
+ // Scales in image1d_buffer_t.
2513
+ cl_mem e_img = nullptr;
2514
+ // Size of quantized values.
2515
+ size_t size_q = 0;
2516
+ // Size of scales.
2517
+ size_t size_e = 0;
2518
+
2519
+ ~ggml_tensor_extra_cl_mxfp4() {
2520
+ reset();
2521
+ }
2522
+
2523
+ void reset() {
2524
+ // q and d are subbuffers into the bigger buffer allocated in ggml_backend_buffer.
2525
+ // They must be properly released so that the original buffer can be
2526
+ // properly released to avoid memory leak.
2527
+ if (q != nullptr) {
2528
+ CL_CHECK(clReleaseMemObject(q));
2529
+ q = nullptr;
2530
+ }
2531
+ if (e != nullptr) {
2532
+ CL_CHECK(clReleaseMemObject(e));
2533
+ e = nullptr;
2534
+ }
2535
+ if (q_img != nullptr) {
2536
+ CL_CHECK(clReleaseMemObject(q_img));
2537
+ q_img = nullptr;
2538
+ }
2539
+ // Currently, q_img and d_img are not used. They can be image1d_buffer_t
2540
+ // that wraps around q and d to utilize image access path.
2541
+ q_img = nullptr;
2542
+ e_img = nullptr;
2543
+ size_q = 0;
2544
+ size_e = 0;
2545
+ }
2546
+ };
2547
+
2548
+ struct ggml_tensor_extra_cl_q8_0 {
2549
+ cl_mem q = nullptr;
2550
+ cl_mem q_img = nullptr;
2551
+
2552
+ cl_mem d = nullptr;
2553
+ cl_mem d_img = nullptr;
2554
+
2555
+ size_t size_q = 0;
2556
+ size_t size_d = 0;
2557
+
2558
+ ~ggml_tensor_extra_cl_q8_0() {
2559
+ reset();
2560
+ }
2561
+
2562
+ void reset() {
2563
+ // q and d are subbuffers into the bigger buffer allocated in ggml_backend_buffer.
2564
+ // They must be properly released so that the original buffer can be
2565
+ // properly released to avoid memory leak.
2566
+ if (q != nullptr) {
2567
+ CL_CHECK(clReleaseMemObject(q));
2568
+ q = nullptr;
2569
+ }
2570
+ if (d != nullptr) {
2571
+ CL_CHECK(clReleaseMemObject(d));
2572
+ d = nullptr;
2573
+ }
2574
+ // Currently, q_img and d_img are not used. They can be image1d_buffer_t
2575
+ // that wraps around q and d to utilize image access path.
2576
+ q_img = nullptr;
2577
+ d_img = nullptr;
2578
+ size_q = 0;
2579
+ size_d = 0;
2580
+ }
2581
+ };
2582
+
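These extras split each quantized tensor into a scales buffer and a quants buffer on device. For reference, a self-contained sketch of the block layouts they mirror (struct names here are illustrative, not ggml's actual identifiers):

    #include <cstdint>

    struct block_mxfp4_sketch { uint8_t  e; uint8_t qs[16]; };   // 32 weights: 1-byte E8M0 scale + 16 bytes of 4-bit quants
    struct block_q8_0_sketch  { uint16_t d; int8_t  qs[32]; };   // 32 weights: 2-byte fp16 scale + 32 bytes of 8-bit quants

    // For a tensor of n elements: MXFP4 gives size_e = n/32 and size_q = n/2 bytes;
    // Q8_0 gives size_d = (n/32)*2 and size_q = n bytes, matching the subbuffer
    // sizes computed in ggml_backend_opencl_buffer_set_tensor further down.
    static_assert(sizeof(block_mxfp4_sketch) == 17, "scale byte + 16 quant bytes");
    static_assert(sizeof(block_q8_0_sketch)  == 34, "fp16 scale + 32 quant bytes");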
2101
2583
  //------------------------------------------------------------------------------
2102
2584
  // Backend API
2103
2585
  //------------------------------------------------------------------------------
@@ -2178,37 +2660,127 @@ static void sync_with_other_backends(ggml_backend_t backend) {
2178
2660
  sync_with_other_backends(backend_ctx);
2179
2661
  }
2180
2662
 
2181
- static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
2182
- for (int i = 0; i < cgraph->n_nodes; i++) {
2183
- ggml_tensor * node = cgraph->nodes[i];
2663
+ static bool ggml_opencl_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
2664
+ if (!ggml_can_fuse(cgraph, node_idx, ops)) {
2665
+ return false;
2666
+ }
2184
2667
 
2185
- // NOTE: this may oversynchronize by synchronizing with
2186
- // backends/devices which don't compute 'cgraph's
2187
- // dependencies.
2188
- sync_with_other_backends(backend);
2668
+ if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
2669
+ const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
2670
+ const ggml_tensor *mul = cgraph->nodes[node_idx+1];
2189
2671
 
2190
- if (node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
2191
- continue;
2672
+ GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
2673
+ GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
2674
+
2675
+ // rms_norm only supports f32
2676
+ if (mul->src[0]->type != GGML_TYPE_F32 ||
2677
+ mul->src[1]->type != GGML_TYPE_F32 ||
2678
+ mul->type != GGML_TYPE_F32) {
2679
+ return false;
2192
2680
  }
2193
2681
 
2194
- bool ok = ggml_cl_compute_forward(backend, node);
2195
- if (!ok) {
2196
- GGML_LOG_ERROR("%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
2682
+ // if rms_norm is the B operand, then we don't handle broadcast
2683
+ if (rms_norm == mul->src[1] &&
2684
+ !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
2685
+ return false;
2197
2686
  }
2198
- GGML_ASSERT(ok);
2199
- }
2200
2687
 
2201
- return GGML_STATUS_SUCCESS;
2202
- }
2688
+ // rms_norm assumes contiguous rows
2689
+ if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
2690
+ return false;
2691
+ }
2692
+ } else if (ops.size() == 3 && ops.begin()[0] == GGML_OP_NORM && ops.begin()[1] == GGML_OP_MUL && ops.begin()[2] == GGML_OP_ADD) {
2693
+ const ggml_tensor *norm = cgraph->nodes[node_idx];
2694
+ const ggml_tensor *mul = cgraph->nodes[node_idx+1];
2695
+ const ggml_tensor *add = cgraph->nodes[node_idx+2];
2696
+ const ggml_tensor *w = mul->src[0] == norm ? mul->src[1] : mul->src[0];
2697
+ const ggml_tensor *b = add->src[0] == mul ? add->src[1] : add->src[0];
2698
+
2699
+ // norm fusion only supports F32
2700
+ if (norm->src[0]->type != GGML_TYPE_F32 || w->type != GGML_TYPE_F32 || b->type != GGML_TYPE_F32) {
2701
+ return false;
2702
+ }
2203
2703
 
2204
- static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
2205
- GGML_UNUSED(dev);
2704
+ if (norm->src[0]->ne[0] % 4 != 0) {
2705
+ return false;
2706
+ }
2206
2707
 
2207
- switch (op->op) {
2208
- case GGML_OP_NONE:
2209
- return true;
2210
- case GGML_OP_GET_ROWS:
2211
- switch (op->src[0]->type) {
2708
+ if (!ggml_is_contiguous(norm->src[0]) || !ggml_is_contiguous(w) || !ggml_is_contiguous(b)) {
2709
+ return false;
2710
+ }
2711
+ } else if (ops.size() == 3 && ops.begin()[0] == GGML_OP_GROUP_NORM && ops.begin()[1] == GGML_OP_MUL && ops.begin()[2] == GGML_OP_ADD) {
2712
+ const ggml_tensor *gn = cgraph->nodes[node_idx];
2713
+ const ggml_tensor *mul = cgraph->nodes[node_idx+1];
2714
+ const ggml_tensor *add = cgraph->nodes[node_idx+2];
2715
+ const ggml_tensor *w = mul->src[0] == gn ? mul->src[1] : mul->src[0];
2716
+ const ggml_tensor *b = add->src[0] == mul ? add->src[1] : add->src[0];
2717
+
2718
+ if (gn->src[0]->type != GGML_TYPE_F32 || w->type != GGML_TYPE_F32 || b->type != GGML_TYPE_F32) {
2719
+ return false;
2720
+ }
2721
+
2722
+ if (!ggml_is_contiguous(gn->src[0]) || !ggml_is_contiguous(w) || !ggml_is_contiguous(b)) {
2723
+ return false;
2724
+ }
2725
+ }
2726
+
2727
+ return true;
2728
+ }
2729
+
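For context, the RMS_NORM + MUL pattern this check targets is the usual "normalize, then scale by weights" pair, built on the graph side roughly as follows (a sketch; ctx, x, norm_w and eps are assumed to exist):

    ggml_tensor * cur = ggml_rms_norm(ctx, x, eps);   // GGML_OP_RMS_NORM
    cur               = ggml_mul(ctx, cur, norm_w);   // GGML_OP_MUL consuming the norm output
    // With F32 operands and contiguous rows, graph_compute below folds the pair
    // into a single kernel_rms_norm_mul dispatch instead of two launches.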
2730
+ static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor * rms_norm_tensor, ggml_tensor * mul_tensor);
2731
+ static void ggml_opencl_op_norm_fused(ggml_backend_t backend, ggml_tensor * norm_tensor, ggml_tensor * mul_tensor, ggml_tensor * add_tensor);
2732
+ static void ggml_opencl_op_group_norm_fused(ggml_backend_t backend, ggml_tensor * gn_tensor, ggml_tensor * mul_tensor, ggml_tensor * add_tensor);
2733
+
2734
+ static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
2735
+ ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
2736
+
2737
+ for (int i = 0; i < cgraph->n_nodes; i++) {
2738
+ ggml_tensor * node = cgraph->nodes[i];
2739
+
2740
+ // NOTE: this may oversynchronize by synchronizing with
2741
+ // backends/devices which don't compute 'cgraph's
2742
+ // dependencies.
2743
+ sync_with_other_backends(backend);
2744
+
2745
+ if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
2746
+ continue;
2747
+ }
2748
+
2749
+ if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_NORM, GGML_OP_MUL, GGML_OP_ADD })) {
2750
+ ggml_opencl_op_norm_fused(backend, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);
2751
+ i += 2;
2752
+ continue;
2753
+ }
2754
+ if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_GROUP_NORM, GGML_OP_MUL, GGML_OP_ADD })) {
2755
+ ggml_opencl_op_group_norm_fused(backend, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);
2756
+ i += 2;
2757
+ continue;
2758
+ }
2759
+ if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
2760
+ ggml_opencl_op_rms_norm_fused(backend, node, cgraph->nodes[i+1]);
2761
+ i++;
2762
+ continue;
2763
+ }
2764
+
2765
+ bool ok = ggml_cl_compute_forward(backend, node);
2766
+ if (!ok) {
2767
+ GGML_LOG_ERROR("%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
2768
+ }
2769
+ GGML_ASSERT(ok);
2770
+ }
2771
+
2772
+ return GGML_STATUS_SUCCESS;
2773
+ }
2774
+
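The fused paths can be bypassed at runtime because graph_compute consults backend_ctx->disable_fusion, which ggml_cl2_init sets from the GGML_OPENCL_DISABLE_FUSION environment variable (see further down in this diff). A small usage sketch, e.g. for comparing fused and unfused kernels (POSIX setenv assumed):

    setenv("GGML_OPENCL_DISABLE_FUSION", "1", 1);            // must be set before backend init
    ggml_backend_t backend = ggml_backend_opencl_init();     // per-op kernels only from here on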
2775
+ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
2776
+ ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *)dev->context;
2777
+ ggml_backend_opencl_context * backend_ctx = dev_ctx->backend_ctx;
2778
+
2779
+ switch (op->op) {
2780
+ case GGML_OP_NONE:
2781
+ return true;
2782
+ case GGML_OP_GET_ROWS:
2783
+ switch (op->src[0]->type) {
2212
2784
  case GGML_TYPE_F32:
2213
2785
  case GGML_TYPE_F16:
2214
2786
  return true;
@@ -2222,6 +2794,22 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
2222
2794
  default:
2223
2795
  return false;
2224
2796
  }
2797
+ case GGML_OP_SET_ROWS:
2798
+ {
2799
+ // TODO: add support
2800
+ // ref: https://github.com/ggml-org/llama.cpp/pull/14274
2801
+ #pragma message("TODO: implement BF16, Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, IQ4_NL support (https://github.com/ggml-org/llama.cpp/pull/14661)")
2802
+ if (op->src[0]->type != GGML_TYPE_F32) {
2803
+ return false;
2804
+ }
2805
+ switch (op->type) {
2806
+ case GGML_TYPE_F16:
2807
+ case GGML_TYPE_F32:
2808
+ return (op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32);
2809
+ default:
2810
+ return false;
2811
+ }
2812
+ }
2225
2813
  case GGML_OP_CPY:
2226
2814
  case GGML_OP_DUP:
2227
2815
  case GGML_OP_CONT:
@@ -2245,17 +2833,30 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
2245
2833
  default:
2246
2834
  return false;
2247
2835
  }
2248
- case GGML_OP_ADD:
2249
2836
  case GGML_OP_SCALE:
2837
+ return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]);
2838
+ case GGML_OP_ADD:
2839
+ if (op->type == GGML_TYPE_F16) {
2840
+ const bool src0_ok = op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32;
2841
+ const bool src1_ok = op->src[1]->type == GGML_TYPE_F16 || op->src[1]->type == GGML_TYPE_F32;
2842
+ if (src0_ok && src1_ok) {
2843
+ return true;
2844
+ }
2845
+ }
2250
2846
  case GGML_OP_MUL:
2251
2847
  case GGML_OP_DIV:
2252
2848
  case GGML_OP_SUB:
2849
+ return (op->src[0]->type == op->src[1]->type) &&
2850
+ (op->src[0]->type == op->type) &&
2851
+ (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16);
2852
+ case GGML_OP_ADD_ID:
2253
2853
  return op->src[0]->type == GGML_TYPE_F32;
2254
2854
  case GGML_OP_UNARY:
2255
2855
  switch (ggml_get_unary_op(op)) {
2256
2856
  case GGML_UNARY_OP_GELU:
2257
2857
  case GGML_UNARY_OP_SILU:
2258
2858
  case GGML_UNARY_OP_RELU:
2859
+ case GGML_UNARY_OP_GELU_ERF:
2259
2860
  case GGML_UNARY_OP_GELU_QUICK:
2260
2861
  return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
2261
2862
  case GGML_UNARY_OP_SIGMOID:
@@ -2271,6 +2872,9 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
2271
2872
  case GGML_GLU_OP_GEGLU:
2272
2873
  case GGML_GLU_OP_REGLU:
2273
2874
  case GGML_GLU_OP_SWIGLU:
2875
+ case GGML_GLU_OP_SWIGLU_OAI:
2876
+ case GGML_GLU_OP_GEGLU_ERF:
2877
+ case GGML_GLU_OP_GEGLU_QUICK:
2274
2878
  return ggml_is_contiguous_1(op->src[0]) && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
2275
2879
  default:
2276
2880
  return false;
@@ -2279,15 +2883,22 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
2279
2883
  return op->src[0]->type == GGML_TYPE_F32;
2280
2884
  case GGML_OP_SOFT_MAX:
2281
2885
  case GGML_OP_NORM:
2282
- case GGML_OP_RMS_NORM:
2283
2886
  return true;
2887
+ case GGML_OP_RMS_NORM:
2888
+ return op->ne[0] % 4 == 0 && ggml_is_contiguous_rows(op->src[0]);
2284
2889
  case GGML_OP_REPEAT:
2285
2890
  return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; // Assuming F32 for now, can be expanded
2286
2891
  case GGML_OP_PAD:
2287
2892
  return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32 &&
2288
- op->src[0]->ne[3] == 1 && op->ne[3] == 1;
2893
+ op->src[0]->ne[3] == 1 && op->ne[3] == 1 &&
2894
+ (ggml_get_op_params_i32(op, 0) == 0) && (ggml_get_op_params_i32(op, 2) == 0) &&
2895
+ (ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0);
2289
2896
  case GGML_OP_UPSCALE:
2290
2897
  return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
2898
+ case GGML_OP_CONV_2D:
2899
+ return (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16) ||
2900
+ (op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) ||
2901
+ (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32);
2291
2902
  case GGML_OP_CONCAT:
2292
2903
  return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
2293
2904
  case GGML_OP_TIMESTEP_EMBEDDING:
@@ -2299,13 +2910,17 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
2299
2910
  return true;
2300
2911
  } else if (op->src[0]->type == GGML_TYPE_F32) {
2301
2912
  return op->src[1]->type == GGML_TYPE_F32;
2302
- } else if (op->src[0]->type == GGML_TYPE_Q4_0 ||
2913
+ } else if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_MXFP4 ||
2303
2914
  op->src[0]->type == GGML_TYPE_Q6_K) {
2304
2915
  return op->src[1]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
2916
+ } else if (op->src[0]->type == GGML_TYPE_Q8_0) {
2917
+ return op->src[1]->type == GGML_TYPE_F32;
2305
2918
  }
2306
2919
  return false;
2307
2920
  case GGML_OP_MUL_MAT_ID:
2308
- if (op->src[0]->type == GGML_TYPE_Q4_0) {
2921
+ if (op->src[0]->type == GGML_TYPE_Q4_0 ||
2922
+ op->src[0]->type == GGML_TYPE_Q8_0 ||
2923
+ op->src[0]->type == GGML_TYPE_MXFP4) {
2309
2924
  if (op->src[1]->type == GGML_TYPE_F32) {
2310
2925
  return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
2311
2926
  }
@@ -2340,10 +2955,54 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
2340
2955
  }
2341
2956
  case GGML_OP_IM2COL:
2342
2957
  return true;
2343
- case GGML_OP_ARGSORT:
2344
- return op->src[0]->type == GGML_TYPE_F32;
2958
+ case GGML_OP_ARGSORT: {
2959
+ cl_kernel kernel = backend_ctx->kernel_argsort_f32_i32;
2960
+ int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel);
2961
+
2962
+ int cols = 1;
2963
+ while (cols < op->ne[0]) {
2964
+ cols *= 2;
2965
+ }
2966
+
2967
+ return cols <= max_workgroup_size && op->src[0]->type == GGML_TYPE_F32;
2968
+ }
2345
2969
  case GGML_OP_SUM_ROWS:
2346
2970
  return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]);
2971
+ case GGML_OP_FLASH_ATTN_EXT:
2972
+ {
2973
+ const ggml_tensor * q = op->src[0];
2974
+ const ggml_tensor * k = op->src[1];
2975
+ const ggml_tensor * v = op->src[2];
2976
+
2977
+ const int dk = q->ne[0];
2978
+ const int dv = v->ne[0];
2979
+
2980
+ const struct { int dk; int dv; } supported_dims[] = {
2981
+ { 40, 40}, { 64, 64}, { 80, 80}, { 96, 96},
2982
+ {112, 112}, {128, 128}, {192, 128},
2983
+ {192, 192}, {256, 256},
2984
+ };
2985
+
2986
+ bool dims_supported = false;
2987
+ for (size_t i = 0; i < sizeof(supported_dims)/sizeof(supported_dims[0]); ++i) {
2988
+ if (supported_dims[i].dk == dk && supported_dims[i].dv == dv) {
2989
+ dims_supported = true;
2990
+ break;
2991
+ }
2992
+ }
2993
+ if (!dims_supported) {
2994
+ return false;
2995
+ }
2996
+
2997
+ const bool is_f32_f32 = q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F32 &&
2998
+ v->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
2999
+ const bool is_f16_f16 = q->type == GGML_TYPE_F16 && k->type == GGML_TYPE_F16 &&
3000
+ v->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16;
3001
+ const bool is_f32_f16 = q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F16 &&
3002
+ v->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F32;
3003
+
3004
+ return is_f32_f32 || is_f16_f16 || is_f32_f16;
3005
+ }
2347
3006
  default:
2348
3007
  return false;
2349
3008
  }
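The ARGSORT case above rounds ne[0] up to the next power of two because the bitonic-sort kernel needs the whole padded row to fit in one workgroup. The same rounding written as a self-contained helper (a sketch, not part of the diff):

    #include <cstddef>

    static size_t argsort_padded_cols(size_t n) {
        size_t cols = 1;
        while (cols < n) cols *= 2;   // identical loop to the support check above
        return cols;                  // must be <= the kernel's max workgroup size
    }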
@@ -2371,6 +3030,7 @@ static ggml_backend_i ggml_backend_opencl_i = {
2371
3030
  /* .graph_compute = */ ggml_backend_opencl_graph_compute,
2372
3031
  /* .event_record = */ NULL,
2373
3032
  /* .event_wait = */ NULL,
3033
+ /* .graph_optimize = */ NULL,
2374
3034
  };
2375
3035
 
2376
3036
  ggml_backend_t ggml_backend_opencl_init(void) {
@@ -2378,10 +3038,10 @@ ggml_backend_t ggml_backend_opencl_init(void) {
2378
3038
  ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(dev);
2379
3039
 
2380
3040
  ggml_backend_t backend = new ggml_backend {
2381
- /* .guid = */ ggml_backend_opencl_guid(),
2382
- /* .interface = */ ggml_backend_opencl_i,
2383
- /* .device = */ dev,
2384
- /* .context = */ backend_ctx
3041
+ /* .guid = */ ggml_backend_opencl_guid(),
3042
+ /* .iface = */ ggml_backend_opencl_i,
3043
+ /* .device = */ dev,
3044
+ /* .context = */ backend_ctx
2385
3045
  };
2386
3046
 
2387
3047
  return backend;
@@ -2426,6 +3086,18 @@ struct ggml_backend_opencl_buffer_context {
2426
3086
  for (ggml_tensor_extra_cl_q4_0 * e : temp_tensor_extras_q4_0_in_use) {
2427
3087
  delete e;
2428
3088
  }
3089
+ for (ggml_tensor_extra_cl_mxfp4 * e : temp_tensor_extras_mxfp4) {
3090
+ delete e;
3091
+ }
3092
+ for (ggml_tensor_extra_cl_mxfp4 * e : temp_tensor_extras_mxfp4_in_use) {
3093
+ delete e;
3094
+ }
3095
+ for (ggml_tensor_extra_cl_q8_0 * e : temp_tensor_extras_q8_0) {
3096
+ delete e;
3097
+ }
3098
+ for (ggml_tensor_extra_cl_q8_0 * e : temp_tensor_extras_q8_0_in_use) {
3099
+ delete e;
3100
+ }
2429
3101
  }
2430
3102
 
2431
3103
  ggml_tensor_extra_cl * ggml_opencl_alloc_temp_tensor_extra() {
@@ -2458,6 +3130,36 @@ struct ggml_backend_opencl_buffer_context {
2458
3130
  return extra;
2459
3131
  }
2460
3132
 
3133
+ ggml_tensor_extra_cl_mxfp4 * ggml_opencl_alloc_temp_tensor_extra_mxfp4() {
3134
+ ggml_tensor_extra_cl_mxfp4 * extra;
3135
+ if (temp_tensor_extras_mxfp4.empty()) {
3136
+ extra = new ggml_tensor_extra_cl_mxfp4();
3137
+ } else {
3138
+ extra = temp_tensor_extras_mxfp4.back();
3139
+ temp_tensor_extras_mxfp4.pop_back();
3140
+ }
3141
+
3142
+ temp_tensor_extras_mxfp4_in_use.push_back(extra);
3143
+
3144
+ extra->reset();
3145
+ return extra;
3146
+ }
3147
+
3148
+ ggml_tensor_extra_cl_q8_0 * ggml_opencl_alloc_temp_tensor_extra_q8_0() {
3149
+ ggml_tensor_extra_cl_q8_0 * extra;
3150
+ if (temp_tensor_extras_q8_0.empty()) {
3151
+ extra = new ggml_tensor_extra_cl_q8_0();
3152
+ } else {
3153
+ extra = temp_tensor_extras_q8_0.back();
3154
+ temp_tensor_extras_q8_0.pop_back();
3155
+ }
3156
+
3157
+ temp_tensor_extras_q8_0_in_use.push_back(extra);
3158
+
3159
+ extra->reset();
3160
+ return extra;
3161
+ }
3162
+
2461
3163
  void reset() {
2462
3164
  for (ggml_tensor_extra_cl * e : temp_tensor_extras_in_use) {
2463
3165
  temp_tensor_extras.push_back(e);
@@ -2468,6 +3170,16 @@ struct ggml_backend_opencl_buffer_context {
2468
3170
  temp_tensor_extras_q4_0.push_back(e);
2469
3171
  }
2470
3172
  temp_tensor_extras_q4_0_in_use.clear();
3173
+
3174
+ for (ggml_tensor_extra_cl_mxfp4 * e : temp_tensor_extras_mxfp4_in_use) {
3175
+ temp_tensor_extras_mxfp4.push_back(e);
3176
+ }
3177
+ temp_tensor_extras_mxfp4_in_use.clear();
3178
+
3179
+ for (ggml_tensor_extra_cl_q8_0 * e : temp_tensor_extras_q8_0_in_use) {
3180
+ temp_tensor_extras_q8_0.push_back(e);
3181
+ }
3182
+ temp_tensor_extras_q8_0_in_use.clear();
2471
3183
  }
2472
3184
 
2473
3185
  // Pools for extras. Available extras are in `temp_tensor_extras`. Extras
@@ -2479,6 +3191,10 @@ struct ggml_backend_opencl_buffer_context {
2479
3191
  std::vector<ggml_tensor_extra_cl *> temp_tensor_extras_in_use;
2480
3192
  std::vector<ggml_tensor_extra_cl_q4_0 *> temp_tensor_extras_q4_0;
2481
3193
  std::vector<ggml_tensor_extra_cl_q4_0 *> temp_tensor_extras_q4_0_in_use;
3194
+ std::vector<ggml_tensor_extra_cl_mxfp4 *> temp_tensor_extras_mxfp4;
3195
+ std::vector<ggml_tensor_extra_cl_mxfp4 *> temp_tensor_extras_mxfp4_in_use;
3196
+ std::vector<ggml_tensor_extra_cl_q8_0 *> temp_tensor_extras_q8_0;
3197
+ std::vector<ggml_tensor_extra_cl_q8_0 *> temp_tensor_extras_q8_0_in_use;
2482
3198
 
2483
3199
  // The buffer_context is initially created by ggml_backend_buft_alloc_buffer
2484
3200
  // before any tensor is initialized (at the beginning of alloc_tensor_range).
@@ -2691,7 +3407,10 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
2691
3407
  // cl_mem qT_d = clCreateBuffer(context, CL_MEM_READ_WRITE, q_size_bytes, NULL, &err);
2692
3408
  CL_CHECK(err);
2693
3409
 
2694
- // size_t d_size_bytes = M * (K / 32) / 2 * sizeof(float);
3410
+ bool K_tile_trans = true;
3411
+ if ((K / 32) % 4 != 0) {
3412
+ K_tile_trans = false;
3413
+ }
2695
3414
  size_t d_size_bytes = M * (K / 32) * 2;
2696
3415
  region.origin = 0;
2697
3416
  region.size = d_size_bytes;
@@ -2732,10 +3451,15 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
2732
3451
  qT_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err);
2733
3452
  CL_CHECK(err);
2734
3453
 
2735
- img_fmt_1d = { CL_RGBA, CL_HALF_FLOAT };
2736
3454
  memset(&img_desc_1d, 0, sizeof(img_desc_1d));
3455
+ if (K_tile_trans) {
3456
+ img_fmt_1d = { CL_RGBA, CL_HALF_FLOAT };
3457
+ img_desc_1d.image_width = M * K / 32 / 4;
3458
+ } else {
3459
+ img_fmt_1d = { CL_R, CL_HALF_FLOAT };
3460
+ img_desc_1d.image_width = M * K / 32;
3461
+ }
2737
3462
  img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
2738
- img_desc_1d.image_width = M * K / 32 / 4;
2739
3463
  img_desc_1d.buffer = extra->d;
2740
3464
  d_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err);
2741
3465
  CL_CHECK(err);
@@ -2771,6 +3495,10 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
2771
3495
  int width_s = K / 32 / 4;
2772
3496
 
2773
3497
  kernel = backend_ctx->kernel_transpose_16;
3498
+ if (!K_tile_trans) {
3499
+ kernel = backend_ctx->kernel_transpose_16_4x1;
3500
+ width_s = K / 32;
3501
+ }
2774
3502
  CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &d_d_image1D));
2775
3503
  CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &dT_d_image1D));
2776
3504
  CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_s));
@@ -2809,6 +3537,135 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
2809
3537
  }
2810
3538
  #endif // GGML_OPENCL_USE_ADRENO_KERNELS
2811
3539
 
3540
+ return;
3541
+
3542
+ }
3543
+ if (tensor->type == GGML_TYPE_MXFP4) {
3544
+ ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra;
3545
+ GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized");
3546
+
3547
+ // Allocate the new extra and create aliases from the original.
3548
+ ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;
3549
+ ggml_tensor_extra_cl_mxfp4 * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_mxfp4();
3550
+
3551
+ size_t size_e = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(char);
3552
+ size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2;
3553
+ GGML_ASSERT(size_e + size_q == ggml_nbytes(tensor) && "Incorrect tensor size");
3554
+
3555
+ cl_int err;
3556
+ cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE,
3557
+ ggml_nbytes(tensor), NULL, &err);
3558
+ CL_CHECK(err);
3559
+ CL_CHECK(clEnqueueWriteBuffer(
3560
+ queue, data_device, CL_TRUE, 0,
3561
+ ggml_nbytes(tensor), data, 0, NULL, NULL));
3562
+
3563
+ // The original tensor memory is divided into scales and quants, i.e.,
3564
+ // we first store scales, then quants.
3565
+ cl_buffer_region region;
3566
+
3567
+ // Create subbuffer for scales.
3568
+ region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment);
3569
+ region.size = size_e;
3570
+ extra->e = clCreateSubBuffer(
3571
+ extra_orig->data_device, CL_MEM_READ_WRITE,
3572
+ CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
3573
+ CL_CHECK(err);
3574
+ auto previous_origin = region.origin;
3575
+
3576
+ // Create subbuffer for quants.
3577
+ region.origin = align_to(previous_origin + size_e, backend_ctx->alignment);
3578
+ region.size = size_q;
3579
+ extra->q = clCreateSubBuffer(
3580
+ extra_orig->data_device, CL_MEM_READ_WRITE,
3581
+ CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
3582
+ CL_CHECK(err);
3583
+
3584
+ cl_kernel kernel = backend_ctx->kernel_convert_block_mxfp4;
3585
+
3586
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device));
3587
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q));
3588
+ CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->e));
3589
+
3590
+ size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
3591
+ size_t local_work_size[] = {64, 1, 1};
3592
+
3593
+ cl_event evt;
3594
+ CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
3595
+ CL_CHECK(clWaitForEvents(1, &evt));
3596
+ CL_CHECK(clReleaseMemObject(data_device));
3597
+
3598
+ // Create image for Q
3599
+ cl_image_format img_format_q = {CL_RG, CL_UNSIGNED_INT32};
3600
+ cl_image_desc img_desc_q = {
3601
+ CL_MEM_OBJECT_IMAGE1D_BUFFER,
3602
+ static_cast<size_t>(ggml_nelements(tensor)/32*2),
3603
+ 0, 0, 0, 0, 0, 0, 0,
3604
+ { extra->q }
3605
+ };
3606
+ extra->q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_q, &img_desc_q, NULL, &err);
3607
+
3608
+ tensor->extra = extra;
3609
+
3610
+ return;
3611
+ }
3612
+ if (tensor->type == GGML_TYPE_Q8_0) {
3613
+ ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra;
3614
+ GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized");
3615
+
3616
+ // Allocate the new extra and create aliases from the original.
3617
+ ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;
3618
+ ggml_tensor_extra_cl_q8_0 * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_q8_0();
3619
+
3620
+ size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t);
3621
+ size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*(ggml_blck_size(tensor->type)*sizeof(char));
3622
+ GGML_ASSERT(size_d + size_q == ggml_nbytes(tensor) && "Incorrect tensor size");
3623
+
3624
+ cl_int err;
3625
+ cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE,
3626
+ ggml_nbytes(tensor), NULL, &err);
3627
+ CL_CHECK(err);
3628
+ CL_CHECK(clEnqueueWriteBuffer(
3629
+ queue, data_device, CL_TRUE, 0,
3630
+ ggml_nbytes(tensor), data, 0, NULL, NULL));
3631
+
3632
+ // The original tensor memory is divided into scales and quants, i.e.,
3633
+ // we first store scales, then quants.
3634
+ cl_buffer_region region;
3635
+
3636
+ // Create subbuffer for scales.
3637
+ region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment);
3638
+ region.size = size_d;
3639
+ extra->d = clCreateSubBuffer(
3640
+ extra_orig->data_device, CL_MEM_READ_WRITE,
3641
+ CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
3642
+ CL_CHECK(err);
3643
+ auto previous_origin = region.origin;
3644
+
3645
+ // Create subbuffer for quants.
3646
+ region.origin = align_to(previous_origin + size_d, backend_ctx->alignment);
3647
+ region.size = size_q;
3648
+ extra->q = clCreateSubBuffer(
3649
+ extra_orig->data_device, CL_MEM_READ_WRITE,
3650
+ CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
3651
+ CL_CHECK(err);
3652
+
3653
+ cl_kernel kernel = backend_ctx->kernel_convert_block_q8_0;
3654
+
3655
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device));
3656
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q));
3657
+ CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d));
3658
+
3659
+ size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
3660
+ size_t local_work_size[] = {64, 1, 1};
3661
+
3662
+ cl_event evt;
3663
+ CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
3664
+ CL_CHECK(clWaitForEvents(1, &evt));
3665
+ CL_CHECK(clReleaseMemObject(data_device));
3666
+
3667
+ tensor->extra = extra;
3668
+
2812
3669
  return;
2813
3670
  }
2814
3671
  #endif // GGML_OPENCL_SOA_Q
@@ -2866,25 +3723,76 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer,
2866
3723
  size, data, 0, NULL, NULL));
2867
3724
  CL_CHECK(clReleaseMemObject(data_device));
2868
3725
  return;
2869
- }
2870
- #endif // GGML_OPENCL_SOA_Q
3726
+ } else if (tensor->type == GGML_TYPE_MXFP4) {
3727
+ ggml_tensor_extra_cl_mxfp4 * extra = (ggml_tensor_extra_cl_mxfp4 *)tensor->extra;
2871
3728
 
2872
- ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) tensor->extra;
3729
+ cl_int err;
3730
+ cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE,
3731
+ ggml_nbytes(tensor), NULL, &err);
3732
+ CL_CHECK(err);
2873
3733
 
2874
- CL_CHECK(clEnqueueReadBuffer(
2875
- queue, extra->data_device, CL_TRUE, extra->offset + tensor->view_offs + offset,
2876
- size, data, 0, NULL, NULL));
3734
+ cl_kernel kernel = backend_ctx->kernel_restore_block_mxfp4;
3735
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q));
3736
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->e));
3737
+ CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &data_device));
2877
3738
 
2878
- GGML_UNUSED(buffer);
2879
- }
3739
+ size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
3740
+ size_t local_work_size[] = {1, 1, 1};
2880
3741
 
2881
- static void ggml_backend_opencl_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
2882
- ggml_backend_dev_t dev = buffer->buft->device;
2883
- ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(dev);
2884
- cl_command_queue queue = backend_ctx->queue;
3742
+ cl_event evt;
3743
+ CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL,
3744
+ global_work_size, local_work_size, 0, NULL, &evt));
3745
+ CL_CHECK(clWaitForEvents(1, &evt));
3746
+ CL_CHECK(clEnqueueReadBuffer(
3747
+ queue, data_device, CL_TRUE, offset,
3748
+ size, data, 0, NULL, NULL));
3749
+ CL_CHECK(clReleaseMemObject(data_device));
3750
+ return;
3751
+ }
3752
+ if (tensor->type == GGML_TYPE_Q8_0) {
3753
+ ggml_tensor_extra_cl_q8_0 * extra = (ggml_tensor_extra_cl_q8_0 *)tensor->extra;
2885
3754
 
2886
- ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;
2887
- for (cl_mem buf : ctx->buffer) {
3755
+ cl_int err;
3756
+ cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE,
3757
+ ggml_nbytes(tensor), NULL, &err);
3758
+ CL_CHECK(err);
3759
+
3760
+ cl_kernel kernel = backend_ctx->kernel_restore_block_q8_0;
3761
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q));
3762
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->d));
3763
+ CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &data_device));
3764
+
3765
+ size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
3766
+ size_t local_work_size[] = {1, 1, 1};
3767
+
3768
+ cl_event evt;
3769
+ CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL,
3770
+ global_work_size, local_work_size, 0, NULL, &evt));
3771
+ CL_CHECK(clWaitForEvents(1, &evt));
3772
+ CL_CHECK(clEnqueueReadBuffer(
3773
+ queue, data_device, CL_TRUE, offset,
3774
+ size, data, 0, NULL, NULL));
3775
+ CL_CHECK(clReleaseMemObject(data_device));
3776
+ return;
3777
+ }
3778
+ #endif // GGML_OPENCL_SOA_Q
3779
+
3780
+ ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) tensor->extra;
3781
+
3782
+ CL_CHECK(clEnqueueReadBuffer(
3783
+ queue, extra->data_device, CL_TRUE, extra->offset + tensor->view_offs + offset,
3784
+ size, data, 0, NULL, NULL));
3785
+
3786
+ GGML_UNUSED(buffer);
3787
+ }
3788
+
3789
+ static void ggml_backend_opencl_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
3790
+ ggml_backend_dev_t dev = buffer->buft->device;
3791
+ ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(dev);
3792
+ cl_command_queue queue = backend_ctx->queue;
3793
+
3794
+ ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;
3795
+ for (cl_mem buf : ctx->buffer) {
2888
3796
  CL_CHECK(clEnqueueFillBuffer(queue, buf, &value, sizeof(value), 0, buffer->size, 0, NULL, NULL));
2889
3797
  }
2890
3798
  CL_CHECK(clFinish(queue));
@@ -3178,6 +4086,19 @@ static void dump_tensor(ggml_backend_t backend, const struct ggml_tensor * tenso
3178
4086
  CL_CHECK(clEnqueueReadBuffer(queue, extra->q, CL_TRUE, 0, size_q, buf_q, 0, NULL, NULL));
3179
4087
  CL_CHECK(clEnqueueReadBuffer(queue, extra->d, CL_TRUE, 0, size_d, buf_d, 0, NULL, NULL));
3180
4088
  CL_CHECK(clFinish(queue));
4089
+ } else if (tensor->type == GGML_TYPE_MXFP4) {
4090
+ ggml_tensor_extra_cl_mxfp4 * extra = (ggml_tensor_extra_cl_mxfp4 *) tensor->extra;
4091
+ GGML_ASSERT(extra);
4092
+
4093
+ size_t size_q = ggml_nelements(tensor)/QK_MXFP4 * QK_MXFP4/2;
4094
+ size_t size_e = ggml_nelements(tensor)/QK_MXFP4 * sizeof(char);
4095
+ GGML_ASSERT(size_q + size_e == ggml_nbytes(tensor));
4096
+ buf_q = malloc(size_q);
4097
+ buf_d = malloc(size_e);
4098
+
4099
+ CL_CHECK(clEnqueueReadBuffer(queue, extra->q, CL_TRUE, 0, size_q, buf_q, 0, NULL, NULL));
4100
+ CL_CHECK(clEnqueueReadBuffer(queue, extra->d, CL_TRUE, 0, size_e, buf_d, 0, NULL, NULL));
4101
+ CL_CHECK(clFinish(queue));
3181
4102
  } else {
3182
4103
  // Read out the tensor from GPU memory.
3183
4104
  ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) tensor->extra;
@@ -3199,7 +4120,7 @@ static void dump_tensor(ggml_backend_t backend, const struct ggml_tensor * tenso
3199
4120
 
3200
4121
  // Open file and dump.
3201
4122
  char fname[512];
3202
- sprintf(fname, "./tensor-dumps/%s.txt", tensor->name);
4123
+ snprintf(fname, sizeof(fname), "./tensor-dumps/%s.txt", tensor->name);
3203
4124
  FILE * f = fopen(fname, "w");
3204
4125
  if (!f) {
3205
4126
  printf("Failed to open %s\n", fname);
@@ -3358,6 +4279,120 @@ static void ggml_cl_get_rows(ggml_backend_t backend, const ggml_tensor * src0, c
3358
4279
  backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
3359
4280
  }
3360
4281
 
4282
+ static void ggml_cl_set_rows(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
4283
+ GGML_ASSERT(src0);
4284
+ GGML_ASSERT(src0->extra);
4285
+ GGML_ASSERT(src1);
4286
+ GGML_ASSERT(src1->extra);
4287
+ GGML_ASSERT(dst);
4288
+ GGML_ASSERT(dst->extra);
4289
+ GGML_ASSERT(src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32);
4290
+
4291
+ // ne0 = ne00
4292
+ // ne2 = ne02
4293
+ // ne3 = ne03
4294
+
4295
+ const int ne01 = src0->ne[1];
4296
+ const int ne02 = src0->ne[2];
4297
+ const int ne03 = src0->ne[3];
4298
+
4299
+ const cl_ulong nb01 = src0->nb[1];
4300
+ const cl_ulong nb02 = src0->nb[2];
4301
+ const cl_ulong nb03 = src0->nb[3];
4302
+
4303
+ const int ne11 = src1->ne[1];
4304
+ const int ne12 = src1->ne[2];
4305
+
4306
+ const cl_ulong nb10 = src1->nb[0];
4307
+ const cl_ulong nb11 = src1->nb[1];
4308
+ const cl_ulong nb12 = src1->nb[2];
4309
+
4310
+ const int ne0 = dst->ne[0];
4311
+
4312
+ const cl_ulong nb1 = dst->nb[1];
4313
+ const cl_ulong nb2 = dst->nb[2];
4314
+ const cl_ulong nb3 = dst->nb[3];
4315
+
4316
+ const int nblk0 = ne0/ggml_blck_size(dst->type);
4317
+
4318
+ ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
4319
+
4320
+ ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
4321
+ ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
4322
+ ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
4323
+
4324
+ cl_ulong offset0 = extra0->offset + src0->view_offs;
4325
+ cl_ulong offset1 = extra1->offset + src1->view_offs;
4326
+ cl_ulong offsetd = extrad->offset + dst->view_offs;
4327
+
4328
+ cl_kernel kernel;
4329
+
4330
+ switch (dst->type) {
4331
+ case GGML_TYPE_F32:
4332
+ if (src1->type == GGML_TYPE_I64) {
4333
+ kernel = backend_ctx->kernel_set_rows_f32_i64;
4334
+ } else {
4335
+ kernel = backend_ctx->kernel_set_rows_f32_i32;
4336
+ }
4337
+ break;
4338
+ case GGML_TYPE_F16:
4339
+ if (src1->type == GGML_TYPE_I64) {
4340
+ kernel = backend_ctx->kernel_set_rows_f16_i64;
4341
+ } else {
4342
+ kernel = backend_ctx->kernel_set_rows_f16_i32;
4343
+ }
4344
+ break;
4345
+ default:
4346
+ GGML_ABORT("not implemented");
4347
+ }
4348
+
4349
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
4350
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
4351
+ CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
4352
+ CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
4353
+ CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
4354
+ CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
4355
+ CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne01));
4356
+ CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01));
4357
+ CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02));
4358
+ CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb03));
4359
+ CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne11));
4360
+ CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne12));
4361
+ CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb10));
4362
+ CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb11));
4363
+ CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb12));
4364
+ CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &nblk0));
4365
+ CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb1));
4366
+ CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb2));
4367
+ CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb3));
4368
+
4369
+ int nth0 = 64;
4370
+ if (backend_ctx->gpu_family == INTEL) {
4371
+ nth0 = 32;
4372
+ } else if (backend_ctx->gpu_family == ADRENO) {
4373
+ nth0 = 64;
4374
+ }
4375
+
4376
+ int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel);
4377
+ while (nth0 < nblk0 && nth0 < max_workgroup_size) {
4378
+ nth0 *= 2;
4379
+ }
4380
+
4381
+ int rows_per_workgroup = 1;
4382
+ if (nth0 > nblk0) {
4383
+ rows_per_workgroup = nth0 / nblk0;
4384
+ nth0 = nblk0;
4385
+ }
4386
+
4387
+ size_t global_work_size[] = {
4388
+ (size_t)(ne01 + rows_per_workgroup - 1)/rows_per_workgroup*nth0,
4389
+ (size_t)ne02*rows_per_workgroup,
4390
+ (size_t)ne03};
4391
+ size_t local_work_size[] = {(size_t)nth0, (size_t)rows_per_workgroup, 1};
4392
+
4393
+ backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
4394
+ }
4395
+
3361
4396
  static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3362
4397
  GGML_ASSERT(src0);
3363
4398
  GGML_ASSERT(src0->extra);
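The new ggml_cl_set_rows in the previous hunk sizes its ND-range so that a row of nblk0 destination blocks maps to at most one work-group, packing several short rows into one work-group when a single row does not fill it. A sketch of that sizing logic under the same assumptions (illustrative only; the helper and its parameter names are not taken from the diff):

    /* Reproduces the launch-shape calculation used by set_rows above:
     * nth0 starts at a GPU-family default, doubles up to the kernel's max
     * work-group size, then short rows are packed rows_per_workgroup at a time. */
    static void set_rows_launch_shape(int nblk0, int ne01, int ne02, int ne03,
                                      int max_wg_size, int nth0_default,
                                      size_t gws[3], size_t lws[3]) {
        int nth0 = nth0_default;              /* 32 on Intel, 64 on Adreno above */
        while (nth0 < nblk0 && nth0 < max_wg_size) {
            nth0 *= 2;
        }
        int rows_per_workgroup = 1;
        if (nth0 > nblk0) {                   /* short rows: share one work-group */
            rows_per_workgroup = nth0 / nblk0;
            nth0 = nblk0;
        }
        gws[0] = (size_t)(ne01 + rows_per_workgroup - 1) / rows_per_workgroup * nth0;
        gws[1] = (size_t)ne02 * rows_per_workgroup;
        gws[2] = (size_t)ne03;
        lws[0] = (size_t)nth0;
        lws[1] = (size_t)rows_per_workgroup;
        lws[2] = 1;
    }

    /* Example: rows of nblk0 = 16 blocks, ne01 = 100 rows, default nth0 = 64:
     * nth0 becomes 16, rows_per_workgroup = 4, gws = {25*16, ne02*4, ne03}. */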
@@ -3366,35 +4401,35 @@ static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const
3366
4401
  GGML_ASSERT(dst);
3367
4402
  GGML_ASSERT(dst->extra);
3368
4403
 
3369
- const int ne00 = src0 ? src0->ne[0] : 0;
3370
- const int ne01 = src0 ? src0->ne[1] : 0;
3371
- const int ne02 = src0 ? src0->ne[2] : 0;
3372
- const int ne03 = src0 ? src0->ne[3] : 0;
4404
+ const int ne00 = src0->ne[0];
4405
+ const int ne01 = src0->ne[1];
4406
+ const int ne02 = src0->ne[2];
4407
+ const int ne03 = src0->ne[3];
3373
4408
 
3374
- const cl_ulong nb00 = src0 ? src0->nb[0] : 0;
3375
- const cl_ulong nb01 = src0 ? src0->nb[1] : 0;
3376
- const cl_ulong nb02 = src0 ? src0->nb[2] : 0;
3377
- const cl_ulong nb03 = src0 ? src0->nb[3] : 0;
4409
+ const cl_ulong nb00 = src0->nb[0];
4410
+ const cl_ulong nb01 = src0->nb[1];
4411
+ const cl_ulong nb02 = src0->nb[2];
4412
+ const cl_ulong nb03 = src0->nb[3];
3378
4413
 
3379
- const int ne10 = src1 ? src1->ne[0] : 0;
3380
- const int ne11 = src1 ? src1->ne[1] : 0;
3381
- const int ne12 = src1 ? src1->ne[2] : 0;
3382
- const int ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13);
4414
+ const int ne10 = src1->ne[0];
4415
+ const int ne11 = src1->ne[1];
4416
+ const int ne12 = src1->ne[2];
4417
+ const int ne13 = src1->ne[3];
3383
4418
 
3384
- const cl_ulong nb10 = src1 ? src1->nb[0] : 0;
3385
- const cl_ulong nb11 = src1 ? src1->nb[1] : 0;
3386
- const cl_ulong nb12 = src1 ? src1->nb[2] : 0;
3387
- const cl_ulong nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13);
4419
+ const cl_ulong nb10 = src1->nb[0];
4420
+ const cl_ulong nb11 = src1->nb[1];
4421
+ const cl_ulong nb12 = src1->nb[2];
4422
+ const cl_ulong nb13 = src1->nb[3];
3388
4423
 
3389
- const int ne0 = dst ? dst->ne[0] : 0;
3390
- const int ne1 = dst ? dst->ne[1] : 0;
3391
- const int ne2 = dst ? dst->ne[2] : 0;
3392
- const int ne3 = dst ? dst->ne[3] : 0;
4424
+ const int ne0 = dst->ne[0];
4425
+ const int ne1 = dst->ne[1];
4426
+ const int ne2 = dst->ne[2];
4427
+ const int ne3 = dst->ne[3];
3393
4428
 
3394
- const cl_ulong nb0 = dst ? dst->nb[0] : 0;
3395
- const cl_ulong nb1 = dst ? dst->nb[1] : 0;
3396
- const cl_ulong nb2 = dst ? dst->nb[2] : 0;
3397
- const cl_ulong nb3 = dst ? dst->nb[3] : 0;
4429
+ const cl_ulong nb0 = dst->nb[0];
4430
+ const cl_ulong nb1 = dst->nb[1];
4431
+ const cl_ulong nb2 = dst->nb[2];
4432
+ const cl_ulong nb3 = dst->nb[3];
3398
4433
 
3399
4434
  ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
3400
4435
 
@@ -3406,59 +4441,114 @@ static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const
3406
4441
  cl_ulong offset1 = extra1->offset + src1->view_offs;
3407
4442
  cl_ulong offsetd = extrad->offset + dst->view_offs;
3408
4443
 
3409
- bool bcast_row = false;
3410
4444
  cl_kernel kernel;
3411
4445
 
3412
- if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
3413
- GGML_ASSERT(ggml_is_contiguous(src0));
4446
+ const bool bcast_row = ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0;
3414
4447
 
3415
- // src1 is a row
4448
+ if (bcast_row) {
4449
+ GGML_ASSERT(ggml_is_contiguous(src0));
3416
4450
  GGML_ASSERT(ne11 == 1);
4451
+ }
3417
4452
 
3418
- bcast_row = true;
3419
- int ne = ne00 / 4;
3420
- kernel = backend_ctx->kernel_add_row;
3421
-
3422
- CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
3423
- CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
3424
- CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
3425
- CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
3426
- CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
3427
- CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
3428
- CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne));
4453
+ if (dst->type == GGML_TYPE_F32) {
4454
+ GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32);
4455
+ if (bcast_row) {
4456
+ kernel = backend_ctx->kernel_add_row;
4457
+ const int ne = ne00 / 4;
4458
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
4459
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
4460
+ CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
4461
+ CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
4462
+ CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
4463
+ CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
4464
+ CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne));
4465
+ } else {
4466
+ kernel = backend_ctx->kernel_add;
4467
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
4468
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
4469
+ CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
4470
+ CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
4471
+ CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
4472
+ CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
4473
+ CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
4474
+ CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
4475
+ CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02));
4476
+ CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne03));
4477
+ CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00));
4478
+ CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01));
4479
+ CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02));
4480
+ CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03));
4481
+ CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne10));
4482
+ CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne11));
4483
+ CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne12));
4484
+ CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne13));
4485
+ CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb10));
4486
+ CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11));
4487
+ CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12));
4488
+ CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13));
4489
+ CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &ne0));
4490
+ CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &ne1));
4491
+ CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &ne2));
4492
+ CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &ne3));
4493
+ CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb0));
4494
+ CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb1));
4495
+ CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb2));
4496
+ CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nb3));
4497
+ }
4498
+ } else if (dst->type == GGML_TYPE_F16) {
4499
+ GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_F32);
4500
+ GGML_ASSERT(src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
4501
+ const int type_src0 = (src0->type == GGML_TYPE_F32);
4502
+ const int type_src1 = (src1->type == GGML_TYPE_F32);
4503
+ if (bcast_row) {
4504
+ kernel = backend_ctx->kernel_add_row_f16;
4505
+ const int ne = ne00 / 4;
4506
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
4507
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
4508
+ CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
4509
+ CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
4510
+ CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
4511
+ CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
4512
+ CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne));
4513
+ CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &type_src0));
4514
+ CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &type_src1));
4515
+ } else {
4516
+ kernel = backend_ctx->kernel_add_f16;
4517
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
4518
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
4519
+ CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
4520
+ CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
4521
+ CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
4522
+ CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
4523
+ CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
4524
+ CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
4525
+ CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02));
4526
+ CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne03));
4527
+ CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00));
4528
+ CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01));
4529
+ CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02));
4530
+ CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03));
4531
+ CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne10));
4532
+ CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne11));
4533
+ CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne12));
4534
+ CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne13));
4535
+ CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb10));
4536
+ CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11));
4537
+ CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12));
4538
+ CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13));
4539
+ CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &ne0));
4540
+ CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &ne1));
4541
+ CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &ne2));
4542
+ CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &ne3));
4543
+ CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb0));
4544
+ CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb1));
4545
+ CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb2));
4546
+ CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nb3));
4547
+ CL_CHECK(clSetKernelArg(kernel, 30, sizeof(int), &type_src0));
4548
+ CL_CHECK(clSetKernelArg(kernel, 31, sizeof(int), &type_src1));
4549
+ }
3429
4550
  } else {
3430
- kernel = backend_ctx->kernel_add;
3431
-
3432
- CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
3433
- CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
3434
- CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
3435
- CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
3436
- CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
3437
- CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
3438
- CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
3439
- CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
3440
- CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02));
3441
- CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne03));
3442
- CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00));
3443
- CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01));
3444
- CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02));
3445
- CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03));
3446
- CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne10));
3447
- CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne11));
3448
- CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne12));
3449
- CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne13));
3450
- CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb10));
3451
- CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11));
3452
- CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12));
3453
- CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13));
3454
- CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &ne0));
3455
- CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &ne1));
3456
- CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &ne2));
3457
- CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &ne3));
3458
- CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb0));
3459
- CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb1));
3460
- CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb2));
3461
- CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nb3));
4551
+ GGML_ASSERT(false && "unsupported data types for add");
3462
4552
  }
3463
4553
 
3464
4554
  if (bcast_row) {
@@ -3468,19 +4558,88 @@ static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const
3468
4558
 
3469
4559
  size_t * local_work_size_ptr = local_work_size;
3470
4560
  if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) {
3471
- local_work_size_ptr = nullptr; // Let driver choose the work-group sizes.
4561
+ local_work_size_ptr = nullptr;
3472
4562
  }
3473
4563
 
3474
- backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
4564
+ backend_ctx->enqueue_ndrange_kernel(kernel, 1, global_work_size, local_work_size_ptr, dst);
3475
4565
  } else {
3476
4566
  unsigned int nth = MIN(64, ne0);
3477
- size_t global_work_size[] = {ne01*nth, (size_t)ne02, (size_t)ne03};
4567
+ size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
3478
4568
  size_t local_work_size[] = {nth, 1, 1};
3479
4569
 
3480
4570
  backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
3481
4571
  }
3482
4572
  }
3483
4573
 
4574
+ static void ggml_cl_add_id(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
4575
+ GGML_ASSERT(src0);
4576
+ GGML_ASSERT(src0->extra);
4577
+ GGML_ASSERT(src1);
4578
+ GGML_ASSERT(src1->extra);
4579
+ GGML_ASSERT(dst);
4580
+ GGML_ASSERT(dst->extra);
4581
+
4582
+ const ggml_tensor * src2 = dst->src[2];
4583
+ GGML_ASSERT(src2);
4584
+ GGML_ASSERT(src2->extra);
4585
+
4586
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
4587
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
4588
+ GGML_ASSERT(src2->type == GGML_TYPE_I32);
4589
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
4590
+
4591
+ GGML_ASSERT(ggml_is_contiguous_rows(src0));
4592
+
4593
+ const int ne00 = src0->ne[0];
4594
+ const int ne01 = src0->ne[1];
4595
+ const int ne02 = src0->ne[2];
4596
+
4597
+ const cl_ulong nb01 = src0->nb[1];
4598
+ const cl_ulong nb02 = src0->nb[2];
4599
+
4600
+ const cl_ulong nb11 = src1->nb[1];
4601
+
4602
+ const cl_ulong nb21 = src2->nb[1];
4603
+
4604
+ const int ne0 = dst->ne[0];
4605
+ const int ne1 = dst->ne[1];
4606
+
4607
+ ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
4608
+
4609
+ ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
4610
+ ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
4611
+ ggml_tensor_extra_cl * extra2 = (ggml_tensor_extra_cl *)src2->extra;
4612
+ ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
4613
+
4614
+ cl_ulong offset0 = extra0->offset + src0->view_offs;
4615
+ cl_ulong offset1 = extra1->offset + src1->view_offs;
4616
+ cl_ulong offset2 = extra2->offset + src2->view_offs;
4617
+ cl_ulong offsetd = extrad->offset + dst->view_offs;
4618
+
4619
+ cl_kernel kernel = backend_ctx->kernel_add_id;
4620
+
4621
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
4622
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
4623
+ CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
4624
+ CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
4625
+ CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra2->data_device));
4626
+ CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset2));
4627
+ CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device));
4628
+ CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd));
4629
+ CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01));
4630
+ CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02));
4631
+ CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb11));
4632
+ CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb21));
4633
+ CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne0));
4634
+ CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne1));
4635
+
4636
+ int nth = MIN(ne00, (int) backend_ctx->get_kernel_workgroup_size(kernel));
4637
+ size_t global_work_size[] = { (size_t)ne01*nth, (size_t)ne02, 1 };
4638
+ size_t local_work_size[] = { (size_t)nth, 1, 1 };
4639
+
4640
+ backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
4641
+ }
4642
+
3484
4643
  static void ggml_cl_mul(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3485
4644
  GGML_ASSERT(src0);
3486
4645
  GGML_ASSERT(src0->extra);
@@ -3489,35 +4648,39 @@ static void ggml_cl_mul(ggml_backend_t backend, const ggml_tensor * src0, const
3489
4648
  GGML_ASSERT(dst);
3490
4649
  GGML_ASSERT(dst->extra);
3491
4650
 
3492
- const int ne00 = src0 ? src0->ne[0] : 0;
3493
- const int ne01 = src0 ? src0->ne[1] : 0;
3494
- const int ne02 = src0 ? src0->ne[2] : 0;
3495
- const int ne03 = src0 ? src0->ne[3] : 0;
4651
+ GGML_ASSERT(src0->type == src1->type);
4652
+ GGML_ASSERT(src0->type == dst->type);
4653
+ GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
3496
4654
 
3497
- const cl_ulong nb00 = src0 ? src0->nb[0] : 0;
3498
- const cl_ulong nb01 = src0 ? src0->nb[1] : 0;
3499
- const cl_ulong nb02 = src0 ? src0->nb[2] : 0;
3500
- const cl_ulong nb03 = src0 ? src0->nb[3] : 0;
4655
+ const int ne00 = src0->ne[0];
4656
+ const int ne01 = src0->ne[1];
4657
+ const int ne02 = src0->ne[2];
4658
+ const int ne03 = src0->ne[3];
3501
4659
 
3502
- const int ne10 = src1 ? src1->ne[0] : 0;
3503
- const int ne11 = src1 ? src1->ne[1] : 0;
3504
- const int ne12 = src1 ? src1->ne[2] : 0;
3505
- const int ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13);
4660
+ const cl_ulong nb00 = src0->nb[0];
4661
+ const cl_ulong nb01 = src0->nb[1];
4662
+ const cl_ulong nb02 = src0->nb[2];
4663
+ const cl_ulong nb03 = src0->nb[3];
3506
4664
 
3507
- const cl_ulong nb10 = src1 ? src1->nb[0] : 0;
3508
- const cl_ulong nb11 = src1 ? src1->nb[1] : 0;
3509
- const cl_ulong nb12 = src1 ? src1->nb[2] : 0;
3510
- const cl_ulong nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13);
4665
+ const int ne10 = src1->ne[0];
4666
+ const int ne11 = src1->ne[1];
4667
+ const int ne12 = src1->ne[2];
4668
+ const int ne13 = src1->ne[3]; UNUSED(ne13);
3511
4669
 
3512
- const int ne0 = dst ? dst->ne[0] : 0;
3513
- const int ne1 = dst ? dst->ne[1] : 0;
3514
- const int ne2 = dst ? dst->ne[2] : 0;
3515
- const int ne3 = dst ? dst->ne[3] : 0;
4670
+ const cl_ulong nb10 = src1->nb[0];
4671
+ const cl_ulong nb11 = src1->nb[1];
4672
+ const cl_ulong nb12 = src1->nb[2];
4673
+ const cl_ulong nb13 = src1->nb[3]; UNUSED(nb13);
4674
+
4675
+ const int ne0 = dst->ne[0];
4676
+ const int ne1 = dst->ne[1];
4677
+ const int ne2 = dst->ne[2];
4678
+ const int ne3 = dst->ne[3];
3516
4679
 
3517
- const cl_ulong nb0 = dst ? dst->nb[0] : 0;
3518
- const cl_ulong nb1 = dst ? dst->nb[1] : 0;
3519
- const cl_ulong nb2 = dst ? dst->nb[2] : 0;
3520
- const cl_ulong nb3 = dst ? dst->nb[3] : 0;
4680
+ const cl_ulong nb0 = dst->nb[0];
4681
+ const cl_ulong nb1 = dst->nb[1];
4682
+ const cl_ulong nb2 = dst->nb[2];
4683
+ const cl_ulong nb3 = dst->nb[3];
3521
4684
 
3522
4685
  ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
3523
4686
 
@@ -3540,7 +4703,12 @@ static void ggml_cl_mul(ggml_backend_t backend, const ggml_tensor * src0, const
3540
4703
 
3541
4704
  bcast_row = true;
3542
4705
  int ne = ne00 / 4;
3543
- kernel = backend_ctx->kernel_mul_row;
4706
+
4707
+ if (src0->type == GGML_TYPE_F32) {
4708
+ kernel = backend_ctx->kernel_mul_row;
4709
+ } else {
4710
+ kernel = backend_ctx->kernel_mul_row_f16;
4711
+ }
3544
4712
 
3545
4713
  CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
3546
4714
  CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
@@ -3550,7 +4718,11 @@ static void ggml_cl_mul(ggml_backend_t backend, const ggml_tensor * src0, const
3550
4718
  CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
3551
4719
  CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne));
3552
4720
  } else {
3553
- kernel = backend_ctx->kernel_mul;
4721
+ if (src0->type == GGML_TYPE_F32) {
4722
+ kernel = backend_ctx->kernel_mul;
4723
+ } else {
4724
+ kernel = backend_ctx->kernel_mul_f16;
4725
+ }
3554
4726
 
3555
4727
  CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
3556
4728
  CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
@@ -3612,6 +4784,10 @@ static void ggml_cl_div(ggml_backend_t backend, const ggml_tensor * src0, const
3612
4784
  GGML_ASSERT(dst);
3613
4785
  GGML_ASSERT(dst->extra);
3614
4786
 
4787
+ GGML_ASSERT(src0->type == src1->type);
4788
+ GGML_ASSERT(src0->type == dst->type);
4789
+ GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
4790
+
3615
4791
  const int ne00 = src0->ne[0];
3616
4792
  const int ne01 = src0->ne[1];
3617
4793
  const int ne02 = src0->ne[2];
@@ -3660,7 +4836,12 @@ static void ggml_cl_div(ggml_backend_t backend, const ggml_tensor * src0, const
3660
4836
 
3661
4837
  bcast_row = true;
3662
4838
  int ne = ne00 / 4;
3663
- kernel = backend_ctx->kernel_div_row;
4839
+
4840
+ if (src0->type == GGML_TYPE_F32) {
4841
+ kernel = backend_ctx->kernel_div_row;
4842
+ } else {
4843
+ kernel = backend_ctx->kernel_div_row_f16;
4844
+ }
3664
4845
 
3665
4846
  CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
3666
4847
  CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
@@ -3670,7 +4851,11 @@ static void ggml_cl_div(ggml_backend_t backend, const ggml_tensor * src0, const
3670
4851
  CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
3671
4852
  CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne));
3672
4853
  } else {
3673
- kernel = backend_ctx->kernel_div;
4854
+ if (src0->type == GGML_TYPE_F32) {
4855
+ kernel = backend_ctx->kernel_div;
4856
+ } else {
4857
+ kernel = backend_ctx->kernel_div_f16;
4858
+ }
3674
4859
 
3675
4860
  CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
3676
4861
  CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
@@ -3720,6 +4905,10 @@ static void ggml_cl_sub(ggml_backend_t backend, const ggml_tensor * src0, const
3720
4905
  GGML_ASSERT(dst);
3721
4906
  GGML_ASSERT(dst->extra);
3722
4907
 
4908
+ GGML_ASSERT(src0->type == src1->type);
4909
+ GGML_ASSERT(src0->type == dst->type);
4910
+ GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
4911
+
3723
4912
  const int ne00 = src0->ne[0];
3724
4913
  const int ne01 = src0->ne[1];
3725
4914
  const int ne02 = src0->ne[2];
@@ -3768,7 +4957,12 @@ static void ggml_cl_sub(ggml_backend_t backend, const ggml_tensor * src0, const
3768
4957
 
3769
4958
  bcast_row = true;
3770
4959
  int ne = ne00 / 4;
3771
- kernel = backend_ctx->kernel_sub_row;
4960
+
4961
+ if (src0->type == GGML_TYPE_F32) {
4962
+ kernel = backend_ctx->kernel_sub_row;
4963
+ } else {
4964
+ kernel = backend_ctx->kernel_sub_row_f16;
4965
+ }
3772
4966
 
3773
4967
  CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
3774
4968
  CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
@@ -3778,7 +4972,11 @@ static void ggml_cl_sub(ggml_backend_t backend, const ggml_tensor * src0, const
3778
4972
  CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
3779
4973
  CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne));
3780
4974
  } else {
3781
- kernel = backend_ctx->kernel_sub;
4975
+ if (src0->type == GGML_TYPE_F32) {
4976
+ kernel = backend_ctx->kernel_sub;
4977
+ } else {
4978
+ kernel = backend_ctx->kernel_sub_f16;
4979
+ }
3782
4980
 
3783
4981
  CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
3784
4982
  CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
@@ -3858,6 +5056,44 @@ static void ggml_cl_gelu(ggml_backend_t backend, const ggml_tensor * src0, const
3858
5056
  backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
3859
5057
  }
3860
5058
 
5059
+ static void ggml_cl_gelu_erf(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
5060
+ GGML_ASSERT(src0);
5061
+ GGML_ASSERT(src0->extra);
5062
+ GGML_ASSERT(dst);
5063
+ GGML_ASSERT(dst->extra);
5064
+
5065
+ UNUSED(src1);
5066
+
5067
+ ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
5068
+
5069
+ ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
5070
+ ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
5071
+
5072
+ cl_ulong offset0 = extra0->offset + src0->view_offs;
5073
+ cl_ulong offsetd = extrad->offset + dst->view_offs;
5074
+
5075
+ cl_kernel kernel;
5076
+
5077
+ int n = ggml_nelements(dst);
5078
+
5079
+ if (n % 4 == 0) {
5080
+ kernel = backend_ctx->kernel_gelu_erf_4;
5081
+ n /= 4;
5082
+ } else {
5083
+ kernel = backend_ctx->kernel_gelu_erf;
5084
+ }
5085
+
5086
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
5087
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
5088
+ CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
5089
+ CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
5090
+
5091
+ size_t global_work_size[] = {(size_t)n, 1, 1};
5092
+ size_t local_work_size[] = {64, 1, 1};
5093
+
5094
+ backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
5095
+ }
5096
+
3861
5097
  static void ggml_cl_gelu_quick(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3862
5098
  GGML_ASSERT(src0);
3863
5099
  GGML_ASSERT(src0->extra);
@@ -4188,6 +5424,251 @@ static void ggml_cl_rms_norm(ggml_backend_t backend, const ggml_tensor * src0, c
4188
5424
  backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
4189
5425
  }
4190
5426
 
5427
+ static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor * rms_norm_tensor, ggml_tensor * mul_tensor) {
5428
+ GGML_ASSERT(mul_tensor);
5429
+ GGML_ASSERT(rms_norm_tensor);
5430
+
5431
+ // src0 is the input of rms_norm; src1 is the operand of mul that is not the rms_norm result
5432
+ const ggml_tensor * src0 = rms_norm_tensor->src[0];
5433
+ const ggml_tensor * src1;
5434
+ if (mul_tensor->src[0] == rms_norm_tensor) {
5435
+ src1 = mul_tensor->src[1];
5436
+ } else if (mul_tensor->src[1] == rms_norm_tensor) {
5437
+ src1 = mul_tensor->src[0];
5438
+ } else {
5439
+ GGML_ASSERT(false && "Invalid args for rms_norm and mul");
5440
+ }
5441
+ const ggml_tensor * dst = mul_tensor;
5442
+
5443
+ GGML_ASSERT(src0);
5444
+ GGML_ASSERT(src0->extra);
5445
+ GGML_ASSERT(src1);
5446
+ GGML_ASSERT(src1->extra);
5447
+ GGML_ASSERT(dst);
5448
+ GGML_ASSERT(dst->extra);
5449
+
5450
+ ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
5451
+ ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
5452
+ ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
5453
+
5454
+ cl_ulong offset0 = extra0->offset + src0->view_offs;
5455
+ cl_ulong offset1 = extra1->offset + src1->view_offs;
5456
+ cl_ulong offsetd = extrad->offset + dst->view_offs;
5457
+
5458
+ ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
5459
+
5460
+ float eps;
5461
+ memcpy(&eps, rms_norm_tensor->op_params, sizeof(float));
5462
+
5463
+ const int ne00 = src0->ne[0];
5464
+ const int ne01 = src0->ne[1];
5465
+ const int ne02 = src0->ne[2];
5466
+ const int ne03 = src0->ne[3];
5467
+
5468
+ const cl_ulong nb01 = src0->nb[1];
5469
+ const cl_ulong nb02 = src0->nb[2];
5470
+ const cl_ulong nb03 = src0->nb[3];
5471
+
5472
+ const int ne10 = src1->ne[0];
5473
+ const int ne11 = src1->ne[1];
5474
+ const int ne12 = src1->ne[2];
5475
+ const int ne13 = src1->ne[3];
5476
+
5477
+ const cl_ulong nb11 = src1->nb[1];
5478
+ const cl_ulong nb12 = src1->nb[2];
5479
+ const cl_ulong nb13 = src1->nb[3];
5480
+
5481
+ const cl_ulong nb1 = dst->nb[1];
5482
+ const cl_ulong nb2 = dst->nb[2];
5483
+ const cl_ulong nb3 = dst->nb[3];
5484
+
5485
+ GGML_ASSERT(ne00 % 4 == 0);
5486
+
5487
+ size_t sgs;
5488
+ if (backend_ctx->gpu_family == ADRENO) {
5489
+ sgs = 64;
5490
+ } else if (backend_ctx->gpu_family == INTEL) {
5491
+ sgs = 32;
5492
+ } else {
5493
+ GGML_ASSERT(false && "Unsupported GPU");
5494
+ }
5495
+
5496
+ cl_kernel kernel = backend_ctx->kernel_rms_norm_mul;
5497
+
5498
+ int nth = sgs;
5499
+ int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel);
5500
+ while (nth < ne00 && nth < max_workgroup_size) {
5501
+ nth *= 2;
5502
+ }
5503
+ nth = MIN(nth, max_workgroup_size);
5504
+ nth = MIN(nth, ne00);
5505
+
5506
+ size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
5507
+ size_t local_work_size[] = {(size_t)nth, 1, 1};
5508
+
5509
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
5510
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
5511
+ CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
5512
+ CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
5513
+ CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
5514
+ CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
5515
+ CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
5516
+ CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
5517
+ CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02));
5518
+ CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne03));
5519
+ CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb01));
5520
+ CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb02));
5521
+ CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb03));
5522
+ CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne10));
5523
+ CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne11));
5524
+ CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne12));
5525
+ CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne13));
5526
+ CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb11));
5527
+ CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb12));
5528
+ CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb13));
5529
+ CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb1));
5530
+ CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb2));
5531
+ CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &nb3));
5532
+ CL_CHECK(clSetKernelArg(kernel, 23, sizeof(float), &eps));
5533
+ CL_CHECK(clSetKernelArg(kernel, 24, sizeof(float)*nth/sgs, NULL));
5534
+
5535
+ backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
5536
+ }
5537
+
5538
+ static void ggml_opencl_op_norm_fused(ggml_backend_t backend, ggml_tensor * norm_tensor, ggml_tensor * mul_tensor, ggml_tensor * add_tensor) {
5539
+ GGML_ASSERT(norm_tensor && mul_tensor && add_tensor);
5540
+
5541
+ const ggml_tensor * src0 = norm_tensor->src[0];
5542
+ const ggml_tensor * src1 = mul_tensor->src[0] == norm_tensor ? mul_tensor->src[1] : mul_tensor->src[0];
5543
+ const ggml_tensor * src2 = add_tensor->src[0] == mul_tensor ? add_tensor->src[1] : add_tensor->src[0];
5544
+ const ggml_tensor * dst = add_tensor;
5545
+
5546
+ ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
5547
+ ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
5548
+ ggml_tensor_extra_cl * extra2 = (ggml_tensor_extra_cl *)src2->extra;
5549
+ ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
5550
+
5551
+ cl_ulong offset0 = extra0->offset + src0->view_offs;
5552
+ cl_ulong offset1 = extra1->offset + src1->view_offs;
5553
+ cl_ulong offset2 = extra2->offset + src2->view_offs;
5554
+ cl_ulong offsetd = extrad->offset + dst->view_offs;
5555
+
5556
+ ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
5557
+
5558
+ float eps;
5559
+ memcpy(&eps, norm_tensor->op_params, sizeof(float));
5560
+
5561
+ const int ne00 = src0->ne[0], ne01 = src0->ne[1], ne02 = src0->ne[2], ne03 = src0->ne[3];
5562
+ const cl_ulong nb01 = src0->nb[1], nb02 = src0->nb[2], nb03 = src0->nb[3];
5563
+ const int ne10 = src1->ne[0], ne11 = src1->ne[1], ne12 = src1->ne[2], ne13 = src1->ne[3];
5564
+ const cl_ulong nb11 = src1->nb[1], nb12 = src1->nb[2], nb13 = src1->nb[3];
5565
+ const int ne20 = src2->ne[0], ne21 = src2->ne[1], ne22 = src2->ne[2], ne23 = src2->ne[3];
5566
+ const cl_ulong nb21 = src2->nb[1], nb22 = src2->nb[2], nb23 = src2->nb[3];
5567
+ const cl_ulong nbd1 = dst->nb[1], nbd2 = dst->nb[2], nbd3 = dst->nb[3];
5568
+
5569
+ size_t sgs;
5570
+ if (backend_ctx->gpu_family == ADRENO) sgs = 64;
5571
+ else if (backend_ctx->gpu_family == INTEL) sgs = 32;
5572
+ else GGML_ASSERT(false && "Unsupported GPU");
5573
+
5574
+ cl_kernel kernel = backend_ctx->kernel_norm_mul_add;
5575
+
5576
+ int nth = sgs;
5577
+ int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel);
5578
+ while (nth < ne00/4 && nth < max_workgroup_size) nth *= 2;
5579
+ nth = MIN(nth, max_workgroup_size);
5580
+ nth = MIN(nth, ne00/4);
5581
+
5582
+ size_t gws[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
5583
+ size_t lws[] = {(size_t)nth, 1, 1};
5584
+ size_t num_subgroups = (nth + sgs - 1) / sgs;
5585
+
5586
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
5587
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
5588
+ CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
5589
+ CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
5590
+ CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra2->data_device));
5591
+ CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset2));
5592
+ CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device));
5593
+ CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd));
5594
+ CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00));
5595
+ CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01));
5596
+ CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne02));
5597
+ CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne03));
5598
+ CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb01));
5599
+ CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb02));
5600
+ CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb03));
5601
+ CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne10));
5602
+ CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne11));
5603
+ CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne12));
5604
+ CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &ne13));
5605
+ CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11));
5606
+ CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12));
5607
+ CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13));
5608
+ CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &ne20));
5609
+ CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &ne21));
5610
+ CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &ne22));
5611
+ CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &ne23));
5612
+ CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb21));
5613
+ CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb22));
5614
+ CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb23));
5615
+ CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nbd1));
5616
+ CL_CHECK(clSetKernelArg(kernel, 30, sizeof(cl_ulong), &nbd2));
5617
+ CL_CHECK(clSetKernelArg(kernel, 31, sizeof(cl_ulong), &nbd3));
5618
+ CL_CHECK(clSetKernelArg(kernel, 32, sizeof(float), &eps));
5619
+ CL_CHECK(clSetKernelArg(kernel, 33, sizeof(cl_float2) * num_subgroups, NULL));
5620
+
5621
+ backend_ctx->enqueue_ndrange_kernel(kernel, 3, gws, lws, dst);
5622
+ }
5623
+
5624
+ static void ggml_opencl_op_group_norm_fused(ggml_backend_t backend, ggml_tensor * gn_tensor, ggml_tensor * mul_tensor, ggml_tensor * add_tensor) {
5625
+ GGML_ASSERT(gn_tensor && mul_tensor && add_tensor);
5626
+
5627
+ const ggml_tensor * src0 = gn_tensor->src[0];
5628
+ const ggml_tensor * src1 = mul_tensor->src[0] == gn_tensor ? mul_tensor->src[1] : mul_tensor->src[0];
5629
+ const ggml_tensor * src2 = add_tensor->src[0] == mul_tensor ? add_tensor->src[1] : add_tensor->src[0];
5630
+ const ggml_tensor * dst = add_tensor;
5631
+
5632
+ ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
5633
+ ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
5634
+ ggml_tensor_extra_cl * extra2 = (ggml_tensor_extra_cl *)src2->extra;
5635
+ ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
5636
+
5637
+ cl_ulong offset0 = extra0->offset + src0->view_offs;
5638
+ cl_ulong offset1 = extra1->offset + src1->view_offs;
5639
+ cl_ulong offset2 = extra2->offset + src2->view_offs;
5640
+ cl_ulong offsetd = extrad->offset + dst->view_offs;
5641
+
5642
+ ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
5643
+
5644
+ int groups;
5645
+ float eps;
5646
+ memcpy(&groups, gn_tensor->op_params, sizeof(int));
5647
+ memcpy(&eps, (char *)gn_tensor->op_params + sizeof(int), sizeof(float));
5648
+
5649
+ cl_kernel kernel = backend_ctx->kernel_group_norm_mul_add;
5650
+ int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel);
5651
+ int ne = ggml_nelements(src0);
5652
+ int group_size = ne / groups;
5653
+
5654
+ size_t lws[] = { (size_t)MIN(max_workgroup_size, group_size) };
5655
+ size_t gws[] = { (size_t)groups * lws[0] };
5656
+
5657
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
5658
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
5659
+ CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
5660
+ CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
5661
+ CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra2->data_device));
5662
+ CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset2));
5663
+ CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device));
5664
+ CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd));
5665
+ CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne));
5666
+ CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &group_size));
5667
+ CL_CHECK(clSetKernelArg(kernel, 10, sizeof(float), &eps));
5668
+
5669
+ backend_ctx->enqueue_ndrange_kernel(kernel, 1, gws, lws, dst);
5670
+ }
5671
+
4191
5672
  static void ggml_cl_group_norm(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
4192
5673
  GGML_ASSERT(src0);
4193
5674
  GGML_ASSERT(src0->extra);
@@ -4453,7 +5934,8 @@ static void ggml_cl_upscale(ggml_backend_t backend, const ggml_tensor * src0, gg
4453
5934
 
4454
5935
  ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
4455
5936
 
4456
- const ggml_scale_mode mode = (ggml_scale_mode) ggml_get_op_params_i32(dst, 0);
5937
+ const int mode_flags = (ggml_scale_mode) ggml_get_op_params_i32(dst, 0);
5938
+ const ggml_scale_mode mode = (ggml_scale_mode) (mode_flags & 0xFF);
4457
5939
  cl_kernel kernel = nullptr;
4458
5940
 
4459
5941
  if (mode == GGML_SCALE_MODE_NEAREST) {
@@ -4484,18 +5966,22 @@ static void ggml_cl_upscale(ggml_backend_t backend, const ggml_tensor * src0, gg
4484
5966
  const cl_ulong nb02 = src0->nb[2];
4485
5967
  const cl_ulong nb03 = src0->nb[3];
4486
5968
 
4487
- const int ne00_src = src0->ne[0];
4488
- const int ne01_src = src0->ne[1];
5969
+ const int ne00 = src0->ne[0];
5970
+ const int ne01 = src0->ne[1];
5971
+ const int ne02 = src0->ne[2];
5972
+ const int ne03 = src0->ne[3];
4489
5973
 
4490
- const int ne10_dst = dst->ne[0];
4491
- const int ne11_dst = dst->ne[1];
4492
- const int ne12_dst = dst->ne[2];
4493
- const int ne13_dst = dst->ne[3];
5974
+ const int ne0 = dst->ne[0];
5975
+ const int ne1 = dst->ne[1];
5976
+ const int ne2 = dst->ne[2];
5977
+ const int ne3 = dst->ne[3];
4494
5978
 
4495
- const float sf0 = (float)dst->ne[0] / src0->ne[0];
4496
- const float sf1 = (float)dst->ne[1] / src0->ne[1];
4497
- const float sf2 = (float)dst->ne[2] / src0->ne[2];
4498
- const float sf3 = (float)dst->ne[3] / src0->ne[3];
5979
+ float sf0 = (float)ne0 / ne00;
5980
+ float sf1 = (float)ne1 / ne01;
5981
+ float sf2 = (float)ne2 / ne02;
5982
+ float sf3 = (float)ne3 / ne03;
5983
+
5984
+ float pixel_offset = 0.5f;
4499
5985
 
4500
5986
  CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_src0->data_device));
4501
5987
  CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &off_src0));
@@ -4507,29 +5993,36 @@ static void ggml_cl_upscale(ggml_backend_t backend, const ggml_tensor * src0, gg
4507
5993
  CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb03));
4508
5994
 
4509
5995
  if (mode == GGML_SCALE_MODE_NEAREST) {
4510
- CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne10_dst));
4511
- CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne11_dst));
4512
- CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12_dst));
4513
- CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne13_dst));
5996
+ CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne0));
5997
+ CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne1));
5998
+ CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne2));
5999
+ CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne3));
4514
6000
  CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float), &sf0));
4515
6001
  CL_CHECK(clSetKernelArg(kernel, 13, sizeof(float), &sf1));
4516
6002
  CL_CHECK(clSetKernelArg(kernel, 14, sizeof(float), &sf2));
4517
6003
  CL_CHECK(clSetKernelArg(kernel, 15, sizeof(float), &sf3));
4518
6004
  } else if (mode == GGML_SCALE_MODE_BILINEAR) {
4519
- CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00_src));
4520
- CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01_src));
4521
- CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne10_dst));
4522
- CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne11_dst));
4523
- CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne12_dst));
4524
- CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne13_dst));
6005
+ if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
6006
+ sf0 = (float)(ne0 - 1) / (ne00 - 1);
6007
+ sf1 = (float)(ne1 - 1) / (ne01 - 1);
6008
+ pixel_offset = 0.0f;
6009
+ }
6010
+
6011
+ CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00));
6012
+ CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01));
6013
+ CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne0));
6014
+ CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne1));
6015
+ CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne2));
6016
+ CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne3));
4525
6017
  CL_CHECK(clSetKernelArg(kernel, 14, sizeof(float), &sf0));
4526
6018
  CL_CHECK(clSetKernelArg(kernel, 15, sizeof(float), &sf1));
4527
6019
  CL_CHECK(clSetKernelArg(kernel, 16, sizeof(float), &sf2));
4528
6020
  CL_CHECK(clSetKernelArg(kernel, 17, sizeof(float), &sf3));
6021
+ CL_CHECK(clSetKernelArg(kernel, 18, sizeof(float), &pixel_offset));
4529
6022
  }
4530
6023
 
4531
6024
 
4532
- size_t dst_total_elements = (size_t)ne10_dst * ne11_dst * ne12_dst * ne13_dst;
6025
+ size_t dst_total_elements = (size_t)ne0 * ne1 * ne2 * ne3;
4533
6026
  if (dst_total_elements == 0) {
4534
6027
  return;
4535
6028
  }
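The reworked upscale path above decodes GGML_SCALE_FLAG_ALIGN_CORNERS from the op's mode flags and, for bilinear scaling, switches both the scale factors and the half-pixel offset. A small illustration of the resulting values (the helper is hypothetical; the in-kernel sampling formula is not shown in this hunk):

    /* Scale factor handed to the bilinear upscale kernel for one dimension. */
    static float upscale_scale_factor(int ne_dst, int ne_src, int align_corners) {
        return align_corners ? (float)(ne_dst - 1) / (float)(ne_src - 1)
                             : (float)ne_dst / (float)ne_src;
    }

    /* Resizing 4 -> 8 along dim 0:
     *   default:       sf0 = 8/4         = 2.0,      pixel_offset = 0.5
     *   align corners: sf0 = (8-1)/(4-1) = 2.333...,  pixel_offset = 0.0
     * With align corners the first and last source samples map exactly onto
     * the first and last destination samples. */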
@@ -4626,12 +6119,12 @@ static void ggml_cl_concat(ggml_backend_t backend, const ggml_tensor * src0, con
4626
6119
  } else {
4627
6120
  cl_kernel kernel = backend_ctx->kernel_concat_f32_non_contiguous;
4628
6121
 
4629
- long ne00 = src0->ne[0], ne01 = src0->ne[1], ne02 = src0->ne[2], ne03 = src0->ne[3];
6122
+ cl_long ne00 = src0->ne[0], ne01 = src0->ne[1], ne02 = src0->ne[2], ne03 = src0->ne[3];
4630
6123
  cl_ulong nb00 = src0->nb[0], nb01 = src0->nb[1], nb02 = src0->nb[2], nb03 = src0->nb[3];
4631
6124
 
4632
6125
  cl_ulong nb10 = src1->nb[0], nb11 = src1->nb[1], nb12 = src1->nb[2], nb13 = src1->nb[3];
4633
6126
 
4634
- long d_ne0 = dst->ne[0], d_ne1 = dst->ne[1], d_ne2 = dst->ne[2], d_ne3 = dst->ne[3];
6127
+ cl_long d_ne0 = dst->ne[0], d_ne1 = dst->ne[1], d_ne2 = dst->ne[2], d_ne3 = dst->ne[3];
4635
6128
  cl_ulong d_nb0 = dst->nb[0], d_nb1 = dst->nb[1], d_nb2 = dst->nb[2], d_nb3 = dst->nb[3];
4636
6129
 
4637
6130
 
@@ -4642,10 +6135,10 @@ static void ggml_cl_concat(ggml_backend_t backend, const ggml_tensor * src0, con
4642
6135
  CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad_cl->data_device));
4643
6136
  CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &off_dst));
4644
6137
 
4645
- CL_CHECK(clSetKernelArg(kernel, 6, sizeof(long), &ne00));
4646
- CL_CHECK(clSetKernelArg(kernel, 7, sizeof(long), &ne01));
4647
- CL_CHECK(clSetKernelArg(kernel, 8, sizeof(long), &ne02));
4648
- CL_CHECK(clSetKernelArg(kernel, 9, sizeof(long), &ne03));
6138
+ CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_long), &ne00));
6139
+ CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_long), &ne01));
6140
+ CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_long), &ne02));
6141
+ CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_long), &ne03));
4649
6142
  CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00));
4650
6143
  CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01));
4651
6144
  CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02));
@@ -4656,10 +6149,10 @@ static void ggml_cl_concat(ggml_backend_t backend, const ggml_tensor * src0, con
4656
6149
  CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb12));
4657
6150
  CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb13));
4658
6151
 
4659
- CL_CHECK(clSetKernelArg(kernel, 18, sizeof(long), &d_ne0));
4660
- CL_CHECK(clSetKernelArg(kernel, 19, sizeof(long), &d_ne1));
4661
- CL_CHECK(clSetKernelArg(kernel, 20, sizeof(long), &d_ne2));
4662
- CL_CHECK(clSetKernelArg(kernel, 21, sizeof(long), &d_ne3));
6152
+ CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_long), &d_ne0));
6153
+ CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_long), &d_ne1));
6154
+ CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_long), &d_ne2));
6155
+ CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_long), &d_ne3));
4663
6156
  CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &d_nb0));
4664
6157
  CL_CHECK(clSetKernelArg(kernel, 23, sizeof(cl_ulong), &d_nb1));
4665
6158
  CL_CHECK(clSetKernelArg(kernel, 24, sizeof(cl_ulong), &d_nb2));
@@ -4718,6 +6211,270 @@ static void ggml_cl_timestep_embedding(ggml_backend_t backend, const ggml_tensor
4718
6211
  backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, NULL, dst);
4719
6212
  }
4720
6213
 
6214
+ static void ggml_cl_flash_attn(ggml_backend_t backend, const ggml_tensor * q, const ggml_tensor * k, ggml_tensor * dst) {
6215
+ const ggml_tensor * v = dst->src[2];
6216
+ const ggml_tensor * mask = dst->src[3];
6217
+ const ggml_tensor * sinks = dst->src[4];
6218
+ GGML_ASSERT(q->extra);
6219
+ GGML_ASSERT(k->extra);
6220
+ GGML_ASSERT(v->extra);
6221
+ GGML_ASSERT(dst->extra);
6222
+ if (mask) {
6223
+ GGML_ASSERT(mask->extra);
6224
+ }
6225
+ if (sinks) {
6226
+ GGML_ASSERT(sinks->extra);
6227
+ }
6228
+
6229
+ ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
6230
+
6231
+ const int n_q = q->ne[1];
6232
+ const int n_kv = k->ne[1];
6233
+ const int d_head_q = q->ne[0];
6234
+ const int d_head_v = v->ne[0];
6235
+ const int n_head = q->ne[2];
6236
+ const int n_head_kv = k->ne[2];
6237
+ const int n_batch = q->ne[3];
6238
+
6239
+ cl_kernel kernel = NULL;
6240
+
6241
+ const bool is_f16 = q->type == GGML_TYPE_F16;
6242
+ const bool is_mixed = q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F16;
6243
+ const std::pair<int, int> dk_dv = {d_head_q, d_head_v};
6244
+
6245
+ if (n_q == 1) {
6246
+ if (is_mixed) {
6247
+ kernel = backend_ctx->kernels_flash_attn_f32_f16_q1.at(dk_dv);
6248
+ } else if (is_f16) {
6249
+ kernel = backend_ctx->kernels_flash_attn_f16_q1.at(dk_dv);
6250
+ } else {
6251
+ kernel = backend_ctx->kernels_flash_attn_f32_q1.at(dk_dv);
6252
+ }
6253
+ } else {
6254
+ if (is_mixed) {
6255
+ kernel = backend_ctx->kernels_flash_attn_f32_f16.at(dk_dv);
6256
+ } else if (is_f16) {
6257
+ kernel = backend_ctx->kernels_flash_attn_f16.at(dk_dv);
6258
+ } else {
6259
+ kernel = backend_ctx->kernels_flash_attn_f32.at(dk_dv);
6260
+ }
6261
+ }
6262
+ GGML_ASSERT(kernel != NULL);
6263
+
6264
+ ggml_tensor_extra_cl * extra_q = (ggml_tensor_extra_cl *)q->extra;
6265
+ ggml_tensor_extra_cl * extra_k = (ggml_tensor_extra_cl *)k->extra;
6266
+ ggml_tensor_extra_cl * extra_v = (ggml_tensor_extra_cl *)v->extra;
6267
+ ggml_tensor_extra_cl * extra_o = (ggml_tensor_extra_cl *)dst->extra;
6268
+ ggml_tensor_extra_cl * extra_mask = mask ? (ggml_tensor_extra_cl *)mask->extra : NULL;
6269
+ ggml_tensor_extra_cl * extra_sinks = sinks ? (ggml_tensor_extra_cl *)sinks->extra : NULL;
6270
+
6271
+ cl_ulong offset_q = extra_q->offset + q->view_offs;
6272
+ cl_ulong offset_k = extra_k->offset + k->view_offs;
6273
+ cl_ulong offset_v = extra_v->offset + v->view_offs;
6274
+ cl_ulong offset_o = extra_o->offset + dst->view_offs;
6275
+ cl_mem mask_buffer = extra_mask ? extra_mask->data_device : NULL;
6276
+ cl_ulong offset_mask = extra_mask ? extra_mask->offset + mask->view_offs : 0;
6277
+ cl_mem sinks_buffer = extra_sinks ? extra_sinks->data_device : NULL;
6278
+ cl_ulong offset_sinks = extra_sinks ? extra_sinks->offset + sinks->view_offs : 0;
6279
+
6280
+ const cl_ulong q_nb1 = q->nb[1], q_nb2 = q->nb[2], q_nb3 = q->nb[3];
6281
+ const cl_ulong k_nb1 = k->nb[1], k_nb2 = k->nb[2], k_nb3 = k->nb[3];
6282
+ const cl_ulong v_nb1 = v->nb[1], v_nb2 = v->nb[2], v_nb3 = v->nb[3];
6283
+ const cl_ulong o_nb1 = dst->nb[1], o_nb2 = dst->nb[2], o_nb3 = dst->nb[3];
6284
+ const cl_ulong mask_nb1 = mask ? mask->nb[1] : 0;
6285
+ const cl_ulong mask_nb2 = mask ? mask->nb[2] : 0;
6286
+ const cl_ulong mask_nb3 = mask ? mask->nb[3] : 0;
6287
+ const int mask_ne2 = mask ? mask->ne[2] : 0;
6288
+ const int mask_ne3 = mask ? mask->ne[3] : 0;
6289
+
6290
+ float scale, max_bias, logit_softcap;
6291
+ const float * params = (const float *)dst->op_params;
6292
+ scale = params[0];
6293
+ max_bias = params[1];
6294
+ logit_softcap = params[2];
6295
+
6296
+ const int is_causal = (mask == NULL && n_q > 1 && n_q == n_kv);
6297
+
6298
+ const int n_head_log2_val = n_head > 0 ? 1u << (int)floorf(log2f((float)n_head)) : 0;
6299
+ const float n_head_log2_f = n_head_log2_val > 0 ? (float)n_head_log2_val : 1.0f;
6300
+ const float m0 = powf(2.0f, -(max_bias) / n_head_log2_f);
6301
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2_f);
6302
+
6303
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_q->data_device));
6304
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset_q));
6305
+ CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra_k->data_device));
6306
+ CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset_k));
6307
+ CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra_v->data_device));
6308
+ CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset_v));
6309
+ CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extra_o->data_device));
6310
+ CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offset_o));
6311
+ CL_CHECK(clSetKernelArg(kernel, 8, sizeof(float), &scale));
6312
+ CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &n_q));
6313
+ CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &n_kv));
6314
+ CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &is_causal));
6315
+ CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &n_head));
6316
+ CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &q_nb1)); CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &q_nb2)); CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &q_nb3));
6317
+ CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &k_nb1)); CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &k_nb2)); CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &k_nb3));
6318
+ CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &v_nb1)); CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &v_nb2)); CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &v_nb3));
6319
+ CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &o_nb1)); CL_CHECK(clSetKernelArg(kernel, 23, sizeof(cl_ulong), &o_nb2)); CL_CHECK(clSetKernelArg(kernel, 24, sizeof(cl_ulong), &o_nb3));
6320
+ CL_CHECK(clSetKernelArg(kernel, 25, sizeof(float), &max_bias));
6321
+ CL_CHECK(clSetKernelArg(kernel, 26, sizeof(float), &m0));
6322
+ CL_CHECK(clSetKernelArg(kernel, 27, sizeof(float), &m1));
6323
+ CL_CHECK(clSetKernelArg(kernel, 28, sizeof(int), &n_head_log2_val));
6324
+ CL_CHECK(clSetKernelArg(kernel, 29, sizeof(float), &logit_softcap));
6325
+ CL_CHECK(clSetKernelArg(kernel, 30, sizeof(int), &n_head_kv));
6326
+ CL_CHECK(clSetKernelArg(kernel, 31, sizeof(cl_mem), &mask_buffer));
6327
+ CL_CHECK(clSetKernelArg(kernel, 32, sizeof(cl_ulong), &offset_mask));
6328
+ CL_CHECK(clSetKernelArg(kernel, 33, sizeof(cl_ulong), &mask_nb1));
6329
+ CL_CHECK(clSetKernelArg(kernel, 34, sizeof(cl_ulong), &mask_nb2));
6330
+ CL_CHECK(clSetKernelArg(kernel, 35, sizeof(cl_ulong), &mask_nb3));
6331
+ CL_CHECK(clSetKernelArg(kernel, 36, sizeof(int), &mask_ne2));
6332
+ CL_CHECK(clSetKernelArg(kernel, 37, sizeof(int), &mask_ne3));
6333
+ CL_CHECK(clSetKernelArg(kernel, 38, sizeof(cl_mem), &sinks_buffer));
6334
+ CL_CHECK(clSetKernelArg(kernel, 39, sizeof(cl_ulong), &offset_sinks));
6335
+
6336
+ if (n_q == 1) {
6337
+ const size_t wg_size = 64;
6338
+ size_t local_work_size[] = { wg_size, 1 };
6339
+ size_t global_work_size[] = { wg_size, (size_t)(n_head * n_batch) };
6340
+ backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst);
6341
+ } else {
6342
+ const int block_m = backend_ctx->kernels_flash_attn_bm.at(dk_dv);
6343
+ const size_t wg_size = block_m;
6344
+ size_t local_work_size[] = { wg_size, 1 };
6345
+ size_t global_work_size[] = { (size_t)((n_q + block_m - 1) / block_m) * wg_size, (size_t)(n_head * n_batch) };
6346
+ backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst);
6347
+ }
6348
+ }
6349
+
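The m0/m1 pair computed above follows the usual ALiBi slope schedule: heads below the nearest power of two of n_head get slopes that are powers of m0, the remaining heads get odd powers of m1. A minimal host-side sketch of that mapping, assuming the kernels follow the standard GGML ALiBi convention (the helper name is illustrative, not a function in this file):

    #include <math.h>

    // Per-head ALiBi slope implied by max_bias, m0, m1 and n_head_log2 as set up
    // in ggml_cl_flash_attn above (illustrative reference, not used by the backend).
    static float alibi_slope_sketch(float max_bias, int n_head_log2, float m0, float m1, int h) {
        if (max_bias <= 0.0f) {
            return 1.0f; // no positional bias requested
        }
        return h < n_head_log2
            ? powf(m0, (float)(h + 1))                       // first n_head_log2 heads
            : powf(m1, (float)(2 * (h - n_head_log2) + 1));  // remaining heads
    }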
6350
+ static void ggml_cl_mul_mat_f16_f32_tiled(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
6351
+ ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
6352
+
6353
+ ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
6354
+ ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
6355
+ ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
6356
+
6357
+ cl_ulong offset0 = extra0->offset + src0->view_offs;
6358
+ cl_ulong offset1 = extra1->offset + src1->view_offs;
6359
+ cl_ulong offsetd = extrad->offset + dst->view_offs;
6360
+
6361
+ const int M = src0->ne[1];
6362
+ const int N = src1->ne[1];
6363
+ const int K = src0->ne[0];
6364
+
6365
+ cl_kernel kernel = backend_ctx->kernel_mul_mat_f16_f32_tiled;
6366
+
6367
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(int), &M));
6368
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(int), &N));
6369
+ CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &K));
6370
+ CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0->data_device));
6371
+ CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offset0));
6372
+ CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extra1->data_device));
6373
+ CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offset1));
6374
+ CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_mem), &extrad->data_device));
6375
+ CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &offsetd));
6376
+
6377
+ // Tiling parameters. These need to be tuned for optimal performance.
6378
+ // They must match the #defines in the kernel mul_mat_f16_f32.cl.
6379
+ //
6380
+ // OPWM / OPWN: Output tile size per Work-Group. A work-group computes a tile of size OPWM x OPWN.
6381
+ // TPWM / TPWN: Threads per Work-group. This is the work-group size.
6382
+ // OPTM / OPTN: Output elements per Thread. Each thread computes OPTM x OPTN elements.
6383
+ //
6384
+ // The following relationships must hold:
6385
+ // OPWM = TPWM * OPTM
6386
+ // OPWN = TPWN * OPTN
6387
+ //
6388
+ const int OPWM = 64;
6389
+ const int OPWN = 64;
6390
+ const int TPWM = 16;
6391
+ const int TPWN = 8;
6392
+
6393
+ size_t local_work_size[2] = { TPWM, TPWN };
6394
+ size_t global_work_size[2] = {
6395
+ (size_t) ((M + OPWM - 1) / OPWM) * TPWM,
6396
+ (size_t) ((N + OPWN - 1) / OPWN) * TPWN,
6397
+ };
6398
+
6399
+ backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst);
6400
+ }
6401
+
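To make the tile relationships above concrete: with OPWM = OPWN = 64, TPWM = 16 and TPWN = 8, each work-group runs 16 x 8 = 128 threads, so each thread produces OPTM x OPTN = 4 x 8 outputs and both OPWM = TPWM * OPTM and OPWN = TPWN * OPTN hold. A small sketch of the same launch-size arithmetic (values mirror the constants above; the function is illustrative only):

    #include <stddef.h>

    // Launch geometry for the tiled f16 x f32 GEMM above: one 16x8 work-group per
    // 64x64 output tile, rounded up at the matrix edges (illustrative sketch).
    static void tiled_gemm_launch_sketch(int M, int N, size_t global[2], size_t local[2]) {
        const int OPWM = 64, OPWN = 64;  // output tile per work-group
        const int TPWM = 16, TPWN = 8;   // work-group shape
        global[0] = (size_t)((M + OPWM - 1) / OPWM) * TPWM;
        global[1] = (size_t)((N + OPWN - 1) / OPWN) * TPWN;
        local[0]  = (size_t)TPWM;
        local[1]  = (size_t)TPWN;
        // e.g. M = N = 100 -> global = {32, 16}, local = {16, 8}
    }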
6402
+ static void ggml_cl_conv_2d(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
6403
+ GGML_TENSOR_BINARY_OP_LOCALS;
6404
+ ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
6405
+
6406
+ ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
6407
+ ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
6408
+ ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
6409
+
6410
+ cl_ulong offset0 = extra0->offset + src0->view_offs;
6411
+ cl_ulong offset1 = extra1->offset + src1->view_offs;
6412
+ cl_ulong offsetd = extrad->offset + dst->view_offs;
6413
+
6414
+ const cl_uint Cout = ne03; const cl_uint Cin = ne02; const cl_uint N = ne13;
6415
+ const cl_uint KW = ne00; const cl_uint KH = ne01; const cl_uint W = ne10; const cl_uint H = ne11; const cl_uint OW = ne0; const cl_uint OH = ne1;
6416
+
6417
+ const cl_uint s0 = dst->op_params[0]; const cl_uint s1 = dst->op_params[1];
6418
+ const cl_uint p0 = dst->op_params[2]; const cl_uint p1 = dst->op_params[3];
6419
+ const cl_uint d0 = dst->op_params[4]; const cl_uint d1 = dst->op_params[5];
6420
+
6421
+ const cl_uint cl_nb01 = nb01/ggml_type_size(src0->type); const cl_uint cl_nb02 = nb02/ggml_type_size(src0->type); const cl_uint cl_nb03 = nb03/ggml_type_size(src0->type);
6422
+ const cl_uint cl_nb11 = nb11/ggml_type_size(src1->type); const cl_uint cl_nb12 = nb12/ggml_type_size(src1->type); const cl_uint cl_nb13 = nb13/ggml_type_size(src1->type);
6423
+ const cl_uint cl_nb1 = nb1/ggml_type_size(dst->type); const cl_uint cl_nb2 = nb2/ggml_type_size(dst->type); const cl_uint cl_nb3 = nb3/ggml_type_size(dst->type);
6424
+
6425
+ const int64_t NPQ = (int64_t)N * OW * OH;
6426
+
6427
+ const uint32_t BS_K = 64;
6428
+ const uint32_t BS_NPQ = 64;
6429
+ const uint32_t BS_CRS = 16;
6430
+ const uint32_t VEC_SIZE = 4;
6431
+
6432
+ const uint32_t TS_K = 4;
6433
+ const uint32_t TS_NPQ = 8;
6434
+
6435
+ const uint32_t WG_K = BS_K / TS_K;
6436
+ const uint32_t WG_NPQ = BS_NPQ / TS_NPQ;
6437
+
6438
+ auto splitWork = [](uint32_t work_size, uint32_t block_size) { return (block_size + work_size - 1) / block_size; };
6439
+ const uint32_t NB_K = splitWork(Cout, BS_K);
6440
+ const uint32_t NB_NPQ = splitWork(NPQ, BS_NPQ);
6441
+
6442
+ cl_kernel kernel;
6443
+ size_t shmem_size;
6444
+
6445
+ if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
6446
+ kernel = backend_ctx->kernel_conv_2d_f16;
6447
+ shmem_size = (size_t)(BS_K * BS_CRS * sizeof(cl_half) + BS_CRS * (BS_NPQ / VEC_SIZE) * sizeof(cl_half4));
6448
+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
6449
+ kernel = backend_ctx->kernel_conv_2d_f32;
6450
+ shmem_size = (size_t)(BS_K * BS_CRS * sizeof(cl_float) + BS_CRS * (BS_NPQ / VEC_SIZE) * sizeof(cl_float4));
6451
+ } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
6452
+ kernel = backend_ctx->kernel_conv_2d_f16_f32;
6453
+ shmem_size = (size_t)(BS_K * BS_CRS * sizeof(cl_half) + BS_CRS * (BS_NPQ / VEC_SIZE) * sizeof(cl_float4));
6454
+ } else {
6455
+ GGML_ASSERT(false && "Unsupported data type combination for conv2d");
6456
+ }
6457
+
6458
+ cl_uint idx = 0;
6459
+ CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra0->data_device)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &offset0));
6460
+ CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra1->data_device)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &offset1));
6461
+ CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extrad->data_device)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &offsetd));
6462
+ CL_CHECK(clSetKernelArg(kernel, idx++, shmem_size, NULL));
6463
+ CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &Cout)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &Cin)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &N));
6464
+ CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &KW)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &KH)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &W)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &H));
6465
+ CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &OW)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &OH));
6466
+ CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &s0)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &s1)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &p0)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &p1));
6467
+ CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &d0)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &d1));
6468
+ CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb01)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb02)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb03));
6469
+ CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb11)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb12)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb13));
6470
+ CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb1)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb2)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb3));
6471
+
6472
+ size_t global_work_size[] = { (size_t)NB_K * WG_K, (size_t)NB_NPQ * WG_NPQ, 1 };
6473
+ size_t local_work_size[] = { (size_t)WG_K, (size_t)WG_NPQ, 1 };
6474
+
6475
+ backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst);
6476
+ }
6477
+
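OW and OH above (dst->ne[0] and dst->ne[1]) are expected to already satisfy the standard dilated-convolution output relation for the stride/padding/dilation op params; a minimal reference of that relation, assuming the usual GGML conv-2d semantics (illustrative, not code from this file):

    #include <stdint.h>

    // Output extent of a dilated 2D convolution along one axis.
    static inline int64_t conv2d_out_size_sketch(int64_t in, int64_t kernel_size,
                                                 int64_t stride, int64_t pad, int64_t dilation) {
        return (in + 2 * pad - dilation * (kernel_size - 1) - 1) / stride + 1;
    }
    // e.g. W = 224, KW = 3, s0 = 1, p0 = 1, d0 = 1  ->  OW = 224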
4721
6478
  static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
4722
6479
  GGML_ASSERT(src0);
4723
6480
  GGML_ASSERT(src0->extra);
@@ -4741,6 +6498,8 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
4741
6498
 
4742
6499
  #ifdef GGML_OPENCL_SOA_Q
4743
6500
  ggml_tensor_extra_cl_q4_0 * extra0_q4_0 = (ggml_tensor_extra_cl_q4_0 *)src0->extra;
6501
+ ggml_tensor_extra_cl_mxfp4 * extra0_mxfp4 = (ggml_tensor_extra_cl_mxfp4 *)src0->extra;
6502
+ ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra;
4744
6503
  #endif
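These flat extras exist because, under GGML_OPENCL_SOA_Q, the interleaved CPU block formats are split into separate device allocations (quants in one buffer, scales in another) before the mat-mul kernels run. A rough sketch of the blocks being de-interleaved, assuming the standard GGML block definitions with 32 values per block (the struct names are illustrative):

    #include <stdint.h>

    // Standard GGML Q8_0 block layout assumed here: one fp16 scale + 32 int8 quants.
    struct block_q8_0_sketch {
        uint16_t d;       // fp16 scale   -> gathered into extra0_q8_0->d
        int8_t   qs[32];  // int8 quants  -> gathered into extra0_q8_0->q
    };

    // Standard GGML MXFP4 block layout assumed here: one E8M0 exponent + 32 packed 4-bit values.
    struct block_mxfp4_sketch {
        uint8_t e;        // shared block exponent  -> extra0_mxfp4->e
        uint8_t qs[16];   // 32 x 4-bit FP4 values  -> extra0_mxfp4->q (or q_img on Adreno)
    };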
4745
6504
 
4746
6505
  const int ne00 = src0 ? src0->ne[0] : 0;
@@ -5070,12 +6829,107 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
5070
6829
  CL_CHECK(clReleaseMemObject(B_d_input_image));
5071
6830
  CL_CHECK(clReleaseMemObject(C_d));
5072
6831
  }
5073
- // <--------------------------------------------> //
6832
+ // <--------------------------------------------> //
6833
+
6834
+ return;
6835
+ }
6836
+ } // if (ne01 && ne1)
6837
+ #endif // GGML_OPENCL_USE_ADRENO_KERNELS
6838
+
6839
+ // GEMM using local memory
6840
+ // Current BK = 16, so ne00 % 16 == 0
6841
+ if (ggml_is_contiguous(src0) &&
6842
+ ggml_is_contiguous(src1) &&
6843
+ src1t == GGML_TYPE_F32 &&
6844
+ ne00 % 16 == 0 &&
6845
+ ne11 > 1) {
6846
+ switch(src0t) {
6847
+ case GGML_TYPE_F32: {
6848
+ kernel = backend_ctx->kernel_mul_mm_f32_f32_l4_lm;
6849
+ nth0 = 128; // calculated as (BM*BN)/(TM*TN)
6850
+
6851
+ int batch_stride_a = ne00*ne01;
6852
+ int batch_stride_b = ne10*ne11;
6853
+ int batch_stride_d = ne0*ne1;
6854
+
6855
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
6856
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
6857
+ CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
6858
+ CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
6859
+ CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
6860
+ CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
6861
+ CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
6862
+ CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
6863
+ CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02));
6864
+ CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne11));
6865
+ CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12));
6866
+ CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne10)); // stride_a
6867
+ CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10)); // stride_b
6868
+ CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne01)); // stride_d
6869
+ CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &batch_stride_a));
6870
+ CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &batch_stride_b));
6871
+ CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &batch_stride_d));
6872
+ CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r2));
6873
+ CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r3));
6874
+
6875
+ // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed.
6876
+ size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13};
6877
+ size_t local_work_size[] = {(size_t)nth0, 1, 1};
6878
+
6879
+ backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
6880
+ return;
6881
+ }
6882
+ case GGML_TYPE_F16: {
6883
+ kernel = backend_ctx->kernel_mul_mm_f16_f32_l4_lm;
6884
+ nth0 = 128; // calculated as (BM*BN)/(TM*TN)
6885
+
6886
+ int batch_stride_a = ne00*ne01;
6887
+ int batch_stride_b = ne10*ne11;
6888
+ int batch_stride_d = ne0*ne1;
6889
+
6890
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
6891
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
6892
+ CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
6893
+ CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
6894
+ CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
6895
+ CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
6896
+ CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
6897
+ CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
6898
+ CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02));
6899
+ CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne11));
6900
+ CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12));
6901
+ CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne10)); // stride_a
6902
+ CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10)); // stride_b
6903
+ CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne01)); // stride_d
6904
+ CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &batch_stride_a));
6905
+ CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &batch_stride_b));
6906
+ CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &batch_stride_d));
6907
+ CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r2));
6908
+ CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r3));
6909
+
6910
+ // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed.
6911
+ size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13};
6912
+ size_t local_work_size[] = {(size_t)nth0, 1, 1};
6913
+
6914
+ backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
6915
+ return;
6916
+ }
6917
+ default:
6918
+ break;
6919
+ }
6920
+ }
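A quick sanity check of the l4_lm launch arithmetic above, assuming the 64 x 64 block tile noted in the comments: nth0 = (BM*BN)/(TM*TN) = 4096/32 = 128, and the grid covers CEIL_DIV(ne01, 64) x CEIL_DIV(ne11, 64) tiles per ne12*ne13 batch slice, so each 128-thread work-group produces one 64 x 64 tile of dst (32 values per thread). An illustrative helper under those assumptions:

    // Illustrative: number of work-groups launched per batch slice for the l4_lm path,
    // assuming the 64x64 block tile mentioned in the comments above.
    static inline size_t l4_lm_tiles_sketch(int ne01, int ne11) {
        return (size_t)CEIL_DIV(ne01, 64) * (size_t)CEIL_DIV(ne11, 64);
    }
    // e.g. ne01 = 1000, ne11 = 96 -> 16 * 2 = 32 work-groups of 128 threads each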
5074
6921
 
6922
+ if (src0t == GGML_TYPE_F16 && src1t == GGML_TYPE_F32 &&
6923
+ src0->ne[1] > 32 && // M > 32
6924
+ src1->ne[1] > 32 && // N > 32
6925
+ src0->ne[0] > 32 && // K > 32
6926
+ src0->ne[2] == 1 && src0->ne[3] == 1 &&
6927
+ src1->ne[2] == 1 && src1->ne[3] == 1 &&
6928
+ ggml_is_contiguous(src0) && ggml_is_contiguous(src1) &&
6929
+ backend_ctx->kernel_mul_mat_f16_f32_tiled != NULL) {
6930
+ ggml_cl_mul_mat_f16_f32_tiled(backend, src0, src1, dst);
5075
6931
  return;
5076
6932
  }
5077
- } // if (ne01 && ne1)
5078
- #endif // GGML_OPENCL_USE_ADRENO_KERNELS
5079
6933
 
5080
6934
  if (!ggml_is_transposed(src0) &&
5081
6935
  !ggml_is_transposed(src1) &&
@@ -5315,7 +7169,84 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
5315
7169
  #endif // GGML_OPENCL_SOA_Q
5316
7170
  break;
5317
7171
  case GGML_TYPE_Q4_1:
5318
- case GGML_TYPE_Q8_0:
7172
+ case GGML_TYPE_Q8_0: {
7173
+ #ifdef GGML_OPENCL_SOA_Q
7174
+ kernel = backend_ctx->kernel_mul_mv_q8_0_f32_flat;
7175
+
7176
+ // nth0 - subgroup size
7177
+ // nth1 - number of subgroups per workgroup
7178
+ // ndst - number of output values per workgroup = output per subgroup * number of subgroups
7179
+ if (backend_ctx->gpu_family == INTEL) {
7180
+ nth0 = 16;
7181
+ nth1 = 2;
7182
+ ndst = nth1*4;
7183
+ } else if (backend_ctx->gpu_family == ADRENO) {
7184
+ nth0 = 64;
7185
+ nth1 = 2;
7186
+ ndst = nth1*4;
7187
+ } else {
7188
+ GGML_ASSERT(false && "TODO: Unknown GPU");
7189
+ }
7190
+
7191
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q8_0->q));
7192
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q8_0->d));
7193
+ CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
7194
+ CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
7195
+ CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
7196
+ CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
7197
+ CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
7198
+ CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
7199
+ CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01));
7200
+ CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02));
7201
+ CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03));
7202
+ CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne12));
7203
+ CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11));
7204
+ CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12));
7205
+ CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb13));
7206
+ CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne0));
7207
+ CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne1));
7208
+ CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r2));
7209
+ CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r3));
7210
+ #else
7211
+ kernel = backend_ctx->kernel_mul_mv_q8_0_f32;
7212
+
7213
+ // nth0 - subgroup size
7214
+ // nth1 - number of subgroups per workgroup
7215
+ // ndst - number of output values per workgroup = output per subgroup * number of subgroups
7216
+ if (backend_ctx->gpu_family == INTEL) {
7217
+ nth0 = 16;
7218
+ nth1 = 2;
7219
+ ndst = nth1*4;
7220
+ } else if (backend_ctx->gpu_family == ADRENO) {
7221
+ nth0 = 64;
7222
+ nth1 = 2;
7223
+ ndst = nth1*4;
7224
+ } else {
7225
+ GGML_ASSERT(false && "TODO: Unknown GPU");
7226
+ }
7227
+
7228
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
7229
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
7230
+ CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
7231
+ CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
7232
+ CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
7233
+ CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
7234
+ CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
7235
+ CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
7236
+ CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01));
7237
+ CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02));
7238
+ CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03));
7239
+ CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne12));
7240
+ CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11));
7241
+ CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12));
7242
+ CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb13));
7243
+ CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne0));
7244
+ CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne1));
7245
+ CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r2));
7246
+ CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r3));
7247
+ #endif // GGML_OPENCL_SOA_Q
7248
+ break;
7249
+ }
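For reference, the Q8_0 mat-vec kernels selected above all reduce 32-value blocks in which each element is reconstructed as qs[i] * d; a scalar sketch of one block's contribution to a dst row (illustrative only, assuming the standard GGML Q8_0 quantization):

    #include <stdint.h>

    // One Q8_0 block (32 int8 quants + one scale d) dotted with 32 floats of src1.
    static float q8_0_block_dot_sketch(const int8_t qs[32], float d, const float y[32]) {
        float acc = 0.0f;
        for (int i = 0; i < 32; ++i) {
            acc += (float)qs[i] * y[i];
        }
        return acc * d; // apply the block scale once
    }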
5319
7250
  case GGML_TYPE_Q2_K:
5320
7251
  case GGML_TYPE_Q3_K:
5321
7252
  case GGML_TYPE_Q4_K:
@@ -5349,11 +7280,87 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
5349
7280
  CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &r2));
5350
7281
  CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3));
5351
7282
  break;
7283
+ case GGML_TYPE_MXFP4: {
7284
+ #ifdef GGML_OPENCL_SOA_Q
7285
+ kernel = backend_ctx->kernel_mul_mv_mxfp4_f32_flat;
7286
+
7287
+ cl_mem q;
7288
+ if (backend_ctx->gpu_family == INTEL) {
7289
+ nth0 = 16;
7290
+ nth1 = 2;
7291
+ ndst = nth1*2;
7292
+
7293
+ q = extra0_mxfp4->q;
7294
+ } else if (backend_ctx->gpu_family == ADRENO) {
7295
+ nth0 = 64;
7296
+ nth1 = 2;
7297
+ ndst = nth1*2;
7298
+
7299
+ q = extra0_mxfp4->q_img;
7300
+ } else {
7301
+ GGML_ASSERT(false && "TODO: Unknown GPU");
7302
+ }
7303
+
7304
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &q));
7305
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_mxfp4->e));
7306
+ CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
7307
+ CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
7308
+ CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
7309
+ CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
7310
+ CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
7311
+ CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01));
7312
+ CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02));
7313
+ CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb03));
7314
+ CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12));
7315
+ CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb11));
7316
+ CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb12));
7317
+ CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb13));
7318
+ CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne0));
7319
+ CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne1));
7320
+ CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &r2));
7321
+ CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r3));
7322
+ #else
7323
+ kernel = backend_ctx->kernel_mul_mv_mxfp4_f32;
7324
+
7325
+ if (backend_ctx->gpu_family == INTEL) {
7326
+ nth0 = 16;
7327
+ nth1 = 2;
7328
+ ndst = nth1*2;
7329
+ } else if (backend_ctx->gpu_family == ADRENO) {
7330
+ nth0 = 64;
7331
+ nth1 = 2;
7332
+ ndst = nth1*2;
7333
+ } else {
7334
+ GGML_ASSERT(false && "TODO: Unknown GPU");
7335
+ }
7336
+
7337
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
7338
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
7339
+ CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
7340
+ CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
7341
+ CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
7342
+ CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
7343
+ CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
7344
+ CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01));
7345
+ CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02));
7346
+ CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb03));
7347
+ CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12));
7348
+ CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb11));
7349
+ CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb12));
7350
+ CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb13));
7351
+ CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne0));
7352
+ CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne1));
7353
+ CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &r2));
7354
+ CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r3));
7355
+ CL_CHECK(clSetKernelArg(kernel, 18, sizeof(float)*nth0,nullptr));
7356
+ #endif
7357
+ break;
7358
+ }
5352
7359
  default:
5353
7360
  GGML_ASSERT(false && "not implemented");
5354
7361
  }
5355
7362
 
5356
- if (src0t == GGML_TYPE_Q4_0 ||
7363
+ if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_MXFP4 ||
5357
7364
  src0t == GGML_TYPE_Q4_1 ||
5358
7365
  src0t == GGML_TYPE_Q8_0 ||
5359
7366
  src0t == GGML_TYPE_Q2_K) {
@@ -5402,16 +7409,22 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
5402
7409
 
5403
7410
  ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
5404
7411
 
7412
+ ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
5405
7413
  ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
5406
7414
  ggml_tensor_extra_cl * extra2 = (ggml_tensor_extra_cl *)src2->extra;
5407
7415
  ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
5408
7416
 
7417
+ cl_ulong offset0 = extra0->offset + src0->view_offs;
5409
7418
  cl_ulong offset1 = extra1->offset + src1->view_offs;
5410
7419
  cl_ulong offset2 = extra2->offset + src2->view_offs;
5411
7420
  cl_ulong offsetd = extrad->offset + dst->view_offs;
5412
7421
 
7422
+ GGML_UNUSED(offset0);
7423
+
5413
7424
  #ifdef GGML_OPENCL_SOA_Q
5414
7425
  ggml_tensor_extra_cl_q4_0 * extra0_q4_0 = (ggml_tensor_extra_cl_q4_0 *)src0->extra;
7426
+ ggml_tensor_extra_cl_mxfp4 * extra0_mxfp4 = (ggml_tensor_extra_cl_mxfp4 *)src0->extra;
7427
+ ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra;
5415
7428
  #endif
5416
7429
 
5417
7430
  const int ne00 = src0->ne[0];
@@ -5420,7 +7433,9 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
5420
7433
  const int ne03 = src0->ne[3];
5421
7434
 
5422
7435
  const cl_ulong nb00 = src0->nb[0];
7436
+ const cl_ulong nb01 = src0->nb[1];
5423
7437
  const cl_ulong nb02 = src0->nb[2];
7438
+ const cl_ulong nb03 = src0->nb[3];
5424
7439
 
5425
7440
  const int ne10 = src1->ne[0];
5426
7441
  const int ne11 = src1->ne[1];
@@ -5429,6 +7444,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
5429
7444
 
5430
7445
  const cl_ulong nb11 = src1->nb[1];
5431
7446
  const cl_ulong nb12 = src1->nb[2];
7447
+ const cl_ulong nb13 = src1->nb[3];
5432
7448
 
5433
7449
  const int ne20 = src2->ne[0];
5434
7450
  const int ne21 = src2->ne[1];
@@ -5496,6 +7512,170 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
5496
7512
 
5497
7513
  break;
5498
7514
  }
7515
+ case GGML_TYPE_Q8_0: {
7516
+ #ifdef GGML_OPENCL_SOA_Q
7517
+ kernel = backend_ctx->kernel_mul_mv_id_q8_0_f32_flat;
7518
+
7519
+ if (backend_ctx->gpu_family == INTEL) {
7520
+ sgs = 16;
7521
+ nsg = 2;
7522
+ ndst = 4;
7523
+ } else if (backend_ctx->gpu_family == ADRENO) {
7524
+ sgs = 64;
7525
+ nsg = 2;
7526
+ ndst = 4;
7527
+ } else {
7528
+ GGML_ASSERT(false && "TODO: Unknown GPU");
7529
+ }
7530
+
7531
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q8_0->q));
7532
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q8_0->d));
7533
+ CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
7534
+ CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
7535
+ CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra2->data_device));
7536
+ CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset2));
7537
+ CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device));
7538
+ CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd));
7539
+ CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00));
7540
+ CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01));
7541
+ CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb01));
7542
+ CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb02));
7543
+ CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne11));
7544
+ CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne12));
7545
+ CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb11));
7546
+ CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb12));
7547
+ CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne20));
7548
+ CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne21));
7549
+ CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb21));
7550
+ CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &ne0));
7551
+ CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &ne1));
7552
+ #else
7553
+ kernel = backend_ctx->kernel_mul_mv_id_q8_0_f32;
7554
+
7555
+ if (backend_ctx->gpu_family == INTEL) {
7556
+ sgs = 16;
7557
+ nsg = 2;
7558
+ ndst = 4;
7559
+ } else if (backend_ctx->gpu_family == ADRENO) {
7560
+ sgs = 64;
7561
+ nsg = 2;
7562
+ ndst = 4;
7563
+ } else {
7564
+ GGML_ASSERT(false && "TODO: Unknown GPU");
7565
+ }
7566
+
7567
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
7568
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
7569
+ CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
7570
+ CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
7571
+ CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra2->data_device));
7572
+ CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset2));
7573
+ CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device));
7574
+ CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd));
7575
+ CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00));
7576
+ CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01));
7577
+ CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb01));
7578
+ CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb02));
7579
+ CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne11));
7580
+ CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne12));
7581
+ CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb11));
7582
+ CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb12));
7583
+ CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne20));
7584
+ CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne21));
7585
+ CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb21));
7586
+ CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &ne0));
7587
+ CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &ne1));
7588
+ #endif // GGML_OPENCL_SOA_Q
7589
+ break;
7590
+ }
7591
+ case GGML_TYPE_MXFP4: {
7592
+ #ifdef GGML_OPENCL_SOA_Q
7593
+ kernel = backend_ctx->kernel_mul_mv_id_mxfp4_f32_flat;
7594
+
7595
+ cl_mem q;
7596
+ if (backend_ctx->gpu_family == INTEL) {
7597
+ sgs = 16;
7598
+ nsg = 2;
7599
+ ndst = 2;
7600
+
7601
+ q = extra0_mxfp4->q;
7602
+ } else if (backend_ctx->gpu_family == ADRENO) {
7603
+ sgs = 64;
7604
+ nsg = 1;
7605
+ ndst = 4;
7606
+
7607
+ q = extra0_mxfp4->q_img;
7608
+ } else {
7609
+ GGML_ASSERT(false && "TODO: Unknown GPU");
7610
+ }
7611
+
7612
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &q));
7613
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_mxfp4->e));
7614
+ CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
7615
+ CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
7616
+ CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra2->data_device));
7617
+ CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset2));
7618
+ CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device));
7619
+ CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd));
7620
+ CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00));
7621
+ CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01));
7622
+ CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb02));
7623
+ CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb03));
7624
+ CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne11));
7625
+ CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne12));
7626
+ CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb11));
7627
+ CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb12));
7628
+ CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb13));
7629
+ CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne20));
7630
+ CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &ne21));
7631
+ CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb21));
7632
+ CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &ne0));
7633
+ CL_CHECK(clSetKernelArg(kernel, 21, sizeof(int), &ne1));
7634
+ CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &r2));
7635
+ CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &r3));
7636
+ #else // GGML_OPENCL_SOA_Q
7637
+ kernel = backend_ctx->kernel_mul_mv_id_mxfp4_f32;
7638
+
7639
+ if (backend_ctx->gpu_family == INTEL) {
7640
+ sgs = 16;
7641
+ nsg = 2;
7642
+ ndst = 2;
7643
+ } else if (backend_ctx->gpu_family == ADRENO) {
7644
+ sgs = 64;
7645
+ nsg = 2;
7646
+ ndst = 2;
7647
+ } else {
7648
+ GGML_ASSERT(false && "TODO: Unknown GPU");
7649
+ }
7650
+
7651
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
7652
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
7653
+ CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
7654
+ CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
7655
+ CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra2->data_device));
7656
+ CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset2));
7657
+ CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device));
7658
+ CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd));
7659
+ CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00));
7660
+ CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01));
7661
+ CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb02));
7662
+ CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb03));
7663
+ CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne11));
7664
+ CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne12));
7665
+ CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb11));
7666
+ CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb12));
7667
+ CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb13));
7668
+ CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne20));
7669
+ CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &ne21));
7670
+ CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb21));
7671
+ CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &ne0));
7672
+ CL_CHECK(clSetKernelArg(kernel, 21, sizeof(int), &ne1));
7673
+ CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &r2));
7674
+ CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &r3));
7675
+ CL_CHECK(clSetKernelArg(kernel, 24, sizeof(float)*sgs,nullptr));
7676
+ #endif // GGML_OPENCL_SOA_Q
7677
+ break;
7678
+ }
5499
7679
  default:
5500
7680
  GGML_ASSERT(false && "not implemented");;
5501
7681
  }
@@ -5521,7 +7701,9 @@ static void ggml_cl_scale(ggml_backend_t backend, const ggml_tensor * src0, cons
5521
7701
  ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
5522
7702
 
5523
7703
  float scale;
5524
- memcpy(&scale, dst->op_params, sizeof(scale));
7704
+ float bias;
7705
+ memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(float));
7706
+ memcpy(&bias, ((int32_t *) dst->op_params) + 1, sizeof(float));
5525
7707
 
5526
7708
  ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
5527
7709
  ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
@@ -5536,6 +7718,7 @@ static void ggml_cl_scale(ggml_backend_t backend, const ggml_tensor * src0, cons
5536
7718
  CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
5537
7719
  CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
5538
7720
  CL_CHECK(clSetKernelArg(kernel, 4, sizeof(float), &scale));
7721
+ CL_CHECK(clSetKernelArg(kernel, 5, sizeof(float), &bias));
5539
7722
 
5540
7723
  int n = ggml_nelements(dst)/4;
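With the second op param wired through, the scale kernel becomes an affine transform per element; the device side is not shown in this hunk, but judging from the host parameters it presumably computes the usual GGML scale-with-bias:

    // Scalar reference (illustrative): what each lane of the n float4 vectors is expected to compute.
    static inline float scale_with_bias_sketch(float x, float scale, float bias) {
        return x * scale + bias;
    }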
5541
7724
 
@@ -5733,31 +7916,50 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
5733
7916
  GGML_ASSERT(src1->extra);
5734
7917
  }
5735
7918
 
7919
+ const ggml_tensor * src2 = dst->src[2];
7920
+ if (src2) {
7921
+ GGML_ASSERT(src2->extra);
7922
+ }
7923
+
5736
7924
  ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
5737
7925
 
5738
7926
  ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
5739
7927
  ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
5740
7928
 
5741
7929
  ggml_tensor_extra_cl * extra1 = src1 ? (ggml_tensor_extra_cl *)src1->extra : nullptr;
7930
+ ggml_tensor_extra_cl * extra2 = src2 ? (ggml_tensor_extra_cl *)src2->extra : nullptr;
5742
7931
 
5743
7932
  cl_ulong offset0 = extra0->offset + src0->view_offs;
5744
7933
  cl_ulong offsetd = extrad->offset + dst->view_offs;
5745
7934
 
5746
7935
  cl_ulong offset1 = extra1 ? extra1->offset + src1->view_offs : offset0;
7936
+ cl_ulong offset2 = extra2 ? extra2->offset + src2->view_offs : offset0;
5747
7937
 
5748
- const int ne00 = src0 ? src0->ne[0] : 0;
5749
- const int ne01 = src0 ? src0->ne[1] : 0;
5750
- const int ne02 = src0 ? src0->ne[2] : 0;
5751
- const int ne03 = src0 ? src0->ne[3] : 0;
7938
+ const int ne00 = src0->ne[0];
7939
+ const int ne01 = src0->ne[1];
7940
+ const int ne02 = src0->ne[2];
7941
+ const int ne03 = src0->ne[3];
7942
+
7943
+ const cl_long nb01 = src0->nb[1];
7944
+ const cl_long nb02 = src0->nb[2];
7945
+ const cl_long nb03 = src0->nb[3];
7946
+
7947
+ const int ne12 = src1 ? src1->ne[2] : 0;
7948
+ const int ne13 = src1 ? src1->ne[3] : 0;
7949
+
7950
+ const cl_long nb11 = src1 ? src1->nb[1] : 0;
7951
+ const cl_long nb12 = src1 ? src1->nb[2] : 0;
7952
+ const cl_long nb13 = src1 ? src1->nb[3] : 0;
7953
+
7954
+ const cl_long nb1 = dst->nb[1];
7955
+ const cl_long nb2 = dst->nb[2];
7956
+ const cl_long nb3 = dst->nb[3];
5752
7957
 
5753
7958
  float scale, max_bias;
5754
7959
  memcpy(&scale, dst->op_params + 0, sizeof(float));
5755
7960
  memcpy(&max_bias, dst->op_params + 1, sizeof(float));
5756
7961
 
5757
- const int nrows_x = ggml_nrows(src0);
5758
- const int nrows_y = src0->ne[1];
5759
-
5760
- const int n_head = nrows_x/nrows_y;
7962
+ const int n_head = src0->ne[2];
5761
7963
  const int n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
5762
7964
 
5763
7965
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
@@ -5799,16 +8001,27 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
5799
8001
  CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
5800
8002
  CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), extra1 ? &extra1->data_device : &extra0->data_device));
5801
8003
  CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
5802
- CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
5803
- CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
5804
- CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
5805
- CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
5806
- CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02));
5807
- CL_CHECK(clSetKernelArg(kernel, 9, sizeof(float), &scale));
5808
- CL_CHECK(clSetKernelArg(kernel, 10, sizeof(float), &max_bias));
5809
- CL_CHECK(clSetKernelArg(kernel, 11, sizeof(float), &m0));
5810
- CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float), &m1));
5811
- CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &n_head_log2));
8004
+ CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), extra2 ? &extra2->data_device : &extra0->data_device));
8005
+ CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset2));
8006
+ CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device));
8007
+ CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd));
8008
+ CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00));
8009
+ CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01));
8010
+ CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb02));
8011
+ CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb03));
8012
+ CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne12));
8013
+ CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne13));
8014
+ CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb11));
8015
+ CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb12));
8016
+ CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb13));
8017
+ CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb1));
8018
+ CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb2));
8019
+ CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb3));
8020
+ CL_CHECK(clSetKernelArg(kernel, 20, sizeof(float), &scale));
8021
+ CL_CHECK(clSetKernelArg(kernel, 21, sizeof(float), &max_bias));
8022
+ CL_CHECK(clSetKernelArg(kernel, 22, sizeof(float), &m0));
8023
+ CL_CHECK(clSetKernelArg(kernel, 23, sizeof(float), &m1));
8024
+ CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &n_head_log2));
5812
8025
 
5813
8026
  size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
5814
8027
  size_t local_work_size[] = {(size_t)nth, 1, 1};
@@ -6215,6 +8428,23 @@ static void ggml_cl_glu(ggml_backend_t backend, const ggml_tensor * src0, const
6215
8428
  kernel = backend_ctx->kernel_swiglu_f16;
6216
8429
  }
6217
8430
  break;
8431
+ case GGML_GLU_OP_SWIGLU_OAI:
8432
+ kernel = backend_ctx->kernel_swiglu_oai;
8433
+ break;
8434
+ case GGML_GLU_OP_GEGLU_ERF:
8435
+ if (dst->type == GGML_TYPE_F32) {
8436
+ kernel = backend_ctx->kernel_geglu_erf;
8437
+ } else {
8438
+ kernel = backend_ctx->kernel_geglu_erf_f16;
8439
+ }
8440
+ break;
8441
+ case GGML_GLU_OP_GEGLU_QUICK:
8442
+ if (dst->type == GGML_TYPE_F32) {
8443
+ kernel = backend_ctx->kernel_geglu_quick;
8444
+ } else {
8445
+ kernel = backend_ctx->kernel_geglu_quick_f16;
8446
+ }
8447
+ break;
6218
8448
  default:
6219
8449
  GGML_ABORT("Unsupported glu op");
6220
8450
  }
@@ -6236,7 +8466,10 @@ static void ggml_cl_glu(ggml_backend_t backend, const ggml_tensor * src0, const
6236
8466
 
6237
8467
  const cl_ulong nb1 = dst->nb[1];
6238
8468
 
6239
- const int swp = ((const int32_t *) dst->op_params)[1];
8469
+ const int swp = ggml_get_op_params_i32(dst, 1);
8470
+ const float alpha = ggml_get_op_params_f32(dst, 2);
8471
+ const float limit = ggml_get_op_params_f32(dst, 3);
8472
+
6240
8473
  const int ne00_off = src1 ? 0 : (swp ? ne0 : 0);
6241
8474
  const int ne10_off = src1 ? 0 : (swp ? 0 : ne0);
6242
8475
 
@@ -6253,6 +8486,11 @@ static void ggml_cl_glu(ggml_backend_t backend, const ggml_tensor * src0, const
6253
8486
  CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne00_off));
6254
8487
  CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne10_off));
6255
8488
 
8489
+ if (ggml_get_glu_op(dst) == GGML_GLU_OP_SWIGLU_OAI) {
8490
+ CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float), &limit));
8491
+ CL_CHECK(clSetKernelArg(kernel, 13, sizeof(float), &alpha));
8492
+ }
8493
+
6256
8494
  const size_t nrows = ggml_nrows(src0);
6257
8495
  size_t nth = 512;
6258
8496
  size_t global_work_size[] = {nrows*nth, 1, 1};
@@ -6284,6 +8522,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
6284
8522
  }
6285
8523
  func = ggml_cl_get_rows;
6286
8524
  break;
8525
+ case GGML_OP_SET_ROWS:
8526
+ if (!any_on_device) {
8527
+ return false;
8528
+ }
8529
+ func = ggml_cl_set_rows;
8530
+ break;
6287
8531
  case GGML_OP_CPY:
6288
8532
  if (!any_on_device) {
6289
8533
  return false;
@@ -6303,6 +8547,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
6303
8547
  }
6304
8548
  func = ggml_cl_add;
6305
8549
  break;
8550
+ case GGML_OP_ADD_ID:
8551
+ if (!any_on_device) {
8552
+ return false;
8553
+ }
8554
+ func = ggml_cl_add_id;
8555
+ break;
6306
8556
  case GGML_OP_MUL:
6307
8557
  if (!any_on_device) {
6308
8558
  return false;
@@ -6329,6 +8579,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
6329
8579
  }
6330
8580
  func = ggml_cl_gelu;
6331
8581
  break;
8582
+ case GGML_UNARY_OP_GELU_ERF:
8583
+ if (!any_on_device) {
8584
+ return false;
8585
+ }
8586
+ func = ggml_cl_gelu_erf;
8587
+ break;
6332
8588
  case GGML_UNARY_OP_GELU_QUICK:
6333
8589
  if (!any_on_device) {
6334
8590
  return false;
@@ -6410,6 +8666,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
6410
8666
  }
6411
8667
  ggml_cl_upscale(backend, tensor->src[0], tensor);
6412
8668
  return true;
8669
+ case GGML_OP_CONV_2D:
8670
+ if (!any_on_device) {
8671
+ return false;
8672
+ }
8673
+ func = ggml_cl_conv_2d;
8674
+ break;
6413
8675
  case GGML_OP_CONCAT:
6414
8676
  if (!any_on_device) {
6415
8677
  return false;
@@ -6485,6 +8747,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
6485
8747
  }
6486
8748
  func = ggml_cl_sum_rows;
6487
8749
  break;
8750
+ case GGML_OP_FLASH_ATTN_EXT:
8751
+ if (!any_on_device) {
8752
+ return false;
8753
+ }
8754
+ ggml_cl_flash_attn(backend, tensor->src[0], tensor->src[1], tensor);
8755
+ return true;
6488
8756
  default:
6489
8757
  return false;
6490
8758
  }