whispercpp 1.3.4 → 1.3.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (891)
  1. checksums.yaml +4 -4
  2. data/LICENSE +1 -1
  3. data/README.md +158 -44
  4. data/ext/extconf.rb +3 -2
  5. data/ext/ruby_whisper.c +34 -6
  6. data/ext/ruby_whisper.h +67 -0
  7. data/ext/ruby_whisper_context.c +236 -144
  8. data/ext/ruby_whisper_context_params.c +163 -0
  9. data/ext/ruby_whisper_model.c +12 -13
  10. data/ext/ruby_whisper_params.c +47 -24
  11. data/ext/ruby_whisper_segment.c +84 -20
  12. data/ext/ruby_whisper_token.c +371 -0
  13. data/ext/ruby_whisper_transcribe.cpp +5 -2
  14. data/ext/ruby_whisper_vad_context.c +122 -0
  15. data/ext/ruby_whisper_vad_context_detect.cpp +51 -0
  16. data/ext/ruby_whisper_vad_params.c +0 -1
  17. data/ext/ruby_whisper_vad_segment.c +138 -0
  18. data/ext/ruby_whisper_vad_segments.c +105 -0
  19. data/ext/sources/CMakeLists.txt +4 -1
  20. data/ext/sources/bindings/javascript/package.json +1 -1
  21. data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
  22. data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
  23. data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  24. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  25. data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
  26. data/ext/sources/examples/addon.node/vad-example.js +2 -2
  27. data/ext/sources/examples/bench/bench.cpp +23 -18
  28. data/ext/sources/examples/cli/cli.cpp +129 -112
  29. data/ext/sources/examples/common-ggml.cpp +2 -0
  30. data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
  31. data/ext/sources/examples/miniaudio.h +4507 -2131
  32. data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
  33. data/ext/sources/examples/server/server.cpp +28 -15
  34. data/ext/sources/examples/talk-llama/CMakeLists.txt +8 -3
  35. data/ext/sources/examples/talk-llama/llama-adapter.cpp +5 -2
  36. data/ext/sources/examples/talk-llama/llama-adapter.h +7 -0
  37. data/ext/sources/examples/talk-llama/llama-arch.cpp +2378 -1988
  38. data/ext/sources/examples/talk-llama/llama-arch.h +109 -2
  39. data/ext/sources/examples/talk-llama/llama-batch.cpp +78 -34
  40. data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
  41. data/ext/sources/examples/talk-llama/llama-chat.cpp +100 -4
  42. data/ext/sources/examples/talk-llama/llama-chat.h +5 -0
  43. data/ext/sources/examples/talk-llama/llama-context.cpp +1088 -403
  44. data/ext/sources/examples/talk-llama/llama-context.h +70 -23
  45. data/ext/sources/examples/talk-llama/llama-cparams.h +6 -0
  46. data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
  47. data/ext/sources/examples/talk-llama/llama-grammar.cpp +295 -60
  48. data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
  49. data/ext/sources/examples/talk-llama/llama-graph.cpp +925 -155
  50. data/ext/sources/examples/talk-llama/llama-graph.h +234 -23
  51. data/ext/sources/examples/talk-llama/llama-hparams.cpp +79 -38
  52. data/ext/sources/examples/talk-llama/llama-hparams.h +118 -18
  53. data/ext/sources/examples/talk-llama/llama-impl.cpp +11 -7
  54. data/ext/sources/examples/talk-llama/llama-impl.h +14 -2
  55. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +8 -4
  56. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +405 -140
  57. data/ext/sources/examples/talk-llama/llama-kv-cache.h +24 -10
  58. data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
  59. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
  60. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
  61. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
  62. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +42 -31
  63. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
  64. data/ext/sources/examples/talk-llama/llama-mmap.cpp +197 -45
  65. data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
  66. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +606 -116
  67. data/ext/sources/examples/talk-llama/llama-model-loader.h +41 -5
  68. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +61 -44
  69. data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
  70. data/ext/sources/examples/talk-llama/llama-model.cpp +2756 -13643
  71. data/ext/sources/examples/talk-llama/llama-model.h +112 -18
  72. data/ext/sources/examples/talk-llama/llama-quant.cpp +582 -365
  73. data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +1409 -199
  74. data/ext/sources/examples/talk-llama/llama-sampler.h +42 -0
  75. data/ext/sources/examples/talk-llama/llama-vocab.cpp +248 -82
  76. data/ext/sources/examples/talk-llama/llama-vocab.h +50 -40
  77. data/ext/sources/examples/talk-llama/llama.cpp +802 -21
  78. data/ext/sources/examples/talk-llama/llama.h +210 -39
  79. data/ext/sources/examples/talk-llama/models/afmoe.cpp +190 -0
  80. data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
  81. data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
  82. data/ext/sources/examples/talk-llama/models/arctic.cpp +137 -0
  83. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
  84. data/ext/sources/examples/talk-llama/models/baichuan.cpp +123 -0
  85. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +143 -0
  86. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +133 -0
  87. data/ext/sources/examples/talk-llama/models/bert.cpp +184 -0
  88. data/ext/sources/examples/talk-llama/models/bitnet.cpp +145 -0
  89. data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
  90. data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
  91. data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
  92. data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
  93. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
  94. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
  95. data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
  96. data/ext/sources/examples/talk-llama/models/dbrx.cpp +122 -0
  97. data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
  98. data/ext/sources/examples/talk-llama/models/deepseek.cpp +142 -0
  99. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +262 -0
  100. data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
  101. data/ext/sources/examples/talk-llama/models/dots1.cpp +132 -0
  102. data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
  103. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +148 -0
  104. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
  105. data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
  106. data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
  107. data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
  108. data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
  109. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +111 -0
  110. data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
  111. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
  112. data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
  113. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
  114. data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
  115. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
  116. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
  117. data/ext/sources/examples/talk-llama/models/glm4.cpp +157 -0
  118. data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
  119. data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
  120. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +195 -0
  121. data/ext/sources/examples/talk-llama/models/granite.cpp +210 -0
  122. data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
  123. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +139 -0
  124. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
  125. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +153 -0
  126. data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
  127. data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
  128. data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
  129. data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
  130. data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
  131. data/ext/sources/examples/talk-llama/models/lfm2.cpp +196 -0
  132. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
  133. data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
  134. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
  135. data/ext/sources/examples/talk-llama/models/llama.cpp +175 -0
  136. data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
  137. data/ext/sources/examples/talk-llama/models/mamba-base.cpp +289 -0
  138. data/ext/sources/examples/talk-llama/models/mamba.cpp +54 -0
  139. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +129 -0
  140. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +200 -0
  141. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +123 -0
  142. data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
  143. data/ext/sources/examples/talk-llama/models/models.h +704 -0
  144. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +109 -0
  145. data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
  146. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +162 -0
  147. data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
  148. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
  149. data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
  150. data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
  151. data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
  152. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
  153. data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
  154. data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
  155. data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
  156. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
  157. data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
  158. data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
  159. data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
  160. data/ext/sources/examples/talk-llama/models/plamo2.cpp +320 -0
  161. data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
  162. data/ext/sources/examples/talk-llama/models/plm.cpp +169 -0
  163. data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
  164. data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
  165. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
  166. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
  167. data/ext/sources/examples/talk-llama/models/qwen3.cpp +120 -0
  168. data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
  169. data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
  170. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +131 -0
  171. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +525 -0
  172. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +140 -0
  173. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +132 -0
  174. data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
  175. data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
  176. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +164 -0
  177. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
  178. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
  179. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +137 -0
  180. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
  181. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
  182. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
  183. data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
  184. data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
  185. data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
  186. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
  187. data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
  188. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
  189. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
  190. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
  191. data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
  192. data/ext/sources/examples/talk-llama/unicode.cpp +121 -79
  193. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
  194. data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
  195. data/ext/sources/ggml/CMakeLists.txt +90 -56
  196. data/ext/sources/ggml/include/ggml-alloc.h +9 -0
  197. data/ext/sources/ggml/include/ggml-backend.h +5 -2
  198. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  199. data/ext/sources/ggml/include/ggml-cpu.h +6 -0
  200. data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
  201. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  202. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  203. data/ext/sources/ggml/include/ggml-rpc.h +14 -12
  204. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  205. data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
  206. data/ext/sources/ggml/include/ggml.h +246 -21
  207. data/ext/sources/ggml/src/CMakeLists.txt +85 -11
  208. data/ext/sources/ggml/src/ggml-alloc.c +128 -50
  209. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  210. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  211. data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
  212. data/ext/sources/ggml/src/ggml-backend-reg.cpp +54 -88
  213. data/ext/sources/ggml/src/ggml-backend.cpp +76 -23
  214. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +18 -4
  215. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +11 -11
  216. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +58 -46
  217. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +139 -48
  218. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2427 -1785
  219. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -362
  220. data/ext/sources/ggml/src/ggml-cann/common.h +285 -211
  221. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +663 -831
  222. data/ext/sources/ggml/src/ggml-common.h +11 -0
  223. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +170 -95
  224. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -18
  225. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  226. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
  227. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  228. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +513 -27
  229. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +4192 -992
  230. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
  231. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  232. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1761 -49
  233. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
  234. data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  235. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
  236. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
  237. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +124 -24
  238. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +157 -28
  239. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  240. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  241. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +8 -3
  242. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +251 -80
  243. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +19 -0
  244. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +587 -119
  245. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
  246. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1093 -194
  247. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1284 -203
  248. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
  249. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1519 -527
  250. data/ext/sources/ggml/src/ggml-cpu/ops.h +6 -4
  251. data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
  252. data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
  253. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +3632 -781
  254. data/ext/sources/ggml/src/ggml-cpu/repack.h +129 -4
  255. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
  256. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +152 -46
  257. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
  258. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +152 -1
  259. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
  260. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +140 -0
  261. data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
  262. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
  263. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
  264. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +132 -6
  265. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
  266. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +33 -31
  267. data/ext/sources/ggml/src/ggml-cuda/common.cuh +474 -85
  268. data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
  269. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  270. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
  271. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +342 -246
  272. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
  273. data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
  274. data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
  275. data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
  276. data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
  277. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +98 -74
  278. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +973 -665
  279. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
  280. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1255 -0
  281. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +33 -40
  282. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +40 -18
  283. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  284. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +206 -45
  285. data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
  286. data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
  287. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
  288. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  289. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1688 -302
  290. data/ext/sources/ggml/src/ggml-cuda/mean.cu +12 -10
  291. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +908 -48
  292. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +88 -20
  293. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +502 -90
  294. data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
  295. data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
  296. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
  297. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +532 -193
  298. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +460 -104
  299. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +5 -2
  300. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +360 -122
  301. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +2 -1
  302. data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
  303. data/ext/sources/ggml/src/ggml-cuda/pad.cu +73 -39
  304. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +152 -1
  305. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
  306. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
  307. data/ext/sources/ggml/src/ggml-cuda/rope.cu +364 -149
  308. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
  309. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
  310. data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
  311. data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
  312. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +163 -41
  313. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
  314. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
  315. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +68 -50
  316. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  317. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
  318. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
  319. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
  320. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
  321. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
  322. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
  323. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
  324. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  325. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  326. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  327. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  328. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  329. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  330. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  331. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  332. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  333. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +22 -4
  334. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +95 -0
  335. data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
  336. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +275 -119
  337. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -7
  338. data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
  339. data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
  340. data/ext/sources/ggml/src/ggml-cuda/unary.cu +160 -11
  341. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +38 -0
  342. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
  343. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
  344. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  345. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +22 -1
  346. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
  347. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +117 -0
  348. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3325 -0
  349. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +46 -0
  350. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +813 -0
  351. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
  352. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +891 -0
  353. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  354. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
  355. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +713 -0
  356. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
  357. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.c +63 -0
  358. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +182 -0
  359. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
  360. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  361. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
  362. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  363. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +155 -0
  364. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +63 -0
  365. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  366. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  367. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
  368. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
  369. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
  370. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  371. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
  372. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  373. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  374. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  375. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  376. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
  377. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  378. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  379. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +26 -0
  380. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1199 -0
  381. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2670 -0
  382. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +497 -0
  383. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
  384. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +419 -0
  385. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
  386. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  387. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +382 -0
  388. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +293 -0
  389. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  390. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  391. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  392. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  393. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
  394. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
  395. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +14 -13
  396. data/ext/sources/ggml/src/ggml-impl.h +129 -6
  397. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  398. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +15 -4
  399. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  400. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +173 -34
  401. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +912 -344
  402. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +124 -59
  403. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +588 -144
  404. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +396 -23
  405. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1724 -421
  406. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +16 -3
  407. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +333 -114
  408. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +3050 -1539
  409. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
  410. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +30 -1
  411. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4279 -497
  412. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
  413. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
  414. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  415. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +267 -0
  416. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  417. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  418. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +113 -0
  419. data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  420. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
  421. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  422. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  423. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  424. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
  425. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  426. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  427. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  428. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +140 -0
  429. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  430. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
  431. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
  432. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  433. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  434. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  435. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
  436. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  437. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  438. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  439. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  440. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  441. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
  442. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  443. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  444. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  445. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
  446. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  447. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  448. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
  449. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +116 -0
  450. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  451. data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  452. data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  453. data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  454. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  455. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  456. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +39 -0
  457. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  458. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  459. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  460. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
  461. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  462. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
  463. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  464. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
  465. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
  466. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  467. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  468. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  469. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  470. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  471. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  472. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  473. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  474. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  475. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  476. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  477. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  478. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  479. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  480. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  481. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  482. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  483. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
  484. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  485. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  486. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  487. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  488. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  489. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  490. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
  491. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
  492. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
  493. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
  494. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  495. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  496. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  497. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  498. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  499. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
  500. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  501. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
  502. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
  503. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
  504. data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
  505. data/ext/sources/ggml/src/ggml-quants.c +96 -5
  506. data/ext/sources/ggml/src/ggml-quants.h +3 -0
  507. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
  508. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +59 -87
  509. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +81 -0
  510. data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
  511. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +7 -0
  512. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -29
  513. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
  514. data/ext/sources/ggml/src/ggml-sycl/common.hpp +427 -20
  515. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
  516. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +103 -1
  517. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
  518. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  519. data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  520. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
  521. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
  522. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +867 -50
  523. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +401 -358
  524. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +12 -2
  525. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
  526. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
  527. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
  528. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
  529. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
  530. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  531. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
  532. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
  533. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +645 -155
  534. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
  535. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +221 -66
  536. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
  537. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  538. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
  539. data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
  540. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
  541. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
  542. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +5 -0
  543. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
  544. data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  545. data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  546. data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
  547. data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
  548. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +457 -281
  549. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  550. data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
  551. data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
  552. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
  553. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
  554. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  555. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  556. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  557. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  558. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  559. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  560. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  561. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  562. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  563. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  564. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  565. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
  566. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
  567. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
  568. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
  569. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
  570. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
  571. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
  572. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
  573. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
  574. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
  575. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
  576. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
  577. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
  578. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
  579. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
  580. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
  581. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
  582. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
  583. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
  584. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
  585. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
  586. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
  587. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
  588. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
  589. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
  590. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
  591. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
  592. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
  593. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
  594. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
  595. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
  596. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
  597. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
  598. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
  599. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
  600. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
  601. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +71 -0
  602. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  603. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  604. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  605. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  606. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  607. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  608. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  609. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  610. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  611. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  612. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  613. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  614. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  615. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  616. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  617. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  618. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  619. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  620. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  621. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  622. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  623. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  624. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
  625. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
  626. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  627. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
  628. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  629. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  630. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  631. data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
  632. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  633. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  634. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  635. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  636. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  637. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  638. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  639. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
  640. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  641. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  642. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  643. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
  644. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
  645. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +39 -19
  646. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5994 -3055
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +18 -10
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  663. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  664. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
  665. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
  666. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  667. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  668. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  669. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
  670. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
  671. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
  672. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
  673. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  674. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
  675. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
  676. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  677. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  678. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  679. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  680. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  681. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
  682. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
  683. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
  684. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  685. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  686. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
  687. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
  688. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  689. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  690. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  691. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
  692. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  693. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  694. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
  695. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  696. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  697. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
  698. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  699. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  700. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  701. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
  702. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  703. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +386 -160
  704. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +82 -20
  705. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +400 -174
  706. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +123 -37
  707. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  708. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +10 -9
  709. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  710. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
  711. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  712. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
  713. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
  714. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  715. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
  716. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  717. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
  718. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
  719. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
  720. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
  721. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
  722. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
  723. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  724. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
  725. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
  726. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
  727. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
  728. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +13 -10
  729. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  730. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  731. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  732. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
  733. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +77 -29
  734. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
  735. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
  736. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
  737. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
  738. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
  739. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
  740. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
  741. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
  742. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
  743. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  744. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
  745. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
  746. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
  747. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
  748. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
  749. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
  750. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
  751. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +88 -105
  752. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +41 -26
  753. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
  754. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +74 -0
  755. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +92 -230
  756. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
  757. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  758. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
  759. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  760. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  761. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  762. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
  763. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
  764. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  765. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
  766. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  767. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  768. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  769. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  770. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +49 -4
  771. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  772. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
  773. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
  774. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +207 -0
  775. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  776. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +8 -49
  777. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +8 -32
  778. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +8 -32
  779. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +33 -0
  780. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +8 -38
  781. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  782. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
  783. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  784. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  785. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  786. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  787. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  788. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
  789. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
  790. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
  791. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
  792. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
  793. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
  794. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  795. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
  796. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
  797. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  798. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +50 -0
  799. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
  800. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  801. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  802. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
  803. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
  804. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  805. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
  806. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  807. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
  808. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
  809. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
  810. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
  811. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
  812. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  813. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
  814. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
  815. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +384 -180
  816. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
  817. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
  818. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1374 -0
  819. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +2544 -726
  820. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  821. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  822. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  823. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  824. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
  825. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
  826. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
  827. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +107 -0
  828. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  829. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +73 -15
  830. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +636 -0
  831. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
  832. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  833. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +72 -261
  834. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +766 -0
  835. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +147 -0
  836. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +196 -0
  837. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
  838. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  839. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  840. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
  841. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  842. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +63 -0
  843. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
  844. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  845. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  846. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
  847. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
  848. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +91 -0
  849. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +469 -0
  850. data/ext/sources/ggml/src/ggml.c +590 -64
  851. data/ext/sources/ggml/src/gguf.cpp +229 -44
  852. data/ext/sources/include/whisper.h +1 -0
  853. data/ext/sources/src/CMakeLists.txt +3 -1
  854. data/ext/sources/src/whisper.cpp +106 -62
  855. data/ext/sources/tests/CMakeLists.txt +2 -2
  856. data/ext/sources/tests/test-vad-full.cpp +4 -2
  857. data/ext/sources/tests/test-vad.cpp +1 -1
  858. data/extsources.rb +1 -0
  859. data/lib/whisper/model/uri.rb +17 -18
  860. data/sig/whisper.rbs +162 -4
  861. data/test/test_context_params.rb +82 -0
  862. data/test/test_params.rb +16 -8
  863. data/test/test_segment.rb +0 -1
  864. data/test/test_token.rb +81 -0
  865. data/test/test_vad.rb +1 -1
  866. data/test/test_vad_context.rb +100 -0
  867. data/test/test_vad_segment.rb +19 -0
  868. data/test/test_vad_segments.rb +16 -0
  869. data/test/test_whisper.rb +27 -0
  870. data/whispercpp.gemspec +1 -1
  871. metadata +502 -37
  872. data/ext/sources/build-xcframework.sh +0 -571
  873. data/ext/sources/examples/talk-llama/llama-sampling.h +0 -32
  874. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  875. data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
  876. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
  877. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
  878. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
  879. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
  880. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  881. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
  882. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
  883. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
  884. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
  885. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  886. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  887. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  888. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  889. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
  890. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
  891. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
@@ -30,22 +30,32 @@
30
30
  #include <regex>
31
31
 
32
32
  #include <sycl/sycl.hpp>
33
+ #if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC
34
+ # include <sycl/ext/oneapi/experimental/async_alloc/async_alloc.hpp>
35
+ #endif
33
36
  #include <sycl/half_type.hpp>
34
37
 
38
+ #include "ggml.h"
35
39
  #include "ggml-sycl.h"
36
40
  #include "ggml-impl.h"
37
41
  #include "ggml-backend-impl.h"
38
42
 
43
+ #include "ggml-sycl/add-id.hpp"
39
44
  #include "ggml-sycl/backend.hpp"
40
45
  #include "ggml-sycl/common.hpp"
41
46
  #include "ggml-sycl/element_wise.hpp"
42
- #include "ggml-sycl/presets.hpp"
47
+ #include "ggml-sycl/gated_delta_net.hpp"
43
48
  #include "ggml-sycl/gemm.hpp"
44
- #include "ggml-sycl/set_rows.hpp"
45
- #include "ggml-sycl/sycl_hw.hpp"
46
49
  #include "ggml-sycl/getrows.hpp"
50
+ #include "ggml-sycl/norm.hpp"
51
+ #include "ggml-sycl/presets.hpp"
47
52
  #include "ggml-sycl/quantize.hpp"
48
- #include "ggml.h"
53
+ #include "ggml-sycl/repeat_back.hpp"
54
+ #include "ggml-sycl/set_rows.hpp"
55
+ #include "ggml-sycl/set.hpp"
56
+ #include "ggml-sycl/ssm_conv.hpp"
57
+ #include "ggml-sycl/sycl_hw.hpp"
58
+
49
59
 
50
60
  static bool g_sycl_loaded = false;
51
61
  int g_ggml_sycl_debug = 0;
@@ -53,6 +63,9 @@ int g_ggml_sycl_disable_optimize = 0;
53
63
  int g_ggml_sycl_disable_graph = 0;
54
64
  int g_ggml_sycl_disable_dnn = 0;
55
65
  int g_ggml_sycl_prioritize_dmmv = 0;
66
+ int g_ggml_sycl_use_async_mem_op = 0;
67
+ int g_ggml_sycl_enable_flash_attention = 1;
68
+
56
69
 
57
70
  static ggml_sycl_device_info ggml_sycl_init() {
58
71
  ggml_sycl_device_info info = {};
@@ -85,8 +98,14 @@ static ggml_sycl_device_info ggml_sycl_init() {
85
98
 
86
99
  info.devices[i].cc =
87
100
  100 * prop.get_major_version() + 10 * prop.get_minor_version();
101
+ info.devices[i].nsm = prop.get_max_compute_units() / 16; //16: Number of Xe Cores
88
102
  info.devices[i].opt_feature.reorder = device.ext_oneapi_architecture_is(syclex::arch_category::intel_gpu);
103
+ info.devices[i].smpbo = prop.get_local_mem_size();
104
+ info.devices[i].warp_size = WARP_SIZE;
105
+
89
106
  info.max_work_group_sizes[i] = prop.get_max_work_group_size();
107
+ info.devices[i].max_wg_per_cu = info.max_work_group_sizes[i] / prop.get_max_compute_units();
108
+
90
109
  }
91
110
 
92
111
  for (int id = 0; id < info.device_count; ++id) {
@@ -199,7 +218,37 @@ static void ggml_check_sycl() try {
199
218
  g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1);
200
219
  g_ggml_sycl_disable_dnn = get_sycl_env("GGML_SYCL_DISABLE_DNN", 0);
201
220
  g_ggml_sycl_prioritize_dmmv = get_sycl_env("GGML_SYCL_PRIORITIZE_DMMV", 0);
221
+
222
+ #ifdef SYCL_FLASH_ATTN
223
+ g_ggml_sycl_enable_flash_attention = get_sycl_env("GGML_SYCL_ENABLE_FLASH_ATTN", 1);
224
+ #else
225
+ g_ggml_sycl_enable_flash_attention = 0;
226
+ #endif
227
+
202
228
  GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n");
229
+
230
+ GGML_LOG_INFO("Build with Macros:\n");
231
+ #if defined(GGML_SYCL_FORCE_MMQ)
232
+ GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: yes\n");
233
+ #else
234
+ GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: no\n");
235
+ #endif
236
+ #if defined(GGML_SYCL_F16)
237
+ GGML_LOG_INFO(" GGML_SYCL_F16: yes\n");
238
+ #else
239
+ GGML_LOG_INFO(" GGML_SYCL_F16: no\n");
240
+ #endif
241
+ #if defined(GGML_SYCL_GRAPH)
242
+ GGML_LOG_INFO(" GGML_SYCL_GRAPH: yes\n");
243
+ #else
244
+ GGML_LOG_INFO(" GGML_SYCL_GRAPH: no\n");
245
+ #endif
246
+ #if defined(GGML_SYCL_DNNL)
247
+ GGML_LOG_INFO(" GGML_SYCL_DNNL: yes\n");
248
+ #else
249
+ GGML_LOG_INFO(" GGML_SYCL_DNNL: no\n");
250
+ #endif
251
+
203
252
  GGML_LOG_INFO("Running with Environment Variables:\n");
204
253
  GGML_LOG_INFO(" GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug);
205
254
  GGML_LOG_INFO(" GGML_SYCL_DISABLE_OPT: %d\n", g_ggml_sycl_disable_optimize);
@@ -214,16 +263,12 @@ static void ggml_check_sycl() try {
214
263
  GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: DNN disabled by compile flag\n");
215
264
  #endif
216
265
  GGML_LOG_INFO(" GGML_SYCL_PRIORITIZE_DMMV: %d\n", g_ggml_sycl_prioritize_dmmv);
217
- GGML_LOG_INFO("Build with Macros:\n");
218
- #if defined(GGML_SYCL_FORCE_MMQ)
219
- GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: yes\n");
220
- #else
221
- GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: no\n");
222
- #endif
223
- #if defined(GGML_SYCL_F16)
224
- GGML_LOG_INFO(" GGML_SYCL_F16: yes\n");
266
+
267
+ #ifdef SYCL_FLASH_ATTN
268
+ GGML_LOG_INFO(" GGML_SYCL_ENABLE_FLASH_ATTN: %d\n", g_ggml_sycl_enable_flash_attention);
225
269
  #else
226
- GGML_LOG_INFO(" GGML_SYCL_F16: no\n");
270
+ GGML_LOG_INFO(" GGML_SYCL_ENABLE_FLASH_ATTN: %d disabled by compile flag\n",
271
+ g_ggml_sycl_enable_flash_attention);
227
272
  #endif
228
273
 
229
274
  /* NOT REMOVE, keep it for next optimize for XMX.
@@ -233,7 +278,20 @@ static void ggml_check_sycl() try {
233
278
  fprintf(stderr, "%s: SYCL_USE_XMX: no\n", __func__);
234
279
  #endif
235
280
  */
236
-
281
+ // Currently, we only use async malloc / free when graphs are enabled as it is required for the calls to be
282
+ // properly recorded. As this SYCL extension matures it may be beneficial to enable as the default path and in
283
+ // other places.
284
+ #if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC
285
+ g_ggml_sycl_use_async_mem_op = !g_ggml_sycl_disable_graph;
286
+ if (g_ggml_sycl_use_async_mem_op) {
287
+ for (unsigned int i = 0; i < dpct::dev_mgr::instance().device_count(); ++i) {
288
+ if (!dpct::dev_mgr::instance().get_device(i).has(sycl::aspect::ext_oneapi_async_memory_alloc)) {
289
+ g_ggml_sycl_use_async_mem_op = 0;
290
+ break;
291
+ }
292
+ }
293
+ }
294
+ #endif
237
295
  if (CHECK_TRY_ERROR(g_all_sycl_device_count =
238
296
  dpct::dev_mgr::instance().device_count()) != 0) {
239
297
  initialized = true;
@@ -1132,13 +1190,28 @@ static const char * ggml_backend_sycl_host_buffer_type_name(ggml_backend_buffer_
1132
1190
  GGML_UNUSED(buft);
1133
1191
  }
1134
1192
 
1193
+ inline void * aligned_malloc_host(size_t alignment, size_t size) {
1194
+ #ifdef _WIN32
1195
+ return _aligned_malloc(size, alignment);
1196
+ #else
1197
+ return aligned_alloc(alignment, size);
1198
+ #endif
1199
+ }
1200
+
1201
+ inline void free_aligned_mem_host(void * memblock) {
1202
+ #ifdef _WIN32
1203
+ _aligned_free(memblock);
1204
+ #else
1205
+ free(memblock);
1206
+ #endif
1207
+ }
1208
+
1135
1209
  static void ggml_backend_sycl_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
1136
- ggml_sycl_host_free(buffer->context);
1210
+ free_aligned_mem_host((void *)buffer->context);
1137
1211
  }
1138
1212
 
1139
1213
  static ggml_backend_buffer_t ggml_backend_sycl_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
1140
- void * ptr = ggml_sycl_host_malloc(size);
1141
-
1214
+ void * ptr = aligned_malloc_host(TENSOR_ALIGNMENT, size);
1142
1215
  if (ptr == nullptr) {
1143
1216
  // fallback to cpu buffer
1144
1217
  return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
@@ -1511,60 +1584,70 @@ static inline void ggml_sycl_swap(T & a, T & b) {
1511
1584
  template <ggml_sort_order order>
1512
1585
  __dpct_inline__ static void
1513
1586
  k_argsort_f32_i32(const float *x, int *dst, const int ncols, int ncols_pad,
1514
- const sycl::nd_item<3> &item_ct1, uint8_t *dpct_local) {
1587
+ const int tasks_per_thread, const sycl::nd_item<3> &item_ct1,
1588
+ uint8_t *dpct_local) {
1515
1589
  // bitonic sort
1516
- int col = item_ct1.get_local_id(2);
1590
+ int col_index = item_ct1.get_local_id(2);
1517
1591
  int row = item_ct1.get_group(1);
1518
1592
 
1519
- if (col >= ncols_pad) {
1520
- return;
1593
+ for (int i = 0; i < tasks_per_thread; i++) {
1594
+ int col = col_index * tasks_per_thread + i;
1595
+ if (col >= ncols_pad) {
1596
+ return;
1597
+ }
1521
1598
  }
1522
1599
 
1523
1600
  const float * x_row = x + row * ncols;
1524
1601
  auto dst_row = (int *)dpct_local;
1525
1602
 
1526
1603
  // initialize indices
1527
- dst_row[col] = col;
1604
+ for (int i=0;i<tasks_per_thread;i++){
1605
+ int col = col_index*tasks_per_thread+i;
1606
+ dst_row[col] = col;
1607
+ }
1528
1608
 
1529
1609
  item_ct1.barrier(sycl::access::fence_space::local_space);
1530
1610
 
1531
1611
  for (int k = 2; k <= ncols_pad; k *= 2) {
1532
1612
  for (int j = k / 2; j > 0; j /= 2) {
1533
- int ixj = col ^ j;
1534
- if (ixj > col) {
1535
- if ((col & k) == 0) {
1536
- if (dst_row[col] >= ncols ||
1537
- (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
1538
- x_row[dst_row[col]] > x_row[dst_row[ixj]] :
1539
- x_row[dst_row[col]] < x_row[dst_row[ixj]]))
1540
- ) {
1541
- ggml_sycl_swap(dst_row[col], dst_row[ixj]);
1542
- }
1543
- } else {
1544
- if (dst_row[ixj] >= ncols ||
1545
- (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
1546
- x_row[dst_row[col]] < x_row[dst_row[ixj]] :
1547
- x_row[dst_row[col]] > x_row[dst_row[ixj]]))
1548
- ) {
1549
- ggml_sycl_swap(dst_row[col], dst_row[ixj]);
1613
+ for (int i = 0; i < tasks_per_thread; i++) {
1614
+ int col = col_index * tasks_per_thread + i;
1615
+ int ixj = col ^ j;
1616
+ if (ixj > col) {
1617
+ if ((col & k) == 0) {
1618
+ if (dst_row[col] >= ncols ||
1619
+ (dst_row[ixj] < ncols &&
1620
+ (order == GGML_SORT_ORDER_ASC
1621
+ ? x_row[dst_row[col]] > x_row[dst_row[ixj]]
1622
+ : x_row[dst_row[col]] <
1623
+ x_row[dst_row[ixj]]))) {
1624
+ ggml_sycl_swap(dst_row[col], dst_row[ixj]);
1625
+ }
1626
+ } else {
1627
+ if (dst_row[ixj] >= ncols ||
1628
+ (dst_row[col] < ncols &&
1629
+ (order == GGML_SORT_ORDER_ASC
1630
+ ? x_row[dst_row[col]] < x_row[dst_row[ixj]]
1631
+ : x_row[dst_row[col]] >
1632
+ x_row[dst_row[ixj]]))) {
1633
+ ggml_sycl_swap(dst_row[col], dst_row[ixj]);
1634
+ }
1550
1635
  }
1551
1636
  }
1637
+ item_ct1.barrier(sycl::access::fence_space::local_space);
1552
1638
  }
1553
- /*
1554
- DPCT1118:1: SYCL group functions and algorithms must be encountered
1555
- in converged control flow. You may need to adjust the code.
1556
- */
1557
- item_ct1.barrier(sycl::access::fence_space::local_space);
1558
1639
  }
1559
1640
  }
1560
1641
 
1561
1642
  // copy the result to dst without the padding
1562
- if (col < ncols) {
1563
- dst[row * ncols + col] = dst_row[col];
1643
+ for (int i = 0; i < tasks_per_thread; i++) {
1644
+ int col = col_index * tasks_per_thread + i;
1645
+ if (col < ncols) {
1646
+ dst[row * ncols + col] = dst_row[col];
1647
+ }
1564
1648
  }
1565
1649
  }
1566
1650
 
1567
-
1568
1651
  static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past,
1569
1652
  const sycl::nd_item<3> &item_ct1) {
1570
1653
  const int col = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
@@ -1737,13 +1820,23 @@ static int next_power_of_2(int x) {
1737
1820
 
1738
1821
  static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
1739
1822
  const int nrows, ggml_sort_order order,
1740
- queue_ptr stream) {
1823
+ queue_ptr stream, int device) {
1741
1824
  // bitonic sort requires ncols to be power of 2
1742
1825
  const int ncols_pad = next_power_of_2(ncols);
1743
1826
 
1744
- const sycl::range<3> block_dims(1, 1, ncols_pad);
1827
+ int nth = 1;
1828
+ int max_block_size = ggml_sycl_info().max_work_group_sizes[device];
1829
+ while (nth < ncols_pad && nth < max_block_size)
1830
+ nth *= 2;
1831
+ if (nth > max_block_size)
1832
+ nth = max_block_size;
1833
+
1834
+ const int tasks_per_thread = ncols_pad / nth;
1835
+
1836
+ const sycl::range<3> block_dims(1, 1, nth);
1745
1837
  const sycl::range<3> block_nums(1, nrows, 1);
1746
1838
  const size_t shared_mem = ncols_pad * sizeof(int);
1839
+ GGML_ASSERT(shared_mem<=ggml_sycl_info().devices[device].smpbo);
1747
1840
 
1748
1841
  if (order == GGML_SORT_ORDER_ASC) {
1749
1842
  stream->submit([&](sycl::handler &cgh) {
@@ -1754,8 +1847,9 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
1754
1847
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
1755
1848
  [=](sycl::nd_item<3> item_ct1) {
1756
1849
  k_argsort_f32_i32<GGML_SORT_ORDER_ASC>(
1757
- x, dst, ncols, ncols_pad, item_ct1,
1758
- dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
1850
+ x, dst, ncols, ncols_pad, tasks_per_thread, item_ct1,
1851
+ dpct_local_acc_ct1
1852
+ .get_multi_ptr<sycl::access::decorated::no>()
1759
1853
  .get());
1760
1854
  });
1761
1855
  });
@@ -1768,8 +1862,9 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
1768
1862
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
1769
1863
  [=](sycl::nd_item<3> item_ct1) {
1770
1864
  k_argsort_f32_i32<GGML_SORT_ORDER_DESC>(
1771
- x, dst, ncols, ncols_pad, item_ct1,
1772
- dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
1865
+ x, dst, ncols, ncols_pad, tasks_per_thread, item_ct1,
1866
+ dpct_local_acc_ct1
1867
+ .get_multi_ptr<sycl::access::decorated::no>()
1773
1868
  .get());
1774
1869
  });
1775
1870
  });
@@ -1778,6 +1873,110 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
1778
1873
  }
1779
1874
  }
1780
1875
 
1876
+ static void top_k_f32_sycl(
1877
+ const float * src,
1878
+ int32_t * dst_indices,
1879
+ const int64_t ncols,
1880
+ const int64_t nrows,
1881
+ const int k,
1882
+ dpct::queue_ptr main_stream
1883
+ ) {
1884
+ const int block_size = 128;
1885
+
1886
+ const sycl::range<1> block_dims(block_size);
1887
+ const sycl::range<1> grid_dims(nrows);
1888
+
1889
+ main_stream->submit([&](sycl::handler &cgh) {
1890
+ sycl::local_accessor<float, 1> shared_vals(sycl::range<1>(block_size * k), cgh);
1891
+ sycl::local_accessor<int, 1> shared_idx(sycl::range<1>(block_size * k), cgh);
1892
+
1893
+ cgh.parallel_for(
1894
+ sycl::nd_range<1>(grid_dims * block_dims, block_dims),
1895
+ [=](sycl::nd_item<1> item_ct1) {
1896
+ const int row = item_ct1.get_group(0);
1897
+ const int tid = item_ct1.get_local_id(0);
1898
+
1899
+ if (row >= nrows) return;
1900
+
1901
+ const float * src_row = src + row * ncols;
1902
+ int32_t * dst_idx_row = dst_indices + row * k;
1903
+
1904
+ float local_vals[32];
1905
+ int local_idx[32];
1906
+
1907
+ for (int i = 0; i < k; i++) {
1908
+ local_vals[i] = -FLT_MAX;
1909
+ local_idx[i] = -1;
1910
+ }
1911
+
1912
+ for (int col = tid; col < ncols; col += block_size) {
1913
+ float val = src_row[col];
1914
+
1915
+ if (val > local_vals[k-1]) {
1916
+ int pos = k - 1;
1917
+ while (pos > 0 && val > local_vals[pos - 1]) {
1918
+ pos--;
1919
+ }
1920
+
1921
+ for (int i = k - 1; i > pos; i--) {
1922
+ local_vals[i] = local_vals[i - 1];
1923
+ local_idx[i] = local_idx[i - 1];
1924
+ }
1925
+ local_vals[pos] = val;
1926
+ local_idx[pos] = col;
1927
+ }
1928
+ }
1929
+
1930
+ for (int i = 0; i < k; i++) {
1931
+ shared_vals[tid * k + i] = local_vals[i];
1932
+ shared_idx[tid * k + i] = local_idx[i];
1933
+ }
1934
+ item_ct1.barrier(sycl::access::fence_space::local_space);
1935
+
1936
+ if (tid == 0) {
1937
+ float final_vals[32];
1938
+ int final_idx[32];
1939
+
1940
+ for (int i = 0; i < k; i++) {
1941
+ final_vals[i] = -FLT_MAX;
1942
+ final_idx[i] = -1;
1943
+ }
1944
+
1945
+ for (int t = 0; t < block_size; t++) {
1946
+ for (int i = 0; i < k; i++) {
1947
+ float val = shared_vals[t * k + i];
1948
+ int idx = shared_idx[t * k + i];
1949
+
1950
+ if (val > final_vals[k-1]) {
1951
+ int pos = k - 1;
1952
+ while (pos > 0 && val > final_vals[pos - 1]) {
1953
+ pos--;
1954
+ }
1955
+
1956
+ for (int j = k - 1; j > pos; j--) {
1957
+ final_vals[j] = final_vals[j - 1];
1958
+ final_idx[j] = final_idx[j - 1];
1959
+ }
1960
+ final_vals[pos] = val;
1961
+ final_idx[pos] = idx;
1962
+ }
1963
+ }
1964
+ }
1965
+
1966
+ for (int i = 0; i < k; i++) {
1967
+ dst_idx_row[i] = final_idx[i];
1968
+ }
1969
+
1970
+ if (k > 1) {
1971
+ int32_t temp = dst_idx_row[0];
1972
+ dst_idx_row[0] = dst_idx_row[1];
1973
+ dst_idx_row[1] = temp;
1974
+ }
1975
+ }
1976
+ });
1977
+ });
1978
+ }
1979
+
1781
1980
  static void argmax_f32_i32_sycl(const float *x, int *dst, const int ncols,
1782
1981
  const int nrows, queue_ptr stream) {
1783
1982
  const sycl::range<3> block_dims(1, 1, SYCL_ARGMAX_BLOCK_SIZE);
@@ -2001,8 +2200,8 @@ inline void ggml_sycl_op_mul_mat_sycl(
2001
2200
  const sycl::half alpha_f16 = 1.0f;
2002
2201
  const sycl::half beta_f16 = 0.0f;
2003
2202
  SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
2004
- *stream, oneapi::math::transpose::trans,
2005
- oneapi::math::transpose::nontrans, row_diff, src1_ncols, ne10,
2203
+ *stream, oneapi::mkl::transpose::trans,
2204
+ oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
2006
2205
  &alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00,
2007
2206
  src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16,
2008
2207
  dst_f16.get(), dpct::library_data_t::real_half, ldc,
@@ -2045,8 +2244,8 @@ inline void ggml_sycl_op_mul_mat_sycl(
2045
2244
  {
2046
2245
  const float alpha = 1.0f;
2047
2246
  const float beta = 0.0f;
2048
- SYCL_CHECK(CHECK_TRY_ERROR(oneapi::math::blas::column_major::gemm(
2049
- get_onemath_backend(*stream), oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, row_diff,
2247
+ SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
2248
+ *stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, row_diff,
2050
2249
  src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10,
2051
2250
  dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
2052
2251
  }
@@ -2127,6 +2326,30 @@ inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor *
2127
2326
  sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
2128
2327
  }
2129
2328
 
2329
+ inline void ggml_sycl_op_mean(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2330
+ GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
2331
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
2332
+
2333
+ dpct::queue_ptr main_stream = ctx.stream();
2334
+ SYCL_CHECK(ggml_sycl_set_device(ctx.device));
2335
+
2336
+ const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
2337
+ float * dst_dd = static_cast<float *>(dst->data);
2338
+
2339
+ const int64_t ncols = dst->src[0]->ne[0];
2340
+ const int64_t nrows = ggml_nrows(dst->src[0]);
2341
+
2342
+ sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
2343
+
2344
+ main_stream->parallel_for(
2345
+ sycl::range<1>(nrows),
2346
+ [=](sycl::id<1> row) {
2347
+ dst_dd[row] /= ncols;
2348
+ }
2349
+ );
2350
+ }
2351
+
2352
+
2130
2353
  inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2131
2354
  GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
2132
2355
  GGML_ASSERT(dst->type == GGML_TYPE_I32);
@@ -2141,7 +2364,32 @@ inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor *
2141
2364
 
2142
2365
  enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
2143
2366
 
2144
- argsort_f32_i32_sycl(src0_dd, (int *) dst_dd, ncols, nrows, order, main_stream);
2367
+ argsort_f32_i32_sycl(src0_dd, (int *)dst_dd, ncols, nrows, order,
2368
+ main_stream, ctx.device);
2369
+ }
2370
+
2371
+ static void ggml_sycl_op_top_k(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2372
+ const ggml_tensor * src0 = dst->src[0];
2373
+
2374
+ GGML_ASSERT(src0);
2375
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
2376
+ GGML_ASSERT(dst->type == GGML_TYPE_I32);
2377
+ GGML_ASSERT(ggml_is_contiguous(src0));
2378
+
2379
+ dpct::queue_ptr main_stream = ctx.stream();
2380
+ SYCL_CHECK(ggml_sycl_set_device(ctx.device));
2381
+
2382
+ const float * src0_dd = static_cast<const float *>(src0->data);
2383
+ int32_t * dst_dd = static_cast<int32_t *>(dst->data);
2384
+
2385
+ const int k = dst->ne[0];
2386
+ const int64_t ncols = src0->ne[0];
2387
+ const int64_t nrows = ggml_nrows(src0);
2388
+
2389
+ GGML_ASSERT(k > 0 && k <= 32);
2390
+ GGML_ASSERT(k <= ncols);
2391
+
2392
+ top_k_f32_sycl(src0_dd, dst_dd, ncols, nrows, k, main_stream);
2145
2393
  }
2146
2394
 
2147
2395
  inline void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
@@ -2176,6 +2424,65 @@ inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_ten
2176
2424
  diag_mask_inf_f32_sycl(src0_dd, dst_dd, ne00, nrows0, ne01, n_past, main_stream);
2177
2425
  }
2178
2426
 
2427
+ static void tri_f32_sycl(
2428
+ const float * src,
2429
+ float * dst,
2430
+ const int64_t ne0,
2431
+ const int64_t ne1,
2432
+ const int64_t ne2,
2433
+ const int64_t ne3,
2434
+ const ggml_tri_type ttype,
2435
+ dpct::queue_ptr main_stream
2436
+ ) {
2437
+ const size_t total = (size_t) ne0 * (size_t) ne1 * (size_t) ne2 * (size_t) ne3;
2438
+
2439
+ main_stream->parallel_for(sycl::range<1>(total), [=](sycl::id<1> tid) {
2440
+ const int64_t idx = (int64_t) tid[0];
2441
+
2442
+ const int64_t i0 = idx % ne0;
2443
+ const int64_t t1 = idx / ne0;
2444
+ const int64_t i1 = t1 % ne1;
2445
+
2446
+ bool keep = false;
2447
+ switch (ttype) {
2448
+ case GGML_TRI_TYPE_LOWER: keep = (i0 < i1); break;
2449
+ case GGML_TRI_TYPE_LOWER_DIAG: keep = (i0 <= i1); break;
2450
+ case GGML_TRI_TYPE_UPPER: keep = (i0 > i1); break;
2451
+ case GGML_TRI_TYPE_UPPER_DIAG: keep = (i0 >= i1); break;
2452
+ default: keep = false; break;
2453
+ }
2454
+
2455
+ dst[idx] = keep ? src[idx] : 0.0f;
2456
+ });
2457
+ }
2458
+
2459
+ static void ggml_sycl_op_tri(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2460
+ const ggml_tensor * src0 = dst->src[0];
2461
+ GGML_ASSERT(src0);
2462
+
2463
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
2464
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
2465
+ GGML_ASSERT(ggml_is_contiguous(src0));
2466
+ GGML_ASSERT(ggml_is_contiguous(dst));
2467
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
2468
+
2469
+ dpct::queue_ptr main_stream = ctx.stream();
2470
+ SYCL_CHECK(ggml_sycl_set_device(ctx.device));
2471
+
2472
+ const float * src0_dd = static_cast<const float *>(src0->data);
2473
+ float * dst_dd = static_cast<float *>(dst->data);
2474
+
2475
+ const ggml_tri_type ttype = (ggml_tri_type) ggml_get_op_params_i32(dst, 0);
2476
+
2477
+ const int64_t ne0 = src0->ne[0];
2478
+ const int64_t ne1 = src0->ne[1];
2479
+ const int64_t ne2 = src0->ne[2];
2480
+ const int64_t ne3 = src0->ne[3];
2481
+
2482
+ tri_f32_sycl(src0_dd, dst_dd, ne0, ne1, ne2, ne3, ttype, main_stream);
2483
+ }
2484
+
2485
+
2179
2486
  inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2180
2487
  GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
2181
2488
  GGML_ASSERT( dst->type == GGML_TYPE_F32);
@@ -2548,6 +2855,10 @@ catch (sycl::exception const &exc) {
2548
2855
  std::exit(1);
2549
2856
  }
2550
2857
 
2858
+ static void ggml_sycl_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2859
+ scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
2860
+ ggml_sycl_op_repeat_back(ctx, dst);
2861
+ }
2551
2862
 
2552
2863
  static void ggml_sycl_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2553
2864
  scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
@@ -2564,6 +2875,11 @@ static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * ds
2564
2875
  ggml_sycl_op_rms_norm(ctx, dst);
2565
2876
  }
2566
2877
 
2878
+ static void ggml_sycl_rms_norm_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2879
+ scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
2880
+ ggml_sycl_op_rms_norm_back(ctx, dst);
2881
+ }
2882
+
2567
2883
  static void ggml_sycl_l2_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2568
2884
  scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
2569
2885
  ggml_sycl_op_l2_norm(ctx, dst);
@@ -2729,7 +3045,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
2729
3045
 
2730
3046
  }
2731
3047
  #if GGML_SYCL_DNNL
2732
- // oneDNN handles strided data and does not need overhead of get_to_fp16_nc_sycl
3048
+ // oneDNN handles strided data and does not need overhead of ggml_get_to_fp16_nc_sycl
2733
3049
  const int64_t ne_src1 = src1->nb[last_str] * src1->ne[last_dim] / type_size_src1;
2734
3050
  src1_f16_alloc.alloc(ne_src1);
2735
3051
  const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst);
@@ -2738,7 +3054,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
2738
3054
  # else
2739
3055
  const int64_t ne_src1 = ggml_nelements(src1);
2740
3056
  src1_f16_alloc.alloc(ne_src1);
2741
- const to_fp16_nc_sycl_t to_fp16_nc_sycl = get_to_fp16_nc_sycl(src1->type);
3057
+ const to_fp16_nc_sycl_t to_fp16_nc_sycl = ggml_get_to_fp16_nc_sycl(src1->type);
2742
3058
  GGML_ASSERT(to_fp16_nc_sycl != nullptr);
2743
3059
  to_fp16_nc_sycl(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, queue);
2744
3060
  #endif
@@ -2882,8 +3198,8 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
2882
3198
  const int64_t smb = ne12 == 1 ? s13 : s12;
2883
3199
 
2884
3200
  // there is no broadcast and src0, src1 are contiguous across dims 2, 3
2885
- SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::math::transpose::trans,
2886
- oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
3201
+ SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::mkl::transpose::trans,
3202
+ oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
2887
3203
  src0_f16, dpct::library_data_t::real_half, nb01 / nb00, sma,
2888
3204
  src1_f16, dpct::library_data_t::real_half, s11, smb, beta, dst_ddf,
2889
3205
  mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type)));
@@ -2907,7 +3223,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
2907
3223
  });
2908
3224
 
2909
3225
  SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
2910
- *queue, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
3226
+ *queue, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
2911
3227
  (const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
2912
3228
  (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, s11, beta,
2913
3229
  (void **) (ptrs_dst.get() + 0 * ne23), mkl_data_type, ne0, ne23, mkl_compute_type, matrix_info.get())));
@@ -2981,19 +3297,51 @@ static bool ggml_sycl_supports_dmmv(enum ggml_type type) {
2981
3297
  }
2982
3298
  }
2983
3299
 
3300
+ // Helper functions to unify device memory allocation for both async and sync paths
3301
+ static inline void * sycl_ext_malloc_device(dpct::queue_ptr stream, size_t size) {
3302
+ bool use_async = g_ggml_sycl_use_async_mem_op;
3303
+ #if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC
3304
+ if (use_async) {
3305
+ return syclex::async_malloc(*stream, sycl::usm::alloc::device, size);
3306
+ }
3307
+ #else
3308
+ // If async allocation extension is not available, use_async should always be false.
3309
+ GGML_ASSERT(!use_async);
3310
+ #endif
3311
+ return sycl::malloc(size, *stream, sycl::usm::alloc::device);
3312
+ }
3313
+
3314
+ static inline void sycl_ext_free(dpct::queue_ptr stream, void * ptr) {
3315
+ bool use_async = g_ggml_sycl_use_async_mem_op;
3316
+ #if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC
3317
+ if (use_async) {
3318
+ syclex::async_free(*stream, ptr);
3319
+ return;
3320
+ }
3321
+ #else
3322
+ // If async allocation extension is not available, use_async should always be false.
3323
+ GGML_ASSERT(!use_async);
3324
+ #endif
3325
+ sycl::free(ptr, *stream);
3326
+ }
3327
+
2984
3328
  static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset,
2985
3329
  dpct::queue_ptr stream) {
2986
- auto * tmp_buf = sycl::malloc_shared<uint8_t>(size, *stream);
2987
- SYCL_CHECK(
2988
- CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size)
2989
- .wait()));
3330
+ uint8_t * tmp_buf = static_cast<uint8_t *>(sycl_ext_malloc_device(stream, size));
3331
+
3332
+ sycl::event copy_event;
3333
+ SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size)));
3334
+ if (!g_ggml_sycl_use_async_mem_op) {
3335
+ copy_event.wait();
3336
+ }
3337
+
2990
3338
  GGML_ASSERT((size % sizeof(block_q4_0) == 0));
2991
3339
  GGML_ASSERT((offset % sizeof(block_q4_0) == 0));
2992
3340
  int offset_blks = offset / sizeof(block_q4_0);
2993
3341
  auto qs_ptr = data_device + offset_blks * QK4_0 / 2;
2994
3342
  auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows / 2) + offset_blks;
2995
3343
 
2996
- stream->parallel_for(
3344
+ auto reorder_event = stream->parallel_for(
2997
3345
  size / sizeof(block_q4_0),
2998
3346
  [=](auto i) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
2999
3347
  const block_q4_0* x = (const block_q4_0*)tmp_buf;
@@ -3004,9 +3352,11 @@ static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nr
3004
3352
  *(qs_ptr + ib * QK4_0 / 2 + j) = x[ib].qs[j];
3005
3353
  }
3006
3354
  *(d_ptr + ib) = x[ib].d;
3007
- }).wait_and_throw();
3008
-
3009
- sycl::free(tmp_buf, *stream);
3355
+ });
3356
+ if (!g_ggml_sycl_use_async_mem_op) {
3357
+ reorder_event.wait_and_throw();
3358
+ }
3359
+ sycl_ext_free(stream, tmp_buf);
3010
3360
  }
3011
3361
 
3012
3362
  static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
@@ -3015,14 +3365,19 @@ static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, d
3015
3365
 
3016
3366
  const int nblocks = size / sizeof(block_q4_K);
3017
3367
 
3018
- auto * tmp_buf = sycl::malloc_shared<uint8_t>(size, *stream);
3019
- SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size).wait()));
3368
+ uint8_t * tmp_buf = static_cast<uint8_t *>(sycl_ext_malloc_device(stream, size));
3369
+
3370
+ sycl::event copy_event;
3371
+ SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size)));
3372
+ if (!g_ggml_sycl_use_async_mem_op) {
3373
+ copy_event.wait();
3374
+ }
3020
3375
 
3021
3376
  auto * qs_ptr = data_device;
3022
3377
  auto * scales_ptr = qs_ptr + QK_K / 2 * nblocks;
3023
3378
  auto * dm_ptr = (sycl::half2 *) (scales_ptr + K_SCALE_SIZE * nblocks);
3024
3379
 
3025
- stream->parallel_for(nblocks, [=](auto i) {
3380
+ auto reorder_event = stream->parallel_for(nblocks, [=](auto i) {
3026
3381
  const block_q4_K * x = (const block_q4_K *) tmp_buf;
3027
3382
  const int ib = i;
3028
3383
 
@@ -3035,9 +3390,11 @@ static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, d
3035
3390
  }
3036
3391
 
3037
3392
  dm_ptr[ib] = x[ib].dm;
3038
- }).wait_and_throw();
3039
-
3040
- sycl::free(tmp_buf, *stream);
3393
+ });
3394
+ if (!g_ggml_sycl_use_async_mem_op) {
3395
+ reorder_event.wait_and_throw();
3396
+ }
3397
+ sycl_ext_free(stream, tmp_buf);
3041
3398
  }
3042
3399
 
3043
3400
  static void reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
@@ -3046,42 +3403,46 @@ static void reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, d
3046
3403
 
3047
3404
  const int nblocks = size / sizeof(block_q6_K);
3048
3405
 
3049
- auto * tmp_buf = sycl::malloc_shared<uint8_t>(size, *stream);
3050
- SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size).wait()));
3406
+ uint8_t * tmp_buf = static_cast<uint8_t *>(sycl_ext_malloc_device(stream, size));
3407
+
3408
+ sycl::event copy_event;
3409
+ SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size)));
3410
+ if (!g_ggml_sycl_use_async_mem_op) {
3411
+ copy_event.wait();
3412
+ }
3051
3413
 
3052
3414
  auto * ql_ptr = data_device;
3053
3415
  auto * qh_ptr = ql_ptr + (QK_K / 2) * nblocks;
3054
3416
  auto * scales_ptr = qh_ptr + (QK_K / 4) * nblocks;
3055
3417
  sycl::half * dm_ptr = (sycl::half *) (scales_ptr + (QK_K / 16) * nblocks);
3056
3418
 
3057
- stream
3058
- ->parallel_for(nblocks,
3059
- [=](auto i) {
3060
- const block_q6_K * x = (const block_q6_K *) tmp_buf;
3061
- const int ib = i;
3062
-
3063
- const uint8_t * ql = x[ib].ql;
3064
- const uint8_t * qh = x[ib].qh;
3065
- uint8_t * base_ql_ptr = ql_ptr + (QK_K / 2) * ib;
3066
- uint8_t * base_qh_ptr = qh_ptr + (QK_K / 4) * ib;
3067
- uint8_t * base_scales_ptr = scales_ptr + (QK_K / 16) * ib;
3419
+ auto reorder_event = stream->parallel_for(nblocks, [=](auto i) {
3420
+ const block_q6_K * x = (const block_q6_K *) tmp_buf;
3421
+ const int ib = i;
3068
3422
 
3069
- for (int j = 0; j < QK_K / 2; ++j) {
3070
- base_ql_ptr[j] = ql[j];
3071
- }
3072
- for (int j = 0; j < QK_K / 4; ++j) {
3073
- base_qh_ptr[j] = qh[j];
3074
- }
3423
+ const uint8_t * ql = x[ib].ql;
3424
+ const uint8_t * qh = x[ib].qh;
3425
+ uint8_t * base_ql_ptr = ql_ptr + (QK_K / 2) * ib;
3426
+ uint8_t * base_qh_ptr = qh_ptr + (QK_K / 4) * ib;
3427
+ uint8_t * base_scales_ptr = scales_ptr + (QK_K / 16) * ib;
3075
3428
 
3076
- for (int j = 0; j < QK_K / 16; ++j) {
3077
- base_scales_ptr[j] = x[ib].scales[j];
3078
- }
3429
+ for (int j = 0; j < QK_K / 2; ++j) {
3430
+ base_ql_ptr[j] = ql[j];
3431
+ }
3432
+ for (int j = 0; j < QK_K / 4; ++j) {
3433
+ base_qh_ptr[j] = qh[j];
3434
+ }
3079
3435
 
3080
- dm_ptr[ib] = x[ib].d;
3081
- })
3082
- .wait_and_throw();
3436
+ for (int j = 0; j < QK_K / 16; ++j) {
3437
+ base_scales_ptr[j] = x[ib].scales[j];
3438
+ }
3083
3439
 
3084
- sycl::free(tmp_buf, *stream);
3440
+ dm_ptr[ib] = x[ib].d;
3441
+ });
3442
+ if (!g_ggml_sycl_use_async_mem_op) {
3443
+ reorder_event.wait_and_throw();
3444
+ }
3445
+ sycl_ext_free(stream, tmp_buf);
3085
3446
  }
3086
3447
 
3087
3448
  static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
@@ -3188,20 +3549,19 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
3188
3549
  bool use_mul_mat_q = ggml_sycl_supports_mmq(src0->type)
3189
3550
  && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
3190
3551
 
3552
+
3191
3553
  // mmvq and mmq need the __dp4a instruction which is available for gen12+
3192
- // Workaround in https://github.com/ggerganov/llama.cpp/commit/95f84d5ce8b449a9b16009434aca800df504a02e
3554
+ // Workaround in https://github.com/ggml-org/llama.cpp/commit/95f84d5ce8b449a9b16009434aca800df504a02e
3193
3555
  use_mul_mat_q = use_mul_mat_q && (src0->type != GGML_TYPE_IQ2_XXS);
3194
3556
  #ifdef SYCL_USE_XMX
3195
3557
  use_mul_mat_q = use_mul_mat_q && (src1->ne[1] <= MMQ_MAX_BATCH_SIZE);
3196
3558
  #endif // SYCL_USE_XMX
3197
3559
 
3198
-
3199
- // mmvq path is faster in the CUDA backend.
3200
- if (!g_ggml_sycl_prioritize_dmmv && (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda
3201
- // Dispatch becomes obscure with the reorder, MMVQ when the reorder optimization
3202
- // is enabled takes precedence over DMMV, the current if-else implementation
3203
- // requires disabling DMMV if both conditions are met
3204
- || (should_reorder_tensor(ctx, dst) && ggml_sycl_supports_reorder_mmvq(src0->type)))) {
3560
+ // Dispatch becomes obscure with the reorder, MMVQ when the reorder optimization
3561
+ // is enabled takes precedence over DMMV, the current if-else implementation
3562
+ // requires disabling DMMV if both conditions are met
3563
+ if (!g_ggml_sycl_prioritize_dmmv && ((should_reorder_tensor(ctx, dst) &&
3564
+ ggml_sycl_supports_reorder_mmvq(src0->type)))) {
3205
3565
  use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
3206
3566
  }
3207
3567
 
@@ -3510,6 +3870,12 @@ static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * ds
3510
3870
  ggml_sycl_op_sum_rows(ctx, dst);
3511
3871
  }
3512
3872
 
3873
+ static void ggml_sycl_mean(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
3874
+ scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
3875
+ GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
3876
+ ggml_sycl_op_mean(ctx, dst);
3877
+ }
3878
+
3513
3879
  static void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
3514
3880
  scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
3515
3881
  GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
@@ -3561,9 +3927,15 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
3561
3927
  case GGML_OP_REPEAT:
3562
3928
  ggml_sycl_repeat(ctx, dst);
3563
3929
  break;
3930
+ case GGML_OP_REPEAT_BACK:
3931
+ ggml_sycl_repeat_back(ctx, dst);
3932
+ break;
3564
3933
  case GGML_OP_GET_ROWS:
3565
3934
  ggml_sycl_get_rows(ctx, dst);
3566
3935
  break;
3936
+ case GGML_OP_SET:
3937
+ ggml_sycl_op_set(ctx, dst);
3938
+ break;
3567
3939
  case GGML_OP_SET_ROWS:
3568
3940
  ggml_sycl_op_set_rows(ctx, dst);
3569
3941
  break;
@@ -3574,6 +3946,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
3574
3946
  case GGML_OP_ADD1: // TODO: more efficient implementation
3575
3947
  ggml_sycl_add(ctx, dst);
3576
3948
  break;
3949
+ case GGML_OP_ADD_ID:
3950
+ ggml_sycl_add_id(ctx, dst);
3951
+ break;
3577
3952
  case GGML_OP_SUB:
3578
3953
  ggml_sycl_sub(ctx, dst);
3579
3954
  break;
@@ -3630,6 +4005,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
3630
4005
  case GGML_UNARY_OP_EXP:
3631
4006
  ggml_sycl_exp(ctx, dst);
3632
4007
  break;
4008
+ case GGML_UNARY_OP_SOFTPLUS:
4009
+ ggml_sycl_softplus(ctx, dst);
4010
+ break;
3633
4011
  case GGML_UNARY_OP_SGN:
3634
4012
  ggml_sycl_sgn(ctx, dst);
3635
4013
  break;
@@ -3639,6 +4017,18 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
3639
4017
  case GGML_UNARY_OP_ELU:
3640
4018
  ggml_sycl_elu(ctx, dst);
3641
4019
  break;
4020
+ case GGML_UNARY_OP_FLOOR:
4021
+ ggml_sycl_floor(ctx, dst);
4022
+ break;
4023
+ case GGML_UNARY_OP_CEIL:
4024
+ ggml_sycl_ceil(ctx, dst);
4025
+ break;
4026
+ case GGML_UNARY_OP_ROUND:
4027
+ ggml_sycl_round(ctx, dst);
4028
+ break;
4029
+ case GGML_UNARY_OP_TRUNC:
4030
+ ggml_sycl_trunc(ctx, dst);
4031
+ break;
3642
4032
  default:
3643
4033
  return false;
3644
4034
  }
@@ -3654,6 +4044,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
3654
4044
  case GGML_GLU_OP_SWIGLU:
3655
4045
  ggml_sycl_swiglu(ctx, dst);
3656
4046
  break;
4047
+ case GGML_GLU_OP_SWIGLU_OAI:
4048
+ ggml_sycl_swiglu_oai(ctx, dst);
4049
+ break;
3657
4050
  case GGML_GLU_OP_GEGLU_ERF:
3658
4051
  ggml_sycl_geglu_erf(ctx, dst);
3659
4052
  break;
@@ -3673,6 +4066,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
3673
4066
  case GGML_OP_CONCAT:
3674
4067
  ggml_sycl_op_concat(ctx, dst);
3675
4068
  break;
4069
+ case GGML_OP_PAD_REFLECT_1D:
4070
+ ggml_sycl_op_pad_reflect_1d(ctx,dst);
4071
+ break;
3676
4072
  case GGML_OP_UPSCALE:
3677
4073
  ggml_sycl_upscale(ctx, dst);
3678
4074
  break;
@@ -3682,6 +4078,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
3682
4078
  case GGML_OP_LEAKY_RELU:
3683
4079
  ggml_sycl_leaky_relu(ctx, dst);
3684
4080
  break;
4081
+ case GGML_OP_RMS_NORM_BACK:
4082
+ ggml_sycl_rms_norm_back(ctx, dst);
4083
+ break;
3685
4084
  case GGML_OP_RMS_NORM:
3686
4085
  ggml_sycl_rms_norm(ctx, dst);
3687
4086
  break;
@@ -3735,15 +4134,24 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
3735
4134
  case GGML_OP_TRANSPOSE:
3736
4135
  GGML_SYCL_DEBUG("%s: Tensor NO-OP\n", __func__);
3737
4136
  break;
4137
+ case GGML_OP_TRI:
4138
+ ggml_sycl_op_tri(ctx, dst);
4139
+ break;
3738
4140
  case GGML_OP_DIAG_MASK_INF:
3739
4141
  ggml_sycl_diag_mask_inf(ctx, dst);
3740
4142
  break;
3741
4143
  case GGML_OP_SOFT_MAX:
3742
4144
  ggml_sycl_op_soft_max(ctx, dst);
3743
4145
  break;
4146
+ case GGML_OP_SOFT_MAX_BACK:
4147
+ ggml_sycl_op_soft_max_back(ctx, dst);
4148
+ break;
3744
4149
  case GGML_OP_ROPE:
3745
4150
  ggml_sycl_rope(ctx, dst);
3746
4151
  break;
4152
+ case GGML_OP_ROPE_BACK:
4153
+ ggml_sycl_rope_back(ctx, dst);
4154
+ break;
3747
4155
  case GGML_OP_IM2COL:
3748
4156
  ggml_sycl_im2col(ctx, dst);
3749
4157
  break;
@@ -3756,9 +4164,15 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
3756
4164
  case GGML_OP_SUM_ROWS:
3757
4165
  ggml_sycl_sum_rows(ctx, dst);
3758
4166
  break;
4167
+ case GGML_OP_MEAN:
4168
+ ggml_sycl_mean(ctx, dst);
4169
+ break;
3759
4170
  case GGML_OP_ARGSORT:
3760
4171
  ggml_sycl_argsort(ctx, dst);
3761
4172
  break;
4173
+ case GGML_OP_TOP_K:
4174
+ ggml_sycl_op_top_k(ctx, dst);
4175
+ break;
3762
4176
  case GGML_OP_TIMESTEP_EMBEDDING:
3763
4177
  ggml_sycl_op_timestep_embedding(ctx, dst);
3764
4178
  break;
@@ -3771,6 +4185,21 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
3771
4185
  case GGML_OP_GATED_LINEAR_ATTN:
3772
4186
  ggml_sycl_op_gated_linear_attn(ctx, dst);
3773
4187
  break;
4188
+ case GGML_OP_GATED_DELTA_NET:
4189
+ ggml_sycl_gated_delta_net(ctx, dst);
4190
+ break;
4191
+ case GGML_OP_SSM_CONV:
4192
+ ggml_sycl_ssm_conv(ctx, dst);
4193
+ break;
4194
+ case GGML_OP_ROLL:
4195
+ ggml_sycl_roll(ctx, dst);
4196
+ break;
4197
+ case GGML_OP_ARANGE:
4198
+ ggml_sycl_arange(ctx, dst);
4199
+ break;
4200
+ case GGML_OP_FLASH_ATTN_EXT:
4201
+ ggml_sycl_flash_attn_ext(ctx, dst);
4202
+ break;
3774
4203
  default:
3775
4204
  return false;
3776
4205
  }
@@ -3778,6 +4207,7 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
3778
4207
  return true;
3779
4208
  } catch (sycl::exception & e) {
3780
4209
  std::cerr << e.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
4210
+ std::cerr << "Error OP "<<ggml_op_name(dst->op)<< std::endl;
3781
4211
  std::exit(1);
3782
4212
  }
3783
4213
 
@@ -3800,16 +4230,6 @@ void ggml_backend_sycl_get_device_memory(int device, size_t *free,
3800
4230
  GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_device_memory\n");
3801
4231
  ggml_sycl_set_device(device);
3802
4232
 
3803
- /*
3804
- DPCT1009:218: SYCL uses exceptions to report errors and does not use the
3805
- error codes. The original code was commented out and a warning string was
3806
- inserted. You need to rewrite this code.
3807
- */
3808
- /*
3809
- DPCT1106:217: 'cudaMemGetInfo' was migrated with the Intel extensions for
3810
- device information which may not be supported by all compilers or runtimes.
3811
- You may need to adjust the code.
3812
- */
3813
4233
  SYCL_CHECK(CHECK_TRY_ERROR(
3814
4234
  dpct::dev_mgr::instance().get_device(device).get_memory_info(*free, *total)));
3815
4235
  }
@@ -3931,6 +4351,9 @@ static void ggml_backend_sycl_graph_compute_impl(ggml_backend_sycl_context * syc
3931
4351
  if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
3932
4352
  continue;
3933
4353
  }
4354
+ if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
4355
+ continue;
4356
+ }
3934
4357
  #ifndef NDEBUG
3935
4358
  assert(node->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device));
3936
4359
  for (int j = 0; j < GGML_MAX_SRC; j++) {
@@ -3972,6 +4395,18 @@ static bool check_graph_compatibility(ggml_cgraph * cgraph) {
3972
4395
  GGML_LOG_INFO("%s: disabling SYCL graphs due to unsupported node type %s\n", __func__,
3973
4396
  ggml_op_name(node_op));
3974
4397
  return false;
4398
+ case GGML_OP_MUL_MAT:
4399
+ // We cannot use graphs with ggml_sycl_mul_mat() when SYCL async memory allocation extensions are not available:
4400
+ // reordering then relies on SYCL malloc / free and host wait calls, none of which are supported while
4401
+ // recording to a graph.
4402
+ if (!g_ggml_sycl_use_async_mem_op) {
4403
+ GGML_LOG_INFO(
4404
+ "%s: disabling SYCL graphs due to unsupported node type when using a compiler without the "
4405
+ "oneAPI async memory allocation extension "
4406
+ "%s\n",
4407
+ __func__, ggml_op_name(node_op));
4408
+ return false;
4409
+ }
3975
4410
  }
3976
4411
  }
3977
4412
  return true;
@@ -4096,6 +4531,7 @@ struct ggml_backend_sycl_device_context {
4096
4531
  int device;
4097
4532
  std::string name;
4098
4533
  std::string description;
4534
+ int op_offload_min_batch_size;
4099
4535
  };
4100
4536
 
4101
4537
  static const char * ggml_backend_sycl_device_get_name(ggml_backend_dev_t dev) {
@@ -4166,6 +4602,9 @@ static ggml_backend_buffer_t ggml_backend_sycl_device_buffer_from_host_ptr(ggml_
4166
4602
  }
4167
4603
 
4168
4604
  static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
4605
+ ggml_backend_sycl_device_context *sycl_ctx =
4606
+ (ggml_backend_sycl_device_context *)dev->context;
4607
+ int device = sycl_ctx->device;
4169
4608
  switch (op->op) {
4170
4609
  case GGML_OP_CONV_TRANSPOSE_1D:
4171
4610
  {
@@ -4178,21 +4617,27 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
4178
4617
  }
4179
4618
  case GGML_OP_UNARY:
4180
4619
  switch (ggml_get_unary_op(op)) {
4620
+ case GGML_UNARY_OP_SGN:
4621
+ case GGML_UNARY_OP_ABS:
4181
4622
  case GGML_UNARY_OP_NEG:
4182
4623
  case GGML_UNARY_OP_STEP:
4624
+ case GGML_UNARY_OP_RELU:
4625
+ case GGML_UNARY_OP_HARDSIGMOID:
4626
+ case GGML_UNARY_OP_TANH:
4183
4627
  case GGML_UNARY_OP_GELU:
4184
4628
  case GGML_UNARY_OP_SILU:
4185
- case GGML_UNARY_OP_RELU:
4186
4629
  case GGML_UNARY_OP_SIGMOID:
4187
- case GGML_UNARY_OP_HARDSIGMOID:
4188
4630
  case GGML_UNARY_OP_HARDSWISH:
4189
4631
  case GGML_UNARY_OP_GELU_QUICK:
4190
4632
  case GGML_UNARY_OP_GELU_ERF:
4191
- case GGML_UNARY_OP_TANH:
4192
4633
  case GGML_UNARY_OP_EXP:
4193
- case GGML_UNARY_OP_SGN:
4194
- case GGML_UNARY_OP_ABS:
4634
+ case GGML_UNARY_OP_SOFTPLUS:
4195
4635
  case GGML_UNARY_OP_ELU:
4636
+ case GGML_UNARY_OP_CEIL:
4637
+ return true;
4638
+ case GGML_UNARY_OP_FLOOR:
4639
+ case GGML_UNARY_OP_ROUND:
4640
+ case GGML_UNARY_OP_TRUNC:
4196
4641
  #if defined (GGML_SYCL_F16)
4197
4642
  return ggml_is_contiguous(op->src[0]) && (op->type == op->src[0]->type);
4198
4643
  #else
@@ -4206,6 +4651,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
4206
4651
  case GGML_GLU_OP_REGLU:
4207
4652
  case GGML_GLU_OP_GEGLU:
4208
4653
  case GGML_GLU_OP_SWIGLU:
4654
+ case GGML_GLU_OP_SWIGLU_OAI:
4209
4655
  case GGML_GLU_OP_GEGLU_ERF:
4210
4656
  case GGML_GLU_OP_GEGLU_QUICK:
4211
4657
  return ggml_is_contiguous_1(op->src[0]);
@@ -4233,15 +4679,18 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
4233
4679
  }
4234
4680
  }
4235
4681
  ggml_type src0_type = op->src[0]->type;
4236
- if (src0_type == GGML_TYPE_BF16 || src0_type == GGML_TYPE_MXFP4) {
4237
- // TODO: support MXFP4
4682
+ if (src0_type == GGML_TYPE_BF16 ) {
4683
+ // TODO: support GGML_TYPE_BF16
4238
4684
  // FIXME: keep a list of supported types to avoid breaking the backend when a new type is added
4239
4685
  return false;
4240
4686
  }
4687
+
4241
4688
  // TODO: The configuration below needs more work to be supported with oneDNN
4242
- if (ggml_is_permuted(a) && !ggml_is_contiguous(a) && a->ne[2] > 1 && a->ne[3] > 1) {
4243
- return false;
4689
+ if (ggml_is_permuted(a) && !ggml_is_contiguous(a) &&
4690
+ a->ne[2] > 1 && a->ne[3] > 1 && src0_type == GGML_TYPE_F16) {
4691
+ return false;
4244
4692
  }
4693
+
4245
4694
  // TODO: This specific configuration can fail with oneDNN and needs more debugging
4246
4695
  if (!ggml_is_permuted(a) && ggml_is_permuted(b) && b->ne[2] > 1 && b->ne[3] > 1 &&
4247
4696
  a->ne[0] > 128 && a->ne[2] == 1 && src0_type == GGML_TYPE_F16) {
@@ -4266,6 +4715,12 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
4266
4715
  return false;
4267
4716
  }
4268
4717
  }
4718
+ case GGML_OP_SET:
4719
+ return (op->type == GGML_TYPE_F32) &&
4720
+ (op->src[0] && op->src[1]) &&
4721
+ (op->src[0]->type == GGML_TYPE_F32) &&
4722
+ (op->src[1]->type == GGML_TYPE_F32);
4723
+
4269
4724
  case GGML_OP_SET_ROWS:
4270
4725
  {
4271
4726
  return ((op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16 ||
@@ -4343,11 +4798,12 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
4343
4798
  }
4344
4799
  return false;
4345
4800
  }
4346
- case GGML_OP_CONCAT:
4801
+ case GGML_OP_REPEAT_BACK:
4347
4802
  {
4348
4803
  ggml_type src0_type = op->src[0]->type;
4349
- return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
4804
+ return src0_type == GGML_TYPE_F32;
4350
4805
  }
4806
+ case GGML_OP_CONCAT:
4351
4807
  case GGML_OP_DUP:
4352
4808
  case GGML_OP_ARGMAX:
4353
4809
  case GGML_OP_NONE:
@@ -4355,15 +4811,17 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
4355
4811
  case GGML_OP_VIEW:
4356
4812
  case GGML_OP_PERMUTE:
4357
4813
  case GGML_OP_TRANSPOSE:
4358
- return true;
4359
4814
  case GGML_OP_ADD:
4360
4815
  case GGML_OP_ADD1:
4816
+ case GGML_OP_ADD_ID:
4361
4817
  case GGML_OP_SUB:
4362
4818
  case GGML_OP_COUNT_EQUAL:
4363
4819
  case GGML_OP_MUL:
4364
4820
  case GGML_OP_DIV:
4365
4821
  case GGML_OP_REPEAT:
4366
4822
  return true;
4823
+ case GGML_OP_PAD_REFLECT_1D:
4824
+ return ggml_is_contiguous(op->src[0]) && op-> type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
4367
4825
  case GGML_OP_SQR:
4368
4826
  case GGML_OP_SQRT:
4369
4827
  case GGML_OP_SIN:
@@ -4376,50 +4834,81 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
4376
4834
  return (op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32) && (op->type == op->src[0]->type);
4377
4835
  #endif
4378
4836
  case GGML_OP_NORM:
4379
- return true;
4380
4837
  case GGML_OP_L2_NORM:
4381
4838
  case GGML_OP_GROUP_NORM:
4382
- return ggml_is_contiguous(op->src[0]);
4383
4839
  case GGML_OP_RMS_NORM:
4384
- return ((op->src[0]->ne[0] % WARP_SIZE) == 0);
4840
+ return true;
4841
+ case GGML_OP_RMS_NORM_BACK:
4842
+ return ggml_is_contiguous(op->src[0]);
4385
4843
  case GGML_OP_SCALE:
4386
4844
  return true;
4387
4845
  case GGML_OP_CONT:
4388
4846
  return op->src[0]->type != GGML_TYPE_BF16;
4389
- case GGML_OP_SOFT_MAX:
4390
- // TODO: support batching
4391
- if (op->src[0]->ne[3] != 1) {
4392
- return false;
4393
- }
4394
- // TODO: support attention sinks [TAG_ATTN_SINKS]
4395
- if (op->src[2]) {
4396
- return false;
4847
+ case GGML_OP_TRI:
4848
+ {
4849
+ const ggml_tensor * src0 = op->src[0];
4850
+ return src0 &&
4851
+ op->type == GGML_TYPE_F32 &&
4852
+ ggml_is_contiguous(src0);
4397
4853
  }
4398
- // TODO: support broadcast
4399
- // ref: https://github.com/ggml-org/llama.cpp/pull/14435
4400
- return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
4401
4854
  case GGML_OP_DIAG_MASK_INF:
4855
+ return true;
4856
+ case GGML_OP_SOFT_MAX:
4857
+ return true;
4858
+ case GGML_OP_SOFT_MAX_BACK: {
4859
+ float max_bias = 0.0f;
4860
+ memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float));
4861
+ return max_bias == 0.0f;
4862
+ }
4402
4863
  case GGML_OP_ROPE:
4864
+ case GGML_OP_ROPE_BACK:
4403
4865
  case GGML_OP_IM2COL:
4404
4866
  return true;
4405
4867
  case GGML_OP_UPSCALE:
4406
- return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
4868
+ return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS);
4407
4869
  case GGML_OP_SUM:
4408
4870
  case GGML_OP_SUM_ROWS:
4409
- case GGML_OP_ARGSORT:
4871
+ case GGML_OP_MEAN:
4410
4872
  return ggml_is_contiguous(op->src[0]);
4873
+ case GGML_OP_ARGSORT:
4874
+ return op->src[0]->ne[0] * sizeof(int) <=
4875
+ ggml_sycl_info().devices[device].smpbo;
4876
+ case GGML_OP_TOP_K: {
4877
+ const ggml_tensor * src0 = op->src[0];
4878
+ const int k = op->ne[0];
4879
+ return src0 &&
4880
+ op->type == GGML_TYPE_I32 &&
4881
+ src0->type == GGML_TYPE_F32 &&
4882
+ ggml_is_contiguous(src0) &&
4883
+ k > 0 && k <= 32;
4884
+ }
4411
4885
  case GGML_OP_POOL_2D:
4412
- case GGML_OP_ACC:
4413
4886
  return true;
4887
+ case GGML_OP_ACC:
4888
+ return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
4414
4889
  case GGML_OP_PAD:
4415
- return (ggml_get_op_params_i32(op, 0) == 0) && (ggml_get_op_params_i32(op, 2) == 0) &&
4416
- (ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0);
4890
+ // TODO: add circular padding support for SYCL, see https://github.com/ggml-org/llama.cpp/pull/16985
4891
+ if (ggml_get_op_params_i32(op, 8) != 0) {
4892
+ return false;
4893
+ }
4894
+ return ggml_is_contiguous(op->src[0]);
4417
4895
  case GGML_OP_LEAKY_RELU:
4418
4896
  case GGML_OP_TIMESTEP_EMBEDDING:
4419
4897
  case GGML_OP_RWKV_WKV6:
4420
4898
  case GGML_OP_RWKV_WKV7:
4421
4899
  case GGML_OP_GATED_LINEAR_ATTN:
4900
+ case GGML_OP_GATED_DELTA_NET:
4422
4901
  return true;
4902
+ case GGML_OP_SSM_CONV:
4903
+ return op->type == GGML_TYPE_F32 &&
4904
+ op->src[0]->type == GGML_TYPE_F32 &&
4905
+ op->src[1]->type == GGML_TYPE_F32;
4906
+ case GGML_OP_ROLL:
4907
+ return op->type == GGML_TYPE_F32;
4908
+ case GGML_OP_ARANGE:
4909
+ return op->type == GGML_TYPE_F32;
4910
+ case GGML_OP_FLASH_ATTN_EXT:
4911
+ return ggml_sycl_flash_attn_ext_supported(device, op);
4423
4912
  default:
4424
4913
  return false;
4425
4914
  }
@@ -4451,9 +4940,8 @@ static int64_t get_op_batch_size(const ggml_tensor * op) {
4451
4940
  }
4452
4941
 
4453
4942
  static bool ggml_backend_sycl_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
4454
- const int min_batch_size = 32;
4455
- return get_op_batch_size(op) >= min_batch_size;
4456
- GGML_UNUSED(dev);
4943
+ ggml_backend_sycl_device_context * sycl_ctx = (ggml_backend_sycl_device_context *)dev->context;
4944
+ return get_op_batch_size(op) >= sycl_ctx->op_offload_min_batch_size;
4457
4945
  }
4458
4946
 
4459
4947
  static ggml_backend_event_t
@@ -4576,6 +5064,7 @@ ggml_backend_reg_t ggml_backend_sycl_reg() {
4576
5064
  std::lock_guard<std::mutex> lock(mutex);
4577
5065
  if (!initialized) {
4578
5066
  ggml_backend_sycl_reg_context * ctx = new ggml_backend_sycl_reg_context;
5067
+ const int min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32;
4579
5068
 
4580
5069
  for (int i = 0; i < ggml_sycl_info().device_count; i++) {
4581
5070
  ggml_backend_sycl_device_context * dev_ctx = new ggml_backend_sycl_device_context;
@@ -4589,6 +5078,7 @@ ggml_backend_reg_t ggml_backend_sycl_reg() {
4589
5078
  prop, dpct::dev_mgr::instance().get_device(i))));
4590
5079
 
4591
5080
  dev_ctx->description = prop.get_name();
5081
+ dev_ctx->op_offload_min_batch_size = min_batch_size;
4592
5082
 
4593
5083
  ggml_backend_dev_t dev = new ggml_backend_device {
4594
5084
  /* .iface = */ ggml_backend_sycl_device_interface,