whispercpp 1.3.4 → 1.3.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (630)
  1. checksums.yaml +4 -4
  2. data/README.md +60 -43
  3. data/ext/extconf.rb +2 -2
  4. data/ext/ruby_whisper.c +14 -2
  5. data/ext/ruby_whisper.h +39 -0
  6. data/ext/ruby_whisper_context.c +22 -22
  7. data/ext/ruby_whisper_model.c +12 -12
  8. data/ext/ruby_whisper_params.c +47 -23
  9. data/ext/ruby_whisper_segment.c +84 -19
  10. data/ext/ruby_whisper_token.c +351 -0
  11. data/ext/ruby_whisper_transcribe.cpp +1 -1
  12. data/ext/ruby_whisper_vad_context.c +75 -0
  13. data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
  14. data/ext/ruby_whisper_vad_segment.c +139 -0
  15. data/ext/ruby_whisper_vad_segments.c +106 -0
  16. data/ext/sources/CMakeLists.txt +4 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
  19. data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
  20. data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  21. data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
  22. data/ext/sources/examples/addon.node/vad-example.js +2 -2
  23. data/ext/sources/examples/cli/cli.cpp +121 -112
  24. data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
  25. data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
  26. data/ext/sources/examples/server/server.cpp +10 -11
  27. data/ext/sources/examples/talk-llama/CMakeLists.txt +5 -1
  28. data/ext/sources/examples/talk-llama/llama-adapter.cpp +12 -3
  29. data/ext/sources/examples/talk-llama/llama-adapter.h +7 -1
  30. data/ext/sources/examples/talk-llama/llama-arch.cpp +2046 -1974
  31. data/ext/sources/examples/talk-llama/llama-arch.h +67 -2
  32. data/ext/sources/examples/talk-llama/llama-batch.cpp +75 -33
  33. data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
  34. data/ext/sources/examples/talk-llama/llama-chat.cpp +79 -3
  35. data/ext/sources/examples/talk-llama/llama-chat.h +4 -0
  36. data/ext/sources/examples/talk-llama/llama-context.cpp +775 -78
  37. data/ext/sources/examples/talk-llama/llama-context.h +57 -9
  38. data/ext/sources/examples/talk-llama/llama-cparams.h +1 -0
  39. data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
  40. data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
  41. data/ext/sources/examples/talk-llama/llama-graph.cpp +381 -64
  42. data/ext/sources/examples/talk-llama/llama-graph.h +103 -13
  43. data/ext/sources/examples/talk-llama/llama-hparams.cpp +26 -2
  44. data/ext/sources/examples/talk-llama/llama-hparams.h +41 -10
  45. data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
  46. data/ext/sources/examples/talk-llama/llama-impl.h +1 -1
  47. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +5 -3
  48. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +145 -65
  49. data/ext/sources/examples/talk-llama/llama-kv-cache.h +22 -7
  50. data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
  51. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
  52. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +32 -19
  53. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
  54. data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
  55. data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
  56. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +91 -9
  57. data/ext/sources/examples/talk-llama/llama-model-loader.h +6 -0
  58. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
  59. data/ext/sources/examples/talk-llama/llama-model.cpp +1529 -13134
  60. data/ext/sources/examples/talk-llama/llama-model.h +44 -3
  61. data/ext/sources/examples/talk-llama/llama-quant.cpp +8 -23
  62. data/ext/sources/examples/talk-llama/llama-sampling.cpp +1294 -198
  63. data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
  64. data/ext/sources/examples/talk-llama/llama-vocab.cpp +133 -37
  65. data/ext/sources/examples/talk-llama/llama-vocab.h +45 -40
  66. data/ext/sources/examples/talk-llama/llama.cpp +729 -2
  67. data/ext/sources/examples/talk-llama/llama.h +152 -14
  68. data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
  69. data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
  70. data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
  71. data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
  72. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
  73. data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
  74. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
  75. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
  76. data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
  77. data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
  78. data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
  79. data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
  80. data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
  81. data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
  82. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
  83. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
  84. data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
  85. data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
  86. data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
  87. data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
  88. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
  89. data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
  90. data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
  91. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
  92. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
  93. data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
  94. data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
  95. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
  96. data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
  97. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
  98. data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
  99. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
  100. data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
  101. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
  102. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
  103. data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
  104. data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
  105. data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
  106. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
  107. data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
  108. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
  109. data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
  110. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
  111. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
  112. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
  113. data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
  114. data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
  115. data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
  116. data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
  117. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
  118. data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
  119. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
  120. data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
  121. data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
  122. data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
  123. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
  124. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
  125. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
  126. data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
  127. data/ext/sources/examples/talk-llama/models/models.h +569 -0
  128. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
  129. data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
  130. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
  131. data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
  132. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
  133. data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
  134. data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
  135. data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
  136. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
  137. data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
  138. data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
  139. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
  140. data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
  141. data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
  142. data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
  143. data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
  144. data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
  145. data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
  146. data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
  147. data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
  148. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
  149. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
  150. data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
  151. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
  152. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
  153. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
  154. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
  155. data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
  156. data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
  157. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
  158. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
  159. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
  160. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
  161. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
  162. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
  163. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
  164. data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
  165. data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
  166. data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
  167. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
  168. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
  169. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
  170. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
  171. data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
  172. data/ext/sources/examples/talk-llama/unicode.cpp +102 -16
  173. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
  174. data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
  175. data/ext/sources/ggml/CMakeLists.txt +82 -54
  176. data/ext/sources/ggml/include/ggml-alloc.h +9 -0
  177. data/ext/sources/ggml/include/ggml-backend.h +4 -1
  178. data/ext/sources/ggml/include/ggml-cpu.h +1 -0
  179. data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
  180. data/ext/sources/ggml/include/ggml-rpc.h +8 -11
  181. data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
  182. data/ext/sources/ggml/include/ggml.h +190 -12
  183. data/ext/sources/ggml/src/CMakeLists.txt +82 -11
  184. data/ext/sources/ggml/src/ggml-alloc.c +124 -41
  185. data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
  186. data/ext/sources/ggml/src/ggml-backend-reg.cpp +27 -3
  187. data/ext/sources/ggml/src/ggml-backend.cpp +71 -21
  188. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
  189. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -9
  190. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +57 -45
  191. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
  192. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2179 -1696
  193. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -317
  194. data/ext/sources/ggml/src/ggml-cann/common.h +283 -208
  195. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +626 -776
  196. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +156 -86
  197. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -0
  198. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  199. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +428 -26
  200. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1004 -0
  201. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
  202. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  203. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +108 -49
  204. data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  205. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6 -6
  206. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +50 -2
  207. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +5 -3
  208. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +195 -71
  209. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
  210. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +573 -106
  211. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
  212. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +298 -112
  213. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
  214. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +819 -125
  215. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
  216. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +708 -431
  217. data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -4
  218. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +671 -31
  219. data/ext/sources/ggml/src/ggml-cpu/repack.h +14 -0
  220. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +41 -43
  221. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
  222. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
  223. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
  224. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +124 -1
  225. data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
  226. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
  227. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
  228. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
  229. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
  230. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +1 -1
  231. data/ext/sources/ggml/src/ggml-cuda/common.cuh +353 -80
  232. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  233. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
  234. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +339 -246
  235. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
  236. data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
  237. data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
  238. data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
  239. data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
  240. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +31 -21
  241. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +663 -596
  242. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
  243. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1241 -0
  244. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +30 -37
  245. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +14 -13
  246. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  247. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +83 -37
  248. data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
  249. data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
  250. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1155 -164
  251. data/ext/sources/ggml/src/ggml-cuda/mean.cu +5 -4
  252. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +741 -48
  253. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +60 -12
  254. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +381 -42
  255. data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
  256. data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
  257. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
  258. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +498 -171
  259. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +375 -79
  260. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +3 -2
  261. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +241 -95
  262. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
  263. data/ext/sources/ggml/src/ggml-cuda/pad.cu +64 -33
  264. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +151 -0
  265. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
  266. data/ext/sources/ggml/src/ggml-cuda/rope.cu +192 -77
  267. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
  268. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
  269. data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
  270. data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
  271. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +203 -6
  272. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
  273. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
  274. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -20
  275. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
  276. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  277. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  278. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  279. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  280. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  281. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  282. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  283. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  284. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  285. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +19 -1
  286. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
  287. data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
  288. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +168 -76
  289. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +11 -4
  290. data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
  291. data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
  292. data/ext/sources/ggml/src/ggml-cuda/unary.cu +105 -11
  293. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +36 -0
  294. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
  295. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  296. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +12 -1
  297. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
  302. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
  303. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  304. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
  305. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
  306. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  307. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
  308. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
  309. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
  310. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
  311. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  312. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
  313. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
  314. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
  315. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
  316. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
  317. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
  318. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
  319. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
  320. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
  321. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
  322. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
  323. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
  324. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
  325. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  326. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
  327. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
  328. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
  329. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +8 -13
  330. data/ext/sources/ggml/src/ggml-impl.h +67 -6
  331. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +2 -2
  332. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +29 -20
  333. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +652 -285
  334. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +103 -56
  335. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +496 -118
  336. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +231 -9
  337. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1227 -224
  338. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +12 -0
  339. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +14 -8
  340. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1972 -704
  341. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
  342. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +11 -0
  343. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1430 -120
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +63 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +13 -0
  365. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
  366. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
  367. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
  368. data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
  369. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +6 -0
  370. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +0 -9
  371. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
  372. data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
  373. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
  374. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +34 -0
  375. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  376. data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  377. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
  378. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
  379. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +76 -3
  380. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +333 -300
  381. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +10 -2
  382. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +335 -110
  383. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
  384. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +156 -0
  385. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
  386. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
  387. data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
  388. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
  389. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
  390. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
  391. data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  392. data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  393. data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
  394. data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
  395. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +30 -17
  396. data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
  397. data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
  398. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
  399. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
  400. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  401. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  402. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +58 -0
  403. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
  404. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5013 -2859
  405. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  406. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
  407. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
  408. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  409. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
  410. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  411. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
  412. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
  413. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  414. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  415. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  416. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  417. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  418. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  419. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
  420. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  421. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  422. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
  423. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
  424. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  425. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  426. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  427. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
  428. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
  429. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
  430. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
  431. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  432. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
  433. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
  434. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  435. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  436. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  437. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  438. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  439. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
  440. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
  441. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
  442. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  443. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  444. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
  445. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
  446. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  447. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  448. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  449. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
  450. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  451. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  452. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
  453. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  454. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  455. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
  456. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  457. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  458. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
  459. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  460. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +39 -17
  461. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +19 -1
  462. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +45 -7
  463. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +50 -12
  464. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +1 -1
  465. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  466. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  467. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
  468. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
  469. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  470. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
  471. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  472. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
  473. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
  474. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
  475. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
  476. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
  477. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
  478. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  479. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
  480. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
  481. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
  482. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
  483. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
  484. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  485. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  486. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  487. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
  488. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +70 -25
  489. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
  490. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
  491. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
  492. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
  493. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
  494. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
  495. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
  496. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
  497. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
  498. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  499. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
  500. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
  501. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
  502. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
  503. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
  504. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
  505. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
  506. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +78 -103
  507. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +34 -23
  508. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
  509. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
  510. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +88 -228
  511. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
  512. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  513. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
  514. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  515. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  516. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  517. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
  518. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
  519. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  520. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
  521. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  522. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  523. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  524. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  525. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +50 -4
  526. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  527. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
  528. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
  529. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
  530. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  531. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -50
  532. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -33
  533. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -33
  534. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
  535. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
  536. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  537. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
  538. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  539. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  540. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  541. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  542. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
  543. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
  544. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
  545. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
  546. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
  547. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
  548. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  549. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
  550. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
  551. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  552. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
  553. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
  554. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  555. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  556. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
  557. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
  558. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  559. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
  560. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  561. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
  562. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
  563. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
  564. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
  565. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
  566. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  567. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
  568. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
  569. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -151
  570. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
  571. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
  572. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
  573. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1964 -435
  574. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  575. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
  576. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
  577. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +33 -10
  578. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
  579. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +1 -1
  580. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  581. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +6 -6
  582. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
  583. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
  584. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
  585. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
  586. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
  587. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  588. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
  589. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
  590. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  591. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
  592. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
  593. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
  594. data/ext/sources/ggml/src/ggml.c +425 -33
  595. data/ext/sources/include/whisper.h +1 -0
  596. data/ext/sources/src/CMakeLists.txt +3 -1
  597. data/ext/sources/src/whisper.cpp +101 -35
  598. data/ext/sources/tests/CMakeLists.txt +2 -2
  599. data/ext/sources/tests/test-vad-full.cpp +4 -2
  600. data/ext/sources/tests/test-vad.cpp +1 -1
  601. data/extsources.rb +1 -0
  602. data/lib/whisper/model/uri.rb +17 -18
  603. data/sig/whisper.rbs +119 -2
  604. data/test/test_params.rb +16 -8
  605. data/test/test_segment.rb +0 -1
  606. data/test/test_token.rb +70 -0
  607. data/test/test_vad.rb +1 -1
  608. data/test/test_vad_context.rb +50 -0
  609. data/test/test_vad_segment.rb +19 -0
  610. data/test/test_vad_segments.rb +16 -0
  611. data/test/test_whisper.rb +7 -0
  612. data/whispercpp.gemspec +1 -1
  613. metadata +287 -34
  614. data/ext/sources/build-xcframework.sh +0 -571
  615. data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
  618. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
  619. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
  620. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
  621. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
  622. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
  623. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
  624. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  625. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  626. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  627. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  628. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
  629. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
  630. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
@@ -15,13 +15,12 @@
 
 #include <CL/cl.h>
 
+#include <inttypes.h>
 #include <string.h>
 
 #include <cstddef>
 #include <cstdint>
-#include <atomic>
 #include <fstream>
-#include <limits>
 #include <vector>
 #include <string>
 #include <cmath>
@@ -54,6 +53,37 @@
 
 bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor);
 
+// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
+// Precompute mp (m' in the paper) and L such that division
+// can be computed using a multiply (high 32b of 64b result)
+// and a shift:
+//
+// n/d = (mulhi(n, mp) + n) >> L;
+struct fastdiv_vals {
+    uint32_t mp;
+    uint32_t L;
+    uint32_t d;
+    uint32_t pad;
+};
+static_assert(sizeof(fastdiv_vals) == 16, "fastdiv_vals size incorrect");
+
+static fastdiv_vals init_fastdiv_values(uint64_t d_64) {
+    GGML_ASSERT(d_64 != 0);
+    GGML_ASSERT(d_64 <= std::numeric_limits<uint32_t>::max());
+
+    uint32_t d = (uint32_t)d_64;
+
+    // compute L = ceil(log2(d));
+    uint32_t L = 0;
+    while (L < 32 && (uint32_t{ 1 } << L) < d) {
+        L++;
+    }
+
+    uint32_t mp = (uint32_t) ((uint64_t{ 1 } << 32) * ((uint64_t{ 1 } << L) - d) / d + 1);
+    // pack divisor as well to reduce error surface
+    return { mp, L, d, 0 };
+}
+
 enum GPU_FAMILY {
     ADRENO,
     INTEL,
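
The fastdiv hunk above replaces a hardware divide by a fixed divisor with a multiply-high plus shift, which is much cheaper on GPU ALUs. A minimal host-side sketch of the identity stated in its comment; mulhi_u32, fastdiv, and the test loop are illustrative helpers, not part of the diff:

    #include <cassert>
    #include <cstdint>

    // High 32 bits of the 64-bit product, as OpenCL's mul_hi() computes on-device.
    static uint32_t mulhi_u32(uint32_t a, uint32_t b) {
        return (uint32_t)(((uint64_t)a * b) >> 32);
    }

    // The identity from the hunk's comment: n/d == (mulhi(n, mp) + n) >> L.
    static uint32_t fastdiv(uint32_t n, uint32_t mp, uint32_t L) {
        return (mulhi_u32(n, mp) + n) >> L;
    }

    int main() {
        const uint32_t d = 6;                       // example divisor
        uint32_t L = 0;                             // L = ceil(log2(6)) = 3
        while (L < 32 && (uint32_t{ 1 } << L) < d) {
            L++;
        }
        const uint32_t mp = (uint32_t)((uint64_t{ 1 } << 32) * ((uint64_t{ 1 } << L) - d) / d + 1);
        for (uint32_t n = 0; n < 1000000; n++) {
            assert(fastdiv(n, mp, L) == n / d);     // exact for every n tested
        }
        return 0;
    }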
@@ -233,6 +263,32 @@ static ggml_cl_compiler_version get_adreno_cl_compiler_version(const char *drive
     return { type, major, minor, patch };
 }
 
+// cl buffer wrapper
+struct ggml_cl_buffer {
+    cl_mem buffer;
+    size_t size;
+
+    ggml_cl_buffer()
+        : buffer(nullptr), size(0) {}
+
+    ~ggml_cl_buffer() {
+        if (buffer) {
+            CL_CHECK(clReleaseMemObject(buffer));
+        }
+    }
+
+    void allocate(cl_context context, size_t new_size) {
+        if (new_size > size) {
+            size = new_size;
+            if (buffer) {
+                CL_CHECK(clReleaseMemObject(buffer));
+            }
+            cl_int err;
+            CL_CHECK((buffer = clCreateBuffer(context, CL_MEM_READ_WRITE, size, NULL, &err), err));
+        }
+    }
+};
+
 // Profiling
 struct ProfilingInfo {
     std::string op_name;
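
ggml_cl_buffer is a grow-only RAII scratch buffer: allocate() reallocates only when a request exceeds the current capacity (size acts as a high-water mark), and the destructor releases the cl_mem handle. A hypothetical call site, assuming the prealloc_act_trans member added to the backend context later in this diff; stage_to_scratch, src, and nbytes are illustrative:

    // Sketch: stage host data into the grow-only device scratch buffer.
    static void stage_to_scratch(ggml_backend_opencl_context * backend_ctx,
                                 const void * src, size_t nbytes) {
        ggml_cl_buffer & scratch = backend_ctx->prealloc_act_trans;   // added in this diff
        scratch.allocate(backend_ctx->context, nbytes);               // no-op if capacity suffices
        CL_CHECK(clEnqueueWriteBuffer(backend_ctx->queue, scratch.buffer, CL_TRUE,
                                      0, nbytes, src, 0, NULL, NULL));
    }

Keeping the allocation as a high-water mark avoids reallocation churn when tensor shapes vary across graph nodes.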
@@ -346,6 +402,11 @@ struct ggml_backend_opencl_context {
     cl_context context;
     cl_command_queue queue;
 
+    // prealloc buffers for transposing weights and activations
+    ggml_cl_buffer prealloc_quant_trans;
+    ggml_cl_buffer prealloc_scales_trans;
+    ggml_cl_buffer prealloc_act_trans;
+
     cl_program program_add;
     cl_program program_add_id;
     cl_program program_clamp;
@@ -377,6 +438,8 @@ struct ggml_backend_opencl_context {
     cl_program program_mul_mv_f32_f32;
     cl_program program_mul;
     cl_program program_mul_mat_f16_f32_tiled;
+    cl_program program_mul_mm_f16_f32_kqv;
+    cl_program program_mul_mm_f16_f32_kq;
     cl_program program_div;
     cl_program program_sub;
     cl_program program_norm;
@@ -402,12 +465,14 @@ struct ggml_backend_opencl_context {
     cl_program program_conv_2d_f32;
     cl_program program_conv_2d_f16_f32;
     cl_program program_tsembd;
+    cl_program program_gemv_moe_mxfp4_f32, program_gemm_moe_mxfp4_f32;
     cl_program program_mul_mv_id_q4_0_f32_8x_flat;
     cl_program program_mul_mv_id_q8_0_f32, program_mul_mv_id_q8_0_f32_flat;
     cl_program program_mul_mv_id_mxfp4_f32;
     cl_program program_mul_mv_id_mxfp4_f32_flat;
     cl_program program_mul_mm_f32_f32_l4_lm;
     cl_program program_mul_mm_f16_f32_l4_lm;
+    cl_program program_mul_mm_q8_0_f32_l4_lm;
 
     cl_kernel kernel_add, kernel_add_row, kernel_add_f16, kernel_add_row_f16;
     cl_kernel kernel_mul, kernel_mul_row, kernel_mul_f16, kernel_mul_row_f16;
@@ -415,12 +480,16 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_sub, kernel_sub_row, kernel_sub_f16, kernel_sub_row_f16;
     cl_kernel kernel_add_id;
     cl_kernel kernel_scale;
+    cl_kernel kernel_sqr_cont_f32, kernel_sqr_cont_f32_4, kernel_sqr_cont_f16, kernel_sqr_cont_f16_4;
+    cl_kernel kernel_sqrt_cont_f32, kernel_sqrt_cont_f32_4, kernel_sqrt_cont_f16, kernel_sqrt_cont_f16_4;
+    cl_kernel kernel_mean_f32;
     cl_kernel kernel_silu, kernel_silu_4;
     cl_kernel kernel_gelu, kernel_gelu_4;
     cl_kernel kernel_gelu_erf, kernel_gelu_erf_4;
     cl_kernel kernel_gelu_quick, kernel_gelu_quick_4;
     cl_kernel kernel_relu;
     cl_kernel kernel_sigmoid_f32, kernel_sigmoid_f16;
+    cl_kernel kernel_fill;
     cl_kernel kernel_clamp;
     cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_swiglu_oai, kernel_geglu_erf, kernel_geglu_quick,
         kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16, kernel_geglu_erf_f16, kernel_geglu_quick_f16;
@@ -449,12 +518,15 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_mul_mat_f16_f32;
     cl_kernel kernel_mul_mat_f16_f32_l4;
     cl_kernel kernel_mul_mat_f16_f32_tiled;
+    cl_kernel kernel_mul_mm_f16_f32_kqv;
+    cl_kernel kernel_mul_mm_f16_f32_kq;
     cl_kernel kernel_mul_mat_q4_0_f32, kernel_mul_mat_q4_0_f32_v;
     cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0;
-    cl_kernel kernel_convert_block_mxfp4, kernel_restore_block_mxfp4;
+    cl_kernel kernel_convert_block_mxfp4, kernel_convert_block_mxfp4_trans, kernel_restore_block_mxfp4, kernel_restore_block_mxfp4_trans;
     cl_kernel kernel_convert_block_q8_0, kernel_restore_block_q8_0;
     cl_kernel kernel_mul_mat_q4_0_f32_8x_flat;
     cl_kernel kernel_convert_block_q4_0_noshuffle;
+    cl_kernel kernel_restore_block_q4_0_noshuffle;
     cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat;
     cl_kernel kernel_mul_mv_q6_K_f32;
     cl_kernel kernel_mul_mv_mxfp4_f32, kernel_mul_mv_mxfp4_f32_flat;
@@ -466,6 +538,10 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_pad;
     cl_kernel kernel_tanh_f32_nd;
     cl_kernel kernel_tanh_f16_nd;
+    cl_kernel kernel_expm1_f32_nd;
+    cl_kernel kernel_expm1_f16_nd;
+    cl_kernel kernel_softplus_f32_nd;
+    cl_kernel kernel_softplus_f16_nd;
     cl_kernel kernel_upscale;
     cl_kernel kernel_upscale_bilinear;
     cl_kernel kernel_concat_f32_contiguous;
@@ -473,13 +549,16 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_conv_2d_f16;
     cl_kernel kernel_conv_2d_f32;
     cl_kernel kernel_conv_2d_f16_f32;
+    cl_kernel kernel_ssm_conv_f32_f32, kernel_ssm_conv_f32_f32_4;
     cl_kernel kernel_timestep_embedding;
+    cl_kernel kernel_gemv_moe_mxfp4_f32, kernel_gemm_moe_mxfp4_f32;
     cl_kernel kernel_mul_mv_id_q4_0_f32_8x_flat;
     cl_kernel kernel_mul_mv_id_q8_0_f32, kernel_mul_mv_id_q8_0_f32_flat;
     cl_kernel kernel_mul_mv_id_mxfp4_f32;
     cl_kernel kernel_mul_mv_id_mxfp4_f32_flat;
     cl_kernel kernel_mul_mm_f32_f32_l4_lm;
     cl_kernel kernel_mul_mm_f16_f32_l4_lm;
+    cl_kernel kernel_mul_mm_q8_0_f32_l4_lm;
 
     std::vector<ProfilingInfo> profiling_info;
 
@@ -529,25 +608,17 @@ struct ggml_backend_opencl_context {
     }
 
     // Dump a csv
-    float total_kernel_time = 0;
-    fprintf(fperf, "op name, kernel name, queued duration (ms), submit duration(ms), exec duration (ms), complete duration (ms), total duration (ms), global size, local size, output size\n");
+    fprintf(fperf, "op name, kernel name, exec duration (ms), global size, local size, output size\n");
     for (const ProfilingInfo & info : profiling_info) {
-        total_kernel_time += info.cmd_duration_ns/1.e6f;
-        fprintf(fperf, "%s,%s,%f,%f,%f,%f,%f,%zux%zux%zu,%zux%zux%zu,%zux%zux%zux%zu\n",
+        fprintf(fperf, "%s,%s,%f,%zux%zux%zu,%zux%zux%zu,%zux%zux%zux%zu\n",
             info.op_name.c_str(), info.kernel_name.c_str(),
-            info.cmd_queued_duration_ns/1.e6f,
-            info.cmd_submit_duration_ns/1.e6f,
             info.cmd_duration_ns/1.e6f,
-            info.cmd_complete_duration_ns/1.e6f,
-            info.cmd_total_duration_ns/1.e6f,
             info.global_size[0], info.global_size[1], info.global_size[2],
             info.local_size[0], info.local_size[1], info.local_size[2],
             info.output_size[0], info.output_size[1], info.output_size[2], info.output_size[3]);
     }
     fclose(fperf);
 
-    GGML_LOG_INFO("ggml_opencl: total kernel time: %f\n", total_kernel_time);
-
     // Dump a simple chrome trace
     FILE* ftrace = fopen("cl_trace.json", "w");
     if (!ftrace) {
@@ -557,14 +628,14 @@ struct ggml_backend_opencl_context {
 
     fprintf(ftrace, "[\n");
     for (const ProfilingInfo & info : profiling_info) {
-        fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"B\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Host\"},\n",
+        fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"B\", \"ts\": %" PRIu64 ", \"pid\": \"\", \"tid\": \"Host\"},\n",
             info.kernel_name.c_str(), info.cmd_queued/1000);
-        fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Host\"},\n",
+        fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %" PRIu64 ", \"pid\": \"\", \"tid\": \"Host\"},\n",
             info.kernel_name.c_str(), info.cmd_submit/1000);
 
-        fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"B\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Device\"},\n",
+        fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"B\", \"ts\": %" PRIu64 ", \"pid\": \"\", \"tid\": \"Device\"},\n",
             info.kernel_name.c_str(), info.cmd_start/1000);
-        fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Device\"},\n",
+        fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %" PRIu64 ", \"pid\": \"\", \"tid\": \"Device\"},\n",
             info.kernel_name.c_str(), info.cmd_end/1000);
     }
     fclose(ftrace);
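
The %lu → PRIu64 change in this hunk (together with the #include <inttypes.h> added in the first hunk) fixes a portability bug: the profiling timestamps are 64-bit, but unsigned long is only 32 bits on LLP64 targets such as 64-bit Windows, so printing them with %lu is undefined behavior there. A minimal illustration of the fix:

    #include <cinttypes>
    #include <cstdint>
    #include <cstdio>

    int main() {
        uint64_t ts = 123456789012345ULL;
        // printf("%lu\n", ts);         // wrong on LLP64: unsigned long may be 32-bit
        printf("%" PRIu64 "\n", ts);    // PRIu64 expands to the correct specifier everywhere
        return 0;
    }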
@@ -600,12 +671,9 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_transpose_32;
     cl_kernel kernel_transpose_32_16;
     cl_kernel kernel_transpose_16;
+    cl_kernel kernel_transpose_16_buf;
     cl_kernel kernel_transpose_16_4x1;
 
-    cl_mem A_s_d_max;            // max scale buffer size for transpose
-    cl_mem A_q_d_max;            // max weight buffer size for transpose
-    cl_mem B_d_max;              // max activation buffer size for transpose
-
     // Gemm and Gemv related programs, kernels, etc
     cl_program program_CL_gemm;
     cl_program program_CL_gemv_general;
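
The three raw cl_mem scratch handles removed here (weight, scale, and activation transpose buffers) line up with the ggml_cl_buffer prealloc_quant_trans / prealloc_scales_trans / prealloc_act_trans members added earlier in this diff, which take over sizing and release of the transpose scratch memory.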
@@ -724,6 +792,24 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         GGML_LOG_CONT(".");
     }
 
+    // fill
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "fill.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("fill.cl");
+#endif
+        cl_program prog =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_fill = clCreateKernel(prog, "kernel_fill_f32", &err), err));
+        GGML_LOG_CONT(".");
+
+        CL_CHECK(clReleaseProgram(prog));
+    }
+
     // clamp
     {
 #ifdef GGML_OPENCL_EMBED_KERNELS
@@ -772,9 +858,12 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
772
858
  build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
773
859
 
774
860
  CL_CHECK((backend_ctx->kernel_convert_block_q4_0_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_0_noshuffle", &err), err));
861
+ CL_CHECK((backend_ctx->kernel_restore_block_q4_0_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_0_noshuffle", &err), err));
775
862
  CL_CHECK((backend_ctx->kernel_convert_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_0", &err), err));
776
863
  CL_CHECK((backend_ctx->kernel_restore_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_0", &err), err));
777
864
  CL_CHECK((backend_ctx->kernel_convert_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4", &err), err));
865
+ CL_CHECK((backend_ctx->kernel_convert_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4_trans", &err), err));
866
+ CL_CHECK((backend_ctx->kernel_restore_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4_trans", &err), err));
778
867
  CL_CHECK((backend_ctx->kernel_restore_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4", &err), err));
779
868
  CL_CHECK((backend_ctx->kernel_convert_block_q8_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q8_0", &err), err));
780
869
  CL_CHECK((backend_ctx->kernel_restore_block_q8_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q8_0", &err), err));
@@ -1191,6 +1280,41 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
1191
1280
  GGML_LOG_CONT(".");
1192
1281
  }
1193
1282
 
1283
+ // mul_mm_q8_0_f32_l4_lm
1284
+ {
1285
+ #ifdef GGML_OPENCL_EMBED_KERNELS
1286
+ const std::string kernel_src {
1287
+ #include "mul_mm_q8_0_f32_l4_lm.cl.h"
1288
+ };
1289
+ #else
1290
+ const std::string kernel_src = read_file("mul_mm_q8_0_f32_l4_lm.cl");
1291
+ #endif
1292
+ backend_ctx->program_mul_mm_q8_0_f32_l4_lm =
1293
+ build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
1294
+
1295
+ CL_CHECK((backend_ctx->kernel_mul_mm_q8_0_f32_l4_lm = clCreateKernel(backend_ctx->program_mul_mm_q8_0_f32_l4_lm, "kernel_mul_mm_q8_0_f32_l4_lm", &err), err));
1296
+ GGML_LOG_CONT(".");
1297
+ }
1298
+
1299
+ // mul_mm_f16_f32_kq_kqv
1300
+ {
1301
+ #ifdef GGML_OPENCL_EMBED_KERNELS
1302
+ const std::string kernel_src {
1303
+ #include "mul_mm_f16_f32_kq_kqv.cl.h"
1304
+ };
1305
+ #else
1306
+ const std::string kernel_src = read_file("mul_mm_f16_f32_kq_kqv.cl");
1307
+ #endif
1308
+ backend_ctx->program_mul_mm_f16_f32_kqv =
1309
+ build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts+" -DKQV ");
1310
+ backend_ctx->program_mul_mm_f16_f32_kq =
1311
+ build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
1312
+
1313
+ CL_CHECK((backend_ctx->kernel_mul_mm_f16_f32_kqv = clCreateKernel(backend_ctx->program_mul_mm_f16_f32_kqv, "mul_mm_f16_f32_kqv", &err), err));
1314
+ CL_CHECK((backend_ctx->kernel_mul_mm_f16_f32_kq = clCreateKernel(backend_ctx->program_mul_mm_f16_f32_kq, "mul_mm_f16_f32_kq", &err), err));
1315
+ GGML_LOG_CONT(".");
1316
+ }
1317
+
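
Note that the mul_mm_f16_f32_kq_kqv block above builds the same .cl source twice; the KQV variant differs only by the extra " -DKQV " define passed to the compiler. A hedged sketch of that pattern with a toy kernel, assuming an existing context ctx, device dev, and cl_int err (the kernel body is illustrative, not the shipped source):

    // One source, two specialized binaries selected by a -D option.
    const char * src =
        "kernel void f(global float * x) {\n"
        "#ifdef KQV\n"
        "    x[get_global_id(0)] *= 2.0f;\n"
        "#else\n"
        "    x[get_global_id(0)] += 1.0f;\n"
        "#endif\n"
        "}\n";
    cl_program p_kqv = clCreateProgramWithSource(ctx, 1, &src, NULL, &err);
    clBuildProgram(p_kqv, 1, &dev, "-DKQV", NULL, NULL);   // KQV specialization
    cl_program p_kq  = clCreateProgramWithSource(ctx, 1, &src, NULL, &err);
    clBuildProgram(p_kq,  1, &dev, "",      NULL, NULL);   // default build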
  // mul
  {
  #ifdef GGML_OPENCL_EMBED_KERNELS
@@ -1485,6 +1609,66 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
  GGML_LOG_CONT(".");
  }

+ // sqr
+ {
+ #ifdef GGML_OPENCL_EMBED_KERNELS
+ const std::string kernel_src {
+ #include "sqr.cl.h"
+ };
+ #else
+ const std::string kernel_src = read_file("sqr.cl");
+ #endif
+ cl_program prog =
+ build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+
+ CL_CHECK((backend_ctx->kernel_sqr_cont_f32 = clCreateKernel(prog, "kernel_sqr_cont_f32", &err), err));
+ CL_CHECK((backend_ctx->kernel_sqr_cont_f32_4 = clCreateKernel(prog, "kernel_sqr_cont_f32_4", &err), err));
+ CL_CHECK((backend_ctx->kernel_sqr_cont_f16 = clCreateKernel(prog, "kernel_sqr_cont_f16", &err), err));
+ CL_CHECK((backend_ctx->kernel_sqr_cont_f16_4 = clCreateKernel(prog, "kernel_sqr_cont_f16_4", &err), err));
+
+ CL_CHECK(clReleaseProgram(prog));
+ GGML_LOG_CONT(".");
+ }
+
+ // sqrt
+ {
+ #ifdef GGML_OPENCL_EMBED_KERNELS
+ const std::string kernel_src {
+ #include "sqrt.cl.h"
+ };
+ #else
+ const std::string kernel_src = read_file("sqrt.cl");
+ #endif
+ cl_program prog =
+ build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+
+ CL_CHECK((backend_ctx->kernel_sqrt_cont_f32 = clCreateKernel(prog, "kernel_sqrt_cont_f32", &err), err));
+ CL_CHECK((backend_ctx->kernel_sqrt_cont_f32_4 = clCreateKernel(prog, "kernel_sqrt_cont_f32_4", &err), err));
+ CL_CHECK((backend_ctx->kernel_sqrt_cont_f16 = clCreateKernel(prog, "kernel_sqrt_cont_f16", &err), err));
+ CL_CHECK((backend_ctx->kernel_sqrt_cont_f16_4 = clCreateKernel(prog, "kernel_sqrt_cont_f16_4", &err), err));
+
+ CL_CHECK(clReleaseProgram(prog));
+ GGML_LOG_CONT(".");
+ }
+
+ // mean
+ {
+ #ifdef GGML_OPENCL_EMBED_KERNELS
+ const std::string kernel_src {
+ #include "mean.cl.h"
+ };
+ #else
+ const std::string kernel_src = read_file("mean.cl");
+ #endif
+ cl_program prog =
+ build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+
+ CL_CHECK((backend_ctx->kernel_mean_f32 = clCreateKernel(prog, "kernel_mean_f32", &err), err));
+
+ CL_CHECK(clReleaseProgram(prog));
+ GGML_LOG_CONT(".");
+ }
+
  // sub
  {
  #ifdef GGML_OPENCL_EMBED_KERNELS
@@ -1619,6 +1803,56 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
  }
  }

+ // expm1
+ {
+ #ifdef GGML_OPENCL_EMBED_KERNELS
+ const std::string kernel_src {
+ #include "expm1.cl.h"
+ };
+ #else
+ const std::string kernel_src = read_file("expm1.cl");
+ #endif
+ cl_program prog;
+ if (!kernel_src.empty()) {
+ prog =
+ build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+ CL_CHECK((backend_ctx->kernel_expm1_f32_nd = clCreateKernel(prog, "kernel_expm1_f32_nd", &err), err));
+ CL_CHECK((backend_ctx->kernel_expm1_f16_nd = clCreateKernel(prog, "kernel_expm1_f16_nd", &err), err));
+ GGML_LOG_CONT(".");
+ } else {
+ GGML_LOG_WARN("ggml_opencl: expm1 kernel source not found or empty. Expm1 operation will not be available.\n");
+ prog = nullptr;
+ backend_ctx->kernel_expm1_f32_nd = nullptr;
+ backend_ctx->kernel_expm1_f16_nd = nullptr;
+ }
+ CL_CHECK(clReleaseProgram(prog));
+ }
+
+ // softplus
+ {
+ #ifdef GGML_OPENCL_EMBED_KERNELS
+ const std::string kernel_src {
+ #include "softplus.cl.h"
+ };
+ #else
+ const std::string kernel_src = read_file("softplus.cl");
+ #endif
+ cl_program prog;
+ if (!kernel_src.empty()) {
+ prog =
+ build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+ CL_CHECK((backend_ctx->kernel_softplus_f32_nd = clCreateKernel(prog, "kernel_softplus_f32_nd", &err), err));
+ CL_CHECK((backend_ctx->kernel_softplus_f16_nd = clCreateKernel(prog, "kernel_softplus_f16_nd", &err), err));
+ GGML_LOG_CONT(".");
+ } else {
+ GGML_LOG_WARN("ggml_opencl: softplus kernel source not found or empty. Softplus operation will not be available.\n");
+ prog = nullptr;
+ backend_ctx->kernel_softplus_f32_nd = nullptr;
+ backend_ctx->kernel_softplus_f16_nd = nullptr;
+ }
+ CL_CHECK(clReleaseProgram(prog));
+ }
+
  // upscale
  {
  #ifdef GGML_OPENCL_EMBED_KERNELS
@@ -1758,6 +1992,24 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
  }
  }

+ // ssm_conv
+ {
+ #ifdef GGML_OPENCL_EMBED_KERNELS
+ const std::string kernel_src {
+ #include "ssm_conv.cl.h"
+ };
+ #else
+ const std::string kernel_src = read_file("ssm_conv.cl");
+ #endif
+ cl_program prog =
+ build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+
+ CL_CHECK((backend_ctx->kernel_ssm_conv_f32_f32 = clCreateKernel(prog, "kernel_ssm_conv_f32_f32", &err), err));
+ CL_CHECK((backend_ctx->kernel_ssm_conv_f32_f32_4 = clCreateKernel(prog, "kernel_ssm_conv_f32_f32_4", &err), err));
+ CL_CHECK(clReleaseProgram(prog));
+ GGML_LOG_CONT(".");
+ }
+
  // mul_mv_id_q4_0_f32_8x_flat
  {
  #ifdef GGML_OPENCL_EMBED_KERNELS
@@ -1855,7 +2107,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
  CL_CHECK((backend_ctx->kernel_transpose_32_16 = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_32_16", &err), err));
  CL_CHECK((backend_ctx->kernel_transpose_32 = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_32", &err), err));
  CL_CHECK((backend_ctx->kernel_transpose_16 = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_16", &err), err));
- CL_CHECK((backend_ctx->kernel_transpose_16_4x1 = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_16_4x1", &err), err));
+ CL_CHECK((backend_ctx->kernel_transpose_16_buf = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_16_buf", &err), err));
+ CL_CHECK((backend_ctx->kernel_transpose_16_4x1 = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_16_4x1", &err), err));
  GGML_LOG_CONT(".");
  }

@@ -1973,6 +2226,42 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
  CL_CHECK((backend_ctx->CL_mul_mat_Ab_Bi_8x4 = clCreateKernel(backend_ctx->program_CL_gemm, "kernel_mul_mat_Ab_Bi_8x4", &err), err));
  GGML_LOG_CONT(".");
  }
+
+ std::string CL_moe_compile_opts = std::string("-cl-std=") + opencl_c_std +
+ " -cl-mad-enable "
+ " -cl-fast-relaxed-math";
+
+ // gemv_moe_mxfp4_f32
+ {
+ #ifdef GGML_OPENCL_EMBED_KERNELS
+ const std::string kernel_src {
+ #include "gemv_moe_mxfp4_f32.cl.h"
+ };
+ #else
+ const std::string kernel_src = read_file("gemv_moe_mxfp4_f32.cl");
+ #endif
+ backend_ctx->program_gemv_moe_mxfp4_f32 =
+ build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts);
+
+ CL_CHECK((backend_ctx->kernel_gemv_moe_mxfp4_f32 = clCreateKernel(backend_ctx->program_gemv_moe_mxfp4_f32, "kernel_gemv_moe_mxfp4_f32", &err), err));
+ GGML_LOG_CONT(".");
+ }
+
+ // gemm_moe_mxfp4_f32
+ {
+ #ifdef GGML_OPENCL_EMBED_KERNELS
+ const std::string kernel_src {
+ #include "gemm_moe_mxfp4_f32.cl.h"
+ };
+ #else
+ const std::string kernel_src = read_file("gemm_moe_mxfp4_f32.cl");
+ #endif
+ backend_ctx->program_gemm_moe_mxfp4_f32 =
+ build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts);
+
+ CL_CHECK((backend_ctx->kernel_gemm_moe_mxfp4_f32 = clCreateKernel(backend_ctx->program_gemm_moe_mxfp4_f32, "kernel_gemm_moe_mxfp4_f32", &err), err));
+ GGML_LOG_CONT(".");
+ }
  #endif // GGML_OPENCL_USE_ADRENO_KERNELS
  GGML_LOG_CONT("\n");
  }
@@ -2348,8 +2637,13 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
  svm_caps & CL_DEVICE_SVM_ATOMICS ? "true" : "false");

  if (opencl_c_version.major >= 3) {
+ // Assume it is not available for 3.0, since it is optional in 3.0.
+ // If compiling against 3.0, then we can query.
+ backend_ctx->non_uniform_workgroups = false;
+ #if CL_TARGET_OPENCL_VERSION >= 300
  CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_NON_UNIFORM_WORK_GROUP_SUPPORT, sizeof(cl_bool),
  &backend_ctx->non_uniform_workgroups, 0));
+ #endif
  } else {
  GGML_ASSERT(opencl_c_version.major == 2);
  // Non-uniform workgroup sizes is mandatory feature in v2.x.
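
Background for the change above: non-uniform work-group sizes are mandatory in OpenCL 2.x but became an optional feature in 3.0, and the CL_DEVICE_NON_UNIFORM_WORK_GROUP_SUPPORT query token only exists in the 3.0 headers. The code therefore defaults the flag to false and queries only when CL_TARGET_OPENCL_VERSION says the headers are new enough. The same pattern in isolation, assuming an existing cl_device_id device:

    // Guarded capability query; mirrors the diff above.
    cl_bool non_uniform = CL_FALSE;   // safe default: assume unsupported on 3.0
    #if CL_TARGET_OPENCL_VERSION >= 300
    // Only the OpenCL 3.0 headers define this device-info token.
    clGetDeviceInfo(device, CL_DEVICE_NON_UNIFORM_WORK_GROUP_SUPPORT,
                    sizeof(non_uniform), &non_uniform, NULL);
    #endif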
@@ -2406,9 +2700,9 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
  required_B_d_bytes, max_B_d_bytes);
  }

- CL_CHECK((backend_ctx->A_q_d_max = clCreateBuffer(context, 0, max_A_q_d_bytes, NULL, &err), err));
- CL_CHECK((backend_ctx->A_s_d_max = clCreateBuffer(context, 0, max_A_s_d_bytes, NULL, &err), err));
- CL_CHECK((backend_ctx->B_d_max = clCreateBuffer(context, 0, max_B_d_bytes, NULL, &err), err));
+ backend_ctx->prealloc_quant_trans.allocate(context, max_A_q_d_bytes);
+ backend_ctx->prealloc_scales_trans.allocate(context, max_A_s_d_bytes);
+ backend_ctx->prealloc_act_trans.allocate(context, max_B_d_bytes);
  #endif // GGML_OPENCL_USE_ADRENO_KERNELS

  backend_ctx->disable_fusion = getenv("GGML_OPENCL_DISABLE_FUSION") != nullptr;
@@ -2681,7 +2975,7 @@ static bool ggml_opencl_can_fuse(const struct ggml_cgraph * cgraph, int node_idx

  // if rms_norm is the B operand, then we don't handle broadcast
  if (rms_norm == mul->src[1] &&
- !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
+ !ggml_are_same_shape(mul->src[0], rms_norm)) {
  return false;
  }

@@ -2851,6 +3145,10 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
  (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16);
  case GGML_OP_ADD_ID:
  return op->src[0]->type == GGML_TYPE_F32;
+ case GGML_OP_SQR:
+ case GGML_OP_SQRT:
+ return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
+ ggml_is_contiguous(op->src[0]);
  case GGML_OP_UNARY:
  switch (ggml_get_unary_op(op)) {
  case GGML_UNARY_OP_GELU:
@@ -2864,6 +3162,12 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
  case GGML_UNARY_OP_TANH:
  return (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) ||
  (op->src[0]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16);
+ case GGML_UNARY_OP_EXPM1:
+ return (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) ||
+ (op->src[0]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16);
+ case GGML_UNARY_OP_SOFTPLUS:
+ return (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) ||
+ (op->src[0]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16);
  default:
  return false;
  }
@@ -2879,6 +3183,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
  default:
  return false;
  }
+ case GGML_OP_FILL:
+ return op->type == GGML_TYPE_F32 && ggml_is_contiguous(op);
  case GGML_OP_CLAMP:
  return op->src[0]->type == GGML_TYPE_F32;
  case GGML_OP_SOFT_MAX:
@@ -2889,16 +3195,23 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
  case GGML_OP_REPEAT:
  return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; // Assuming F32 for now, can be expanded
  case GGML_OP_PAD:
- return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32 &&
- op->src[0]->ne[3] == 1 && op->ne[3] == 1 &&
- (ggml_get_op_params_i32(op, 0) == 0) && (ggml_get_op_params_i32(op, 2) == 0) &&
- (ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0);
- case GGML_OP_UPSCALE:
+ // TODO: add circular padding support for opencl, see https://github.com/ggml-org/llama.cpp/pull/16985
+ if (ggml_get_op_params_i32(op, 8) != 0) {
+ return false;
+ }
  return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
+ case GGML_OP_UPSCALE: {
+ ggml_scale_mode mode = (ggml_scale_mode)(ggml_get_op_params_i32(op, 0) & 0xFF);
+ const bool antialias = (ggml_scale_mode)(ggml_get_op_params_i32(op, 0) & GGML_SCALE_FLAG_ANTIALIAS);
+ return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32 &&
+ (mode == GGML_SCALE_MODE_NEAREST || mode == GGML_SCALE_MODE_BILINEAR) && !antialias;
+ }
  case GGML_OP_CONV_2D:
  return (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16) ||
  (op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) ||
  (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32);
+ case GGML_OP_SSM_CONV:
+ return (op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32);
  case GGML_OP_CONCAT:
  return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
  case GGML_OP_TIMESTEP_EMBEDDING:
@@ -2967,6 +3280,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
  return cols <= max_workgroup_size && op->src[0]->type == GGML_TYPE_F32;
  }
  case GGML_OP_SUM_ROWS:
+ case GGML_OP_MEAN:
  return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]);
  case GGML_OP_FLASH_ATTN_EXT:
  {
@@ -3279,6 +3593,12 @@ inline bool use_adreno_kernels(const ggml_backend_opencl_context *backend_ctx, c
  tensor->ne[2] == 1 && tensor->ne[3] == 1;
  }

+ inline bool use_adreno_moe_kernels(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) {
+ GGML_UNUSED(backend_ctx);
+ int ne01 = tensor->ne[1];
+ return ((strstr(tensor->name, "ffn") != NULL) || (strstr(tensor->name, "as") != NULL)) && (ne01 % 64 == 0);
+ }
+
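
The new use_adreno_moe_kernels gate is a name-based heuristic: it opts a tensor into the transposed MXFP4 path only when its name contains "ffn" or "as" (matching how MoE expert weights are commonly named in ggml graphs) and its row count ne[1] is a multiple of 64, which lines up with the 64-wide work-groups used by the _trans kernels further down. An illustrative check; the tensor names and sizes here are hypothetical examples, not taken from this diff:

    // Hypothetical examples of the gate above.
    ggml_tensor t = {};
    snprintf(t.name, sizeof(t.name), "blk.0.ffn_gate_exps.weight");
    t.ne[1] = 2880;                                   // 2880 % 64 == 0
    bool moe = use_adreno_moe_kernels(NULL, &t);      // -> true
    snprintf(t.name, sizeof(t.name), "blk.0.attn_q.weight");
    moe = use_adreno_moe_kernels(NULL, &t);           // -> false: no "ffn"/"as"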
  static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
  ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(buffer->buft->device);

@@ -3395,32 +3715,35 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
  // use sub_buffer of max buffer size instead

  size_t q_size_bytes = K * M / 8 * sizeof(float);
+ backend_ctx->prealloc_quant_trans.allocate(context, q_size_bytes);
+
  cl_buffer_region region;
  region.origin = 0;
  region.size = q_size_bytes;
  cl_mem qT_d = clCreateSubBuffer(
- backend_ctx->A_q_d_max,
+ backend_ctx->prealloc_quant_trans.buffer,
  0,
  CL_BUFFER_CREATE_TYPE_REGION,
  &region,
  &err);
- // cl_mem qT_d = clCreateBuffer(context, CL_MEM_READ_WRITE, q_size_bytes, NULL, &err);
  CL_CHECK(err);

  bool K_tile_trans = true;
  if ((K / 32) % 4 != 0){
  K_tile_trans =false;
  }
+
  size_t d_size_bytes = M * (K / 32) * 2;
+ backend_ctx->prealloc_scales_trans.allocate(context, d_size_bytes);
+
  region.origin = 0;
  region.size = d_size_bytes;
  cl_mem dT_d = clCreateSubBuffer(
- backend_ctx->A_s_d_max,
+ backend_ctx->prealloc_scales_trans.buffer,
  0,
  CL_BUFFER_CREATE_TYPE_REGION,
  &region,
  &err);
- // cl_mem dT_d = clCreateBuffer(context, CL_MEM_READ_WRITE, d_size_bytes, NULL, &err);
  CL_CHECK(err);

  // <----------------------------------------------------------------------------------> //
@@ -3581,14 +3904,39 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
  CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
  CL_CHECK(err);

+ #ifdef GGML_OPENCL_USE_ADRENO_KERNELS
+ if (use_adreno_moe_kernels(backend_ctx, tensor)) {
+ cl_kernel kernel = backend_ctx->kernel_convert_block_mxfp4_trans;
+
+ int ne00 = tensor->ne[0];
+ int ne01 = tensor->ne[1];
+ int ne02 = tensor->ne[2];
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device));
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q));
+ CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->e));
+ CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &ne00));
+ CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne01));
+
+ size_t global_work_size[3] = {static_cast<size_t>(((ne01 + 63) / 64) * 64), static_cast<size_t>(ne00 / 32), static_cast<size_t>(ne02)};
+ size_t local_work_size[3] = {64, 2, 1};
+
+ cl_event evt;
+ CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
+ CL_CHECK(clWaitForEvents(1, &evt));
+ CL_CHECK(clReleaseMemObject(data_device));
+ tensor->extra = extra;
+
+ return;
+ }
+ #endif
  cl_kernel kernel = backend_ctx->kernel_convert_block_mxfp4;

  CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device));
  CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q));
  CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->e));

- size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
- size_t local_work_size[] = {64, 1, 1};
+ size_t global_work_size[3] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
+ size_t local_work_size[3] = {64, 1, 1};

  cl_event evt;
  CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
@@ -3604,7 +3952,6 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
  { extra->q }
  };
  extra->q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_q, &img_desc_q, NULL, &err);
-
  tensor->extra = extra;

  return;
@@ -3701,6 +4048,91 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer,
  if (tensor->type == GGML_TYPE_Q4_0) {
  ggml_tensor_extra_cl_q4_0 * extra = (ggml_tensor_extra_cl_q4_0 *)tensor->extra;

+ #ifdef GGML_OPENCL_USE_ADRENO_KERNELS
+ if (use_adreno_kernels(backend_ctx, tensor)) {
+ cl_int err;
+ cl_kernel kernel;
+
+ cl_int M = tensor->ne[1]; // ne01
+ cl_int K = tensor->ne[0]; // ne00
+
+ GGML_ASSERT(K % 32 == 0);
+ GGML_ASSERT(M % 4 == 0);
+
+ size_t size_q = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*ggml_blck_size(tensor->type)/2;
+ size_t size_d = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*sizeof(ggml_fp16_t);
+ GGML_ASSERT(size_d + size_q == ggml_nbytes(tensor) && "Incorrect tensor size");
+
+ cl_mem buf_trans_q;
+ cl_mem buf_trans_d;
+
+ CL_CHECK((buf_trans_q = clCreateBuffer(context, CL_MEM_READ_WRITE,
+ size_q, NULL, &err), err));
+ CL_CHECK((buf_trans_d = clCreateBuffer(context, CL_MEM_READ_WRITE,
+ size_d, NULL, &err), err));
+
+ kernel = backend_ctx->kernel_transpose_16_buf;
+
+ // transpose q back
+ cl_int stride_k_q = K/4;
+ size_t local_size_q[3] = {64, 1, 1};
+ size_t global_size_q[3] = {(size_t)M, (size_t)stride_k_q, 1};
+
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q));
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &buf_trans_q));
+ CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_int), &M));
+ CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_int), &stride_k_q));
+
+ CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL,
+ global_size_q, local_size_q, 0, NULL, NULL));
+
+ // transpose scales back
+ cl_int stride_k_d = K/32;
+ size_t local_size_d[3] = {64, 1, 1};
+ size_t global_size_d[3] = {(size_t)M, (size_t)stride_k_d, 1};
+
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->d));
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &buf_trans_d));
+ CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_int), &M));
+ CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_int), &stride_k_d));
+
+ CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL,
+ global_size_d, local_size_d, 0, NULL, NULL));
+
+ // unpack
+ cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE,
+ ggml_nbytes(tensor), NULL, &err);
+ CL_CHECK(err);
+
+ cl_uchar mask_0F = 0x0F;
+ cl_uchar mask_F0 = 0xF0;
+
+ size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
+ size_t local_work_size[] = {1, 1, 1};
+
+ kernel = backend_ctx->kernel_restore_block_q4_0_noshuffle;
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &buf_trans_q));
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &buf_trans_d));
+ CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &data_device));
+ CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_uchar), &mask_0F));
+ CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_uchar), &mask_F0));
+
+ CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL,
+ global_work_size, local_work_size, 0, NULL, NULL));
+
+ // read back to host
+ CL_CHECK(clEnqueueReadBuffer(
+ queue, data_device, CL_TRUE, offset,
+ size, data, 0, NULL, NULL));
+
+ CL_CHECK(clReleaseMemObject(data_device));
+ CL_CHECK(clReleaseMemObject(buf_trans_q));
+ CL_CHECK(clReleaseMemObject(buf_trans_d));
+
+ return;
+ }
+ #endif
+
  cl_int err;
  cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE,
  ggml_nbytes(tensor), NULL, &err);
@@ -3731,6 +4163,33 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer,
  ggml_nbytes(tensor), NULL, &err);
  CL_CHECK(err);

+ #ifdef GGML_OPENCL_USE_ADRENO_KERNELS
+ if (use_adreno_moe_kernels(backend_ctx, tensor)) {
+ cl_kernel kernel = backend_ctx->kernel_restore_block_mxfp4_trans;
+
+ int ne00 = tensor->ne[0];
+ int ne01 = tensor->ne[1];
+ int ne02 = tensor->ne[2];
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q));
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->e));
+ CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &data_device));
+ CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_int), &ne00));
+ CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_int), &ne01));
+
+ size_t global_work_size[3] = {static_cast<size_t>(((ne01 + 63) / 64) * 64), static_cast<size_t>(ne00 / 32), static_cast<size_t>(ne02)};
+ size_t local_work_size[3] = {64, 2, 1};
+
+ cl_event evt;
+ CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL,
+ global_work_size, local_work_size, 0, NULL, &evt));
+ CL_CHECK(clWaitForEvents(1, &evt));
+ CL_CHECK(clEnqueueReadBuffer(
+ queue, data_device, CL_TRUE, offset,
+ size, data, 0, NULL, NULL));
+ CL_CHECK(clReleaseMemObject(data_device));
+ return;
+ }
+ #endif
  cl_kernel kernel = backend_ctx->kernel_restore_block_mxfp4;
  CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q));
  CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->e));
@@ -3888,8 +4347,8 @@ static const char * ggml_backend_opencl_device_get_description(ggml_backend_dev_
  }

  static void ggml_backend_opencl_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
- *free = 1;
- *total = 1;
+ *free = 0;
+ *total = 0;

  GGML_UNUSED(dev);
  }
@@ -4222,15 +4681,19 @@ static void ggml_cl_get_rows(ggml_backend_t backend, const ggml_tensor * src0, c
  GGML_ASSERT(dst);
  GGML_ASSERT(dst->extra);

- const int ne00 = src0 ? src0->ne[0] : 0;
- const cl_ulong nb01 = src0 ? src0->nb[1] : 0;
- const cl_ulong nb02 = src0 ? src0->nb[2] : 0;
- const int ne10 = src1 ? src1->ne[0] : 0;
- const cl_ulong nb10 = src1 ? src1->nb[0] : 0;
- const int ne11 = src1 ? src1->ne[1] : 0;
- const cl_ulong nb11 = src1 ? src1->nb[1] : 0;
- const cl_ulong nb1 = dst ? dst->nb[1] : 0;
- const cl_ulong nb2 = dst ? dst->nb[2] : 0;
+ const int ne00 = src0->ne[0];
+ const cl_ulong nb01 = src0->nb[1];
+ const cl_ulong nb02 = src0->nb[2];
+ const cl_ulong nb03 = src0->nb[3];
+ const int ne10 = src1->ne[0];
+ const cl_ulong nb10 = src1->nb[0];
+ const int ne11 = src1->ne[1];
+ const int ne12 = src1->ne[2];
+ const cl_ulong nb11 = src1->nb[1];
+ const cl_ulong nb12 = src1->nb[2];
+ const cl_ulong nb1 = dst->nb[1];
+ const cl_ulong nb2 = dst->nb[2];
+ const cl_ulong nb3 = dst->nb[3];

  ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;

@@ -4267,14 +4730,17 @@ static void ggml_cl_get_rows(ggml_backend_t backend, const ggml_tensor * src0, c
  CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
  CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01));
  CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02));
- CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10));
- CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb10));
- CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb11));
- CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb1));
- CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb2));
-
- size_t global_work_size[] = {(size_t)ne10, (size_t)ne11, 1};
- size_t local_work_size[] = {1, 1, 1};
+ CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb03));
+ CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne10));
+ CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb10));
+ CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11));
+ CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12));
+ CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb1));
+ CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb2));
+ CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb3));
+
+ size_t global_work_size[] = {(size_t)ne10*64, (size_t)ne11, (size_t)ne12};
+ size_t local_work_size[] = {64, 1, 1};

  backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
  }
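
The reworked get_rows launch above assigns a full 64-lane work-group to each looked-up row (global size ne10*64) instead of one thread per row, and threads the nb03/nb12/nb3 strides through so 4-D batched gathers work. A hedged OpenCL C sketch of the indexing this launch implies for the contiguous f32 case; the shipped get_rows kernels are not part of this diff:

    // Sketch only (contiguous f32, batch dims omitted); not the shipped kernel.
    kernel void get_rows_f32_sketch(global const float * src0,
                                    global const int   * rows,
                                    global float       * dst,
                                    int ne00) {
        const size_t r   = get_group_id(0);   // one 64-lane work-group per index
        const int    row = rows[r];
        for (int i = (int)get_local_id(0); i < ne00; i += 64) {
            dst[r * (size_t)ne00 + i] = src0[(size_t)row * ne00 + i];
        }
    }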
@@ -4346,6 +4812,9 @@ static void ggml_cl_set_rows(ggml_backend_t backend, const ggml_tensor * src0, c
  GGML_ABORT("not implemented");
  }

+ fastdiv_vals ne11_ = init_fastdiv_values(ne11);
+ fastdiv_vals ne12_ = init_fastdiv_values(ne12);
+
  CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
  CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
  CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
@@ -4356,8 +4825,8 @@ static void ggml_cl_set_rows(ggml_backend_t backend, const ggml_tensor * src0, c
  CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01));
  CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02));
  CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb03));
- CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne11));
- CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne12));
+ CL_CHECK(clSetKernelArg(kernel, 10, sizeof(fastdiv_vals), &ne11_));
+ CL_CHECK(clSetKernelArg(kernel, 11, sizeof(fastdiv_vals), &ne12_));
  CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb10));
  CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb11));
  CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb12));
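
set_rows now passes precomputed fastdiv_vals instead of the raw ne11/ne12 integers: the standard "magic number" trick that lets the kernel replace integer division and modulo with a multiply-high and a shift. Neither fastdiv_vals nor init_fastdiv_values is defined in this diff; the sketch below follows the Granlund-Montgomery scheme that ggml's CUDA backend uses for the same purpose, so the field names and layout are assumptions:

    // Hedged sketch of a magic-number divider; the real fastdiv_vals may differ.
    #include <stdint.h>

    typedef struct { uint32_t mp; uint32_t L; uint32_t d; } fastdiv_sketch;

    static fastdiv_sketch init_fastdiv_sketch(uint32_t d) {
        uint32_t L = 0;
        while (L < 32 && ((uint32_t)1 << L) < d) L++;      // L = ceil(log2(d))
        fastdiv_sketch v;
        v.mp = (uint32_t)((((uint64_t)1 << 32) * (((uint64_t)1 << L) - d)) / d + 1);
        v.L  = L;
        v.d  = d;
        return v;
    }

    // Device-side equivalent: n / d == (mulhi(n, mp) + n) >> L, no hardware divide.
    static uint32_t fastdiv_apply(uint32_t n, fastdiv_sketch v) {
        uint32_t hi = (uint32_t)(((uint64_t)n * v.mp) >> 32);
        return (hi + n) >> v.L;
    }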
@@ -5018,12 +5487,11 @@ static void ggml_cl_sub(ggml_backend_t backend, const ggml_tensor * src0, const
  }
  }

- static void ggml_cl_gelu(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ static void ggml_cl_sqr(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
  GGML_ASSERT(src0);
  GGML_ASSERT(src0->extra);
  GGML_ASSERT(dst);
  GGML_ASSERT(dst->extra);
-
  UNUSED(src1);

  ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
@@ -5036,13 +5504,21 @@ static void ggml_cl_gelu(ggml_backend_t backend, const ggml_tensor * src0, const

  cl_kernel kernel;

+ // Currently assumes src0 is contiguous
  int n = ggml_nelements(dst);
-
  if (n % 4 == 0) {
- kernel = backend_ctx->kernel_gelu_4;
+ if (src0->type == GGML_TYPE_F32) {
+ kernel = backend_ctx->kernel_sqr_cont_f32_4;
+ } else {
+ kernel = backend_ctx->kernel_sqr_cont_f16_4;
+ }
  n /= 4;
  } else {
- kernel = backend_ctx->kernel_gelu;
+ if (src0->type == GGML_TYPE_F32) {
+ kernel = backend_ctx->kernel_sqr_cont_f32;
+ } else {
+ kernel = backend_ctx->kernel_sqr_cont_f16;
+ }
  }

  CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
@@ -5053,15 +5529,19 @@ static void ggml_cl_gelu(ggml_backend_t backend, const ggml_tensor * src0, const
  size_t global_work_size[] = {(size_t)n, 1, 1};
  size_t local_work_size[] = {64, 1, 1};

- backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
+ size_t * local_work_size_ptr = local_work_size;
+ if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) {
+ local_work_size_ptr = nullptr;
+ }
+
+ backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
  }

- static void ggml_cl_gelu_erf(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ static void ggml_cl_sqrt(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
  GGML_ASSERT(src0);
  GGML_ASSERT(src0->extra);
  GGML_ASSERT(dst);
  GGML_ASSERT(dst->extra);
-
  UNUSED(src1);

  ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
@@ -5074,14 +5554,221 @@ static void ggml_cl_gelu_erf(ggml_backend_t backend, const ggml_tensor * src0, c

  cl_kernel kernel;

+ // Currently assumes src0 is contiguous
  int n = ggml_nelements(dst);
-
  if (n % 4 == 0) {
- kernel = backend_ctx->kernel_gelu_erf_4;
+ if (src0->type == GGML_TYPE_F32) {
+ kernel = backend_ctx->kernel_sqrt_cont_f32_4;
+ } else {
+ kernel = backend_ctx->kernel_sqrt_cont_f16_4;
+ }
  n /= 4;
  } else {
- kernel = backend_ctx->kernel_gelu_erf;
- }
+ if (src0->type == GGML_TYPE_F32) {
+ kernel = backend_ctx->kernel_sqrt_cont_f32;
+ } else {
+ kernel = backend_ctx->kernel_sqrt_cont_f16;
+ }
+ }
+
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
+ CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
+ CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
+
+ size_t global_work_size[] = {(size_t)n, 1, 1};
+ size_t local_work_size[] = {64, 1, 1};
+
+ size_t * local_work_size_ptr = local_work_size;
+ if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) {
+ local_work_size_ptr = nullptr;
+ }
+
+ backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
+ }
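
The host code above binds exactly four arguments per sqr/sqrt kernel (src buffer plus byte offset, dst buffer plus byte offset), so these are simple element-wise maps with scalar and float4-vectorized variants. A hedged OpenCL C sketch consistent with that signature; the shipped sqr.cl is not part of this diff:

    // Sketch consistent with the four host-side kernel args; not the real sqr.cl.
    kernel void kernel_sqr_cont_f32(global const char * src0, ulong offset0,
                                    global char * dst, ulong offsetd) {
        global const float * x = (global const float *)(src0 + offset0);
        global float       * y = (global float       *)(dst  + offsetd);
        const size_t i = get_global_id(0);
        y[i] = x[i] * x[i];
    }

    kernel void kernel_sqr_cont_f32_4(global const char * src0, ulong offset0,
                                      global char * dst, ulong offsetd) {
        global const float4 * x = (global const float4 *)(src0 + offset0);
        global float4       * y = (global float4       *)(dst  + offsetd);
        const size_t i = get_global_id(0);
        y[i] = x[i] * x[i];   // one work-item squares four packed elements
    }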
+
+ static void ggml_cl_mean(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ GGML_ASSERT(src0);
+ GGML_ASSERT(src0->extra);
+ GGML_ASSERT(dst);
+ GGML_ASSERT(dst->extra);
+ GGML_UNUSED(src1);
+
+ GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+
+ ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+ ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+ cl_ulong offset0 = extra0->offset + src0->view_offs;
+ cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+ const int ne00 = src0->ne[0];
+ const int ne01 = src0->ne[1];
+ const int ne02 = src0->ne[2];
+ const int ne03 = src0->ne[3];
+
+ const cl_ulong nb01 = src0->nb[1];
+ const cl_ulong nb02 = src0->nb[2];
+ const cl_ulong nb03 = src0->nb[3];
+
+ const cl_ulong nb1 = dst->nb[1];
+ const cl_ulong nb2 = dst->nb[2];
+ const cl_ulong nb3 = dst->nb[3];
+
+ cl_kernel kernel = backend_ctx->kernel_mean_f32;
+
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
+ CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
+ CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
+ CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00));
+ CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01));
+ CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02));
+ CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03));
+ CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01));
+ CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02));
+ CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03));
+ CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb1));
+ CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb2));
+ CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb3));
+
+ size_t global_work_size[] = {(size_t)ne01, (size_t)ne02, (size_t)ne03};
+ size_t local_work_size[] = {(size_t)64, 1, 1};
+
+ backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
+ }
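
kernel_mean_f32 above reduces each ne00-long row of src0 to its average, launched over the (i1, i2, i3) row grid. A hedged CPU reference of the same computation (byte strides, matching the kernel arguments), useful for validating results:

    // CPU reference for GGML_OP_MEAN as launched above. Strides are in bytes.
    static void mean_ref(const char * src, char * dst,
                         int ne00, int ne01, int ne02, int ne03,
                         size_t nb01, size_t nb02, size_t nb03,
                         size_t nb1, size_t nb2, size_t nb3) {
        for (int i3 = 0; i3 < ne03; i3++)
        for (int i2 = 0; i2 < ne02; i2++)
        for (int i1 = 0; i1 < ne01; i1++) {
            const float * row = (const float *)(src + i1*nb01 + i2*nb02 + i3*nb03);
            float sum = 0.0f;
            for (int i0 = 0; i0 < ne00; i0++) sum += row[i0];
            *(float *)(dst + i1*nb1 + i2*nb2 + i3*nb3) = sum / ne00;
        }
    }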
+
+ static void ggml_cl_ssm_conv(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ GGML_ASSERT(src0);
+ GGML_ASSERT(src0->extra);
+ GGML_ASSERT(src1);
+ GGML_ASSERT(src1->extra);
+ GGML_ASSERT(dst);
+ GGML_ASSERT(dst->extra);
+
+ ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+
+ ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+ ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
+ ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+ cl_ulong offset0 = extra0->offset + src0->view_offs;
+ cl_ulong offset1 = extra1->offset + src1->view_offs;
+ cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+ int ne01 = src0->ne[1];
+ cl_ulong nb00 = src0->nb[0];
+ cl_ulong nb01 = src0->nb[1];
+ cl_ulong nb02 = src0->nb[2];
+
+ int ne10 = src1->ne[0];
+ cl_ulong nb11 = src1->nb[1];
+
+ int ne1 = dst->ne[1];
+ int ne2 = dst->ne[2];
+ cl_ulong nb0 = dst->nb[0];
+ cl_ulong nb1 = dst->nb[1];
+ cl_ulong nb2 = dst->nb[2];
+
+ cl_kernel kernel = backend_ctx->kernel_ssm_conv_f32_f32;
+
+ if (ne10 % 4 == 0) {
+ kernel = backend_ctx->kernel_ssm_conv_f32_f32_4;
+ }
+
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
+ CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
+ CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
+ CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
+ CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
+ CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &nb00));
+ CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01));
+ CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02));
+ CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10));
+ CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb11));
+ CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb0));
+ CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb1));
+ CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb2));
+
+ size_t global_work_size[] = {(size_t)ne01, (size_t)ne1, (size_t)ne2};
+ size_t local_work_size[] = {64, 1, 1};
+
+ size_t * local_work_size_ptr = local_work_size;
+ if (ne01 % 64 != 0 && !backend_ctx->non_uniform_workgroups) {
+ local_work_size_ptr = nullptr;
+ }
+
+ backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
+ }
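
GGML_OP_SSM_CONV is the short causal depthwise convolution used by Mamba-style state-space blocks: each channel is convolved with its own ne10-tap filter taken from a row of src1. A hedged CPU sketch of the semantics the kernel is expected to match (batch dimension omitted for brevity; the shipped ssm_conv.cl is not in this diff, so treat this as an interpretation, not the implementation):

    // Per-channel sliding-window dot product; strides are in bytes.
    static void ssm_conv_ref(const char * sx, const char * w, char * dst,
                             int ne01 /*channels*/, int ne10 /*taps*/, int ne1 /*tokens*/,
                             size_t nb00, size_t nb01, size_t nb11,
                             size_t nb0, size_t nb1) {
        for (int t = 0; t < ne1; t++) {             // output token
            for (int c = 0; c < ne01; c++) {        // channel
                const char  * xrow = sx + c*nb01 + t*nb00;   // window starts at token t
                const float * wrow = (const float *)(w + c*nb11);
                float acc = 0.0f;
                for (int k = 0; k < ne10; k++) {
                    acc += *(const float *)(xrow + k*nb00) * wrow[k];
                }
                *(float *)(dst + c*nb0 + t*nb1) = acc;
            }
        }
    }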
+
+ static void ggml_cl_gelu(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ GGML_ASSERT(src0);
+ GGML_ASSERT(src0->extra);
+ GGML_ASSERT(dst);
+ GGML_ASSERT(dst->extra);
+
+ UNUSED(src1);
+
+ ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+
+ ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+ ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+ cl_ulong offset0 = extra0->offset + src0->view_offs;
+ cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+ cl_kernel kernel;
+
+ int n = ggml_nelements(dst);
+
+ if (n % 4 == 0) {
+ kernel = backend_ctx->kernel_gelu_4;
+ n /= 4;
+ } else {
+ kernel = backend_ctx->kernel_gelu;
+ }
+
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
+ CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
+ CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
+
+ size_t global_work_size[] = {(size_t)n, 1, 1};
+ size_t local_work_size[] = {64, 1, 1};
+
+ backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
+ }
+
+ static void ggml_cl_gelu_erf(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ GGML_ASSERT(src0);
+ GGML_ASSERT(src0->extra);
+ GGML_ASSERT(dst);
+ GGML_ASSERT(dst->extra);
+
+ UNUSED(src1);
+
+ ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+
+ ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+ ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+ cl_ulong offset0 = extra0->offset + src0->view_offs;
+ cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+ cl_kernel kernel;
+
+ int n = ggml_nelements(dst);
+
+ if (n % 4 == 0) {
+ kernel = backend_ctx->kernel_gelu_erf_4;
+ n /= 4;
+ } else {
+ kernel = backend_ctx->kernel_gelu_erf;
+ }

  CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
  CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
@@ -5254,6 +5941,36 @@ static void ggml_cl_sigmoid(ggml_backend_t backend, const ggml_tensor * src0, co
  backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
  }

+ static void ggml_cl_fill(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ GGML_ASSERT(dst);
+ GGML_ASSERT(dst->extra);
+
+ UNUSED(src0);
+ UNUSED(src1);
+
+ ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+
+ ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+ cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+ float v = 0.0f;
+ memcpy(&v, ((int32_t *) dst->op_params), sizeof(float));
+
+ const int64_t n = ggml_nelements(dst);
+
+ cl_kernel kernel = backend_ctx->kernel_fill;
+
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extrad->data_device));
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offsetd));
+ CL_CHECK(clSetKernelArg(kernel, 2, sizeof(float), &v));
+ CL_CHECK(clSetKernelArg(kernel, 3, sizeof(float), &n));
+
+ size_t local_work_size[1] = { 256 };
+ size_t global_work_size[1] = { ((size_t)n + local_work_size[0] - 1) / local_work_size[0] * local_work_size[0] };
+
+ backend_ctx->enqueue_ndrange_kernel(kernel, 1, global_work_size, local_work_size, dst);
+ }
+
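
Two details of ggml_cl_fill above are worth flagging: the global size is rounded up to a multiple of 256, so the kernel must bounds-check, and the element count is bound with sizeof(float) even though n is an int64_t, meaning only the low four bytes of n reach the device; that is consistent with a 4-byte count parameter on the kernel side (and caps fill at roughly 2^31 elements). A hedged OpenCL C sketch matching that launch; the shipped fill.cl may differ:

    // Sketch matching the host-side launch above; not necessarily the real fill.cl.
    kernel void kernel_fill_f32(global char * dst, ulong offsetd, float v, int n) {
        global float * d = (global float *)(dst + offsetd);
        const size_t i = get_global_id(0);
        if (i < (size_t)n) {   // gws is padded up to a multiple of 256
            d[i] = v;
        }
    }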
5257
5974
  static void ggml_cl_clamp(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
5258
5975
  GGML_ASSERT(src0);
5259
5976
  GGML_ASSERT(src0->extra);
@@ -5530,7 +6247,7 @@ static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor *
5530
6247
  CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb2));
5531
6248
  CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &nb3));
5532
6249
  CL_CHECK(clSetKernelArg(kernel, 23, sizeof(float), &eps));
5533
- CL_CHECK(clSetKernelArg(kernel, 24, sizeof(float)*nth/sgs, NULL));
6250
+ CL_CHECK(clSetKernelArg(kernel, 24, sizeof(float)*sgs, NULL));
5534
6251
 
5535
6252
  backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
5536
6253
  }
@@ -5666,10 +6383,148 @@ static void ggml_opencl_op_group_norm_fused(ggml_backend_t backend, ggml_tensor
5666
6383
  CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &group_size));
5667
6384
  CL_CHECK(clSetKernelArg(kernel, 10, sizeof(float), &eps));
5668
6385
 
5669
- backend_ctx->enqueue_ndrange_kernel(kernel, 1, gws, lws, dst);
6386
+ backend_ctx->enqueue_ndrange_kernel(kernel, 1, gws, lws, dst);
6387
+ }
6388
+
6389
+ static void ggml_cl_group_norm(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
6390
+ GGML_ASSERT(src0);
6391
+ GGML_ASSERT(src0->extra);
6392
+ GGML_ASSERT(dst);
6393
+ GGML_ASSERT(dst->extra);
6394
+
6395
+ UNUSED(src1);
6396
+
6397
+ ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
6398
+
6399
+ ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
6400
+ ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
6401
+
6402
+ cl_ulong offset0 = extra0->offset + src0->view_offs;
6403
+ cl_ulong offsetd = extrad->offset + dst->view_offs;
6404
+
6405
+ int32_t n_groups = ((const int32_t *) dst->op_params)[0];
6406
+ int32_t group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + n_groups - 1) / n_groups);
6407
+ float eps = ((const float *) dst->op_params)[1];
6408
+
6409
+ const int ne00 = src0->ne[0];
6410
+ const int ne01 = src0->ne[1];
6411
+ const int ne02 = src0->ne[2];
6412
+ const int ne = ne00*ne01*ne02;
6413
+
6414
+ cl_kernel kernel = backend_ctx->kernel_group_norm;
6415
+
6416
+ size_t sgs = 64;
6417
+ if (backend_ctx->gpu_family == ADRENO) {
6418
+ sgs = 64;
6419
+ } else if (backend_ctx->gpu_family == INTEL) {
6420
+ sgs = 32;
6421
+ } else {
6422
+ GGML_ASSERT(false && "Unsupported GPU");
6423
+ }
6424
+
6425
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
6426
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
6427
+ CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
6428
+ CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
6429
+ CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne));
6430
+ CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &group_size));
6431
+ CL_CHECK(clSetKernelArg(kernel, 6, sizeof(float), &eps));
6432
+
6433
+ size_t global_work_size[] = {(size_t)n_groups*sgs, 1, 1};
6434
+ size_t local_work_size[] = {(size_t)sgs, 1, 1};
6435
+
6436
+ backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
6437
+ }
6438
+
6439
+ static void ggml_cl_tanh(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
6440
+ GGML_ASSERT(src0);
6441
+ GGML_ASSERT(src0->extra);
6442
+ GGML_ASSERT(dst);
6443
+ GGML_ASSERT(dst->extra);
6444
+
6445
+ UNUSED(src1);
6446
+
6447
+ ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
6448
+
6449
+ ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
6450
+ ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
6451
+
6452
+ cl_ulong offset0_abs = extra0->offset + src0->view_offs;
6453
+ cl_ulong offsetd_abs = extrad->offset + dst->view_offs;
6454
+
6455
+ cl_kernel kernel;
6456
+ if (dst->type == GGML_TYPE_F32) {
6457
+ kernel = backend_ctx->kernel_tanh_f32_nd;
6458
+ } else if (dst->type == GGML_TYPE_F16) {
6459
+ kernel = backend_ctx->kernel_tanh_f16_nd;
6460
+ } else {
6461
+ GGML_ASSERT(false && "Unsupported type for ggml_cl_tanh");
6462
+ }
6463
+ GGML_ASSERT(kernel != nullptr);
6464
+
6465
+ const int ne00 = src0->ne[0]; const int ne01 = src0->ne[1]; const int ne02 = src0->ne[2]; const int ne03 = src0->ne[3];
6466
+ const cl_ulong nb00 = src0->nb[0]; const cl_ulong nb01 = src0->nb[1]; const cl_ulong nb02 = src0->nb[2]; const cl_ulong nb03 = src0->nb[3];
6467
+
6468
+ const int ne10 = dst->ne[0]; const int ne11 = dst->ne[1]; const int ne12 = dst->ne[2]; const int ne13 = dst->ne[3];
6469
+ const cl_ulong nb10 = dst->nb[0]; const cl_ulong nb11 = dst->nb[1]; const cl_ulong nb12 = dst->nb[2]; const cl_ulong nb13 = dst->nb[3];
6470
+
6471
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
6472
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0_abs));
6473
+ CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
6474
+ CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd_abs));
6475
+
6476
+ CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00));
6477
+ CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01));
6478
+ CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02));
6479
+ CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03));
6480
+ CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb00));
6481
+ CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01));
6482
+ CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong),&nb02));
6483
+ CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong),&nb03));
6484
+
6485
+ CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10));
6486
+ CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne11));
6487
+ CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne12));
6488
+ CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne13));
6489
+ CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong),&nb10));
6490
+ CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong),&nb11));
6491
+ CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong),&nb12));
6492
+ CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong),&nb13));
6493
+
6494
+ size_t global_work_size[3];
6495
+ if (ne10 == 0 || ne11 == 0 || ne12 == 0 || ne13 == 0) { // Handle case of 0 elements
6496
+ return;
6497
+ }
6498
+ global_work_size[0] = (size_t)ne10;
6499
+ global_work_size[1] = (size_t)ne11;
6500
+ global_work_size[2] = (size_t)ne12;
6501
+
6502
+ size_t lws0 = 16, lws1 = 4, lws2 = 1;
6503
+ if (ne10 < 16) lws0 = ne10;
6504
+ if (ne11 < 4) lws1 = ne11;
6505
+ if (ne12 < 1) lws2 = ne12 > 0 ? ne12 : 1;
6506
+
6507
+ while (lws0 * lws1 * lws2 > 256 && lws0 > 1) lws0 /= 2;
6508
+ while (lws0 * lws1 * lws2 > 256 && lws1 > 1) lws1 /= 2;
6509
+ while (lws0 * lws1 * lws2 > 256 && lws2 > 1) lws2 /= 2;
6510
+
6511
+
6512
+ size_t local_work_size[] = {lws0, lws1, lws2};
6513
+
6514
+ size_t* local_work_size_ptr = local_work_size;
6515
+ if (!backend_ctx->non_uniform_workgroups) {
6516
+ if (global_work_size[0] % local_work_size[0] != 0 ||
6517
+ global_work_size[1] % local_work_size[1] != 0 ||
6518
+ global_work_size[2] % local_work_size[2] != 0) {
6519
+ local_work_size_ptr = NULL;
6520
+ }
6521
+ }
6522
+ if (global_work_size[0] == 0 || global_work_size[1] == 0 || global_work_size[2] == 0) return;
6523
+
6524
+ backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
5670
6525
  }
5671
6526
 
5672
- static void ggml_cl_group_norm(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
6527
+ static void ggml_cl_expm1(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
5673
6528
  GGML_ASSERT(src0);
5674
6529
  GGML_ASSERT(src0->extra);
5675
6530
  GGML_ASSERT(dst);
@@ -5682,44 +6537,96 @@ static void ggml_cl_group_norm(ggml_backend_t backend, const ggml_tensor * src0,
5682
6537
  ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
5683
6538
  ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
5684
6539
 
5685
- cl_ulong offset0 = extra0->offset + src0->view_offs;
5686
- cl_ulong offsetd = extrad->offset + dst->view_offs;
6540
+ cl_ulong offset0_abs = extra0->offset + src0->view_offs;
6541
+ cl_ulong offsetd_abs = extrad->offset + dst->view_offs;
5687
6542
 
5688
- int32_t n_groups = ((const int32_t *) dst->op_params)[0];
5689
- int32_t group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + n_groups - 1) / n_groups);
5690
- float eps = ((const float *) dst->op_params)[1];
6543
+ cl_kernel kernel;
6544
+ if (dst->type == GGML_TYPE_F32) {
6545
+ kernel = backend_ctx->kernel_expm1_f32_nd;
6546
+ } else if (dst->type == GGML_TYPE_F16) {
6547
+ kernel = backend_ctx->kernel_expm1_f16_nd;
6548
+ } else {
6549
+ GGML_ASSERT(false && "Unsupported type for ggml_cl_expm1");
6550
+ }
6551
+ GGML_ASSERT(kernel != nullptr);
5691
6552
 
5692
6553
  const int ne00 = src0->ne[0];
5693
6554
  const int ne01 = src0->ne[1];
5694
6555
  const int ne02 = src0->ne[2];
5695
- const int ne = ne00*ne01*ne02;
6556
+ const int ne03 = src0->ne[3];
5696
6557
 
5697
- cl_kernel kernel = backend_ctx->kernel_group_norm;
6558
+ const cl_ulong nb00 = src0->nb[0];
6559
+ const cl_ulong nb01 = src0->nb[1];
6560
+ const cl_ulong nb02 = src0->nb[2];
6561
+ const cl_ulong nb03 = src0->nb[3];
5698
6562
 
5699
- size_t sgs = 64;
5700
- if (backend_ctx->gpu_family == ADRENO) {
5701
- sgs = 64;
5702
- } else if (backend_ctx->gpu_family == INTEL) {
5703
- sgs = 32;
5704
- } else {
5705
- GGML_ASSERT(false && "Unsupported GPU");
5706
- }
6563
+ const int ne10 = dst->ne[0];
6564
+ const int ne11 = dst->ne[1];
6565
+ const int ne12 = dst->ne[2];
6566
+ const int ne13 = dst->ne[3];
6567
+
6568
+ const cl_ulong nb10 = dst->nb[0];
6569
+ const cl_ulong nb11 = dst->nb[1];
6570
+ const cl_ulong nb12 = dst->nb[2];
6571
+ const cl_ulong nb13 = dst->nb[3];
5707
6572
 
5708
6573
  CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
5709
- CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
6574
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0_abs));
5710
6575
  CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
5711
- CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
5712
- CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne));
5713
- CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &group_size));
5714
- CL_CHECK(clSetKernelArg(kernel, 6, sizeof(float), &eps));
6576
+ CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd_abs));
5715
6577
 
5716
- size_t global_work_size[] = {(size_t)n_groups*sgs, 1, 1};
5717
- size_t local_work_size[] = {(size_t)sgs, 1, 1};
6578
+ CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00));
6579
+ CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01));
6580
+ CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02));
6581
+ CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03));
6582
+ CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb00));
6583
+ CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01));
6584
+ CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong),&nb02));
6585
+ CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong),&nb03));
5718
6586
 
5719
- backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
6587
+ CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10));
6588
+ CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne11));
6589
+ CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne12));
6590
+ CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne13));
6591
+ CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong),&nb10));
6592
+ CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong),&nb11));
6593
+ CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong),&nb12));
6594
+ CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong),&nb13));
6595
+
6596
+ size_t global_work_size[3];
6597
+ if (ne10 == 0 || ne11 == 0 || ne12 == 0 || ne13 == 0) { // Handle case of 0 elements
6598
+ return;
6599
+ }
6600
+ global_work_size[0] = (size_t)ne10;
6601
+ global_work_size[1] = (size_t)ne11;
6602
+ global_work_size[2] = (size_t)ne12;
6603
+
6604
+ size_t lws0 = 16, lws1 = 4, lws2 = 1;
6605
+ if (ne10 < 16) lws0 = ne10;
6606
+ if (ne11 < 4) lws1 = ne11;
6607
+ if (ne12 < 1) lws2 = 1; // equivalent to the old ternary; unreachable after the zero-size early return above
6608
+
6609
+ while (lws0 * lws1 * lws2 > 256 && lws0 > 1) lws0 /= 2;
6610
+ while (lws0 * lws1 * lws2 > 256 && lws1 > 1) lws1 /= 2;
6611
+ while (lws0 * lws1 * lws2 > 256 && lws2 > 1) lws2 /= 2;
6612
+
6613
+
6614
+ size_t local_work_size[] = {lws0, lws1, lws2};
6615
+
6616
+ size_t* local_work_size_ptr = local_work_size;
6617
+ if (!backend_ctx->non_uniform_workgroups) {
6618
+ if (global_work_size[0] % local_work_size[0] != 0 ||
6619
+ global_work_size[1] % local_work_size[1] != 0 ||
6620
+ global_work_size[2] % local_work_size[2] != 0) {
6621
+ local_work_size_ptr = NULL;
6622
+ }
6623
+ }
6624
+ if (global_work_size[0] == 0 || global_work_size[1] == 0 || global_work_size[2] == 0) return;
6625
+
6626
+ backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
5720
6627
  }
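Note: ggml_cl_expm1 above and ggml_cl_softplus below wire up the two new unary kernels. Assuming the usual definitions expm1(x) = e^x - 1 and softplus(x) = ln(1 + e^x), a host-side reference is a one-liner each; the overflow-safe softplus rewrite below is the standard trick and is a sketch, not the kernel source:

    #include <cmath>
    #include <cstdio>
    #include <initializer_list>

    static float softplus_ref(float x) {
        // overflow-safe form: softplus(x) = max(x, 0) + log1p(exp(-|x|))
        return std::fmax(x, 0.0f) + std::log1p(std::exp(-std::fabs(x)));
    }

    int main() {
        for (float x : {-20.0f, -1.0f, 0.0f, 1.0f, 20.0f})
            printf("x=%6.1f  expm1=%13.6g  softplus=%13.6g\n",
                   x, std::expm1(x), softplus_ref(x));
    }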
5721
6628
 
5722
- static void ggml_cl_tanh(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
6629
+ static void ggml_cl_softplus(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
5723
6630
  GGML_ASSERT(src0);
5724
6631
  GGML_ASSERT(src0->extra);
5725
6632
  GGML_ASSERT(dst);
@@ -5737,19 +6644,33 @@ static void ggml_cl_tanh(ggml_backend_t backend, const ggml_tensor * src0, const
5737
6644
 
5738
6645
  cl_kernel kernel;
5739
6646
  if (dst->type == GGML_TYPE_F32) {
5740
- kernel = backend_ctx->kernel_tanh_f32_nd;
6647
+ kernel = backend_ctx->kernel_softplus_f32_nd;
5741
6648
  } else if (dst->type == GGML_TYPE_F16) {
5742
- kernel = backend_ctx->kernel_tanh_f16_nd;
6649
+ kernel = backend_ctx->kernel_softplus_f16_nd;
5743
6650
  } else {
5744
- GGML_ASSERT(false && "Unsupported type for ggml_cl_tanh");
6651
+ GGML_ASSERT(false && "Unsupported type for ggml_cl_softplus");
5745
6652
  }
5746
6653
  GGML_ASSERT(kernel != nullptr);
5747
6654
 
5748
- const int ne00 = src0->ne[0]; const int ne01 = src0->ne[1]; const int ne02 = src0->ne[2]; const int ne03 = src0->ne[3];
5749
- const cl_ulong nb00 = src0->nb[0]; const cl_ulong nb01 = src0->nb[1]; const cl_ulong nb02 = src0->nb[2]; const cl_ulong nb03 = src0->nb[3];
6655
+ const int ne00 = src0->ne[0];
6656
+ const int ne01 = src0->ne[1];
6657
+ const int ne02 = src0->ne[2];
6658
+ const int ne03 = src0->ne[3];
5750
6659
 
5751
- const int ne10 = dst->ne[0]; const int ne11 = dst->ne[1]; const int ne12 = dst->ne[2]; const int ne13 = dst->ne[3];
5752
- const cl_ulong nb10 = dst->nb[0]; const cl_ulong nb11 = dst->nb[1]; const cl_ulong nb12 = dst->nb[2]; const cl_ulong nb13 = dst->nb[3];
6660
+ const cl_ulong nb00 = src0->nb[0];
6661
+ const cl_ulong nb01 = src0->nb[1];
6662
+ const cl_ulong nb02 = src0->nb[2];
6663
+ const cl_ulong nb03 = src0->nb[3];
6664
+
6665
+ const int ne10 = dst->ne[0];
6666
+ const int ne11 = dst->ne[1];
6667
+ const int ne12 = dst->ne[2];
6668
+ const int ne13 = dst->ne[3];
6669
+
6670
+ const cl_ulong nb10 = dst->nb[0];
6671
+ const cl_ulong nb11 = dst->nb[1];
6672
+ const cl_ulong nb12 = dst->nb[2];
6673
+ const cl_ulong nb13 = dst->nb[3];
5753
6674
 
5754
6675
  CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
5755
6676
  CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0_abs));
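Note: each of these *_nd unary ops passes both extents (ne) and byte strides (nb) for src and dst, so the kernels can address permuted or sliced views directly: element (i0, i1, i2, i3) sits at byte offset i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3. A tiny illustration (the stride values are made up):

    #include <cstdio>
    #include <cstddef>

    static size_t elem_offset(const int i[4], const size_t nb[4]) {
        return i[0] * nb[0] + i[1] * nb[1] + i[2] * nb[2] + i[3] * nb[3];
    }

    int main() {
        const size_t contig[4] = {4, 16, 48, 48};  // 4 x 3 f32 tensor, contiguous
        const size_t transp[4] = {16, 4, 48, 48};  // transposed view of the same data
        const int i[4] = {1, 2, 0, 0};
        printf("contiguous: byte offset %zu, transposed view: %zu\n",
               elem_offset(i, contig), elem_offset(i, transp));
    }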
@@ -5874,7 +6795,6 @@ static void ggml_cl_pad(ggml_backend_t backend, const ggml_tensor * src0, ggml_t
5874
6795
  GGML_ASSERT(dst->extra);
5875
6796
  GGML_ASSERT(src0->type == GGML_TYPE_F32);
5876
6797
  GGML_ASSERT(dst->type == GGML_TYPE_F32);
5877
- GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1);
5878
6798
 
5879
6799
  ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
5880
6800
 
@@ -5892,28 +6812,67 @@ static void ggml_cl_pad(ggml_backend_t backend, const ggml_tensor * src0, ggml_t
5892
6812
  const int s_ne0 = src0->ne[0];
5893
6813
  const int s_ne1 = src0->ne[1];
5894
6814
  const int s_ne2 = src0->ne[2];
6815
+ const int s_ne3 = src0->ne[3];
6816
+
6817
+ const cl_ulong s_nb0 = src0->nb[0]; // cl_ulong: these are set with sizeof(cl_ulong) below
6818
+ const cl_ulong s_nb1 = src0->nb[1];
6819
+ const cl_ulong s_nb2 = src0->nb[2];
6820
+ const cl_ulong s_nb3 = src0->nb[3];
5895
6821
 
5896
6822
  const int d_ne0 = dst->ne[0];
5897
6823
  const int d_ne1 = dst->ne[1];
5898
6824
  const int d_ne2 = dst->ne[2];
6825
+ const int d_ne3 = dst->ne[3];
6826
+
6827
+ const cl_ulong d_nb0 = dst->nb[0]; // cl_ulong: these are set with sizeof(cl_ulong) below
6828
+ const cl_ulong d_nb1 = dst->nb[1];
6829
+ const cl_ulong d_nb2 = dst->nb[2];
6830
+ const cl_ulong d_nb3 = dst->nb[3];
6831
+
6832
+ const int lp0 = ((const int*)(dst->op_params))[0];
6833
+ const int rp0 = ((const int*)(dst->op_params))[1];
6834
+ const int lp1 = ((const int*)(dst->op_params))[2];
6835
+ const int rp1 = ((const int*)(dst->op_params))[3];
6836
+ const int lp2 = ((const int*)(dst->op_params))[4];
6837
+ const int rp2 = ((const int*)(dst->op_params))[5];
6838
+ const int lp3 = ((const int*)(dst->op_params))[6];
6839
+ const int rp3 = ((const int*)(dst->op_params))[7];
5899
6840
 
5900
6841
  cl_kernel kernel = backend_ctx->kernel_pad;
5901
6842
 
5902
- CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_src0->data_device));
5903
- CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &off_src0));
5904
- CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra_dst->data_device));
5905
- CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &off_dst));
5906
- CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &s_ne0));
5907
- CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &s_ne1));
5908
- CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &s_ne2));
5909
- CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &d_ne0));
5910
- CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &d_ne1));
5911
- CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &d_ne2));
6843
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_src0->data_device));
6844
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &off_src0));
6845
+ CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra_dst->data_device));
6846
+ CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &off_dst));
6847
+ CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &s_ne0));
6848
+ CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &s_ne1));
6849
+ CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &s_ne2));
6850
+ CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &s_ne3));
6851
+ CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &s_nb0));
6852
+ CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &s_nb1));
6853
+ CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &s_nb2));
6854
+ CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &s_nb3));
6855
+ CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &d_ne0));
6856
+ CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &d_ne1));
6857
+ CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &d_ne2));
6858
+ CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &d_ne3));
6859
+ CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &d_nb0));
6860
+ CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &d_nb1));
6861
+ CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &d_nb2));
6862
+ CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &d_nb3));
6863
+ CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &lp0));
6864
+ CL_CHECK(clSetKernelArg(kernel, 21, sizeof(int), &rp0));
6865
+ CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &lp1));
6866
+ CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &rp1));
6867
+ CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &lp2));
6868
+ CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &rp2));
6869
+ CL_CHECK(clSetKernelArg(kernel, 26, sizeof(int), &lp3));
6870
+ CL_CHECK(clSetKernelArg(kernel, 27, sizeof(int), &rp3));
5912
6871
 
5913
6872
  size_t lws0 = 64;
5914
6873
  size_t gws0 = (( (size_t)d_ne0 + lws0 - 1 ) / lws0) * lws0;
5915
6874
 
5916
- size_t global_work_size[] = { gws0, (size_t)d_ne1, (size_t)d_ne2 };
6875
+ size_t global_work_size[] = { gws0, (size_t)d_ne1, (size_t)d_ne2*d_ne3 };
5917
6876
  size_t local_work_size[] = { lws0, 1, 1 };
5918
6877
 
5919
6878
  size_t * local_work_size_ptr = local_work_size;
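Note: with the ne[3] == 1 assert removed, PAD now covers all four dimensions; the eight op_params are (left, right) pad pairs per dimension, and the launch folds dims 2 and 3 into the third global dimension. A CPU reference for the index mapping (a sketch that zero-fills, matching ggml's PAD semantics; shown in 2-D since the higher dims are 1 in this example):

    #include <cstdio>
    #include <vector>

    int main() {
        const int s_ne[2] = {2, 2};                 // source extents (dims 2 and 3 are 1)
        const int lp[2] = {1, 0}, rp[2] = {1, 1};   // (left, right) pads from op_params
        const int d_ne[2] = {lp[0] + s_ne[0] + rp[0], lp[1] + s_ne[1] + rp[1]};

        std::vector<float> src((size_t)s_ne[0] * s_ne[1], 1.0f);
        std::vector<float> dst((size_t)d_ne[0] * d_ne[1], 0.0f);  // PAD zero-fills

        for (int i1 = 0; i1 < d_ne[1]; ++i1)
            for (int i0 = 0; i0 < d_ne[0]; ++i0) {
                const int s0 = i0 - lp[0], s1 = i1 - lp[1];       // map dst -> src
                if (s0 >= 0 && s0 < s_ne[0] && s1 >= 0 && s1 < s_ne[1])
                    dst[(size_t)i1 * d_ne[0] + i0] = src[(size_t)s1 * s_ne[0] + s0];
            }
        for (int i1 = 0; i1 < d_ne[1]; ++i1) {
            for (int i0 = 0; i0 < d_ne[0]; ++i0)
                printf("%.0f ", dst[(size_t)i1 * d_ne[0] + i0]);
            printf("\n");
        }
    }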
@@ -6003,8 +6962,8 @@ static void ggml_cl_upscale(ggml_backend_t backend, const ggml_tensor * src0, gg
6003
6962
  CL_CHECK(clSetKernelArg(kernel, 15, sizeof(float), &sf3));
6004
6963
  } else if (mode == GGML_SCALE_MODE_BILINEAR) {
6005
6964
  if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
6006
- sf0 = (float)(ne0 - 1) / (ne00 - 1);
6007
- sf1 = (float)(ne1 - 1) / (ne01 - 1);
6965
+ sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0;
6966
+ sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1;
6008
6967
  pixel_offset = 0.0f;
6009
6968
  }
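Note: the guarded rewrite above avoids the division by zero the old align-corners formula hit whenever an extent was 1. A sketch of the resulting scale-factor selection (align_corners_sf is an illustrative name):

    #include <cstdio>

    static float align_corners_sf(int ne_out, int ne_in) {
        float sf = (float)ne_out / ne_in;          // default scale factor
        if (ne_out > 1 && ne_in > 1)               // the guarded align-corners rewrite
            sf = (float)(ne_out - 1) / (ne_in - 1);
        return sf;
    }

    int main() {
        printf("8<-4: %g, 8<-1: %g\n", align_corners_sf(8, 4), align_corners_sf(8, 1));
    }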
6010
6969
 
@@ -6475,6 +7434,146 @@ static void ggml_cl_conv_2d(ggml_backend_t backend, const ggml_tensor * src0, co
6475
7434
  backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst);
6476
7435
  }
6477
7436
 
7437
+ static void ggml_cl_mul_mat_kq_kqv_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
7438
+ ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
7439
+
7440
+ ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
7441
+ ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
7442
+ ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
7443
+
7444
+ const int ne00 = src0->ne[0];
7445
+ const int ne01 = src0->ne[1];
7446
+ const int ne02 = src0->ne[2];
7447
+
7448
+ const cl_ulong nb01 = src0->nb[1];
7449
+ const cl_ulong nb02 = src0->nb[2];
7450
+
7451
+ const int ne10 = src1->ne[0];
7452
+ const int ne11 = src1->ne[1];
7453
+ const int ne12 = src1->ne[2];
7454
+
7455
+ const cl_ulong nb10 = src1->nb[0];
7456
+
7457
+ const int ne0 = dst->ne[0];
7458
+ const int ne1 = dst->ne[1];
7459
+
7460
+ GGML_ASSERT(ne00 == ne10);
7461
+
7462
+ cl_kernel kernel;
7463
+ cl_context context = backend_ctx->context;
7464
+
7465
+ cl_int status;
7466
+ cl_image_format img_fmt_1d;
7467
+ cl_image_desc img_desc_1d;
7468
+ cl_buffer_region region;
7469
+ cl_mem A_image1d;
7470
+ cl_mem A_sub_buffer;
7471
+ cl_mem B_sub_buffer;
7472
+ cl_mem D_image1d;
7473
+ cl_mem D_sub_buffer;
7474
+
7475
+ int M = ne01;
7476
+ int N = ne1;
7477
+ int K = ne00;
7478
+
7479
+ if (nb01 > nb02) {
7480
+ // KQ
7481
+ kernel = backend_ctx->kernel_mul_mm_f16_f32_kq;
7482
+ } else {
7483
+ // KQV
7484
+ kernel = backend_ctx->kernel_mul_mm_f16_f32_kqv;
7485
+ }
7486
+ // create sub-buffer for A
7487
+ // <--------------------------------------------> //
7488
+ extra0 = src0->view_src ? (ggml_tensor_extra_cl *)src0->view_src->extra : (ggml_tensor_extra_cl *)src0->extra;
7489
+
7490
+ region.origin = (extra0->offset);
7491
+ if (nb01 > nb02) {
7492
+ // KQ
7493
+ region.size = nb01 * ne01;
7494
+ } else {
7495
+ // KQV
7496
+ region.size = nb02 * ne02;
7497
+ }
7498
+
7499
+ A_sub_buffer = clCreateSubBuffer((extra0->data_device), 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &status);
7500
+ CL_CHECK(status);
7501
+
7502
+ // <--------------------------------------------> //
7503
+
7504
+ // create sub-buffer for B
7505
+ // <--------------------------------------------> //
7506
+ region.origin = (extra1->offset);
7507
+ region.size = nb10 * ne10 * ne11 * ne12;
7508
+ B_sub_buffer = clCreateSubBuffer((extra1->data_device), 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &status);
7509
+ CL_CHECK(status);
7510
+ // <--------------------------------------------> //
7511
+
7512
+ img_fmt_1d = {CL_RGBA, CL_FLOAT};
7513
+ memset(&img_desc_1d, 0, sizeof(img_desc_1d));
7514
+ img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
7515
+ if (nb01 > nb02) {
7516
+ img_desc_1d.image_width = (nb01 * ne01 / 4)/4;
7517
+ }
7518
+ else {
7519
+ img_desc_1d.image_width = (nb02 * ne02 / 4)/4;
7520
+ }
7521
+ img_desc_1d.buffer = A_sub_buffer;
7522
+ A_image1d = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt_1d, &img_desc_1d, NULL, &status);
7523
+ CL_CHECK(status);
7524
+
7525
+ // create sub-buffer for output C
7526
+ // <--------------------------------------------> //
7527
+ region.origin = (extrad->offset);
7528
+ region.size = ne0 * ne1 * dst->ne[2] * dst->nb[0]; // size of C in bytes
7529
+ D_sub_buffer = clCreateSubBuffer((extrad->data_device), 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &status);
7530
+ CL_CHECK(status);
7531
+ // <--------------------------------------------> //
7532
+
7533
+ // create image for C output
7534
+ // <--------------------------------------------> //
7535
+ img_fmt_1d = {CL_R, CL_FLOAT};
7536
+ memset(&img_desc_1d, 0, sizeof(img_desc_1d));
7537
+ img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
7538
+ img_desc_1d.image_width = ne0 * ne1 * dst->ne[2] * dst->nb[0] / 4;
7539
+ img_desc_1d.buffer = D_sub_buffer;
7540
+ D_image1d = clCreateImage(context, CL_MEM_WRITE_ONLY, &img_fmt_1d, &img_desc_1d, NULL, &status);
7541
+ CL_CHECK(status);
7542
+ // <--------------------------------------------> //
7543
+
7544
+ int offset_src0 = 0;
7545
+ int offset_src1 = 0;
7546
+
7547
+ // set kernel args
7548
+ // <--------------------------------------------> //
7549
+ cl_uint k_arg = 0;
7550
+ CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &A_image1d));
7551
+ CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &offset_src0));
7552
+ CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &B_sub_buffer));
7553
+ CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &offset_src1));
7554
+ CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &D_image1d));
7555
+ CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &extrad->offset));
7556
+ CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &M));
7557
+ CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &K));
7558
+ CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &N));
7559
+ CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne02));
7560
+ CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne12));
7561
+ CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &nb01));
7562
+
7563
+ size_t global_work_size[3] = {64, static_cast<size_t>(((M+63)/64)), static_cast<size_t>(((N+31)/32)*ne12)};
7564
+ size_t local_work_size[3] = {64, 1, 2};
7565
+
7566
+ backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
7567
+
7568
+ // deallocate sub buffers and images
7569
+ // <--------------------------------------------> //
7570
+ CL_CHECK(clReleaseMemObject(A_image1d));
7571
+ CL_CHECK(clReleaseMemObject(D_image1d));
7572
+ CL_CHECK(clReleaseMemObject(A_sub_buffer));
7573
+ CL_CHECK(clReleaseMemObject(B_sub_buffer));
7574
+ CL_CHECK(clReleaseMemObject(D_sub_buffer));
7575
+ }
7576
+
6478
7577
  static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
6479
7578
  GGML_ASSERT(src0);
6480
7579
  GGML_ASSERT(src0->extra);
@@ -6541,6 +7640,27 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
6541
7640
  #ifdef GGML_OPENCL_USE_ADRENO_KERNELS
6542
7641
  cl_context context = backend_ctx->context;
6543
7642
 
7643
+ if(src0t == GGML_TYPE_F16 && src1t == GGML_TYPE_F32){
7644
+ if (ne01 >= 64 && ne1 >= 32 && ne00 >= 16 && (ne12 % ne02) == 0) {
7645
+ // For KQ
7646
+ if (ggml_is_permuted(src0) && ggml_is_permuted(src1) &&
7647
+ nb00 <= nb02 &&
7648
+ nb02 <= nb01 &&
7649
+ nb01 <= nb03 &&
7650
+ nb10 <= nb12 &&
7651
+ nb12 <= nb11 &&
7652
+ nb11 <= nb13) {
7653
+ ggml_cl_mul_mat_kq_kqv_adreno(backend, src0, src1, dst);
7654
+ return;
7655
+ }
7656
+ // For KQV
7657
+ if (!ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
7658
+ ggml_cl_mul_mat_kq_kqv_adreno(backend, src0, src1, dst);
7659
+ return;
7660
+ }
7661
+ }
7662
+ }
7663
+
6544
7664
  if (ne01 && ne1 && use_adreno_kernels(backend_ctx, src0)) {
6545
7665
 
6546
7666
  // init CL objects
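Note: the KQ gate added above keys off stride ordering: a permuted K/Q operand carries byte strides nested as nb00 <= nb02 <= nb01 <= nb03 (and likewise for src1). A sketch of that test with made-up f16 strides:

    #include <cstdio>
    #include <cstddef>

    static bool kq_stride_order(const size_t nb[4]) {
        return nb[0] <= nb[2] && nb[2] <= nb[1] && nb[1] <= nb[3];
    }

    int main() {
        // f16 tensor, logical 128 x 64 x 8, with dims 1 and 2 swapped in memory:
        const size_t permuted[4] = {2, 2 * 128 * 8, 2 * 128, 2 * 128 * 8 * 64};
        const size_t contig[4]   = {2, 2 * 128, 2 * 128 * 64, 2 * 128 * 64 * 8};
        printf("permuted: %s, contiguous: %s\n",
               kq_stride_order(permuted) ? "KQ order" : "no",
               kq_stride_order(contig)   ? "KQ order" : "no");
    }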
@@ -6620,8 +7740,10 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
6620
7740
  region.origin = 0;
6621
7741
  // Specify the size of the sub-buffer (divide by 2 for FP16)
6622
7742
  region.size = K * (N + padding) * sizeof(float)/2;
7743
+ backend_ctx->prealloc_act_trans.allocate(context, region.size);
7744
+
6623
7745
  B_d = clCreateSubBuffer(
6624
- backend_ctx->B_d_max,
7746
+ backend_ctx->prealloc_act_trans.buffer,
6625
7747
  0,
6626
7748
  CL_BUFFER_CREATE_TYPE_REGION,
6627
7749
  &region,
@@ -6914,6 +8036,44 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
6914
8036
  backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
6915
8037
  return;
6916
8038
  }
8039
+ case GGML_TYPE_Q8_0: {
8040
+ if (ne11 < 32) {
8041
+ break;
8042
+ }
8043
+ kernel = backend_ctx->kernel_mul_mm_q8_0_f32_l4_lm;
8044
+ nth0 = 128; // calculated as (BM*BN)/(TM*TN)
8045
+
8046
+ int batch_stride_a = ne00*ne01;
8047
+ int batch_stride_b = ne10*ne11;
8048
+ int batch_stride_d = ne0*ne1;
8049
+
8050
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q8_0->q));
8051
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q8_0->d));
8052
+ CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
8053
+ CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
8054
+ CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
8055
+ CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
8056
+ CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
8057
+ CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
8058
+ CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02));
8059
+ CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne11));
8060
+ CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12));
8061
+ CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne10)); // stride_a
8062
+ CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10)); // stride_b
8063
+ CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne01)); // stride_d
8064
+ CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &batch_stride_a));
8065
+ CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &batch_stride_b));
8066
+ CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &batch_stride_d));
8067
+ CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r2));
8068
+ CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r3));
8069
+
8070
+ // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed.
8071
+ size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13};
8072
+ size_t local_work_size[] = {(size_t)nth0, 1, 1};
8073
+
8074
+ backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
8075
+ return;
8076
+ }
6917
8077
  default:
6918
8078
  break;
6919
8079
  }
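Note: the new q8_0 path launches one 128-thread work-group per 64x64 output tile; nth0 = BM*BN/(TM*TN) with BM = BN = 64 implies each thread produces a 32-element micro-tile. The grid math, as a sketch with illustrative shapes:

    #include <cstdio>
    #include <cstddef>

    #define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))

    int main() {
        const int BM = 64, BN = 64, nth0 = 128;          // per the kernel's tiling
        const int ne01 = 4096, ne11 = 100, nbatch = 2;   // illustrative M, N, ne12*ne13
        const size_t gws[3] = { (size_t)(CEIL_DIV(ne01, BM) * nth0),
                                (size_t)CEIL_DIV(ne11, BN),
                                (size_t)nbatch };
        printf("tiles: %zu x %zu x %zu, threads/tile: %d, outputs/thread: %d\n",
               gws[0] / nth0, gws[1], gws[2], nth0, BM * BN / nth0);
    }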
@@ -7450,6 +8610,9 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
7450
8610
  const int ne21 = src2->ne[1];
7451
8611
 
7452
8612
  const cl_ulong nb21 = src2->nb[1];
8613
+ const cl_ulong nb20 = src2->nb[0];
8614
+
8615
+ UNUSED(nb20);
7453
8616
 
7454
8617
  const int ne0 = dst->ne[0];
7455
8618
  const int ne1 = dst->ne[1];
@@ -7589,6 +8752,105 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
7589
8752
  break;
7590
8753
  }
7591
8754
  case GGML_TYPE_MXFP4: {
8755
+ #ifdef GGML_OPENCL_USE_ADRENO_KERNELS
8756
+ if (use_adreno_moe_kernels(backend_ctx, src0)) {
8757
+ cl_int status;
8758
+
8759
+ size_t local_size[3] = {64, 2, 1};
8760
+ size_t global_size[3] = {64, 2, 1};
8761
+
8762
+ cl_mem src1_sub_buffer, buf_src1_image, buf_src2;
8763
+
8764
+ int tile_size = 320;
8765
+ if (ne12 == 1) { // for gemv
8766
+ kernel = backend_ctx->kernel_gemv_moe_mxfp4_f32;
8767
+
8768
+ // create a sub_buffer for src2
8769
+ cl_buffer_region region;
8770
+ region.origin = offset2;
8771
+ region.size = ne20 * ne21 * sizeof(int);
8772
+ buf_src2 = clCreateSubBuffer(extra2->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &status);
8773
+ CL_CHECK(status);
8774
+
8775
+ // set thread grid
8776
+ global_size[0] = static_cast<size_t>(ne01);
8777
+ global_size[1] = 4;
8778
+ global_size[2] = static_cast<size_t>(ne20);
8779
+ local_size[1] = 4;
8780
+ } else { // for gemm
8781
+ kernel = backend_ctx->kernel_gemm_moe_mxfp4_f32;
8782
+
8783
+ // preprocess router table
8784
+ int num_tiles_per_expert = (ne01 + tile_size - 1) / tile_size;
8785
+ void * host_src2_reorder = malloc(ne20 * ne21 * 4 * num_tiles_per_expert * sizeof(short));
8786
+ void * host_src2 = malloc(ne21 * nb21);
8787
+ CL_CHECK(clEnqueueReadBuffer(backend_ctx->queue, extra2->data_device, CL_TRUE, offset2, ne21 * nb21, host_src2, 0, NULL, NULL));
8788
+ int total_experts = nb21 / nb20;
8789
+ int out_idx = 0;
8790
+ for (int i_expert = 0; i_expert < ne02; i_expert++) {
8791
+ for (int i_tile = 0; i_tile < num_tiles_per_expert; i_tile++) {
8792
+ for (int j = 0; j < ne21; j++) {
8793
+ for (int i = 0; i < ne20; i++) {
8794
+ int expert = ((int *)host_src2)[j * total_experts + i];
8795
+ if (i_expert == expert) {
8796
+ ((short *)host_src2_reorder)[out_idx] = static_cast<short>(expert);
8797
+ ((short *)host_src2_reorder)[out_idx + 1] = static_cast<short>(j * ne11 + (i % ne11));
8798
+ ((short *)host_src2_reorder)[out_idx + 2] = static_cast<short>(j * ne20 + i);
8799
+ ((short *)host_src2_reorder)[out_idx + 3] = static_cast<short>(i_tile);
8800
+ out_idx += 4;
8801
+ }
8802
+ }
8803
+ }
8804
+ }
8805
+ }
8806
+ buf_src2 = clCreateBuffer(backend_ctx->context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, ne20 * ne21 * 4 * num_tiles_per_expert * sizeof(short), host_src2_reorder, &status);
8807
+ CL_CHECK(status);
+ free(host_src2_reorder); // copied into buf_src2 by CL_MEM_COPY_HOST_PTR
+ free(host_src2);
8808
+
8809
+ // set thread grid
8810
+ global_size[0] = static_cast<size_t>(tile_size);
8811
+ global_size[2] = static_cast<size_t>(ne20 * ne21 * num_tiles_per_expert);
8812
+ }
8813
+
8814
+ // create a sub_buffer for src1
8815
+ cl_buffer_region region;
8816
+ region.origin = offset1;
8817
+ region.size = ne10 * ne11 * ne12 * sizeof(float);
8818
+ src1_sub_buffer = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &status);
8819
+ CL_CHECK(status);
8820
+
8821
+ // create image for src1
8822
+ cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT};
8823
+ cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne10 * ne11 * ne12 / 4), 0,0,0,0,0,0,0, {src1_sub_buffer}};
8824
+ buf_src1_image = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status);
8825
+ CL_CHECK(status);
8826
+
8827
+ // Set kernel args
8828
+ int arg_idx = 0;
8829
+ CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_mxfp4->q));
8830
+ CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_mxfp4->e));
8831
+ CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src1_image));
8832
+ CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2));
8833
+ CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extrad->data_device));
8834
+ CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_ulong), &offsetd));
8835
+ CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00));
8836
+ CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01));
8837
+ if (ne12 == 1) {
8838
+ CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne11));
8839
+ } else {
8840
+ CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &tile_size));
8841
+ }
8842
+
8843
+ // launch kernel
8844
+ backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst);
8845
+
8846
+ // deallocate sub buffers and images
8847
+ CL_CHECK(clReleaseMemObject(src1_sub_buffer));
8848
+ CL_CHECK(clReleaseMemObject(buf_src1_image));
8849
+ CL_CHECK(clReleaseMemObject(buf_src2));
8850
+ return;
8851
+ } // else fallback to generic kernel
8852
+ #endif // GGML_OPENCL_USE_ADRENO_KERNELS
8853
+
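Note: the gemm branch above flattens the (token, slot) -> expert router table into per-expert quadruples of shorts (expert id, src1 row, dst row, tile index), so the kernel can stream contiguous work per expert. A host-side sketch with a made-up router table (ne11 = 1, one tile per expert):

    #include <cstdio>
    #include <vector>

    int main() {
        const int n_tokens = 3, n_slots = 2, n_experts = 4;   // ne21, ne20, ne02
        const int ne11 = 1, num_tiles = 1;                    // rows per token, tiles per expert
        const int router[3][2] = {{0, 2}, {2, 3}, {0, 1}};    // router[token][slot] = expert

        std::vector<short> out;
        for (int e = 0; e < n_experts; ++e)
            for (int t = 0; t < num_tiles; ++t)
                for (int j = 0; j < n_tokens; ++j)
                    for (int i = 0; i < n_slots; ++i)
                        if (router[j][i] == e) {
                            out.push_back((short)e);                     // expert id
                            out.push_back((short)(j * ne11 + i % ne11)); // src1 row
                            out.push_back((short)(j * n_slots + i));     // dst row
                            out.push_back((short)t);                     // tile index
                        }
        for (size_t k = 0; k < out.size(); k += 4)
            printf("(expert=%d src=%d dst=%d tile=%d)\n",
                   out[k], out[k + 1], out[k + 2], out[k + 3]);
    }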
7592
8854
  #ifdef GGML_OPENCL_SOA_Q
7593
8855
  kernel = backend_ctx->kernel_mul_mv_id_mxfp4_f32_flat;
7594
8856
 
@@ -8106,6 +9368,7 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const
8106
9368
  const bool is_neox = mode & 2;
8107
9369
  const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
8108
9370
  const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
9371
+ const int is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
8109
9372
 
8110
9373
  if (is_mrope) {
8111
9374
  GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
@@ -8196,9 +9459,14 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const
8196
9459
  CL_CHECK(clSetKernelArg(kernel, 30, sizeof(float), &attn_factor));
8197
9460
  CL_CHECK(clSetKernelArg(kernel, 31, sizeof(float), &beta_fast));
8198
9461
  CL_CHECK(clSetKernelArg(kernel, 32, sizeof(float), &beta_slow));
9462
+ // both mrope and vision kernels have sections
8199
9463
  if (is_mrope || is_vision) {
8200
9464
  CL_CHECK(clSetKernelArg(kernel, 33, sizeof(int32_t)*4, &sections));
8201
9465
  }
9466
+ // only mrope has is_imrope
9467
+ if (is_mrope && !is_vision) {
9468
+ CL_CHECK(clSetKernelArg(kernel, 34, sizeof(int), &is_imrope));
9469
+ }
8202
9470
 
8203
9471
  size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
8204
9472
  size_t local_work_size[] = {(size_t)nth, 1, 1};
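Note: the rope mode tests mix bit tests (neox, mrope) with exact-value matches (vision, imrope), and the is_imrope flag is only forwarded, as arg 34, to kernels whose signatures carry it (mrope but not vision). A decoding sketch; the numeric constants here are assumptions mirroring ggml.h, which is authoritative:

    #include <cstdio>
    #include <initializer_list>

    // Assumed values (see GGML_ROPE_TYPE_* in ggml.h for the real ones).
    constexpr int ROPE_MROPE  = 8;
    constexpr int ROPE_VISION = 8 | 16;
    constexpr int ROPE_IMROPE = 8 | 32;

    static void decode(int mode) {
        const bool is_neox   = mode & 2;             // bit test
        const bool is_mrope  = mode & ROPE_MROPE;    // bit test (true for vision too)
        const bool is_vision = mode == ROPE_VISION;  // exact match
        const int  is_imrope = mode == ROPE_IMROPE;  // int, as it is set as kernel arg 34
        printf("mode=%2d neox=%d mrope=%d vision=%d imrope=%d\n",
               mode, (int)is_neox, (int)is_mrope, (int)is_vision, is_imrope);
    }

    int main() {
        for (int m : {0, 2, ROPE_MROPE, ROPE_VISION, ROPE_IMROPE}) decode(m);
    }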
@@ -8571,6 +9839,24 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
8571
9839
  }
8572
9840
  func = ggml_cl_sub;
8573
9841
  break;
9842
+ case GGML_OP_SQR:
9843
+ if (!any_on_device) {
9844
+ return false;
9845
+ }
9846
+ func = ggml_cl_sqr;
9847
+ break;
9848
+ case GGML_OP_SQRT:
9849
+ if (!any_on_device) {
9850
+ return false;
9851
+ }
9852
+ func = ggml_cl_sqrt;
9853
+ break;
9854
+ case GGML_OP_MEAN:
9855
+ if (!any_on_device) {
9856
+ return false;
9857
+ }
9858
+ func = ggml_cl_mean;
9859
+ break;
8574
9860
  case GGML_OP_UNARY:
8575
9861
  switch (ggml_get_unary_op(tensor)) {
8576
9862
  case GGML_UNARY_OP_GELU:
@@ -8615,6 +9901,18 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
8615
9901
  }
8616
9902
  func = ggml_cl_tanh;
8617
9903
  break;
9904
+ case GGML_UNARY_OP_EXPM1:
9905
+ if (!any_on_device) {
9906
+ return false;
9907
+ }
9908
+ func = ggml_cl_expm1;
9909
+ break;
9910
+ case GGML_UNARY_OP_SOFTPLUS:
9911
+ if (!any_on_device) {
9912
+ return false;
9913
+ }
9914
+ func = ggml_cl_softplus;
9915
+ break;
8618
9916
  default:
8619
9917
  return false;
8620
9918
  } break;
@@ -8624,6 +9922,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
8624
9922
  }
8625
9923
  func = ggml_cl_glu;
8626
9924
  break;
9925
+ case GGML_OP_FILL:
9926
+ if (!any_on_device) {
9927
+ return false;
9928
+ }
9929
+ func = ggml_cl_fill;
9930
+ break;
8627
9931
  case GGML_OP_CLAMP:
8628
9932
  if (!any_on_device) {
8629
9933
  return false;
@@ -8672,6 +9976,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
8672
9976
  }
8673
9977
  func = ggml_cl_conv_2d;
8674
9978
  break;
9979
+ case GGML_OP_SSM_CONV:
9980
+ if (!any_on_device) {
9981
+ return false;
9982
+ }
9983
+ func = ggml_cl_ssm_conv;
9984
+ break;
8675
9985
  case GGML_OP_CONCAT:
8676
9986
  if (!any_on_device) {
8677
9987
  return false;