whispercpp 1.3.4 → 1.3.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (630)
  1. checksums.yaml +4 -4
  2. data/README.md +60 -43
  3. data/ext/extconf.rb +2 -2
  4. data/ext/ruby_whisper.c +14 -2
  5. data/ext/ruby_whisper.h +39 -0
  6. data/ext/ruby_whisper_context.c +22 -22
  7. data/ext/ruby_whisper_model.c +12 -12
  8. data/ext/ruby_whisper_params.c +47 -23
  9. data/ext/ruby_whisper_segment.c +84 -19
  10. data/ext/ruby_whisper_token.c +351 -0
  11. data/ext/ruby_whisper_transcribe.cpp +1 -1
  12. data/ext/ruby_whisper_vad_context.c +75 -0
  13. data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
  14. data/ext/ruby_whisper_vad_segment.c +139 -0
  15. data/ext/ruby_whisper_vad_segments.c +106 -0
  16. data/ext/sources/CMakeLists.txt +4 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
  19. data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
  20. data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  21. data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
  22. data/ext/sources/examples/addon.node/vad-example.js +2 -2
  23. data/ext/sources/examples/cli/cli.cpp +121 -112
  24. data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
  25. data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
  26. data/ext/sources/examples/server/server.cpp +10 -11
  27. data/ext/sources/examples/talk-llama/CMakeLists.txt +5 -1
  28. data/ext/sources/examples/talk-llama/llama-adapter.cpp +12 -3
  29. data/ext/sources/examples/talk-llama/llama-adapter.h +7 -1
  30. data/ext/sources/examples/talk-llama/llama-arch.cpp +2046 -1974
  31. data/ext/sources/examples/talk-llama/llama-arch.h +67 -2
  32. data/ext/sources/examples/talk-llama/llama-batch.cpp +75 -33
  33. data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
  34. data/ext/sources/examples/talk-llama/llama-chat.cpp +79 -3
  35. data/ext/sources/examples/talk-llama/llama-chat.h +4 -0
  36. data/ext/sources/examples/talk-llama/llama-context.cpp +775 -78
  37. data/ext/sources/examples/talk-llama/llama-context.h +57 -9
  38. data/ext/sources/examples/talk-llama/llama-cparams.h +1 -0
  39. data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
  40. data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
  41. data/ext/sources/examples/talk-llama/llama-graph.cpp +381 -64
  42. data/ext/sources/examples/talk-llama/llama-graph.h +103 -13
  43. data/ext/sources/examples/talk-llama/llama-hparams.cpp +26 -2
  44. data/ext/sources/examples/talk-llama/llama-hparams.h +41 -10
  45. data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
  46. data/ext/sources/examples/talk-llama/llama-impl.h +1 -1
  47. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +5 -3
  48. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +145 -65
  49. data/ext/sources/examples/talk-llama/llama-kv-cache.h +22 -7
  50. data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
  51. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
  52. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +32 -19
  53. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
  54. data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
  55. data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
  56. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +91 -9
  57. data/ext/sources/examples/talk-llama/llama-model-loader.h +6 -0
  58. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
  59. data/ext/sources/examples/talk-llama/llama-model.cpp +1529 -13134
  60. data/ext/sources/examples/talk-llama/llama-model.h +44 -3
  61. data/ext/sources/examples/talk-llama/llama-quant.cpp +8 -23
  62. data/ext/sources/examples/talk-llama/llama-sampling.cpp +1294 -198
  63. data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
  64. data/ext/sources/examples/talk-llama/llama-vocab.cpp +133 -37
  65. data/ext/sources/examples/talk-llama/llama-vocab.h +45 -40
  66. data/ext/sources/examples/talk-llama/llama.cpp +729 -2
  67. data/ext/sources/examples/talk-llama/llama.h +152 -14
  68. data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
  69. data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
  70. data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
  71. data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
  72. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
  73. data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
  74. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
  75. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
  76. data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
  77. data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
  78. data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
  79. data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
  80. data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
  81. data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
  82. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
  83. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
  84. data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
  85. data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
  86. data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
  87. data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
  88. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
  89. data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
  90. data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
  91. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
  92. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
  93. data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
  94. data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
  95. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
  96. data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
  97. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
  98. data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
  99. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
  100. data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
  101. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
  102. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
  103. data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
  104. data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
  105. data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
  106. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
  107. data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
  108. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
  109. data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
  110. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
  111. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
  112. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
  113. data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
  114. data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
  115. data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
  116. data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
  117. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
  118. data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
  119. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
  120. data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
  121. data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
  122. data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
  123. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
  124. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
  125. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
  126. data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
  127. data/ext/sources/examples/talk-llama/models/models.h +569 -0
  128. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
  129. data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
  130. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
  131. data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
  132. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
  133. data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
  134. data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
  135. data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
  136. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
  137. data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
  138. data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
  139. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
  140. data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
  141. data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
  142. data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
  143. data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
  144. data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
  145. data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
  146. data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
  147. data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
  148. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
  149. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
  150. data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
  151. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
  152. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
  153. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
  154. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
  155. data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
  156. data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
  157. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
  158. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
  159. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
  160. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
  161. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
  162. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
  163. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
  164. data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
  165. data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
  166. data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
  167. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
  168. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
  169. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
  170. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
  171. data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
  172. data/ext/sources/examples/talk-llama/unicode.cpp +102 -16
  173. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
  174. data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
  175. data/ext/sources/ggml/CMakeLists.txt +82 -54
  176. data/ext/sources/ggml/include/ggml-alloc.h +9 -0
  177. data/ext/sources/ggml/include/ggml-backend.h +4 -1
  178. data/ext/sources/ggml/include/ggml-cpu.h +1 -0
  179. data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
  180. data/ext/sources/ggml/include/ggml-rpc.h +8 -11
  181. data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
  182. data/ext/sources/ggml/include/ggml.h +190 -12
  183. data/ext/sources/ggml/src/CMakeLists.txt +82 -11
  184. data/ext/sources/ggml/src/ggml-alloc.c +124 -41
  185. data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
  186. data/ext/sources/ggml/src/ggml-backend-reg.cpp +27 -3
  187. data/ext/sources/ggml/src/ggml-backend.cpp +71 -21
  188. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
  189. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -9
  190. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +57 -45
  191. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
  192. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2179 -1696
  193. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -317
  194. data/ext/sources/ggml/src/ggml-cann/common.h +283 -208
  195. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +626 -776
  196. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +156 -86
  197. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -0
  198. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  199. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +428 -26
  200. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1004 -0
  201. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
  202. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  203. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +108 -49
  204. data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  205. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6 -6
  206. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +50 -2
  207. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +5 -3
  208. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +195 -71
  209. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
  210. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +573 -106
  211. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
  212. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +298 -112
  213. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
  214. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +819 -125
  215. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
  216. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +708 -431
  217. data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -4
  218. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +671 -31
  219. data/ext/sources/ggml/src/ggml-cpu/repack.h +14 -0
  220. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +41 -43
  221. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
  222. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
  223. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
  224. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +124 -1
  225. data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
  226. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
  227. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
  228. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
  229. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
  230. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +1 -1
  231. data/ext/sources/ggml/src/ggml-cuda/common.cuh +353 -80
  232. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  233. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
  234. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +339 -246
  235. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
  236. data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
  237. data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
  238. data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
  239. data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
  240. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +31 -21
  241. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +663 -596
  242. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
  243. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1241 -0
  244. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +30 -37
  245. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +14 -13
  246. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  247. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +83 -37
  248. data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
  249. data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
  250. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1155 -164
  251. data/ext/sources/ggml/src/ggml-cuda/mean.cu +5 -4
  252. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +741 -48
  253. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +60 -12
  254. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +381 -42
  255. data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
  256. data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
  257. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
  258. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +498 -171
  259. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +375 -79
  260. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +3 -2
  261. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +241 -95
  262. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
  263. data/ext/sources/ggml/src/ggml-cuda/pad.cu +64 -33
  264. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +151 -0
  265. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
  266. data/ext/sources/ggml/src/ggml-cuda/rope.cu +192 -77
  267. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
  268. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
  269. data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
  270. data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
  271. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +203 -6
  272. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
  273. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
  274. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -20
  275. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
  276. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  277. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  278. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  279. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  280. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  281. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  282. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  283. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  284. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  285. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +19 -1
  286. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
  287. data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
  288. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +168 -76
  289. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +11 -4
  290. data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
  291. data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
  292. data/ext/sources/ggml/src/ggml-cuda/unary.cu +105 -11
  293. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +36 -0
  294. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
  295. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  296. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +12 -1
  297. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
  302. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
  303. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  304. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
  305. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
  306. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  307. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
  308. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
  309. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
  310. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
  311. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  312. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
  313. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
  314. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
  315. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
  316. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
  317. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
  318. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
  319. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
  320. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
  321. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
  322. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
  323. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
  324. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
  325. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  326. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
  327. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
  328. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
  329. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +8 -13
  330. data/ext/sources/ggml/src/ggml-impl.h +67 -6
  331. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +2 -2
  332. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +29 -20
  333. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +652 -285
  334. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +103 -56
  335. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +496 -118
  336. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +231 -9
  337. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1227 -224
  338. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +12 -0
  339. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +14 -8
  340. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1972 -704
  341. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
  342. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +11 -0
  343. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1430 -120
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +63 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +13 -0
  365. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
  366. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
  367. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
  368. data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
  369. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +6 -0
  370. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +0 -9
  371. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
  372. data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
  373. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
  374. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +34 -0
  375. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  376. data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  377. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
  378. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
  379. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +76 -3
  380. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +333 -300
  381. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +10 -2
  382. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +335 -110
  383. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
  384. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +156 -0
  385. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
  386. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
  387. data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
  388. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
  389. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
  390. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
  391. data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  392. data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  393. data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
  394. data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
  395. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +30 -17
  396. data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
  397. data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
  398. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
  399. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
  400. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  401. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  402. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +58 -0
  403. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
  404. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5013 -2859
  405. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  406. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
  407. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
  408. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  409. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
  410. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  411. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
  412. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
  413. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  414. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  415. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  416. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  417. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  418. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  419. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
  420. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  421. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  422. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
  423. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
  424. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  425. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  426. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  427. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
  428. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
  429. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
  430. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
  431. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  432. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
  433. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
  434. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  435. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  436. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  437. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  438. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  439. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
  440. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
  441. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
  442. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  443. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  444. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
  445. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
  446. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  447. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  448. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  449. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
  450. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  451. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  452. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
  453. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  454. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  455. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
  456. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  457. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  458. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
  459. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  460. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +39 -17
  461. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +19 -1
  462. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +45 -7
  463. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +50 -12
  464. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +1 -1
  465. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  466. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  467. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
  468. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
  469. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  470. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
  471. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  472. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
  473. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
  474. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
  475. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
  476. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
  477. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
  478. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  479. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
  480. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
  481. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
  482. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
  483. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
  484. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  485. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  486. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  487. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
  488. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +70 -25
  489. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
  490. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
  491. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
  492. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
  493. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
  494. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
  495. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
  496. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
  497. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
  498. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  499. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
  500. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
  501. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
  502. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
  503. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
  504. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
  505. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
  506. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +78 -103
  507. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +34 -23
  508. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
  509. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
  510. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +88 -228
  511. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
  512. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  513. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
  514. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  515. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  516. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  517. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
  518. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
  519. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  520. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
  521. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  522. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  523. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  524. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  525. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +50 -4
  526. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  527. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
  528. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
  529. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
  530. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  531. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -50
  532. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -33
  533. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -33
  534. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
  535. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
  536. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  537. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
  538. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  539. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  540. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  541. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  542. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
  543. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
  544. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
  545. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
  546. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
  547. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
  548. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  549. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
  550. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
  551. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  552. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
  553. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
  554. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  555. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  556. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
  557. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
  558. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  559. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
  560. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  561. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
  562. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
  563. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
  564. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
  565. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
  566. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  567. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
  568. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
  569. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -151
  570. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
  571. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
  572. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
  573. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1964 -435
  574. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  575. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
  576. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
  577. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +33 -10
  578. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
  579. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +1 -1
  580. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  581. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +6 -6
  582. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
  583. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
  584. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
  585. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
  586. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
  587. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  588. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
  589. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
  590. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  591. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
  592. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
  593. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
  594. data/ext/sources/ggml/src/ggml.c +425 -33
  595. data/ext/sources/include/whisper.h +1 -0
  596. data/ext/sources/src/CMakeLists.txt +3 -1
  597. data/ext/sources/src/whisper.cpp +101 -35
  598. data/ext/sources/tests/CMakeLists.txt +2 -2
  599. data/ext/sources/tests/test-vad-full.cpp +4 -2
  600. data/ext/sources/tests/test-vad.cpp +1 -1
  601. data/extsources.rb +1 -0
  602. data/lib/whisper/model/uri.rb +17 -18
  603. data/sig/whisper.rbs +119 -2
  604. data/test/test_params.rb +16 -8
  605. data/test/test_segment.rb +0 -1
  606. data/test/test_token.rb +70 -0
  607. data/test/test_vad.rb +1 -1
  608. data/test/test_vad_context.rb +50 -0
  609. data/test/test_vad_segment.rb +19 -0
  610. data/test/test_vad_segments.rb +16 -0
  611. data/test/test_whisper.rb +7 -0
  612. data/whispercpp.gemspec +1 -1
  613. metadata +287 -34
  614. data/ext/sources/build-xcframework.sh +0 -571
  615. data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
  618. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
  619. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
  620. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
  621. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
  622. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
  623. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
  624. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  625. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  626. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  627. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  628. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
  629. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
  630. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
@@ -30,21 +30,29 @@
30
30
  #include <regex>
31
31
 
32
32
  #include <sycl/sycl.hpp>
33
+ #if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC
34
+ # include <sycl/ext/oneapi/experimental/async_alloc/async_alloc.hpp>
35
+ #endif
33
36
  #include <sycl/half_type.hpp>
34
37
 
35
38
  #include "ggml-sycl.h"
36
39
  #include "ggml-impl.h"
37
40
  #include "ggml-backend-impl.h"
38
41
 
42
+ #include "ggml-sycl/add-id.hpp"
39
43
  #include "ggml-sycl/backend.hpp"
40
44
  #include "ggml-sycl/common.hpp"
41
45
  #include "ggml-sycl/element_wise.hpp"
46
+ #include "ggml-sycl/norm.hpp"
42
47
  #include "ggml-sycl/presets.hpp"
43
48
  #include "ggml-sycl/gemm.hpp"
44
49
  #include "ggml-sycl/set_rows.hpp"
50
+ #include "ggml-sycl/set.hpp"
45
51
  #include "ggml-sycl/sycl_hw.hpp"
46
52
  #include "ggml-sycl/getrows.hpp"
53
+ #include "ggml-sycl/repeat_back.hpp"
47
54
  #include "ggml-sycl/quantize.hpp"
55
+ #include "ggml-sycl/ssm_conv.hpp"
48
56
  #include "ggml.h"
49
57
 
50
58
  static bool g_sycl_loaded = false;
@@ -53,6 +61,7 @@ int g_ggml_sycl_disable_optimize = 0;
53
61
  int g_ggml_sycl_disable_graph = 0;
54
62
  int g_ggml_sycl_disable_dnn = 0;
55
63
  int g_ggml_sycl_prioritize_dmmv = 0;
64
+ int g_ggml_sycl_use_async_mem_op = 0;
56
65
 
57
66
  static ggml_sycl_device_info ggml_sycl_init() {
58
67
  ggml_sycl_device_info info = {};
@@ -85,7 +94,10 @@ static ggml_sycl_device_info ggml_sycl_init() {
85
94
 
86
95
  info.devices[i].cc =
87
96
  100 * prop.get_major_version() + 10 * prop.get_minor_version();
97
+ info.devices[i].nsm = prop.get_max_compute_units();
88
98
  info.devices[i].opt_feature.reorder = device.ext_oneapi_architecture_is(syclex::arch_category::intel_gpu);
99
+ info.devices[i].smpbo = prop.get_local_mem_size();
100
+
89
101
  info.max_work_group_sizes[i] = prop.get_max_work_group_size();
90
102
  }
91
103
 
@@ -233,7 +245,20 @@ static void ggml_check_sycl() try {
233
245
  fprintf(stderr, "%s: SYCL_USE_XMX: no\n", __func__);
234
246
  #endif
235
247
  */
236
-
248
+ // Currently, we only use async malloc / free when graphs are enabled as it is required for the calls to be
249
+ // properly recorded. As this SYCL extension matures it may be beneficial to enable as the default path and in
250
+ // other places.
251
+ #if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC
252
+ g_ggml_sycl_use_async_mem_op = !g_ggml_sycl_disable_graph;
253
+ if (g_ggml_sycl_use_async_mem_op) {
254
+ for (unsigned int i = 0; i < dpct::dev_mgr::instance().device_count(); ++i) {
255
+ if (!dpct::dev_mgr::instance().get_device(i).has(sycl::aspect::ext_oneapi_async_memory_alloc)) {
256
+ g_ggml_sycl_use_async_mem_op = 0;
257
+ break;
258
+ }
259
+ }
260
+ }
261
+ #endif
237
262
  if (CHECK_TRY_ERROR(g_all_sycl_device_count =
238
263
  dpct::dev_mgr::instance().device_count()) != 0) {
239
264
  initialized = true;
@@ -1511,60 +1536,70 @@ static inline void ggml_sycl_swap(T & a, T & b) {
1511
1536
  template <ggml_sort_order order>
1512
1537
  __dpct_inline__ static void
1513
1538
  k_argsort_f32_i32(const float *x, int *dst, const int ncols, int ncols_pad,
1514
- const sycl::nd_item<3> &item_ct1, uint8_t *dpct_local) {
1539
+ const int tasks_per_thread, const sycl::nd_item<3> &item_ct1,
1540
+ uint8_t *dpct_local) {
1515
1541
  // bitonic sort
1516
- int col = item_ct1.get_local_id(2);
1542
+ int col_index = item_ct1.get_local_id(2);
1517
1543
  int row = item_ct1.get_group(1);
1518
1544
 
1519
- if (col >= ncols_pad) {
1520
- return;
1545
+ for (int i = 0; i < tasks_per_thread; i++) {
1546
+ int col = col_index * tasks_per_thread + i;
1547
+ if (col >= ncols_pad) {
1548
+ return;
1549
+ }
1521
1550
  }
1522
1551
 
1523
1552
  const float * x_row = x + row * ncols;
1524
1553
  auto dst_row = (int *)dpct_local;
1525
1554
 
1526
1555
  // initialize indices
1527
- dst_row[col] = col;
1556
+ for (int i=0;i<tasks_per_thread;i++){
1557
+ int col = col_index*tasks_per_thread+i;
1558
+ dst_row[col] = col;
1559
+ }
1528
1560
 
1529
1561
  item_ct1.barrier(sycl::access::fence_space::local_space);
1530
1562
 
1531
1563
  for (int k = 2; k <= ncols_pad; k *= 2) {
1532
1564
  for (int j = k / 2; j > 0; j /= 2) {
1533
- int ixj = col ^ j;
1534
- if (ixj > col) {
1535
- if ((col & k) == 0) {
1536
- if (dst_row[col] >= ncols ||
1537
- (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
1538
- x_row[dst_row[col]] > x_row[dst_row[ixj]] :
1539
- x_row[dst_row[col]] < x_row[dst_row[ixj]]))
1540
- ) {
1541
- ggml_sycl_swap(dst_row[col], dst_row[ixj]);
1542
- }
1543
- } else {
1544
- if (dst_row[ixj] >= ncols ||
1545
- (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
1546
- x_row[dst_row[col]] < x_row[dst_row[ixj]] :
1547
- x_row[dst_row[col]] > x_row[dst_row[ixj]]))
1548
- ) {
1549
- ggml_sycl_swap(dst_row[col], dst_row[ixj]);
1565
+ for (int i = 0; i < tasks_per_thread; i++) {
1566
+ int col = col_index * tasks_per_thread + i;
1567
+ int ixj = col ^ j;
1568
+ if (ixj > col) {
1569
+ if ((col & k) == 0) {
1570
+ if (dst_row[col] >= ncols ||
1571
+ (dst_row[ixj] < ncols &&
1572
+ (order == GGML_SORT_ORDER_ASC
1573
+ ? x_row[dst_row[col]] > x_row[dst_row[ixj]]
1574
+ : x_row[dst_row[col]] <
1575
+ x_row[dst_row[ixj]]))) {
1576
+ ggml_sycl_swap(dst_row[col], dst_row[ixj]);
1577
+ }
1578
+ } else {
1579
+ if (dst_row[ixj] >= ncols ||
1580
+ (dst_row[col] < ncols &&
1581
+ (order == GGML_SORT_ORDER_ASC
1582
+ ? x_row[dst_row[col]] < x_row[dst_row[ixj]]
1583
+ : x_row[dst_row[col]] >
1584
+ x_row[dst_row[ixj]]))) {
1585
+ ggml_sycl_swap(dst_row[col], dst_row[ixj]);
1586
+ }
1550
1587
  }
1551
1588
  }
1589
+ item_ct1.barrier(sycl::access::fence_space::local_space);
1552
1590
  }
1553
- /*
1554
- DPCT1118:1: SYCL group functions and algorithms must be encountered
1555
- in converged control flow. You may need to adjust the code.
1556
- */
1557
- item_ct1.barrier(sycl::access::fence_space::local_space);
1558
1591
  }
1559
1592
  }
1560
1593
 
1561
1594
  // copy the result to dst without the padding
1562
- if (col < ncols) {
1563
- dst[row * ncols + col] = dst_row[col];
1595
+ for (int i = 0; i < tasks_per_thread; i++) {
1596
+ int col = col_index * tasks_per_thread + i;
1597
+ if (col < ncols) {
1598
+ dst[row * ncols + col] = dst_row[col];
1599
+ }
1564
1600
  }
1565
1601
  }
1566
1602
 
1567
-
1568
1603
  static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past,
1569
1604
  const sycl::nd_item<3> &item_ct1) {
1570
1605
  const int col = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
@@ -1737,13 +1772,23 @@ static int next_power_of_2(int x) {
1737
1772
 
1738
1773
  static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
1739
1774
  const int nrows, ggml_sort_order order,
1740
- queue_ptr stream) {
1775
+ queue_ptr stream, int device) {
1741
1776
  // bitonic sort requires ncols to be power of 2
1742
1777
  const int ncols_pad = next_power_of_2(ncols);
1743
1778
 
1744
- const sycl::range<3> block_dims(1, 1, ncols_pad);
1779
+ int nth = 1;
1780
+ int max_block_size = ggml_sycl_info().max_work_group_sizes[device];
1781
+ while (nth < ncols_pad && nth < max_block_size)
1782
+ nth *= 2;
1783
+ if (nth > max_block_size)
1784
+ nth = max_block_size;
1785
+
1786
+ const int tasks_per_thread = ncols_pad / nth;
1787
+
1788
+ const sycl::range<3> block_dims(1, 1, nth);
1745
1789
  const sycl::range<3> block_nums(1, nrows, 1);
1746
1790
  const size_t shared_mem = ncols_pad * sizeof(int);
1791
+ GGML_ASSERT(shared_mem<=ggml_sycl_info().devices[device].smpbo);
1747
1792
 
1748
1793
  if (order == GGML_SORT_ORDER_ASC) {
1749
1794
  stream->submit([&](sycl::handler &cgh) {
@@ -1754,8 +1799,9 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
1754
1799
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
1755
1800
  [=](sycl::nd_item<3> item_ct1) {
1756
1801
  k_argsort_f32_i32<GGML_SORT_ORDER_ASC>(
1757
- x, dst, ncols, ncols_pad, item_ct1,
1758
- dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
1802
+ x, dst, ncols, ncols_pad, tasks_per_thread, item_ct1,
1803
+ dpct_local_acc_ct1
1804
+ .get_multi_ptr<sycl::access::decorated::no>()
1759
1805
  .get());
1760
1806
  });
1761
1807
  });
@@ -1768,8 +1814,9 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
1768
1814
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
1769
1815
  [=](sycl::nd_item<3> item_ct1) {
1770
1816
  k_argsort_f32_i32<GGML_SORT_ORDER_DESC>(
1771
- x, dst, ncols, ncols_pad, item_ct1,
1772
- dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
1817
+ x, dst, ncols, ncols_pad, tasks_per_thread, item_ct1,
1818
+ dpct_local_acc_ct1
1819
+ .get_multi_ptr<sycl::access::decorated::no>()
1773
1820
  .get());
1774
1821
  });
1775
1822
  });
@@ -2127,6 +2174,30 @@ inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor *
2127
2174
  sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
2128
2175
  }
2129
2176
 
2177
+ inline void ggml_sycl_op_mean(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2178
+ GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
2179
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
2180
+
2181
+ dpct::queue_ptr main_stream = ctx.stream();
2182
+ SYCL_CHECK(ggml_sycl_set_device(ctx.device));
2183
+
2184
+ const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
2185
+ float * dst_dd = static_cast<float *>(dst->data);
2186
+
2187
+ const int64_t ncols = dst->src[0]->ne[0];
2188
+ const int64_t nrows = ggml_nrows(dst->src[0]);
2189
+
2190
+ sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
2191
+
2192
+ main_stream->parallel_for(
2193
+ sycl::range<1>(nrows),
2194
+ [=](sycl::id<1> row) {
2195
+ dst_dd[row] /= ncols;
2196
+ }
2197
+ );
2198
+ }
2199
+
2200
+
2130
2201
  inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2131
2202
  GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
2132
2203
  GGML_ASSERT(dst->type == GGML_TYPE_I32);
@@ -2141,7 +2212,8 @@ inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor *
2141
2212
 
2142
2213
  enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
2143
2214
 
2144
- argsort_f32_i32_sycl(src0_dd, (int *) dst_dd, ncols, nrows, order, main_stream);
2215
+ argsort_f32_i32_sycl(src0_dd, (int *)dst_dd, ncols, nrows, order,
2216
+ main_stream, ctx.device);
2145
2217
  }
2146
2218
 
2147
2219
  inline void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
@@ -2548,6 +2620,10 @@ catch (sycl::exception const &exc) {
2548
2620
  std::exit(1);
2549
2621
  }
2550
2622
 
2623
+ static void ggml_sycl_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2624
+ scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
2625
+ ggml_sycl_op_repeat_back(ctx, dst);
2626
+ }
2551
2627
 
2552
2628
  static void ggml_sycl_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2553
2629
  scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
@@ -2564,6 +2640,11 @@ static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * ds
2564
2640
  ggml_sycl_op_rms_norm(ctx, dst);
2565
2641
  }
2566
2642
 
2643
+ static void ggml_sycl_rms_norm_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2644
+ scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
2645
+ ggml_sycl_op_rms_norm_back(ctx, dst);
2646
+ }
2647
+
2567
2648
  static void ggml_sycl_l2_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2568
2649
  scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
2569
2650
  ggml_sycl_op_l2_norm(ctx, dst);
@@ -2981,19 +3062,51 @@ static bool ggml_sycl_supports_dmmv(enum ggml_type type) {
2981
3062
  }
2982
3063
  }
2983
3064
 
3065
+ // Helper functions to unify device memory allocation for both async and sync paths
3066
+ static inline void * sycl_ext_malloc_device(dpct::queue_ptr stream, size_t size) {
3067
+ bool use_async = g_ggml_sycl_use_async_mem_op;
3068
+ #if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC
3069
+ if (use_async) {
3070
+ return syclex::async_malloc(*stream, sycl::usm::alloc::device, size);
3071
+ }
3072
+ #else
3073
+ // If async allocation extension is not available, use_async should always be false.
3074
+ GGML_ASSERT(!use_async);
3075
+ #endif
3076
+ return sycl::malloc(size, *stream, sycl::usm::alloc::device);
3077
+ }
3078
+
3079
+ static inline void sycl_ext_free(dpct::queue_ptr stream, void * ptr) {
3080
+ bool use_async = g_ggml_sycl_use_async_mem_op;
3081
+ #if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC
3082
+ if (use_async) {
3083
+ syclex::async_free(*stream, ptr);
3084
+ return;
3085
+ }
3086
+ #else
3087
+ // If async allocation extension is not available, use_async should always be false.
3088
+ GGML_ASSERT(!use_async);
3089
+ #endif
3090
+ sycl::free(ptr, *stream);
3091
+ }
3092
+
2984
3093
  static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset,
2985
3094
  dpct::queue_ptr stream) {
2986
- auto * tmp_buf = sycl::malloc_shared<uint8_t>(size, *stream);
2987
- SYCL_CHECK(
2988
- CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size)
2989
- .wait()));
3095
+ uint8_t * tmp_buf = static_cast<uint8_t *>(sycl_ext_malloc_device(stream, size));
3096
+
3097
+ sycl::event copy_event;
3098
+ SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size)));
3099
+ if (!g_ggml_sycl_use_async_mem_op) {
3100
+ copy_event.wait();
3101
+ }
3102
+
2990
3103
  GGML_ASSERT((size % sizeof(block_q4_0) == 0));
2991
3104
  GGML_ASSERT((offset % sizeof(block_q4_0) == 0));
2992
3105
  int offset_blks = offset / sizeof(block_q4_0);
2993
3106
  auto qs_ptr = data_device + offset_blks * QK4_0 / 2;
2994
3107
  auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows / 2) + offset_blks;
2995
3108
 
2996
- stream->parallel_for(
3109
+ auto reorder_event = stream->parallel_for(
2997
3110
  size / sizeof(block_q4_0),
2998
3111
  [=](auto i) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
2999
3112
  const block_q4_0* x = (const block_q4_0*)tmp_buf;
@@ -3004,9 +3117,11 @@ static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nr
3004
3117
  *(qs_ptr + ib * QK4_0 / 2 + j) = x[ib].qs[j];
3005
3118
  }
3006
3119
  *(d_ptr + ib) = x[ib].d;
3007
- }).wait_and_throw();
3008
-
3009
- sycl::free(tmp_buf, *stream);
3120
+ });
3121
+ if (!g_ggml_sycl_use_async_mem_op) {
3122
+ reorder_event.wait_and_throw();
3123
+ }
3124
+ sycl_ext_free(stream, tmp_buf);
3010
3125
  }
3011
3126
 
3012
3127
  static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
@@ -3015,14 +3130,19 @@ static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, d
3015
3130
 
3016
3131
  const int nblocks = size / sizeof(block_q4_K);
3017
3132
 
3018
- auto * tmp_buf = sycl::malloc_shared<uint8_t>(size, *stream);
3019
- SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size).wait()));
3133
+ uint8_t * tmp_buf = static_cast<uint8_t *>(sycl_ext_malloc_device(stream, size));
3134
+
3135
+ sycl::event copy_event;
3136
+ SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size)));
3137
+ if (!g_ggml_sycl_use_async_mem_op) {
3138
+ copy_event.wait();
3139
+ }
3020
3140
 
3021
3141
  auto * qs_ptr = data_device;
3022
3142
  auto * scales_ptr = qs_ptr + QK_K / 2 * nblocks;
3023
3143
  auto * dm_ptr = (sycl::half2 *) (scales_ptr + K_SCALE_SIZE * nblocks);
3024
3144
 
3025
- stream->parallel_for(nblocks, [=](auto i) {
3145
+ auto reorder_event = stream->parallel_for(nblocks, [=](auto i) {
3026
3146
  const block_q4_K * x = (const block_q4_K *) tmp_buf;
3027
3147
  const int ib = i;
3028
3148
 
@@ -3035,9 +3155,11 @@ static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, d
3035
3155
  }
3036
3156
 
3037
3157
  dm_ptr[ib] = x[ib].dm;
3038
- }).wait_and_throw();
3039
-
3040
- sycl::free(tmp_buf, *stream);
3158
+ });
3159
+ if (!g_ggml_sycl_use_async_mem_op) {
3160
+ reorder_event.wait_and_throw();
3161
+ }
3162
+ sycl_ext_free(stream, tmp_buf);
3041
3163
  }
3042
3164
 
3043
3165
  static void reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
@@ -3046,42 +3168,46 @@ static void reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, d
3046
3168
 
3047
3169
  const int nblocks = size / sizeof(block_q6_K);
3048
3170
 
3049
- auto * tmp_buf = sycl::malloc_shared<uint8_t>(size, *stream);
3050
- SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size).wait()));
3171
+ uint8_t * tmp_buf = static_cast<uint8_t *>(sycl_ext_malloc_device(stream, size));
3172
+
3173
+ sycl::event copy_event;
3174
+ SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size)));
3175
+ if (!g_ggml_sycl_use_async_mem_op) {
3176
+ copy_event.wait();
3177
+ }
3051
3178
 
3052
3179
  auto * ql_ptr = data_device;
3053
3180
  auto * qh_ptr = ql_ptr + (QK_K / 2) * nblocks;
3054
3181
  auto * scales_ptr = qh_ptr + (QK_K / 4) * nblocks;
3055
3182
  sycl::half * dm_ptr = (sycl::half *) (scales_ptr + (QK_K / 16) * nblocks);
3056
3183
 
3057
- stream
3058
- ->parallel_for(nblocks,
3059
- [=](auto i) {
3060
- const block_q6_K * x = (const block_q6_K *) tmp_buf;
3061
- const int ib = i;
3062
-
3063
- const uint8_t * ql = x[ib].ql;
3064
- const uint8_t * qh = x[ib].qh;
3065
- uint8_t * base_ql_ptr = ql_ptr + (QK_K / 2) * ib;
3066
- uint8_t * base_qh_ptr = qh_ptr + (QK_K / 4) * ib;
3067
- uint8_t * base_scales_ptr = scales_ptr + (QK_K / 16) * ib;
3184
+ auto reorder_event = stream->parallel_for(nblocks, [=](auto i) {
3185
+ const block_q6_K * x = (const block_q6_K *) tmp_buf;
3186
+ const int ib = i;
3068
3187
 
3069
- for (int j = 0; j < QK_K / 2; ++j) {
3070
- base_ql_ptr[j] = ql[j];
3071
- }
3072
- for (int j = 0; j < QK_K / 4; ++j) {
3073
- base_qh_ptr[j] = qh[j];
3074
- }
3188
+ const uint8_t * ql = x[ib].ql;
3189
+ const uint8_t * qh = x[ib].qh;
3190
+ uint8_t * base_ql_ptr = ql_ptr + (QK_K / 2) * ib;
3191
+ uint8_t * base_qh_ptr = qh_ptr + (QK_K / 4) * ib;
3192
+ uint8_t * base_scales_ptr = scales_ptr + (QK_K / 16) * ib;
3075
3193
 
3076
- for (int j = 0; j < QK_K / 16; ++j) {
3077
- base_scales_ptr[j] = x[ib].scales[j];
3078
- }
3194
+ for (int j = 0; j < QK_K / 2; ++j) {
3195
+ base_ql_ptr[j] = ql[j];
3196
+ }
3197
+ for (int j = 0; j < QK_K / 4; ++j) {
3198
+ base_qh_ptr[j] = qh[j];
3199
+ }
3079
3200
 
3080
- dm_ptr[ib] = x[ib].d;
3081
- })
3082
- .wait_and_throw();
3201
+ for (int j = 0; j < QK_K / 16; ++j) {
3202
+ base_scales_ptr[j] = x[ib].scales[j];
3203
+ }
3083
3204
 
3084
- sycl::free(tmp_buf, *stream);
3205
+ dm_ptr[ib] = x[ib].d;
3206
+ });
3207
+ if (!g_ggml_sycl_use_async_mem_op) {
3208
+ reorder_event.wait_and_throw();
3209
+ }
3210
+ sycl_ext_free(stream, tmp_buf);
3085
3211
  }
3086
3212
 
3087
3213
  static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
@@ -3188,6 +3314,7 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
3188
3314
  bool use_mul_mat_q = ggml_sycl_supports_mmq(src0->type)
3189
3315
  && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
3190
3316
 
3317
+
3191
3318
  // mmvq and mmq need the __dp4a instruction which is available for gen12+
3192
3319
  // Workaround in https://github.com/ggerganov/llama.cpp/commit/95f84d5ce8b449a9b16009434aca800df504a02e
3193
3320
  use_mul_mat_q = use_mul_mat_q && (src0->type != GGML_TYPE_IQ2_XXS);
@@ -3195,7 +3322,6 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
3195
3322
  use_mul_mat_q = use_mul_mat_q && (src1->ne[1] <= MMQ_MAX_BATCH_SIZE);
3196
3323
  #endif // SYCL_USE_XMX
3197
3324
 
3198
-
3199
3325
  // mmvq path is faster in the CUDA backend.
3200
3326
  if (!g_ggml_sycl_prioritize_dmmv && (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda
3201
3327
  // Dispatch becomes obscure with the reorder, MMVQ when the reorder optimization
@@ -3510,6 +3636,12 @@ static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * ds
3510
3636
  ggml_sycl_op_sum_rows(ctx, dst);
3511
3637
  }
3512
3638
 
3639
+ static void ggml_sycl_mean(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
3640
+ scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
3641
+ GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
3642
+ ggml_sycl_op_mean(ctx, dst);
3643
+ }
3644
+
3513
3645
  static void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
3514
3646
  scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
3515
3647
  GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
@@ -3561,9 +3693,15 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
3561
3693
  case GGML_OP_REPEAT:
3562
3694
  ggml_sycl_repeat(ctx, dst);
3563
3695
  break;
3696
+ case GGML_OP_REPEAT_BACK:
3697
+ ggml_sycl_repeat_back(ctx, dst);
3698
+ break;
3564
3699
  case GGML_OP_GET_ROWS:
3565
3700
  ggml_sycl_get_rows(ctx, dst);
3566
3701
  break;
3702
+ case GGML_OP_SET:
3703
+ ggml_sycl_op_set(ctx, dst);
3704
+ break;
3567
3705
  case GGML_OP_SET_ROWS:
3568
3706
  ggml_sycl_op_set_rows(ctx, dst);
3569
3707
  break;
@@ -3574,6 +3712,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
3574
3712
  case GGML_OP_ADD1: // TODO: more efficient implementation
3575
3713
  ggml_sycl_add(ctx, dst);
3576
3714
  break;
3715
+ case GGML_OP_ADD_ID:
3716
+ ggml_sycl_add_id(ctx, dst);
3717
+ break;
3577
3718
  case GGML_OP_SUB:
3578
3719
  ggml_sycl_sub(ctx, dst);
3579
3720
  break;
@@ -3639,6 +3780,18 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
3639
3780
  case GGML_UNARY_OP_ELU:
3640
3781
  ggml_sycl_elu(ctx, dst);
3641
3782
  break;
3783
+ case GGML_UNARY_OP_FLOOR:
3784
+ ggml_sycl_floor(ctx, dst);
3785
+ break;
3786
+ case GGML_UNARY_OP_CEIL:
3787
+ ggml_sycl_ceil(ctx, dst);
3788
+ break;
3789
+ case GGML_UNARY_OP_ROUND:
3790
+ ggml_sycl_round(ctx, dst);
3791
+ break;
3792
+ case GGML_UNARY_OP_TRUNC:
3793
+ ggml_sycl_trunc(ctx, dst);
3794
+ break;
3642
3795
  default:
3643
3796
  return false;
3644
3797
  }
@@ -3654,6 +3807,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
3654
3807
  case GGML_GLU_OP_SWIGLU:
3655
3808
  ggml_sycl_swiglu(ctx, dst);
3656
3809
  break;
3810
+ case GGML_GLU_OP_SWIGLU_OAI:
3811
+ ggml_sycl_swiglu_oai(ctx, dst);
3812
+ break;
3657
3813
  case GGML_GLU_OP_GEGLU_ERF:
3658
3814
  ggml_sycl_geglu_erf(ctx, dst);
3659
3815
  break;
@@ -3673,6 +3829,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
3673
3829
  case GGML_OP_CONCAT:
3674
3830
  ggml_sycl_op_concat(ctx, dst);
3675
3831
  break;
3832
+ case GGML_OP_PAD_REFLECT_1D:
3833
+ ggml_sycl_op_pad_reflect_1d(ctx,dst);
3834
+ break;
3676
3835
  case GGML_OP_UPSCALE:
3677
3836
  ggml_sycl_upscale(ctx, dst);
3678
3837
  break;
@@ -3682,6 +3841,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
3682
3841
  case GGML_OP_LEAKY_RELU:
3683
3842
  ggml_sycl_leaky_relu(ctx, dst);
3684
3843
  break;
3844
+ case GGML_OP_RMS_NORM_BACK:
3845
+ ggml_sycl_rms_norm_back(ctx, dst);
3846
+ break;
3685
3847
  case GGML_OP_RMS_NORM:
3686
3848
  ggml_sycl_rms_norm(ctx, dst);
3687
3849
  break;
@@ -3741,6 +3903,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
3741
3903
  case GGML_OP_SOFT_MAX:
3742
3904
  ggml_sycl_op_soft_max(ctx, dst);
3743
3905
  break;
3906
+ case GGML_OP_SOFT_MAX_BACK:
3907
+ ggml_sycl_op_soft_max_back(ctx, dst);
3908
+ break;
3744
3909
  case GGML_OP_ROPE:
3745
3910
  ggml_sycl_rope(ctx, dst);
3746
3911
  break;
@@ -3756,6 +3921,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
3756
3921
  case GGML_OP_SUM_ROWS:
3757
3922
  ggml_sycl_sum_rows(ctx, dst);
3758
3923
  break;
3924
+ case GGML_OP_MEAN:
3925
+ ggml_sycl_mean(ctx, dst);
3926
+ break;
3759
3927
  case GGML_OP_ARGSORT:
3760
3928
  ggml_sycl_argsort(ctx, dst);
3761
3929
  break;
@@ -3771,6 +3939,15 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
3771
3939
  case GGML_OP_GATED_LINEAR_ATTN:
3772
3940
  ggml_sycl_op_gated_linear_attn(ctx, dst);
3773
3941
  break;
3942
+ case GGML_OP_SSM_CONV:
3943
+ ggml_sycl_ssm_conv(ctx, dst);
3944
+ break;
3945
+ case GGML_OP_ROLL:
3946
+ ggml_sycl_roll(ctx, dst);
3947
+ break;
3948
+ case GGML_OP_ARANGE:
3949
+ ggml_sycl_arange(ctx, dst);
3950
+ break;
3774
3951
  default:
3775
3952
  return false;
3776
3953
  }
@@ -3778,6 +3955,7 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
3778
3955
  return true;
3779
3956
  } catch (sycl::exception & e) {
3780
3957
  std::cerr << e.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
3958
+ std::cerr << "Error OP "<<ggml_op_name(dst->op)<< std::endl;
3781
3959
  std::exit(1);
3782
3960
  }
3783
3961
 
@@ -3972,6 +4150,18 @@ static bool check_graph_compatibility(ggml_cgraph * cgraph) {
3972
4150
  GGML_LOG_INFO("%s: disabling SYCL graphs due to unsupported node type %s\n", __func__,
3973
4151
  ggml_op_name(node_op));
3974
4152
  return false;
4153
+ case GGML_OP_MUL_MAT:
4154
+ // We cannot use graphs with ggml_sycl_mul_mat() when SYCL async memory allocation extensions are not available,
4155
+ // as SYCL malloc / free and host wait calls are not supported when recording to a graph which are all present
4156
+ // in reordering.
4157
+ if (!g_ggml_sycl_use_async_mem_op) {
4158
+ GGML_LOG_INFO(
4159
+ "%s: disabling SYCL graphs due to unsupported node type when using a compiler without the "
4160
+ "oneAPI async memory allocation extension "
4161
+ "%s\n",
4162
+ __func__, ggml_op_name(node_op));
4163
+ return false;
4164
+ }
3975
4165
  }
3976
4166
  }
3977
4167
  return true;
@@ -4096,6 +4286,7 @@ struct ggml_backend_sycl_device_context {
4096
4286
  int device;
4097
4287
  std::string name;
4098
4288
  std::string description;
4289
+ int op_offload_min_batch_size;
4099
4290
  };
4100
4291
 
4101
4292
  static const char * ggml_backend_sycl_device_get_name(ggml_backend_dev_t dev) {
@@ -4166,6 +4357,9 @@ static ggml_backend_buffer_t ggml_backend_sycl_device_buffer_from_host_ptr(ggml_
4166
4357
  }
4167
4358
 
4168
4359
  static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
4360
+ ggml_backend_sycl_device_context *sycl_ctx =
4361
+ (ggml_backend_sycl_device_context *)dev->context;
4362
+ int device = sycl_ctx->device;
4169
4363
  switch (op->op) {
4170
4364
  case GGML_OP_CONV_TRANSPOSE_1D:
4171
4365
  {
@@ -4178,21 +4372,26 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
4178
4372
  }
4179
4373
  case GGML_OP_UNARY:
4180
4374
  switch (ggml_get_unary_op(op)) {
4375
+ case GGML_UNARY_OP_SGN:
4376
+ case GGML_UNARY_OP_ABS:
4181
4377
  case GGML_UNARY_OP_NEG:
4182
4378
  case GGML_UNARY_OP_STEP:
4379
+ case GGML_UNARY_OP_RELU:
4380
+ case GGML_UNARY_OP_HARDSIGMOID:
4381
+ case GGML_UNARY_OP_TANH:
4183
4382
  case GGML_UNARY_OP_GELU:
4184
4383
  case GGML_UNARY_OP_SILU:
4185
- case GGML_UNARY_OP_RELU:
4186
4384
  case GGML_UNARY_OP_SIGMOID:
4187
- case GGML_UNARY_OP_HARDSIGMOID:
4188
4385
  case GGML_UNARY_OP_HARDSWISH:
4189
4386
  case GGML_UNARY_OP_GELU_QUICK:
4190
4387
  case GGML_UNARY_OP_GELU_ERF:
4191
- case GGML_UNARY_OP_TANH:
4192
4388
  case GGML_UNARY_OP_EXP:
4193
- case GGML_UNARY_OP_SGN:
4194
- case GGML_UNARY_OP_ABS:
4195
4389
  case GGML_UNARY_OP_ELU:
4390
+ return true;
4391
+ case GGML_UNARY_OP_FLOOR:
4392
+ case GGML_UNARY_OP_CEIL:
4393
+ case GGML_UNARY_OP_ROUND:
4394
+ case GGML_UNARY_OP_TRUNC:
4196
4395
  #if defined (GGML_SYCL_F16)
4197
4396
  return ggml_is_contiguous(op->src[0]) && (op->type == op->src[0]->type);
4198
4397
  #else
@@ -4206,6 +4405,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
4206
4405
  case GGML_GLU_OP_REGLU:
4207
4406
  case GGML_GLU_OP_GEGLU:
4208
4407
  case GGML_GLU_OP_SWIGLU:
4408
+ case GGML_GLU_OP_SWIGLU_OAI:
4209
4409
  case GGML_GLU_OP_GEGLU_ERF:
4210
4410
  case GGML_GLU_OP_GEGLU_QUICK:
4211
4411
  return ggml_is_contiguous_1(op->src[0]);
@@ -4233,15 +4433,18 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
4233
4433
  }
4234
4434
  }
4235
4435
  ggml_type src0_type = op->src[0]->type;
4236
- if (src0_type == GGML_TYPE_BF16 || src0_type == GGML_TYPE_MXFP4) {
4237
- // TODO: support MXFP4
4436
+ if (src0_type == GGML_TYPE_BF16 ) {
4437
+ // TODO: support GGML_TYPE_BF16
4238
4438
  // FIXME: keep a list of supported types to avoid breaking the backend when a new type is added
4239
4439
  return false;
4240
4440
  }
4441
+
4241
4442
  // TODO: The configuration below needs more work to be supported with oneDNN
4242
- if (ggml_is_permuted(a) && !ggml_is_contiguous(a) && a->ne[2] > 1 && a->ne[3] > 1) {
4243
- return false;
4443
+ if (ggml_is_permuted(a) && !ggml_is_contiguous(a) &&
4444
+ a->ne[2] > 1 && a->ne[3] > 1 && src0_type == GGML_TYPE_F16) {
4445
+ return false;
4244
4446
  }
4447
+
4245
4448
  // TODO: This specific configuration can fail with oneDNN and needs more debugging
4246
4449
  if (!ggml_is_permuted(a) && ggml_is_permuted(b) && b->ne[2] > 1 && b->ne[3] > 1 &&
4247
4450
  a->ne[0] > 128 && a->ne[2] == 1 && src0_type == GGML_TYPE_F16) {
@@ -4266,6 +4469,12 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
4266
4469
  return false;
4267
4470
  }
4268
4471
  }
4472
+ case GGML_OP_SET:
4473
+ return (op->type == GGML_TYPE_F32) &&
4474
+ (op->src[0] && op->src[1]) &&
4475
+ (op->src[0]->type == GGML_TYPE_F32) &&
4476
+ (op->src[1]->type == GGML_TYPE_F32);
4477
+
4269
4478
  case GGML_OP_SET_ROWS:
4270
4479
  {
4271
4480
  return ((op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16 ||
@@ -4343,11 +4552,12 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
4343
4552
  }
4344
4553
  return false;
4345
4554
  }
4346
- case GGML_OP_CONCAT:
4555
+ case GGML_OP_REPEAT_BACK:
4347
4556
  {
4348
4557
  ggml_type src0_type = op->src[0]->type;
4349
- return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
4558
+ return src0_type == GGML_TYPE_F32;
4350
4559
  }
4560
+ case GGML_OP_CONCAT:
4351
4561
  case GGML_OP_DUP:
4352
4562
  case GGML_OP_ARGMAX:
4353
4563
  case GGML_OP_NONE:
@@ -4355,15 +4565,17 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
4355
4565
  case GGML_OP_VIEW:
4356
4566
  case GGML_OP_PERMUTE:
4357
4567
  case GGML_OP_TRANSPOSE:
4358
- return true;
4359
4568
  case GGML_OP_ADD:
4360
4569
  case GGML_OP_ADD1:
4570
+ case GGML_OP_ADD_ID:
4361
4571
  case GGML_OP_SUB:
4362
4572
  case GGML_OP_COUNT_EQUAL:
4363
4573
  case GGML_OP_MUL:
4364
4574
  case GGML_OP_DIV:
4365
4575
  case GGML_OP_REPEAT:
4366
4576
  return true;
4577
+ case GGML_OP_PAD_REFLECT_1D:
4578
+ return ggml_is_contiguous(op->src[0]) && op-> type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
4367
4579
  case GGML_OP_SQR:
4368
4580
  case GGML_OP_SQRT:
4369
4581
  case GGML_OP_SIN:
@@ -4382,44 +4594,56 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
4382
4594
  return ggml_is_contiguous(op->src[0]);
4383
4595
  case GGML_OP_RMS_NORM:
4384
4596
  return ((op->src[0]->ne[0] % WARP_SIZE) == 0);
4597
+ case GGML_OP_RMS_NORM_BACK:
4598
+ return ((op->src[0]->ne[0] % WARP_SIZE) == 0);
4385
4599
  case GGML_OP_SCALE:
4386
4600
  return true;
4387
4601
  case GGML_OP_CONT:
4388
4602
  return op->src[0]->type != GGML_TYPE_BF16;
4389
- case GGML_OP_SOFT_MAX:
4390
- // TODO: support batching
4391
- if (op->src[0]->ne[3] != 1) {
4392
- return false;
4393
- }
4394
- // TODO: support attention sinks [TAG_ATTN_SINKS]
4395
- if (op->src[2]) {
4396
- return false;
4397
- }
4398
- // TODO: support broadcast
4399
- // ref: https://github.com/ggml-org/llama.cpp/pull/14435
4400
- return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
4401
4603
  case GGML_OP_DIAG_MASK_INF:
4604
+ return true;
4605
+ case GGML_OP_SOFT_MAX:
4606
+ return true;
4607
+ case GGML_OP_SOFT_MAX_BACK: {
4608
+ float max_bias = 0.0f;
4609
+ memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float));
4610
+ return max_bias == 0.0f;
4611
+ }
4402
4612
  case GGML_OP_ROPE:
4403
4613
  case GGML_OP_IM2COL:
4404
4614
  return true;
4405
4615
  case GGML_OP_UPSCALE:
4406
- return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
4616
+ return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS);
4407
4617
  case GGML_OP_SUM:
4408
4618
  case GGML_OP_SUM_ROWS:
4409
- case GGML_OP_ARGSORT:
4619
+ case GGML_OP_MEAN:
4410
4620
  return ggml_is_contiguous(op->src[0]);
4621
+ case GGML_OP_ARGSORT:
4622
+ return op->src[0]->ne[0] * sizeof(int) <=
4623
+ ggml_sycl_info().devices[device].smpbo;
4411
4624
  case GGML_OP_POOL_2D:
4412
4625
  case GGML_OP_ACC:
4413
4626
  return true;
4414
4627
  case GGML_OP_PAD:
4415
- return (ggml_get_op_params_i32(op, 0) == 0) && (ggml_get_op_params_i32(op, 2) == 0) &&
4416
- (ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0);
4628
+ // TODO: add circular padding support for syscl, see https://github.com/ggml-org/llama.cpp/pull/16985
4629
+ if (ggml_get_op_params_i32(op, 8) != 0) {
4630
+ return false;
4631
+ }
4632
+ return ggml_is_contiguous(op->src[0]);
4417
4633
  case GGML_OP_LEAKY_RELU:
4418
4634
  case GGML_OP_TIMESTEP_EMBEDDING:
4419
4635
  case GGML_OP_RWKV_WKV6:
4420
4636
  case GGML_OP_RWKV_WKV7:
4421
4637
  case GGML_OP_GATED_LINEAR_ATTN:
4422
4638
  return true;
4639
+ case GGML_OP_SSM_CONV:
4640
+ return op->type == GGML_TYPE_F32 &&
4641
+ op->src[0]->type == GGML_TYPE_F32 &&
4642
+ op->src[1]->type == GGML_TYPE_F32;
4643
+ case GGML_OP_ROLL:
4644
+ return op->type == GGML_TYPE_F32;
4645
+ case GGML_OP_ARANGE:
4646
+ return op->type == GGML_TYPE_F32;
4423
4647
  default:
4424
4648
  return false;
4425
4649
  }
@@ -4451,9 +4675,8 @@ static int64_t get_op_batch_size(const ggml_tensor * op) {
4451
4675
  }
4452
4676
 
4453
4677
  static bool ggml_backend_sycl_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
4454
- const int min_batch_size = 32;
4455
- return get_op_batch_size(op) >= min_batch_size;
4456
- GGML_UNUSED(dev);
4678
+ ggml_backend_sycl_device_context * sycl_ctx = (ggml_backend_sycl_device_context *)dev->context;
4679
+ return get_op_batch_size(op) >= sycl_ctx->op_offload_min_batch_size;
4457
4680
  }
4458
4681
 
4459
4682
  static ggml_backend_event_t
@@ -4576,6 +4799,7 @@ ggml_backend_reg_t ggml_backend_sycl_reg() {
4576
4799
  std::lock_guard<std::mutex> lock(mutex);
4577
4800
  if (!initialized) {
4578
4801
  ggml_backend_sycl_reg_context * ctx = new ggml_backend_sycl_reg_context;
4802
+ const int min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32;
4579
4803
 
4580
4804
  for (int i = 0; i < ggml_sycl_info().device_count; i++) {
4581
4805
  ggml_backend_sycl_device_context * dev_ctx = new ggml_backend_sycl_device_context;
@@ -4589,6 +4813,7 @@ ggml_backend_reg_t ggml_backend_sycl_reg() {
4589
4813
  prop, dpct::dev_mgr::instance().get_device(i))));
4590
4814
 
4591
4815
  dev_ctx->description = prop.get_name();
4816
+ dev_ctx->op_offload_min_batch_size = min_batch_size;
4592
4817
 
4593
4818
  ggml_backend_dev_t dev = new ggml_backend_device {
4594
4819
  /* .iface = */ ggml_backend_sycl_device_interface,