whispercpp 1.3.3 → 1.3.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (963)
  1. checksums.yaml +4 -4
  2. data/README.md +60 -43
  3. data/ext/extconf.rb +2 -2
  4. data/ext/ruby_whisper.c +14 -2
  5. data/ext/ruby_whisper.h +39 -0
  6. data/ext/ruby_whisper_context.c +22 -22
  7. data/ext/ruby_whisper_model.c +12 -12
  8. data/ext/ruby_whisper_params.c +79 -25
  9. data/ext/ruby_whisper_segment.c +84 -19
  10. data/ext/ruby_whisper_token.c +351 -0
  11. data/ext/ruby_whisper_transcribe.cpp +1 -1
  12. data/ext/ruby_whisper_vad_context.c +75 -0
  13. data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
  14. data/ext/ruby_whisper_vad_segment.c +139 -0
  15. data/ext/ruby_whisper_vad_segments.c +106 -0
  16. data/ext/sources/CMakeLists.txt +4 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
  19. data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
  20. data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  21. data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
  22. data/ext/sources/examples/CMakeLists.txt +1 -0
  23. data/ext/sources/examples/addon.node/addon.cpp +19 -19
  24. data/ext/sources/examples/addon.node/index.js +7 -5
  25. data/ext/sources/examples/addon.node/vad-example.js +2 -2
  26. data/ext/sources/examples/bench/bench.cpp +26 -16
  27. data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
  28. data/ext/sources/examples/cli/cli.cpp +122 -111
  29. data/ext/sources/examples/command/command.cpp +26 -24
  30. data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
  31. data/ext/sources/examples/common-ggml.cpp +2 -0
  32. data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
  33. data/ext/sources/examples/lsp/lsp.cpp +19 -17
  34. data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
  35. data/ext/sources/examples/server/server.cpp +34 -24
  36. data/ext/sources/examples/server.py +6 -1
  37. data/ext/sources/examples/stream/stream.cpp +4 -2
  38. data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
  39. data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
  40. data/ext/sources/examples/talk-llama/CMakeLists.txt +7 -3
  41. data/ext/sources/examples/talk-llama/llama-adapter.cpp +113 -7
  42. data/ext/sources/examples/talk-llama/llama-adapter.h +13 -1
  43. data/ext/sources/examples/talk-llama/llama-arch.cpp +2136 -1491
  44. data/ext/sources/examples/talk-llama/llama-arch.h +125 -3
  45. data/ext/sources/examples/talk-llama/llama-batch.cpp +174 -100
  46. data/ext/sources/examples/talk-llama/llama-batch.h +46 -20
  47. data/ext/sources/examples/talk-llama/llama-chat.cpp +199 -8
  48. data/ext/sources/examples/talk-llama/llama-chat.h +11 -0
  49. data/ext/sources/examples/talk-llama/llama-context.cpp +1213 -413
  50. data/ext/sources/examples/talk-llama/llama-context.h +99 -36
  51. data/ext/sources/examples/talk-llama/llama-cparams.h +5 -4
  52. data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
  53. data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
  54. data/ext/sources/examples/talk-llama/llama-graph.cpp +883 -294
  55. data/ext/sources/examples/talk-llama/llama-graph.h +361 -161
  56. data/ext/sources/examples/talk-llama/llama-hparams.cpp +144 -6
  57. data/ext/sources/examples/talk-llama/llama-hparams.h +100 -23
  58. data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
  59. data/ext/sources/examples/talk-llama/llama-impl.h +3 -1
  60. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +328 -0
  61. data/ext/sources/examples/talk-llama/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +38 -29
  62. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2100 -0
  63. data/ext/sources/examples/talk-llama/llama-kv-cache.h +373 -27
  64. data/ext/sources/examples/talk-llama/llama-kv-cells.h +124 -30
  65. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +63 -41
  66. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +30 -29
  67. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +77 -35
  68. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +15 -16
  69. data/ext/sources/examples/talk-llama/llama-memory.h +16 -10
  70. data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
  71. data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
  72. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +93 -9
  73. data/ext/sources/examples/talk-llama/llama-model-loader.h +9 -2
  74. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
  75. data/ext/sources/examples/talk-llama/llama-model.cpp +3369 -10145
  76. data/ext/sources/examples/talk-llama/llama-model.h +104 -12
  77. data/ext/sources/examples/talk-llama/llama-quant.cpp +53 -30
  78. data/ext/sources/examples/talk-llama/llama-sampling.cpp +1520 -324
  79. data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
  80. data/ext/sources/examples/talk-llama/llama-vocab.cpp +562 -39
  81. data/ext/sources/examples/talk-llama/llama-vocab.h +50 -0
  82. data/ext/sources/examples/talk-llama/llama.cpp +794 -12
  83. data/ext/sources/examples/talk-llama/llama.h +246 -190
  84. data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
  85. data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
  86. data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
  87. data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
  88. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
  89. data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
  90. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
  91. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
  92. data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
  93. data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
  94. data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
  95. data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
  96. data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
  97. data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
  98. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
  99. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
  100. data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
  101. data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
  102. data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
  103. data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
  104. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
  105. data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
  106. data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
  107. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
  108. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
  109. data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
  110. data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
  111. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
  112. data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
  113. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
  114. data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
  115. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
  116. data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
  117. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
  118. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
  119. data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
  120. data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
  121. data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
  122. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
  123. data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
  124. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
  125. data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
  126. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
  127. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
  128. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
  129. data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
  130. data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
  131. data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
  132. data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
  133. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
  134. data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
  135. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
  136. data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
  137. data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
  138. data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
  139. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
  140. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
  141. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
  142. data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
  143. data/ext/sources/examples/talk-llama/models/models.h +569 -0
  144. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
  145. data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
  146. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
  147. data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
  148. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
  149. data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
  150. data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
  151. data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
  152. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
  153. data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
  154. data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
  155. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
  156. data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
  157. data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
  158. data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
  159. data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
  160. data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
  161. data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
  162. data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
  163. data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
  164. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
  165. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
  166. data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
  167. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
  168. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
  169. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
  170. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
  171. data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
  172. data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
  173. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
  174. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
  175. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
  176. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
  177. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
  178. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
  179. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
  180. data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
  181. data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
  182. data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
  183. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
  184. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
  185. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
  186. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
  187. data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
  188. data/ext/sources/examples/talk-llama/talk-llama.cpp +9 -6
  189. data/ext/sources/examples/talk-llama/unicode.cpp +309 -16
  190. data/ext/sources/examples/talk-llama/unicode.h +45 -0
  191. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
  192. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +4 -2
  193. data/ext/sources/examples/whisper.wasm/index-tmpl.html +18 -17
  194. data/ext/sources/ggml/CMakeLists.txt +135 -79
  195. data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
  196. data/ext/sources/ggml/include/ggml-alloc.h +9 -0
  197. data/ext/sources/ggml/include/ggml-backend.h +21 -2
  198. data/ext/sources/ggml/include/ggml-cpu.h +2 -1
  199. data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
  200. data/ext/sources/ggml/include/ggml-metal.h +1 -6
  201. data/ext/sources/ggml/include/ggml-opt.h +25 -6
  202. data/ext/sources/ggml/include/ggml-rpc.h +8 -11
  203. data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
  204. data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
  205. data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
  206. data/ext/sources/ggml/include/ggml.h +406 -23
  207. data/ext/sources/ggml/src/CMakeLists.txt +99 -13
  208. data/ext/sources/ggml/src/ggml-alloc.c +368 -161
  209. data/ext/sources/ggml/src/ggml-backend-impl.h +5 -5
  210. data/ext/sources/ggml/src/ggml-backend-reg.cpp +55 -14
  211. data/ext/sources/ggml/src/ggml-backend.cpp +290 -57
  212. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
  213. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -13
  214. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +14 -0
  215. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +59 -45
  216. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
  217. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2586 -1917
  218. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +348 -309
  219. data/ext/sources/ggml/src/ggml-cann/common.h +350 -133
  220. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +894 -625
  221. data/ext/sources/ggml/src/ggml-common.h +17 -0
  222. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +167 -75
  223. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +5 -2
  224. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  225. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +560 -622
  226. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1002 -270
  227. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +107 -587
  228. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
  229. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  230. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +373 -486
  231. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
  232. data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  233. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +521 -353
  234. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
  235. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
  236. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +4682 -1660
  237. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +82 -4
  238. data/ext/sources/ggml/src/ggml-cpu/common.h +14 -0
  239. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +18 -9
  240. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +263 -111
  241. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +39 -28
  242. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +683 -82
  243. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +38 -43
  244. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +435 -119
  245. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
  246. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1234 -1182
  247. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
  248. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +2167 -1480
  249. data/ext/sources/ggml/src/ggml-cpu/ops.h +10 -12
  250. data/ext/sources/ggml/src/ggml-cpu/quants.c +35 -0
  251. data/ext/sources/ggml/src/ggml-cpu/quants.h +8 -0
  252. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1132 -81
  253. data/ext/sources/ggml/src/ggml-cpu/repack.h +36 -0
  254. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +120 -93
  255. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1025 -0
  256. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
  257. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
  258. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
  259. data/ext/sources/ggml/src/ggml-cpu/traits.cpp +2 -2
  260. data/ext/sources/ggml/src/ggml-cpu/traits.h +1 -1
  261. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
  262. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
  263. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +294 -27
  264. data/ext/sources/ggml/src/ggml-cpu/vec.h +606 -48
  265. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +92 -17
  266. data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
  267. data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
  268. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
  269. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
  270. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
  271. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
  272. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
  273. data/ext/sources/ggml/src/ggml-cuda/common.cuh +588 -128
  274. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
  275. data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
  276. data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
  277. data/ext/sources/ggml/src/ggml-cuda/convert.cu +95 -22
  278. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +25 -0
  279. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
  280. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +335 -485
  281. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
  282. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  283. data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
  284. data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
  285. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
  286. data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
  287. data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
  288. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +519 -378
  289. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +750 -637
  290. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +49 -0
  291. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1244 -0
  292. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +586 -0
  293. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +98 -61
  294. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  295. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +230 -197
  296. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  297. data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
  298. data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
  299. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
  300. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1557 -294
  301. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
  302. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
  303. data/ext/sources/ggml/src/ggml-cuda/mean.cu +57 -2
  304. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +915 -69
  305. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +171 -0
  306. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +835 -0
  307. data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
  308. data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
  309. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +109 -67
  310. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1601 -733
  311. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +802 -0
  312. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +12 -0
  313. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +286 -149
  314. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
  315. data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
  316. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
  317. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
  318. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
  319. data/ext/sources/ggml/src/ggml-cuda/pad.cu +86 -32
  320. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
  321. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
  322. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +163 -10
  323. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
  324. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
  325. data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
  326. data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
  327. data/ext/sources/ggml/src/ggml-cuda/rope.cu +207 -98
  328. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
  329. data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
  330. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +330 -0
  331. data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  332. data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
  333. data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
  334. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
  335. data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
  336. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +325 -61
  337. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
  338. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
  339. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -12
  340. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +291 -104
  341. data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
  342. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +21 -4
  343. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  344. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  345. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  346. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  347. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  348. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  349. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  350. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  351. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  352. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
  353. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
  354. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
  355. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
  356. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
  357. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
  358. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
  359. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
  360. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
  361. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
  362. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
  363. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
  364. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
  365. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
  366. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
  367. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
  368. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
  369. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
  370. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
  371. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
  372. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
  373. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
  374. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
  375. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
  376. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
  377. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
  378. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
  379. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
  380. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
  381. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
  382. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
  383. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
  384. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
  385. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
  386. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
  387. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
  388. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +40 -19
  389. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
  390. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
  391. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
  392. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
  393. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
  394. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
  395. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
  396. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
  397. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
  398. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
  399. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
  400. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
  401. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
  402. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
  403. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
  404. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
  405. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
  406. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
  407. data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
  408. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +351 -0
  409. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +21 -0
  410. data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
  411. data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
  412. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
  413. data/ext/sources/ggml/src/ggml-cuda/unary.cu +189 -5
  414. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +44 -0
  415. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +248 -6
  416. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
  417. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +8 -0
  418. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +70 -37
  419. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +10 -3
  420. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
  421. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
  422. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
  423. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
  424. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
  425. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  426. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
  427. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
  428. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  429. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
  430. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
  431. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
  432. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
  433. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  434. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
  435. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
  436. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
  437. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
  438. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
  439. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
  440. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
  441. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
  442. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
  443. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
  444. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
  445. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
  446. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
  447. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  448. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
  449. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
  450. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
  451. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +16 -13
  452. data/ext/sources/ggml/src/ggml-impl.h +186 -15
  453. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -7
  454. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
  455. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
  456. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
  457. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +609 -0
  458. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1743 -0
  459. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +273 -0
  460. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1686 -0
  461. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +356 -61
  462. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +4161 -0
  463. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +94 -0
  464. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +724 -0
  465. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +4495 -1876
  466. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +21 -9
  467. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +29 -0
  468. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4005 -427
  469. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
  470. data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  471. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  472. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  473. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +147 -0
  474. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +66 -0
  475. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
  476. data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  477. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
  478. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +371 -0
  479. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
  480. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  481. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  482. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  483. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  484. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +177 -0
  485. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
  486. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  487. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  488. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
  489. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
  490. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  491. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  492. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +146 -0
  493. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +147 -0
  494. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  495. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  496. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
  497. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
  498. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
  499. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  500. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
  501. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
  502. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
  503. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
  504. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  505. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +94 -0
  506. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  507. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  508. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +208 -0
  509. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
  510. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
  511. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
  512. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
  513. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
  514. data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  515. data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  516. data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  517. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
  518. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +33 -0
  519. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +2 -2
  520. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  521. data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
  522. data/ext/sources/ggml/src/ggml-quants.c +111 -16
  523. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  524. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +497 -195
  525. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
  526. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
  527. data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
  528. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +8 -0
  529. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +6 -5
  530. data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
  531. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +50 -30
  532. data/ext/sources/ggml/src/ggml-sycl/conv.cpp +10 -4
  533. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +200 -99
  534. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  535. data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  536. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +72 -309
  537. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
  538. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
  539. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +67 -49
  540. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +77 -34
  541. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +397 -314
  542. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +12 -2
  543. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +14 -26
  544. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +9 -6
  545. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +643 -413
  546. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
  547. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +2 -2
  548. data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +80 -60
  549. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +223 -132
  550. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +230 -55
  551. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
  552. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
  553. data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
  554. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
  555. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
  556. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
  557. data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
  558. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +8 -9
  559. data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  560. data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  561. data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
  562. data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
  563. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +65 -59
  564. data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
  565. data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
  566. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
  567. data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  568. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +330 -165
  569. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
  570. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  571. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  572. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +12 -6
  573. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +60 -6
  574. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +16 -12
  575. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
  576. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +7398 -2635
  577. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  578. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
  579. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +43 -3
  580. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  581. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
  582. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  583. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +15 -6
  584. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +56 -39
  585. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  586. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  587. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  588. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  589. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  590. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  591. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +347 -0
  592. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  593. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  594. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +5 -5
  595. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +67 -13
  596. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  597. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  598. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  599. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
  600. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
  601. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
  602. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
  603. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  604. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +158 -16
  605. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +38 -3
  606. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  607. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  608. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  609. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +2 -2
  610. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  611. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +3 -2
  612. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +7 -6
  613. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +5 -3
  614. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  615. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +4 -4
  618. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +2 -2
  619. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  620. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  621. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +4 -4
  622. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  623. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  624. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +4 -4
  625. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +2 -2
  626. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  627. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
  628. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  629. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  630. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
  631. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  632. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +103 -36
  633. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +220 -0
  634. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +139 -45
  635. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +113 -38
  636. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +75 -14
  637. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  638. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  639. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  640. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  641. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  642. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  643. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  644. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +19 -17
  645. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
  646. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +21 -12
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +28 -18
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +4 -0
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +33 -17
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +125 -0
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +227 -0
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
  663. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
  664. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
  665. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
  666. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
  667. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
  668. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
  669. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +20 -14
  670. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  671. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
  672. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
  673. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
  674. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
  675. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
  676. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +143 -0
  677. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
  678. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -556
  679. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +230 -51
  680. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +566 -0
  681. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
  682. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +90 -223
  683. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
  684. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  685. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +195 -0
  686. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  687. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  688. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  689. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
  690. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +41 -5
  691. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  692. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +59 -9
  693. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  694. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  695. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  696. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  697. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +104 -14
  698. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  699. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
  700. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  701. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
  702. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  703. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -52
  704. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -35
  705. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -35
  706. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
  707. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
  708. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  709. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +5 -0
  710. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +3 -3
  711. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  712. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  713. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  714. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  715. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +30 -8
  716. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +6 -2
  717. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
  718. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
  719. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
  720. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
  721. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  722. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
  723. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
  724. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  725. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
  726. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
  727. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  728. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  729. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +16 -6
  730. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
  731. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  732. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
  733. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  734. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +5 -4
  735. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
  736. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
  737. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
  738. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
  739. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  740. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +435 -24
  741. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +148 -6
  742. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.glsl +25 -0
  743. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +619 -177
  744. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
  745. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +80 -0
  746. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
  747. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3087 -0
  748. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  749. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
  750. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
  751. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
  752. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
  753. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +147 -0
  754. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
  755. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
  756. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  757. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  758. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
  759. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
  760. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
  761. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
  762. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
  763. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +123 -0
  764. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  765. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
  766. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
  767. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
  768. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  769. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
  770. data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  771. data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
  772. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
  773. data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
  774. data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
  775. data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
  776. data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
  777. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
  778. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
  779. data/ext/sources/ggml/src/ggml.c +901 -129
  780. data/ext/sources/ggml/src/gguf.cpp +8 -1
  781. data/ext/sources/include/whisper.h +1 -0
  782. data/ext/sources/src/CMakeLists.txt +3 -1
  783. data/ext/sources/src/whisper.cpp +124 -81
  784. data/ext/sources/tests/CMakeLists.txt +8 -1
  785. data/ext/sources/tests/test-vad-full.cpp +7 -5
  786. data/ext/sources/tests/test-vad.cpp +3 -3
  787. data/extsources.rb +1 -0
  788. data/lib/whisper/model/uri.rb +17 -18
  789. data/sig/whisper.rbs +126 -2
  790. data/test/test_params.rb +24 -8
  791. data/test/test_segment.rb +0 -1
  792. data/test/test_token.rb +70 -0
  793. data/test/test_vad.rb +1 -1
  794. data/test/test_vad_context.rb +50 -0
  795. data/test/test_vad_segment.rb +19 -0
  796. data/test/test_vad_segments.rb +16 -0
  797. data/test/test_whisper.rb +8 -1
  798. data/whispercpp.gemspec +1 -1
  799. metadata +439 -179
  800. data/ext/sources/build-xcframework.sh +0 -547
  801. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +0 -279
  802. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +0 -1841
  803. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +0 -303
  804. data/ext/sources/ggml/include/ggml-kompute.h +0 -50
  805. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
  806. data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
  807. data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
  808. data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
  809. data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
  810. data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
  811. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  812. data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  813. data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  814. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  815. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  816. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  817. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  818. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  819. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  820. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
  821. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
  822. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
  823. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
  824. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
  825. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
  826. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
  827. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +0 -506
  828. data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +0 -11
  829. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
  830. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
  831. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
  832. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
  833. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
  834. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
  835. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
  836. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
  837. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
  838. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
  839. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
  840. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
  841. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
  842. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
  843. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
  844. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
  845. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
  846. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
  847. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
  848. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
  849. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
  850. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
  851. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
  852. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
  853. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
  854. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
  855. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
  856. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
  857. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
  858. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
  859. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
  860. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
  861. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
  862. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
  863. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
  864. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
  865. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
  866. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
  867. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
  868. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
  869. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
  870. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
  871. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
  872. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
  873. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
  874. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
  875. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
  876. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
  877. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
  878. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
  879. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
  880. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
  881. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
  882. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
  883. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
  884. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
  885. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
  886. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
  887. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
  888. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
  889. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
  890. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
  891. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
  892. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
  893. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
  894. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
  895. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
  896. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
  897. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
  898. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
  899. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
  900. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
  901. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
  902. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
  903. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
  904. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
  905. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
  906. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
  907. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
  908. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
  909. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
  910. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
  911. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
  912. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
  913. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
  914. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
  915. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  916. data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  917. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  918. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  919. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  920. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  921. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  922. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  923. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  924. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  925. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  926. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  927. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  928. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  929. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  930. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  931. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  932. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  933. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  934. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  935. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  936. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  937. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  938. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  939. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  940. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  941. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  942. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  943. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  944. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  945. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  946. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  947. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  948. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  949. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  950. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  951. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  952. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  953. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
  954. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -6280
  955. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +0 -162
  956. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +0 -118
  957. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -99
  958. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -58
  959. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  960. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  961. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  962. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  963. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
@@ -30,19 +30,29 @@
30
30
  #include <regex>
31
31
 
32
32
  #include <sycl/sycl.hpp>
33
+ #if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC
34
+ # include <sycl/ext/oneapi/experimental/async_alloc/async_alloc.hpp>
35
+ #endif
33
36
  #include <sycl/half_type.hpp>
34
37
 
35
38
  #include "ggml-sycl.h"
36
39
  #include "ggml-impl.h"
37
40
  #include "ggml-backend-impl.h"
38
41
 
42
+ #include "ggml-sycl/add-id.hpp"
39
43
  #include "ggml-sycl/backend.hpp"
40
44
  #include "ggml-sycl/common.hpp"
41
45
  #include "ggml-sycl/element_wise.hpp"
46
+ #include "ggml-sycl/norm.hpp"
42
47
  #include "ggml-sycl/presets.hpp"
43
48
  #include "ggml-sycl/gemm.hpp"
49
+ #include "ggml-sycl/set_rows.hpp"
50
+ #include "ggml-sycl/set.hpp"
44
51
  #include "ggml-sycl/sycl_hw.hpp"
45
52
  #include "ggml-sycl/getrows.hpp"
53
+ #include "ggml-sycl/repeat_back.hpp"
54
+ #include "ggml-sycl/quantize.hpp"
55
+ #include "ggml-sycl/ssm_conv.hpp"
46
56
  #include "ggml.h"
47
57
 
48
58
  static bool g_sycl_loaded = false;
@@ -51,6 +61,7 @@ int g_ggml_sycl_disable_optimize = 0;
51
61
  int g_ggml_sycl_disable_graph = 0;
52
62
  int g_ggml_sycl_disable_dnn = 0;
53
63
  int g_ggml_sycl_prioritize_dmmv = 0;
64
+ int g_ggml_sycl_use_async_mem_op = 0;
54
65
 
55
66
  static ggml_sycl_device_info ggml_sycl_init() {
56
67
  ggml_sycl_device_info info = {};
@@ -83,7 +94,10 @@ static ggml_sycl_device_info ggml_sycl_init() {
83
94
 
84
95
  info.devices[i].cc =
85
96
  100 * prop.get_major_version() + 10 * prop.get_minor_version();
86
- info.devices[i].opt_feature.reorder = !device.ext_oneapi_architecture_is(syclex::arch_category::intel_gpu);
97
+ info.devices[i].nsm = prop.get_max_compute_units();
98
+ info.devices[i].opt_feature.reorder = device.ext_oneapi_architecture_is(syclex::arch_category::intel_gpu);
99
+ info.devices[i].smpbo = prop.get_local_mem_size();
100
+
87
101
  info.max_work_group_sizes[i] = prop.get_max_work_group_size();
88
102
  }
89
103
 
@@ -231,7 +245,20 @@ static void ggml_check_sycl() try {
231
245
  fprintf(stderr, "%s: SYCL_USE_XMX: no\n", __func__);
232
246
  #endif
233
247
  */
234
-
248
+ // Currently, we only use async malloc / free when graphs are enabled as it is required for the calls to be
249
+ // properly recorded. As this SYCL extension matures it may be beneficial to enable as the default path and in
250
+ // other places.
251
+ #if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC
252
+ g_ggml_sycl_use_async_mem_op = !g_ggml_sycl_disable_graph;
253
+ if (g_ggml_sycl_use_async_mem_op) {
254
+ for (unsigned int i = 0; i < dpct::dev_mgr::instance().device_count(); ++i) {
255
+ if (!dpct::dev_mgr::instance().get_device(i).has(sycl::aspect::ext_oneapi_async_memory_alloc)) {
256
+ g_ggml_sycl_use_async_mem_op = 0;
257
+ break;
258
+ }
259
+ }
260
+ }
261
+ #endif
235
262
  if (CHECK_TRY_ERROR(g_all_sycl_device_count =
236
263
  dpct::dev_mgr::instance().device_count()) != 0) {
237
264
  initialized = true;
@@ -1372,120 +1399,6 @@ typedef void (*ggml_sycl_op_mul_mat_t)(
1372
1399
 
1373
1400
 
1374
1401
 
1375
- template<int QUANT_BLOCK_TILE>
1376
- static void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded,
1377
- const sycl::nd_item<3> &item_ct1) {
1378
- const int ix = (item_ct1.get_local_range(2) * item_ct1.get_group(2) +
1379
- item_ct1.get_local_id(2)) * QUANT_BLOCK_TILE;
1380
-
1381
- if (ix >= kx_padded) {
1382
- return;
1383
- }
1384
-
1385
- const int iy = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
1386
- item_ct1.get_local_id(1);
1387
-
1388
- const int i_padded = iy*kx_padded + ix;
1389
-
1390
- block_q8_1 * y = (block_q8_1 *) vy;
1391
-
1392
- const int ib = i_padded / QK8_1; // block index
1393
- const int iqs = i_padded % QK8_1; // quant index
1394
- typedef sycl::vec<float, QUANT_BLOCK_TILE> TC;
1395
- typedef sycl::vec<int8_t, QUANT_BLOCK_TILE> TQ;
1396
- TC zeros;
1397
- TQ qzeros;
1398
- #pragma unroll
1399
- for (int i = 0; i < QUANT_BLOCK_TILE; i++)
1400
- {
1401
- zeros[i] = 0.f;
1402
- qzeros[i] = 0;
1403
- }
1404
- const TC xi = ix < kx ? *(const TC *)&x[iy * kx + ix] : zeros;
1405
- float sum = xi[0];
1406
- float amax = sycl::fabs(xi[0]);
1407
- #pragma unroll
1408
- for (int i = 1; i < QUANT_BLOCK_TILE; i++)
1409
- {
1410
- sum += xi[i];
1411
- amax = sycl::fmax(sycl::fabs(xi[i]), amax);
1412
- }
1413
- sum = warp_reduce_sum(sum, item_ct1);
1414
- amax = warp_reduce_max(amax, item_ct1);
1415
-
1416
- const float d = amax / 127;
1417
- TQ q = qzeros;
1418
- if (amax != 0.0f)
1419
- {
1420
- #pragma unroll
1421
- for (int i = 0; i < QUANT_BLOCK_TILE; i++) {
1422
- q[i] = sycl::round(xi[i] / d);
1423
- }
1424
- }
1425
-
1426
- *(TQ *)&y[ib].qs[iqs] = q;
1427
-
1428
- if (iqs > 0) {
1429
- return;
1430
- }
1431
-
1432
- reinterpret_cast<sycl::half &>(y[ib].ds.x()) = d;
1433
- reinterpret_cast<sycl::half &>(y[ib].ds.y()) = sum;
1434
- }
1435
-
1436
- template <int ElementsPerWI>
1437
- static __dpct_inline__ void quantize_and_reorder_q8_1(const float * __restrict__ x, void * reordered_q8_tensor,
1438
- const int kx, const int kx_padded, const sycl::nd_item<1> & it) {
1439
- /*
1440
- Quantizes and reorders the resultant q8 tensor in a per row fashion
1441
- Each sub-group calculates one quant block. i.e. QK8_1 quant values and the d and sum values
1442
- */
1443
-
1444
- auto subgroup_id = it.get_group(0);
1445
- auto wi_id = it.get_local_id(0);
1446
-
1447
- const int num_blocks_per_row = kx / QK8_1;
1448
- auto row = subgroup_id / num_blocks_per_row;
1449
- auto col = subgroup_id % num_blocks_per_row;
1450
-
1451
- auto row_offset = row * (kx_padded / QK8_1) * sizeof(block_q8_1);
1452
- auto col_offset = QK8_1 * col + wi_id * ElementsPerWI;
1453
-
1454
- auto quant_ptr = (int8_t *) ((char *) reordered_q8_tensor + row_offset + col_offset);
1455
- auto ds_ptr = (sycl::half2 *) ((char *) reordered_q8_tensor + row_offset + kx + col * sizeof(sycl::half2));
1456
-
1457
- sycl::vec<float, ElementsPerWI> wi_f32_vals;
1458
- sycl::vec<int8_t, ElementsPerWI> quantized_values;
1459
-
1460
- auto float_ptr_offset = subgroup_id * QK8_1 + ElementsPerWI * wi_id;
1461
- wi_f32_vals = *reinterpret_cast<const sycl::vec<float, ElementsPerWI> *>(x + float_ptr_offset);
1462
-
1463
- float sum = 0.0f;
1464
- float amax = 0.0f;
1465
-
1466
- #pragma unroll(ElementsPerWI)
1467
- for (int i = 0; i < ElementsPerWI; i++) {
1468
- sum += wi_f32_vals[i];
1469
- amax = sycl::fmax(amax, sycl::fabs(wi_f32_vals[i]));
1470
- quantized_values[i] = 0;
1471
- }
1472
- sum = sycl::reduce_over_group(it.get_group(), sum, sycl::plus<float>());
1473
- amax = sycl::reduce_over_group(it.get_group(), amax, sycl::maximum<float>());
1474
- float d = amax == 0 ? 1 : amax / 127;
1475
-
1476
- #pragma unroll(ElementsPerWI)
1477
- for (int i = 0; i < ElementsPerWI; i++) {
1478
- quantized_values[i] = sycl::round(wi_f32_vals[i] / d);
1479
- }
1480
-
1481
- d = amax == 0 ? 0 : d;
1482
-
1483
- *reinterpret_cast<sycl::vec<int8_t, ElementsPerWI> *>(quant_ptr) = quantized_values;
1484
- if (wi_id == 0) {
1485
- *ds_ptr = sycl::half2(sycl::half(d), sycl::half(sum));
1486
- }
1487
- }
1488
-
1489
1402
  static void mul_mat_p021_f16_f32(
1490
1403
  const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst,
1491
1404
  const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y,
@@ -1545,7 +1458,7 @@ static void mul_mat_p021_f16_f32(
1545
1458
 
1546
1459
  static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
1547
1460
  const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x,
1548
- const int row_stride_x, const int channel_stride_x, const int channel_x_divisor,
1461
+ const int row_stride_x, const int channel_stride_x,const int channel_stride_y, const int channel_x_divisor,
1549
1462
  const sycl::nd_item<3> &item_ct1) {
1550
1463
 
1551
1464
  const sycl::half *x = (const sycl::half *)vx;
@@ -1556,7 +1469,6 @@ static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
1556
1469
  item_ct1.get_local_id(0);
1557
1470
  const int channel_x = channel / channel_x_divisor;
1558
1471
 
1559
- const int nrows_y = ncols_x;
1560
1472
  const int nrows_dst = nrows_x;
1561
1473
  const int row_dst = row_x;
1562
1474
 
@@ -1575,7 +1487,7 @@ static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
1575
1487
  const int row_y = col_x;
1576
1488
 
1577
1489
  const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
1578
- const int iy = channel*nrows_y + row_y;
1490
+ const int iy = channel * channel_stride_y + row_y;
1579
1491
 
1580
1492
  const float xi =
1581
1493
  sycl::vec<sycl::half, 1>(x[ix])
@@ -1624,60 +1536,70 @@ static inline void ggml_sycl_swap(T & a, T & b) {
1624
1536
  template <ggml_sort_order order>
1625
1537
  __dpct_inline__ static void
1626
1538
  k_argsort_f32_i32(const float *x, int *dst, const int ncols, int ncols_pad,
1627
- const sycl::nd_item<3> &item_ct1, uint8_t *dpct_local) {
1539
+ const int tasks_per_thread, const sycl::nd_item<3> &item_ct1,
1540
+ uint8_t *dpct_local) {
1628
1541
  // bitonic sort
1629
- int col = item_ct1.get_local_id(2);
1542
+ int col_index = item_ct1.get_local_id(2);
1630
1543
  int row = item_ct1.get_group(1);
1631
1544
 
1632
- if (col >= ncols_pad) {
1633
- return;
1545
+ for (int i = 0; i < tasks_per_thread; i++) {
1546
+ int col = col_index * tasks_per_thread + i;
1547
+ if (col >= ncols_pad) {
1548
+ return;
1549
+ }
1634
1550
  }
1635
1551
 
1636
1552
  const float * x_row = x + row * ncols;
1637
1553
  auto dst_row = (int *)dpct_local;
1638
1554
 
1639
1555
  // initialize indices
1640
- dst_row[col] = col;
1556
+ for (int i=0;i<tasks_per_thread;i++){
1557
+ int col = col_index*tasks_per_thread+i;
1558
+ dst_row[col] = col;
1559
+ }
1641
1560
 
1642
1561
  item_ct1.barrier(sycl::access::fence_space::local_space);
1643
1562
 
1644
1563
  for (int k = 2; k <= ncols_pad; k *= 2) {
1645
1564
  for (int j = k / 2; j > 0; j /= 2) {
1646
- int ixj = col ^ j;
1647
- if (ixj > col) {
1648
- if ((col & k) == 0) {
1649
- if (dst_row[col] >= ncols ||
1650
- (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
1651
- x_row[dst_row[col]] > x_row[dst_row[ixj]] :
1652
- x_row[dst_row[col]] < x_row[dst_row[ixj]]))
1653
- ) {
1654
- ggml_sycl_swap(dst_row[col], dst_row[ixj]);
1655
- }
1656
- } else {
1657
- if (dst_row[ixj] >= ncols ||
1658
- (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
1659
- x_row[dst_row[col]] < x_row[dst_row[ixj]] :
1660
- x_row[dst_row[col]] > x_row[dst_row[ixj]]))
1661
- ) {
1662
- ggml_sycl_swap(dst_row[col], dst_row[ixj]);
1565
+ for (int i = 0; i < tasks_per_thread; i++) {
1566
+ int col = col_index * tasks_per_thread + i;
1567
+ int ixj = col ^ j;
1568
+ if (ixj > col) {
1569
+ if ((col & k) == 0) {
1570
+ if (dst_row[col] >= ncols ||
1571
+ (dst_row[ixj] < ncols &&
1572
+ (order == GGML_SORT_ORDER_ASC
1573
+ ? x_row[dst_row[col]] > x_row[dst_row[ixj]]
1574
+ : x_row[dst_row[col]] <
1575
+ x_row[dst_row[ixj]]))) {
1576
+ ggml_sycl_swap(dst_row[col], dst_row[ixj]);
1577
+ }
1578
+ } else {
1579
+ if (dst_row[ixj] >= ncols ||
1580
+ (dst_row[col] < ncols &&
1581
+ (order == GGML_SORT_ORDER_ASC
1582
+ ? x_row[dst_row[col]] < x_row[dst_row[ixj]]
1583
+ : x_row[dst_row[col]] >
1584
+ x_row[dst_row[ixj]]))) {
1585
+ ggml_sycl_swap(dst_row[col], dst_row[ixj]);
1586
+ }
1663
1587
  }
1664
1588
  }
1589
+ item_ct1.barrier(sycl::access::fence_space::local_space);
1665
1590
  }
1666
- /*
1667
- DPCT1118:1: SYCL group functions and algorithms must be encountered
1668
- in converged control flow. You may need to adjust the code.
1669
- */
1670
- item_ct1.barrier(sycl::access::fence_space::local_space);
1671
1591
  }
1672
1592
  }
1673
1593
 
1674
1594
  // copy the result to dst without the padding
1675
- if (col < ncols) {
1676
- dst[row * ncols + col] = dst_row[col];
1595
+ for (int i = 0; i < tasks_per_thread; i++) {
1596
+ int col = col_index * tasks_per_thread + i;
1597
+ if (col < ncols) {
1598
+ dst[row * ncols + col] = dst_row[col];
1599
+ }
1677
1600
  }
1678
1601
  }
1679
1602
 
1680
-
1681
1603
  static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past,
1682
1604
  const sycl::nd_item<3> &item_ct1) {
1683
1605
  const int col = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
@@ -1695,7 +1617,7 @@ static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, con
1695
1617
  dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
1696
1618
  }
1697
1619
 
1698
- static void scale_f32(const float * x, float * dst, const float scale, const int k,
1620
+ static void scale_f32(const float * x, float * dst, const float scale, const float bias, const int k,
1699
1621
  const sycl::nd_item<3> &item_ct1) {
1700
1622
  const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
1701
1623
  item_ct1.get_local_id(2);
@@ -1704,7 +1626,7 @@ static void scale_f32(const float * x, float * dst, const float scale, const int
1704
1626
  return;
1705
1627
  }
1706
1628
 
1707
- dst[i] = scale * x[i];
1629
+ dst[i] = scale * x[i] + bias;
1708
1630
  }
1709
1631
 
1710
1632
 
@@ -1770,32 +1692,6 @@ static void pool2d_nchw_kernel(
1770
1692
  o_ptr[cur_oh * ow + cur_ow] = res;
1771
1693
  }
1772
1694
 
1773
- static void quantize_row_q8_1_sycl(const float * x, void * vy, const int kx, const int ky, const int kx_padded,
1774
- bool reorder_q8_tensor, queue_ptr stream) {
1775
- if (reorder_q8_tensor) {
1776
- auto local_range = std::size_t(WARP_SIZE);
1777
- auto num_quant_blocks = ky * (kx / QK8_1);
1778
- auto global_range = num_quant_blocks * local_range;
1779
- stream->parallel_for(sycl::nd_range<1>({ global_range }, { local_range }),
1780
- [=](sycl::nd_item<1> it) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
1781
- quantize_and_reorder_q8_1<QK8_1 / WARP_SIZE>(x, vy, kx, kx_padded, it);
1782
- });
1783
- } else {
1784
- const int block_num_x = (kx_padded + SYCL_QUANTIZE_BLOCK_SIZE - 1) / SYCL_QUANTIZE_BLOCK_SIZE;
1785
- const sycl::range<3> num_blocks(1, ky, block_num_x);
1786
- int constexpr QUANT_BLOCK_TILE = QK8_1 / WARP_SIZE;
1787
- static_assert(QK8_1 % WARP_SIZE == 0);
1788
- const sycl::range<3> block_size(1, 1, SYCL_QUANTIZE_BLOCK_SIZE / QUANT_BLOCK_TILE);
1789
- {
1790
- dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
1791
-
1792
- stream->parallel_for(sycl::nd_range<3>(num_blocks * block_size, block_size),
1793
- [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
1794
- quantize_q8_1<QUANT_BLOCK_TILE>(x, vy, kx, kx_padded, item_ct1);
1795
- });
1796
- }
1797
- }
1798
- }
1799
1695
 
1800
1696
  static void ggml_mul_mat_p021_f16_f32_sycl(const void *vx, const float *y,
1801
1697
  float *dst, const int ncols_x,
@@ -1822,7 +1718,7 @@ static void ggml_mul_mat_p021_f16_f32_sycl(const void *vx, const float *y,
1822
1718
  static void ggml_mul_mat_vec_nc_f16_f32_sycl(
1823
1719
  const void *vx, const float *y, float *dst, const int ncols_x,
1824
1720
  const int nrows_x, const int row_stride_x, const int nchannels_x,
1825
- const int nchannels_y, const int channel_stride_x, queue_ptr stream) {
1721
+ const int nchannels_y, const int channel_stride_x, const int channel_stride_y, queue_ptr stream) {
1826
1722
 
1827
1723
  const sycl::range<3> block_nums(nchannels_y, nrows_x, 1);
1828
1724
  const sycl::range<3> block_dims(1, 1, WARP_SIZE);
@@ -1834,7 +1730,7 @@ static void ggml_mul_mat_vec_nc_f16_f32_sycl(
1834
1730
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
1835
1731
  [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
1836
1732
  mul_mat_vec_nc_f16_f32(vx, y, dst, ncols_x, nrows_x,
1837
- row_stride_x, channel_stride_x,
1733
+ row_stride_x, channel_stride_x, channel_stride_y,
1838
1734
  nchannels_y / nchannels_x, item_ct1);
1839
1735
  });
1840
1736
  }
@@ -1842,7 +1738,7 @@ static void ggml_mul_mat_vec_nc_f16_f32_sycl(
1842
1738
 
1843
1739
 
1844
1740
 
1845
- static void scale_f32_sycl(const float *x, float *dst, const float scale,
1741
+ static void scale_f32_sycl(const float *x, float *dst, const float scale, const float bias,
1846
1742
  const int k, queue_ptr stream) {
1847
1743
  const int num_blocks = (k + SYCL_SCALE_BLOCK_SIZE - 1) / SYCL_SCALE_BLOCK_SIZE;
1848
1744
  stream->parallel_for(
@@ -1850,7 +1746,7 @@ static void scale_f32_sycl(const float *x, float *dst, const float scale,
1850
1746
  sycl::range<3>(1, 1, SYCL_SCALE_BLOCK_SIZE),
1851
1747
  sycl::range<3>(1, 1, SYCL_SCALE_BLOCK_SIZE)),
1852
1748
  [=](sycl::nd_item<3> item_ct1) {
1853
- scale_f32(x, dst, scale, k, item_ct1);
1749
+ scale_f32(x, dst, scale, bias, k, item_ct1);
1854
1750
  });
1855
1751
  }
1856
1752
 
@@ -1876,37 +1772,51 @@ static int next_power_of_2(int x) {
1876
1772
 
1877
1773
  static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
1878
1774
  const int nrows, ggml_sort_order order,
1879
- queue_ptr stream) {
1775
+ queue_ptr stream, int device) {
1880
1776
  // bitonic sort requires ncols to be power of 2
1881
1777
  const int ncols_pad = next_power_of_2(ncols);
1882
1778
 
1883
- const sycl::range<3> block_dims(1, 1, ncols_pad);
1779
+ int nth = 1;
1780
+ int max_block_size = ggml_sycl_info().max_work_group_sizes[device];
1781
+ while (nth < ncols_pad && nth < max_block_size)
1782
+ nth *= 2;
1783
+ if (nth > max_block_size)
1784
+ nth = max_block_size;
1785
+
1786
+ const int tasks_per_thread = ncols_pad / nth;
1787
+
1788
+ const sycl::range<3> block_dims(1, 1, nth);
1884
1789
  const sycl::range<3> block_nums(1, nrows, 1);
1885
1790
  const size_t shared_mem = ncols_pad * sizeof(int);
1791
+ GGML_ASSERT(shared_mem<=ggml_sycl_info().devices[device].smpbo);
1886
1792
 
1887
1793
  if (order == GGML_SORT_ORDER_ASC) {
1888
- sycl_launch(stream, [&](sycl::handler & cgh) {
1794
+ stream->submit([&](sycl::handler &cgh) {
1889
1795
  sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
1890
1796
  sycl::range<1>(shared_mem), cgh);
1891
1797
 
1892
- sycl_parallel_for(
1893
- cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
1798
+ cgh.parallel_for(
1799
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
1800
+ [=](sycl::nd_item<3> item_ct1) {
1894
1801
  k_argsort_f32_i32<GGML_SORT_ORDER_ASC>(
1895
- x, dst, ncols, ncols_pad, item_ct1,
1896
- dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
1802
+ x, dst, ncols, ncols_pad, tasks_per_thread, item_ct1,
1803
+ dpct_local_acc_ct1
1804
+ .get_multi_ptr<sycl::access::decorated::no>()
1897
1805
  .get());
1898
1806
  });
1899
1807
  });
1900
1808
  } else if (order == GGML_SORT_ORDER_DESC) {
1901
- sycl_launch(stream, [&](sycl::handler & cgh) {
1809
+ stream->submit([&](sycl::handler &cgh) {
1902
1810
  sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
1903
1811
  sycl::range<1>(shared_mem), cgh);
1904
1812
 
1905
- sycl_parallel_for(
1906
- cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
1813
+ cgh.parallel_for(
1814
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
1815
+ [=](sycl::nd_item<3> item_ct1) {
1907
1816
  k_argsort_f32_i32<GGML_SORT_ORDER_DESC>(
1908
- x, dst, ncols, ncols_pad, item_ct1,
1909
- dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
1817
+ x, dst, ncols, ncols_pad, tasks_per_thread, item_ct1,
1818
+ dpct_local_acc_ct1
1819
+ .get_multi_ptr<sycl::access::decorated::no>()
1910
1820
  .get());
1911
1821
  });
1912
1822
  });
@@ -1921,47 +1831,50 @@ static void argmax_f32_i32_sycl(const float *x, int *dst, const int ncols,
1921
1831
  const sycl::range<3> block_nums(1, nrows, 1);
1922
1832
  const size_t shared_mem = 256 * sizeof(float);
1923
1833
 
1924
- sycl_launch(stream, [&](sycl::handler & cgh) {
1834
+ stream->submit([&](sycl::handler &cgh) {
1925
1835
  sycl::local_accessor<float, 1> shared_data(
1926
1836
  sycl::range<1>(shared_mem/sizeof(float)), cgh);
1927
1837
  sycl::local_accessor<int, 1> shared_indices(
1928
1838
  sycl::range<1>(shared_mem/sizeof(float)), cgh);
1929
1839
 
1930
- sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
1931
- const int tid = item_ct1.get_local_id(2);
1932
- const int row = item_ct1.get_global_id(1);
1933
-
1934
- float max_val = -INFINITY;
1935
- int max_idx = -1;
1936
-
1937
- for (int col = tid; col < ncols; col += 256) {
1938
- float val = x[row * ncols + col];
1939
- if (val > max_val) {
1940
- max_val = val;
1941
- max_idx = col;
1840
+ cgh.parallel_for(
1841
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
1842
+ [=](sycl::nd_item<3> item_ct1) {
1843
+ const int tid = item_ct1.get_local_id(2);
1844
+ const int row = item_ct1.get_global_id(1);
1845
+
1846
+ float max_val = -INFINITY;
1847
+ int max_idx = -1;
1848
+
1849
+ for (int col = tid; col < ncols; col += 256) {
1850
+ float val = x[row * ncols + col];
1851
+ if (val > max_val) {
1852
+ max_val = val;
1853
+ max_idx = col;
1854
+ }
1942
1855
  }
1943
- }
1944
1856
 
1945
- shared_data[tid] = max_val;
1946
- shared_indices[tid] = max_idx;
1947
- item_ct1.barrier(sycl::access::fence_space::local_space);
1948
-
1949
- for (int stride = 256 / 2; stride > 0; stride >>= 1) {
1950
- if (tid < stride) {
1951
- float val1 = shared_data[tid];
1952
- float val2 = shared_data[tid + stride];
1953
- if (val2 > val1) {
1954
- shared_data[tid] = val2;
1955
- shared_indices[tid] = shared_indices[tid + stride];
1857
+ shared_data[tid] = max_val;
1858
+ shared_indices[tid] = max_idx;
1859
+ item_ct1.barrier(sycl::access::fence_space::local_space);
1860
+
1861
+ for (int stride = 256/2; stride > 0; stride >>= 1) {
1862
+ if (tid < stride) {
1863
+ float val1 = shared_data[tid];
1864
+ float val2 = shared_data[tid + stride];
1865
+ if (val2 > val1) {
1866
+ shared_data[tid] = val2;
1867
+ shared_indices[tid] = shared_indices[tid + stride];
1868
+ }
1956
1869
  }
1870
+ item_ct1.barrier(sycl::access::fence_space::local_space);
1957
1871
  }
1958
- item_ct1.barrier(sycl::access::fence_space::local_space);
1959
- }
1960
1872
 
1961
- if (tid == 0) {
1962
- dst[row] = shared_indices[0];
1963
- }
1964
- });
1873
+
1874
+ if (tid == 0) {
1875
+ dst[row] = shared_indices[0];
1876
+ }
1877
+ });
1965
1878
  });
1966
1879
  }
1967
1880
  static void diag_mask_inf_f32_sycl(const float *x, float *dst,
@@ -2123,8 +2036,8 @@ inline void ggml_sycl_op_mul_mat_sycl(
2123
2036
 
2124
2037
  #if GGML_SYCL_DNNL
2125
2038
  if (!g_ggml_sycl_disable_dnn) {
2126
- DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ptr,
2127
- DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
2039
+ DnnlGemmWrapper::row_gemm(ctx,row_diff, src1_ncols , ne10, src0_ptr,
2040
+ DnnlGemmWrapper::to_dt<sycl::half>(), src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
2128
2041
  dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
2129
2042
  }
2130
2043
  else
@@ -2170,8 +2083,8 @@ inline void ggml_sycl_op_mul_mat_sycl(
2170
2083
 
2171
2084
  #if GGML_SYCL_DNNL
2172
2085
  if (!g_ggml_sycl_disable_dnn) {
2173
- DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ddf1_i,
2174
- DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(),
2086
+ DnnlGemmWrapper::row_gemm(ctx, row_diff, src1_ncols, ne10, src0_ddf_i,
2087
+ DnnlGemmWrapper::to_dt<float>(), src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(),
2175
2088
  dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
2176
2089
  }
2177
2090
  else
@@ -2261,6 +2174,30 @@ inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor *
2261
2174
  sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
2262
2175
  }
2263
2176
 
2177
+ inline void ggml_sycl_op_mean(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2178
+ GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
2179
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
2180
+
2181
+ dpct::queue_ptr main_stream = ctx.stream();
2182
+ SYCL_CHECK(ggml_sycl_set_device(ctx.device));
2183
+
2184
+ const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
2185
+ float * dst_dd = static_cast<float *>(dst->data);
2186
+
2187
+ const int64_t ncols = dst->src[0]->ne[0];
2188
+ const int64_t nrows = ggml_nrows(dst->src[0]);
2189
+
2190
+ sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
2191
+
2192
+ main_stream->parallel_for(
2193
+ sycl::range<1>(nrows),
2194
+ [=](sycl::id<1> row) {
2195
+ dst_dd[row] /= ncols;
2196
+ }
2197
+ );
2198
+ }
2199
+
2200
+
2264
2201
  inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2265
2202
  GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
2266
2203
  GGML_ASSERT(dst->type == GGML_TYPE_I32);
@@ -2275,7 +2212,8 @@ inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor *
2275
2212
 
2276
2213
  enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
2277
2214
 
2278
- argsort_f32_i32_sycl(src0_dd, (int *) dst_dd, ncols, nrows, order, main_stream);
2215
+ argsort_f32_i32_sycl(src0_dd, (int *)dst_dd, ncols, nrows, order,
2216
+ main_stream, ctx.device);
2279
2217
  }
2280
2218
 
2281
2219
  inline void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
@@ -2319,9 +2257,11 @@ inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor * ds
2319
2257
  float * dst_dd = static_cast<float *>(dst->data);
2320
2258
 
2321
2259
  float scale;
2322
- memcpy(&scale, dst->op_params, sizeof(float));
2260
+ float bias;
2261
+ memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
2262
+ memcpy(&bias, (float *) dst->op_params + 1, sizeof(float));
2323
2263
 
2324
- scale_f32_sycl(src0_dd, dst_dd, scale, ggml_nelements(dst->src[0]), main_stream);
2264
+ scale_f32_sycl(src0_dd, dst_dd, scale, bias, ggml_nelements(dst->src[0]), main_stream);
2325
2265
  /*
2326
2266
  DPCT1010:87: SYCL uses exceptions to report errors and does not use the
2327
2267
  error codes. The call was replaced with 0. You need to rewrite this code.
@@ -2370,10 +2310,10 @@ static void ggml_sycl_set_peer_access(const int n_tokens, int main_device) {
2370
2310
  peer_access_enabled = enable_peer_access;
2371
2311
  }
2372
2312
 
2313
+ template <template <int> typename quantize_f>
2373
2314
  static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
2374
2315
  const ggml_tensor *src1, ggml_tensor *dst,
2375
- ggml_sycl_op_mul_mat_t op,
2376
- const bool convert_src1_to_q8_1) try {
2316
+ ggml_sycl_op_mul_mat_t op) try {
2377
2317
 
2378
2318
  GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne);
2379
2319
 
@@ -2468,6 +2408,8 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
2468
2408
  }
2469
2409
  }
2470
2410
 
2411
+ constexpr bool quantize_enabled = !std::is_same_v<quantize_f<QK8_1 / WARP_SIZE>,
2412
+ no_quantize_q8_1<QK8_1 / WARP_SIZE>>;
2471
2413
  for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
2472
2414
  if ((!split && i != ctx.device) || dev[i].row_low == dev[i].row_high) {
2473
2415
  continue;
@@ -2493,20 +2435,19 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
2493
2435
  dev[i].src1_ddf = dev[i].src1_ddf_alloc.alloc(ctx.pool(i), ggml_nelements(src1));
2494
2436
  }
2495
2437
 
2496
- if (convert_src1_to_q8_1) {
2438
+ if constexpr(quantize_enabled) {
2497
2439
  dev[i].src1_ddq = dev[i].src1_ddq_alloc.alloc(ctx.pool(i), nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs);
2498
2440
 
2499
2441
  if (src1_on_device && src1_is_contiguous) {
2500
- bool reorder_q8_tensor = src0->extra && ((ggml_tensor_extra_gpu *)src0->extra)->optimized_feature.reorder;
2501
2442
  scope_op_debug_print scope_dbg_print(__func__, "/quantize_row_q8_1_sycl", dst,
2502
2443
  /*num_src=*/2, " : converting src1 to Q8_1");
2503
- quantize_row_q8_1_sycl(dev[i].src1_ddf, dev[i].src1_ddq, ne10, nrows1, src1_padded_col_size, reorder_q8_tensor, stream);
2504
- /*
2505
- DPCT1010:90: SYCL uses exceptions to report errors and does not
2506
- use the error codes. The call was replaced with 0. You need to
2507
- rewrite this code.
2508
- */
2509
- SYCL_CHECK(0);
2444
+ try {
2445
+ quantize_row_q8_1_sycl<quantize_f>(dev[i].src1_ddf, dev[i].src1_ddq, ne10, nrows1, src1_padded_col_size, stream);
2446
+ } catch (sycl::exception const &exc) {
2447
+ std::cerr << "Quantize_row_q8_1_sycl error" << exc.what() << "Exception caught at file:" << __FILE__
2448
+ << ", line:" << __LINE__ << std::endl;
2449
+ std::exit(1);
2450
+ }
2510
2451
  }
2511
2452
  }
2512
2453
 
@@ -2522,11 +2463,6 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
2522
2463
  // here an event is recorded that signals that the main device has finished calculating the input data
2523
2464
  if (split && used_devices > 1) {
2524
2465
  ggml_sycl_set_device(ctx.device);
2525
- /*
2526
- DPCT1024:91: The original code returned the error code that was further
2527
- consumed by the program logic. This original code was replaced with 0.
2528
- You may need to rewrite the program logic consuming the error code.
2529
- */
2530
2466
  SYCL_CHECK(CHECK_TRY_ERROR(
2531
2467
  *src0_extra->events[ctx.device][0] =
2532
2468
  ctx.stream()->ext_oneapi_submit_barrier()));
@@ -2550,11 +2486,6 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
2550
2486
 
2551
2487
  // wait for main GPU data if necessary
2552
2488
  if (split && (i != ctx.device || is != 0)) {
2553
- /*
2554
- DPCT1009:163: SYCL uses exceptions to report errors and does not
2555
- use the error codes. The original code was commented out and a
2556
- warning string was inserted. You need to rewrite this code.
2557
- */
2558
2489
  SYCL_CHECK(CHECK_TRY_ERROR(stream->ext_oneapi_submit_barrier(
2559
2490
  {*src0_extra->events[ctx.device][0]})));
2560
2491
  }
@@ -2580,39 +2511,42 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
2580
2511
  // copy src0, src1 to device if necessary
2581
2512
  if (src1_is_contiguous) {
2582
2513
  if (i != ctx.device) {
2583
- if (convert_src1_to_q8_1) {
2514
+ if constexpr (quantize_enabled) {
2584
2515
  char * src1_ddq_i_source = dev[ctx.device].src1_ddq + src1_ddq_i_offset;
2585
- SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy(
2586
- src1_ddq_i, src1_ddq_i_source,
2587
- src1_ncols * src1_padded_col_size * q8_1_ts /
2588
- q8_1_bs).wait()));
2516
+ SYCL_CHECK(
2517
+ CHECK_TRY_ERROR(stream
2518
+ ->memcpy(src1_ddq_i, src1_ddq_i_source,
2519
+ src1_ncols * src1_padded_col_size * q8_1_ts / q8_1_bs)
2520
+ .wait()));
2589
2521
  } else {
2590
-
2591
2522
  float * src1_ddf_i_source = (float *) src1_extra->data_device[ctx.device];
2592
- src1_ddf_i_source += (i0*ne11 + src1_col_0) * ne10;
2523
+ src1_ddf_i_source += (i0 * ne11 + src1_col_0) * ne10;
2593
2524
 
2594
- SYCL_CHECK(CHECK_TRY_ERROR(dev2dev_memcpy(*stream, *main_stream,
2595
- src1_ddf_i, src1_ddf_i_source,
2596
- src1_ncols * ne10 * sizeof(float))));
2525
+ SYCL_CHECK(
2526
+ CHECK_TRY_ERROR(dev2dev_memcpy(*stream, *main_stream, src1_ddf_i, src1_ddf_i_source,
2527
+ src1_ncols * ne10 * sizeof(float))));
2597
2528
  }
2598
2529
  }
2599
- } else if (src1_on_device && !src1_is_contiguous) {
2600
- SYCL_CHECK(ggml_sycl_cpy_tensor_2d(
2601
- src1_ddf_i, src1, i03, i02, src1_col_0, src1_col_0+src1_ncols, stream));
2602
2530
  } else {
2603
- GGML_ABORT("fatal error");
2604
- }
2531
+ if (src1_on_device) {
2532
+ SYCL_CHECK(ggml_sycl_cpy_tensor_2d(src1_ddf_i, src1, i03, i02, src1_col_0,
2533
+ src1_col_0 + src1_ncols, stream));
2534
+ } else {
2535
+ GGML_ABORT("src1 is non-contiguous and not on device");
2536
+ }
2605
2537
 
2606
- if (convert_src1_to_q8_1 && !src1_is_contiguous) {
2607
- scope_op_debug_print scope_dbg_print(__func__, "/quantize_row_q8_1_sycl", dst,
2608
- /*num_src=*/2, " : converting src1 to Q8_1");
2609
- quantize_row_q8_1_sycl(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, false, stream);
2610
- /*
2611
- DPCT1010:92: SYCL uses exceptions to report errors and does
2612
- not use the error codes. The call was replaced with 0. You
2613
- need to rewrite this code.
2614
- */
2615
- SYCL_CHECK(0);
2538
+ if constexpr (quantize_enabled) {
2539
+ scope_op_debug_print scope_dbg_print(__func__, "/quantize_row_q8_1_sycl", dst,
2540
+ /*num_src=*/2, " : converting src1 to Q8_1");
2541
+ try {
2542
+ quantize_row_q8_1_sycl<quantize_q8_1>(src1_ddf_i, src1_ddq_i, ne10, src1_ncols,
2543
+ src1_padded_col_size, stream);
2544
+ } catch (const sycl::exception & exc) {
2545
+ std::cerr << "Quantize_row_q8_1_sycl error" << exc.what()
2546
+ << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
2547
+ std::exit(1);
2548
+ }
2549
+ }
2616
2550
  }
2617
2551
 
2618
2552
  if (src1_col_0 == 0 && !src0_is_contiguous && i02 % i02_divisor == 0) {
@@ -2624,12 +2558,6 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
2624
2558
  // do the computation
2625
2559
  SYCL_CHECK(CHECK_TRY_ERROR(op(ctx, src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i,
2626
2560
  dev[i].row_low, dev[i].row_high, src1_ncols, src1_padded_col_size, stream)));
2627
- /*
2628
- DPCT1010:93: SYCL uses exceptions to report errors and does not
2629
- use the error codes. The call was replaced with 0. You need to
2630
- rewrite this code.
2631
- */
2632
- SYCL_CHECK(0);
2633
2561
 
2634
2562
  // copy dst to host or other device if necessary
2635
2563
  if (!dst_on_device) {
@@ -2660,12 +2588,6 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
2660
2588
 
2661
2589
  // add event for the main device to wait on until other device is done
2662
2590
  if (split && (i != ctx.device || is != 0)) {
2663
- /*
2664
- DPCT1024:94: The original code returned the error code that
2665
- was further consumed by the program logic. This original
2666
- code was replaced with 0. You may need to rewrite the
2667
- program logic consuming the error code.
2668
- */
2669
2591
  SYCL_CHECK(CHECK_TRY_ERROR(
2670
2592
  *src0_extra->events[i][is] =
2671
2593
  stream->ext_oneapi_submit_barrier()));
@@ -2698,6 +2620,10 @@ catch (sycl::exception const &exc) {
2698
2620
  std::exit(1);
2699
2621
  }
2700
2622
 
2623
+ static void ggml_sycl_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2624
+ scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
2625
+ ggml_sycl_op_repeat_back(ctx, dst);
2626
+ }
2701
2627
 
2702
2628
  static void ggml_sycl_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2703
2629
  scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
@@ -2714,6 +2640,11 @@ static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * ds
2714
2640
  ggml_sycl_op_rms_norm(ctx, dst);
2715
2641
  }
2716
2642
 
2643
+ static void ggml_sycl_rms_norm_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2644
+ scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
2645
+ ggml_sycl_op_rms_norm_back(ctx, dst);
2646
+ }
2647
+
2717
2648
  static void ggml_sycl_l2_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2718
2649
  scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
2719
2650
  ggml_sycl_op_l2_norm(ctx, dst);
@@ -2764,6 +2695,8 @@ static void ggml_sycl_mul_mat_vec_nc(ggml_backend_sycl_context & ctx, const ggml
2764
2695
  GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer));
2765
2696
  GGML_ASSERT(src0->type == GGML_TYPE_F16);
2766
2697
  GGML_ASSERT(src1->type == GGML_TYPE_F32);
2698
+ GGML_ASSERT(src1->ne[1] == 1);
2699
+ GGML_ASSERT(src1->ne[3] == 1);
2767
2700
 
2768
2701
  const int64_t ne00 = src0->ne[0];
2769
2702
  const int64_t ne01 = src0->ne[1];
@@ -2773,6 +2706,7 @@ static void ggml_sycl_mul_mat_vec_nc(ggml_backend_sycl_context & ctx, const ggml
2773
2706
  const int64_t nb02 = src0->nb[2];
2774
2707
 
2775
2708
  const int64_t ne12 = src1->ne[2];
2709
+ const int64_t nb11 = src1->nb[1];
2776
2710
 
2777
2711
  SYCL_CHECK(ggml_sycl_set_device(ctx.device));
2778
2712
  queue_ptr main_stream = ctx.stream();
@@ -2783,8 +2717,9 @@ static void ggml_sycl_mul_mat_vec_nc(ggml_backend_sycl_context & ctx, const ggml
2783
2717
 
2784
2718
  const int64_t row_stride_x = nb01 / sizeof(sycl::half);
2785
2719
  const int64_t channel_stride_x = nb02 / sizeof(sycl::half);
2720
+ const int64_t channel_stride_y = nb11 / sizeof(float);
2786
2721
 
2787
- ggml_mul_mat_vec_nc_f16_f32_sycl(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
2722
+ ggml_mul_mat_vec_nc_f16_f32_sycl(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x,channel_stride_y, main_stream);
2788
2723
  }
2789
2724
  catch (sycl::exception const &exc) {
2790
2725
  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
@@ -2838,8 +2773,11 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
2838
2773
  float * dst_ddf = static_cast<float *>(dst->data);
2839
2774
 
2840
2775
  const sycl::half * src1_f16 = static_cast<const sycl::half *>(src1->data);
2776
+ const size_t type_size_src0 = ggml_type_size(src0->type);
2841
2777
  const size_t type_size_src1 = ggml_type_size(src1->type);
2842
- GGML_ASSERT(nb10 == type_size_src1);
2778
+
2779
+ bool is_src0_cont_2 = ggml_is_contiguous_2(src0);
2780
+ bool is_src1_cont_2 = ggml_is_contiguous_2(src1);
2843
2781
 
2844
2782
  // SRC1 strides
2845
2783
  int64_t s11 = nb11 / type_size_src1;
@@ -2851,16 +2789,47 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
2851
2789
  if (src1->type != GGML_TYPE_F16) {
2852
2790
  scope_op_debug_print scope_dbg_print(__func__, "/to_fp16_nc_sycl", dst, /*num_src=*/2,
2853
2791
  " : converting src1 to fp16");
2854
- const to_fp16_nc_sycl_t to_fp16_nc_sycl = get_to_fp16_nc_sycl(src1->type);
2855
- GGML_ASSERT(to_fp16_nc_sycl != nullptr);
2792
+
2793
+ // iterate tensor dims and find the slowest moving dim and stride
2794
+ int last_dim=0;
2795
+ int last_str=0;
2796
+ size_t largest_str=0;
2797
+ for(int i = 0; i< 4; i++){
2798
+ // last stride is always the largest
2799
+ if(src1->nb[i] == largest_str){
2800
+ if(src1->ne[last_dim] == 1){
2801
+ last_str = i;
2802
+ last_dim = i;
2803
+ }
2804
+ }
2805
+ if(src1->nb[i] > largest_str){
2806
+ largest_str = src1->nb[i];
2807
+ last_str = i;
2808
+ last_dim = i;
2809
+ }
2810
+
2811
+ }
2812
+ #if GGML_SYCL_DNNL
2813
+ // oneDNN handles strided data and does not need overhead of get_to_fp16_nc_sycl
2814
+ const int64_t ne_src1 = src1->nb[last_str] * src1->ne[last_dim] / type_size_src1;
2815
+ src1_f16_alloc.alloc(ne_src1);
2816
+ const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst);
2817
+ GGML_ASSERT(to_fp16_sycl != nullptr);
2818
+ to_fp16_sycl(src1_f16, src1_f16_alloc.get(), ne_src1, queue);
2819
+ # else
2856
2820
  const int64_t ne_src1 = ggml_nelements(src1);
2857
2821
  src1_f16_alloc.alloc(ne_src1);
2822
+ const to_fp16_nc_sycl_t to_fp16_nc_sycl = get_to_fp16_nc_sycl(src1->type);
2823
+ GGML_ASSERT(to_fp16_nc_sycl != nullptr);
2858
2824
  to_fp16_nc_sycl(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, queue);
2825
+ #endif
2859
2826
 
2860
2827
  src1_f16 = src1_f16_alloc.get();
2861
2828
  s11 = ne10;
2862
2829
  s12 = ne11 * s11;
2863
2830
  s13 = ne12 * s12;
2831
+
2832
+ is_src1_cont_2 = true;
2864
2833
  }
2865
2834
 
2866
2835
  ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool());
@@ -2889,48 +2858,115 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
2889
2858
 
2890
2859
  #if GGML_SYCL_DNNL
2891
2860
  if (!g_ggml_sycl_disable_dnn) {
2892
- auto dnn_gemm = [&ctx, queue, ne11, ne01, ne10, nb00, nb01, nb02, s11, s12]
2893
- (const sycl::half* src1, const sycl::half* src0, float* dst, const dnnl_dim_t batches_a, const dnnl_dim_t batches_b) {
2894
-
2895
- DnnlGemmWrapper::gemm(ctx, ne11,ne01, ne10,
2896
- src1, DnnlGemmWrapper::to_dt<sycl::half>(), s11, 1, s12,
2897
- src0, DnnlGemmWrapper::to_dt<sycl::half>(), 1, nb01/nb00, nb02/nb00,
2898
- dst, DnnlGemmWrapper::to_dt<float>(), queue, batches_a, batches_b);
2899
- };
2900
-
2901
- if (r2 == 1 && r3 == 1) {
2902
- if (ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
2903
- dnn_gemm(src1_f16, src0_f16, dst_ddf, ne12*ne13, ne02 * ne03);
2904
- }
2905
- else {
2906
- for (int64_t ie03 = 0; ie03 < ne03; ++ie03) {
2907
- const sycl::half* src0_f16_shifted = src0_f16 + ((ie03*nb03)/sizeof(sycl::half)); // nb is in bytes
2908
- const sycl::half* src1_f16_shifted = src1_f16 + ie03*s13;
2909
- float* dst_shifted = dst_ddf + ((ie03*nb3)/sizeof(float));
2910
- dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, ne12, ne02);
2861
+ int64_t str_a0 = nb00 / type_size_src0;
2862
+ int64_t str_a1 = nb01 / type_size_src0;
2863
+ int64_t str_a2 = nb02 / type_size_src0;
2864
+
2865
+ int64_t str_b0 = nb10 / type_size_src1;
2866
+ int64_t str_b1 = nb11 / type_size_src1;
2867
+ int64_t str_b2 = nb12 / type_size_src1;
2868
+
2869
+ auto launch_gemm_for_batches = [&ctx, queue](const sycl::half *src0,
2870
+ const sycl::half *src1, float *dst,
2871
+ int64_t a0, int64_t a1, int64_t batcha,
2872
+ int64_t /*b0*/, int64_t b1, int64_t batchb,
2873
+ int64_t sa0, int64_t sa1, int64_t sa2,
2874
+ int64_t sb0, int64_t sb1, int64_t sb2,
2875
+ int64_t sd2) {
2876
+ bool supported_broadcast = batchb == batcha ? true
2877
+ : batchb == 1 || batcha == 1 ? true
2878
+ : false;
2879
+ if (supported_broadcast) {
2880
+ DnnlGemmWrapper::gemm(ctx, a1, b1, a0, src0,
2881
+ DnnlGemmWrapper::to_dt<sycl::half>(), sa0, sa1, sa2, src1,
2882
+ DnnlGemmWrapper::to_dt<sycl::half>(), sb0, sb1, sb2, dst,
2883
+ DnnlGemmWrapper::to_dt<float>(), queue, batcha, batchb);
2884
+ } else {
2885
+ // iterate over batches from smaller set of matrices (matrix 0)
2886
+ int64_t batches0 = batcha;
2887
+ int64_t batches1 = batchb;
2888
+
2889
+ if (batches0 > batches1) {
2890
+ int64_t num_mul_mats = batches1;
2891
+ int64_t sub_batch = batches0 / num_mul_mats;
2892
+ // src0 is batched and bigger, shift and multiply with src1
2893
+ for (int64_t i0 = 0; i0 < num_mul_mats; i0++) {
2894
+ const sycl::half *src0_shifted = src0 + (sa2 * i0 * sub_batch);
2895
+ const sycl::half *src1_shifted = src1 + (sb2 * i0);
2896
+ float *dst_shifted = dst + (sd2 * i0 * sub_batch);
2897
+ DnnlGemmWrapper::gemm(ctx, a1, b1, a0, src0_shifted,
2898
+ DnnlGemmWrapper::to_dt<sycl::half>(), sa0, sa1, sa2,
2899
+ src1_shifted, DnnlGemmWrapper::to_dt<sycl::half>(), sb0,
2900
+ sb1, sb2, dst_shifted, DnnlGemmWrapper::to_dt<float>(),
2901
+ queue, sub_batch, 1);
2902
+ }
2903
+ } else {
2904
+ int64_t num_mul_mats = batches0;
2905
+ int64_t sub_batch = batches1 / num_mul_mats;
2906
+ // src1 is batched and bigger, shift and multiply with src0
2907
+ for (int64_t i1 = 0; i1 < num_mul_mats; i1++) {
2908
+ const sycl::half *src0_shifted = src0 + (sa2 * i1);
2909
+ const sycl::half *src1_shifted = src1 + (sb2 * i1 * sub_batch);
2910
+ float *dst_shifted = dst + (sd2 * i1 * sub_batch);
2911
+ DnnlGemmWrapper::gemm(ctx, a1, b1, a0, src0_shifted,
2912
+ DnnlGemmWrapper::to_dt<sycl::half>(), sa0, sa1, sa2,
2913
+ src1_shifted, DnnlGemmWrapper::to_dt<sycl::half>(), sb0,
2914
+ sb1, sb2, dst_shifted, DnnlGemmWrapper::to_dt<float>(),
2915
+ queue, 1, sub_batch);
2916
+ }
2917
+ }
2911
2918
  }
2912
- }
2913
- } else {
2914
- // iterate over batches from smaller set of matrices (matrix 0)
2915
- for (int64_t ie02 = 0; ie02 < ne02; ++ie02) {
2916
- for (int64_t ie03 = 0; ie03 < ne03; ++ie03) {
2917
- const sycl::half* src0_f16_shifted = src0_f16 + ((ie02*nb02 + ie03*nb03)/sizeof(sycl::half));
2918
- const sycl::half* src1_f16_shifted = src1_f16 + ie02*s12*r2 + ie03*s13*r3;
2919
- float* dst_shifted = dst_ddf + ((ie02*nb2*r2 + ie03*nb3*r3)/sizeof(float));
2920
- dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, r2*r3, 1);
2919
+ };
2920
+
2921
+ const bool cont_batches_dim2_a = nb02 * ne02 == nb03;
2922
+ const bool cont_batches_dim2_b = nb12 * ne12 == nb13;
2923
+ const bool cont_batches_dim3_a = ne02 == 1 && nb02 * ne01 == nb03;
2924
+ const bool cont_batches_dim3_b = ne12 == 1 && nb12 * ne11 == nb13;
2925
+ if (cont_batches_dim2_a && cont_batches_dim2_b) {
2926
+ // A batch is considered contiguous if the dimension 2 is not strided
2927
+ int64_t batches0 = ne02 * ne03;
2928
+ int64_t batches1 = ne12 * ne13;
2929
+ launch_gemm_for_batches(src0_f16, src1_f16, dst_ddf, ne00, ne01, batches0,
2930
+ ne10, ne11, batches1, str_a0, str_a1, str_a2, str_b0, str_b1,
2931
+ str_b2, nb2 / sizeof(float));
2932
+ } else if (cont_batches_dim3_a && cont_batches_dim3_b) {
2933
+ // This case is similar to the one above with the difference that only the batch in dimension 3 is used and the dimension 2 is of size 1.
2934
+ int64_t batches0 = ne02 * ne03;
2935
+ int64_t batches1 = ne12 * ne13;
2936
+ int64_t str_a3 = nb03 / type_size_src0;
2937
+ int64_t str_b3 = nb13 / type_size_src1;
2938
+ launch_gemm_for_batches(src0_f16, src1_f16, dst_ddf, ne00, ne01, batches0,
2939
+ ne10, ne11, batches1, str_a0, str_a1, str_a3, str_b0, str_b1,
2940
+ str_b3, nb2 / sizeof(float));
2941
+ } else {
2942
+ for (int64_t b_a = 0; b_a < ne03; b_a++) {
2943
+ const sycl::half *src0_f16_shifted
2944
+ = src0_f16 + (nb03 * b_a / type_size_src0);
2945
+ const sycl::half *src1_f16_shifted
2946
+ = src1_f16 + (nb13 * b_a / type_size_src1);
2947
+ float *dst_shifted = dst_ddf + (nb3 * b_a / sizeof(float));
2948
+ int64_t batches0 = ne02;
2949
+ int64_t batches1 = ne12;
2950
+ launch_gemm_for_batches(src0_f16_shifted, src1_f16_shifted, dst_shifted,
2951
+ ne00, ne01, batches0, ne10, ne11, batches1, str_a0, str_a1,
2952
+ str_a2, str_b0, str_b1, str_b2, nb2 / sizeof(float));
2921
2953
  }
2922
2954
  }
2923
- }
2955
+
2924
2956
  }
2925
2957
  else
2926
2958
  #endif
2927
2959
  {
2928
- if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
2960
+ if (r2 == 1 && r3 == 1 && is_src0_cont_2 && is_src1_cont_2) {
2961
+ // with a [0, 2, 1, 3] perm. and ne02==1 the matrix strides need to be determined from dim 3:
2962
+ const int64_t sma = ne02 == 1 ? nb03/nb00 : nb02/nb00;
2963
+ const int64_t smb = ne12 == 1 ? s13 : s12;
2964
+
2929
2965
  // there is no broadcast and src0, src1 are contiguous across dims 2, 3
2930
2966
  SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::math::transpose::trans,
2931
2967
  oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
2932
- src0_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00,
2933
- src1_f16, dpct::library_data_t::real_half, s11, s12, beta, dst_ddf,
2968
+ src0_f16, dpct::library_data_t::real_half, nb01 / nb00, sma,
2969
+ src1_f16, dpct::library_data_t::real_half, s11, smb, beta, dst_ddf,
2934
2970
  mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type)));
2935
2971
  } else {
2936
2972
  const int ne23 = ne12 * ne13;
@@ -2945,7 +2981,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
2945
2981
  void ** ptrs_dst_get = ptrs_dst.get();
2946
2982
  size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : s12 * sizeof(sycl::half);
2947
2983
  size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : s13 * sizeof(sycl::half);
2948
- sycl_parallel_for(cgh, sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
2984
+ cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
2949
2985
  k_compute_batched_ptrs(src0_f16, src1_f16, dst_ddf, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02,
2950
2986
  nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1);
2951
2987
  });
@@ -3026,19 +3062,51 @@ static bool ggml_sycl_supports_dmmv(enum ggml_type type) {
3026
3062
  }
3027
3063
  }
3028
3064
 
3065
+ // Helper functions to unify device memory allocation for both async and sync paths
3066
+ static inline void * sycl_ext_malloc_device(dpct::queue_ptr stream, size_t size) {
3067
+ bool use_async = g_ggml_sycl_use_async_mem_op;
3068
+ #if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC
3069
+ if (use_async) {
3070
+ return syclex::async_malloc(*stream, sycl::usm::alloc::device, size);
3071
+ }
3072
+ #else
3073
+ // If async allocation extension is not available, use_async should always be false.
3074
+ GGML_ASSERT(!use_async);
3075
+ #endif
3076
+ return sycl::malloc(size, *stream, sycl::usm::alloc::device);
3077
+ }
3078
+
3079
+ static inline void sycl_ext_free(dpct::queue_ptr stream, void * ptr) {
3080
+ bool use_async = g_ggml_sycl_use_async_mem_op;
3081
+ #if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC
3082
+ if (use_async) {
3083
+ syclex::async_free(*stream, ptr);
3084
+ return;
3085
+ }
3086
+ #else
3087
+ // If async allocation extension is not available, use_async should always be false.
3088
+ GGML_ASSERT(!use_async);
3089
+ #endif
3090
+ sycl::free(ptr, *stream);
3091
+ }
3092
+
3029
3093
  static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset,
3030
3094
  dpct::queue_ptr stream) {
3031
- auto * tmp_buf = sycl::malloc_shared<uint8_t>(size, *stream);
3032
- SYCL_CHECK(
3033
- CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size)
3034
- .wait()));
3095
+ uint8_t * tmp_buf = static_cast<uint8_t *>(sycl_ext_malloc_device(stream, size));
3096
+
3097
+ sycl::event copy_event;
3098
+ SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size)));
3099
+ if (!g_ggml_sycl_use_async_mem_op) {
3100
+ copy_event.wait();
3101
+ }
3102
+
3035
3103
  GGML_ASSERT((size % sizeof(block_q4_0) == 0));
3036
3104
  GGML_ASSERT((offset % sizeof(block_q4_0) == 0));
3037
3105
  int offset_blks = offset / sizeof(block_q4_0);
3038
3106
  auto qs_ptr = data_device + offset_blks * QK4_0 / 2;
3039
3107
  auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows / 2) + offset_blks;
3040
3108
 
3041
- stream->parallel_for(
3109
+ auto reorder_event = stream->parallel_for(
3042
3110
  size / sizeof(block_q4_0),
3043
3111
  [=](auto i) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
3044
3112
  const block_q4_0* x = (const block_q4_0*)tmp_buf;
@@ -3049,9 +3117,11 @@ static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nr
3049
3117
  *(qs_ptr + ib * QK4_0 / 2 + j) = x[ib].qs[j];
3050
3118
  }
3051
3119
  *(d_ptr + ib) = x[ib].d;
3052
- }).wait_and_throw();
3053
-
3054
- sycl::free(tmp_buf, *stream);
3120
+ });
3121
+ if (!g_ggml_sycl_use_async_mem_op) {
3122
+ reorder_event.wait_and_throw();
3123
+ }
3124
+ sycl_ext_free(stream, tmp_buf);
3055
3125
  }
3056
3126
 
3057
3127
  static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
@@ -3060,14 +3130,19 @@ static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, d
3060
3130
 
3061
3131
  const int nblocks = size / sizeof(block_q4_K);
3062
3132
 
3063
- auto * tmp_buf = sycl::malloc_shared<uint8_t>(size, *stream);
3064
- SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size).wait()));
3133
+ uint8_t * tmp_buf = static_cast<uint8_t *>(sycl_ext_malloc_device(stream, size));
3134
+
3135
+ sycl::event copy_event;
3136
+ SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size)));
3137
+ if (!g_ggml_sycl_use_async_mem_op) {
3138
+ copy_event.wait();
3139
+ }
3065
3140
 
3066
3141
  auto * qs_ptr = data_device;
3067
3142
  auto * scales_ptr = qs_ptr + QK_K / 2 * nblocks;
3068
3143
  auto * dm_ptr = (sycl::half2 *) (scales_ptr + K_SCALE_SIZE * nblocks);
3069
3144
 
3070
- stream->parallel_for(nblocks, [=](auto i) {
3145
+ auto reorder_event = stream->parallel_for(nblocks, [=](auto i) {
3071
3146
  const block_q4_K * x = (const block_q4_K *) tmp_buf;
3072
3147
  const int ib = i;
3073
3148
 
@@ -3080,9 +3155,11 @@ static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, d
3080
3155
  }
3081
3156
 
3082
3157
  dm_ptr[ib] = x[ib].dm;
3083
- }).wait_and_throw();
3084
-
3085
- sycl::free(tmp_buf, *stream);
3158
+ });
3159
+ if (!g_ggml_sycl_use_async_mem_op) {
3160
+ reorder_event.wait_and_throw();
3161
+ }
3162
+ sycl_ext_free(stream, tmp_buf);
3086
3163
  }
3087
3164
 
3088
3165
  static void reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
@@ -3091,42 +3168,46 @@ static void reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, d
3091
3168
 
3092
3169
  const int nblocks = size / sizeof(block_q6_K);
3093
3170
 
3094
- auto * tmp_buf = sycl::malloc_shared<uint8_t>(size, *stream);
3095
- SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size).wait()));
3171
+ uint8_t * tmp_buf = static_cast<uint8_t *>(sycl_ext_malloc_device(stream, size));
3172
+
3173
+ sycl::event copy_event;
3174
+ SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size)));
3175
+ if (!g_ggml_sycl_use_async_mem_op) {
3176
+ copy_event.wait();
3177
+ }
3096
3178
 
3097
3179
  auto * ql_ptr = data_device;
3098
3180
  auto * qh_ptr = ql_ptr + (QK_K / 2) * nblocks;
3099
3181
  auto * scales_ptr = qh_ptr + (QK_K / 4) * nblocks;
3100
3182
  sycl::half * dm_ptr = (sycl::half *) (scales_ptr + (QK_K / 16) * nblocks);
3101
3183
 
3102
- stream
3103
- ->parallel_for(nblocks,
3104
- [=](auto i) {
3105
- const block_q6_K * x = (const block_q6_K *) tmp_buf;
3106
- const int ib = i;
3107
-
3108
- const uint8_t * ql = x[ib].ql;
3109
- const uint8_t * qh = x[ib].qh;
3110
- uint8_t * base_ql_ptr = ql_ptr + (QK_K / 2) * ib;
3111
- uint8_t * base_qh_ptr = qh_ptr + (QK_K / 4) * ib;
3112
- uint8_t * base_scales_ptr = scales_ptr + (QK_K / 16) * ib;
3184
+ auto reorder_event = stream->parallel_for(nblocks, [=](auto i) {
3185
+ const block_q6_K * x = (const block_q6_K *) tmp_buf;
3186
+ const int ib = i;
3113
3187
 
3114
- for (int j = 0; j < QK_K / 2; ++j) {
3115
- base_ql_ptr[j] = ql[j];
3116
- }
3117
- for (int j = 0; j < QK_K / 4; ++j) {
3118
- base_qh_ptr[j] = qh[j];
3119
- }
3188
+ const uint8_t * ql = x[ib].ql;
3189
+ const uint8_t * qh = x[ib].qh;
3190
+ uint8_t * base_ql_ptr = ql_ptr + (QK_K / 2) * ib;
3191
+ uint8_t * base_qh_ptr = qh_ptr + (QK_K / 4) * ib;
3192
+ uint8_t * base_scales_ptr = scales_ptr + (QK_K / 16) * ib;
3120
3193
 
3121
- for (int j = 0; j < QK_K / 16; ++j) {
3122
- base_scales_ptr[j] = x[ib].scales[j];
3123
- }
3194
+ for (int j = 0; j < QK_K / 2; ++j) {
3195
+ base_ql_ptr[j] = ql[j];
3196
+ }
3197
+ for (int j = 0; j < QK_K / 4; ++j) {
3198
+ base_qh_ptr[j] = qh[j];
3199
+ }
3124
3200
 
3125
- dm_ptr[ib] = x[ib].d;
3126
- })
3127
- .wait_and_throw();
3201
+ for (int j = 0; j < QK_K / 16; ++j) {
3202
+ base_scales_ptr[j] = x[ib].scales[j];
3203
+ }
3128
3204
 
3129
- sycl::free(tmp_buf, *stream);
3205
+ dm_ptr[ib] = x[ib].d;
3206
+ });
3207
+ if (!g_ggml_sycl_use_async_mem_op) {
3208
+ reorder_event.wait_and_throw();
3209
+ }
3210
+ sycl_ext_free(stream, tmp_buf);
3130
3211
  }
3131
3212
 
3132
3213
  static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
@@ -3233,6 +3314,7 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
3233
3314
  bool use_mul_mat_q = ggml_sycl_supports_mmq(src0->type)
3234
3315
  && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
3235
3316
 
3317
+
3236
3318
  // mmvq and mmq need the __dp4a instruction which is available for gen12+
3237
3319
  // Workaround in https://github.com/ggerganov/llama.cpp/commit/95f84d5ce8b449a9b16009434aca800df504a02e
3238
3320
  use_mul_mat_q = use_mul_mat_q && (src0->type != GGML_TYPE_IQ2_XXS);
@@ -3240,7 +3322,6 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
3240
3322
  use_mul_mat_q = use_mul_mat_q && (src1->ne[1] <= MMQ_MAX_BATCH_SIZE);
3241
3323
  #endif // SYCL_USE_XMX
3242
3324
 
3243
-
3244
3325
  // mmvq path is faster in the CUDA backend.
3245
3326
  if (!g_ggml_sycl_prioritize_dmmv && (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda
3246
3327
  // Dispatch becomes obscure with the reorder, MMVQ when the reorder optimization
@@ -3260,26 +3341,27 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
3260
3341
  // The kernel from the if path is faster for that specific case, but does not support all mul mats.
3261
3342
  ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
3262
3343
  }
3263
- } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
3344
+ } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1 && src1->ne[3] == 1) {
3264
3345
  // KQV single-batch
3265
3346
  ggml_sycl_mul_mat_vec_nc(ctx, src0, src1, dst);
3266
- } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
3347
+ } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2] * src1->ne[3] > 1) {
3267
3348
  // KQ + KQV multi-batch
3268
3349
  ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
3269
3350
  } else if (use_dequantize_mul_mat_vec) {
3270
- constexpr bool convert_src1_to_q8_1 = false;
3271
3351
  opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::DMMV);
3272
- ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, convert_src1_to_q8_1);
3352
+ ggml_sycl_op_mul_mat<no_quantize_q8_1>(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec);
3273
3353
  } else if (use_mul_mat_vec_q) {
3274
- constexpr bool convert_src1_to_q8_1 = true;
3275
3354
  opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::MMVQ);
3276
- ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, convert_src1_to_q8_1);
3355
+ ggml_tensor_extra_gpu * extra = static_cast<ggml_tensor_extra_gpu *>(src0->extra);
3356
+ if (extra && extra->optimized_feature.reorder) {
3357
+ ggml_sycl_op_mul_mat<quantize_and_reorder_q8_1_soa>(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q);
3358
+ } else {
3359
+ ggml_sycl_op_mul_mat<quantize_q8_1>(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q);
3360
+ }
3277
3361
  } else if (use_mul_mat_q) {
3278
- constexpr bool convert_src1_to_q8_1 = true;
3279
- ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, convert_src1_to_q8_1);
3362
+ ggml_sycl_op_mul_mat<quantize_q8_1>(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q);
3280
3363
  } else {
3281
- constexpr bool convert_src1_to_q8_1 = false;
3282
- ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, convert_src1_to_q8_1);
3364
+ ggml_sycl_op_mul_mat<no_quantize_q8_1>(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl);
3283
3365
  }
3284
3366
  }
3285
3367
 
@@ -3446,10 +3528,13 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
3446
3528
  SYCL_CHECK(CHECK_TRY_ERROR(
3447
3529
  stream->memset(dev_cur_src1_row.get(), 0, sizeof(int))));
3448
3530
 
3531
+ const unsigned int max_work_group_size = ggml_sycl_info().max_work_group_sizes[ctx.device];
3532
+ assert(max_work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
3533
+
3449
3534
  {
3450
- sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, 768u));
3535
+ sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, max_work_group_size));
3451
3536
  sycl::range<3> grid_dims(1, n_ids, ids->ne[1]);
3452
- sycl_launch(stream, [&](sycl::handler & cgh) {
3537
+ stream->submit([&](sycl::handler &cgh) {
3453
3538
  sycl::local_accessor<int, 0> src1_row_acc(cgh);
3454
3539
 
3455
3540
  char *__restrict src1_contiguous_get =
@@ -3461,8 +3546,9 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
3461
3546
  size_t ids_nb_ct6 = ids->nb[1];
3462
3547
  size_t ids_nb_ct7 = ids->nb[0];
3463
3548
 
3464
- sycl_parallel_for(
3465
- cgh, sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
3549
+ cgh.parallel_for(
3550
+ sycl::nd_range<3>(grid_dims * block_dims, block_dims),
3551
+ [=](sycl::nd_item<3> item_ct1) {
3466
3552
  k_copy_src1_to_contiguous(
3467
3553
  src1_original, src1_contiguous_get,
3468
3554
  dev_cur_src1_row_get,
@@ -3491,16 +3577,17 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
3491
3577
  ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
3492
3578
 
3493
3579
  {
3494
- sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, 768u));
3580
+ sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, max_work_group_size));
3495
3581
  sycl::range<3> grid_dims(1, 1, num_src1_rows);
3496
- sycl_launch(stream, [&](sycl::handler & cgh) {
3582
+ stream->submit([&](sycl::handler &cgh) {
3497
3583
  const char *__restrict dst_contiguous_get =
3498
3584
  dst_contiguous.get();
3499
3585
  const mmid_row_mapping *__restrict dev_row_mapping_get =
3500
3586
  dev_row_mapping.get();
3501
3587
 
3502
- sycl_parallel_for(
3503
- cgh, sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
3588
+ cgh.parallel_for(
3589
+ sycl::nd_range<3>(grid_dims * block_dims, block_dims),
3590
+ [=](sycl::nd_item<3> item_ct1) {
3504
3591
  k_copy_dst_from_contiguous(dst_original,
3505
3592
  dst_contiguous_get,
3506
3593
  dev_row_mapping_get,
@@ -3549,6 +3636,12 @@ static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * ds
3549
3636
  ggml_sycl_op_sum_rows(ctx, dst);
3550
3637
  }
3551
3638
 
3639
+ static void ggml_sycl_mean(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
3640
+ scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
3641
+ GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
3642
+ ggml_sycl_op_mean(ctx, dst);
3643
+ }
3644
+
3552
3645
  static void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
3553
3646
  scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
3554
3647
  GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
@@ -3600,9 +3693,18 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
3600
3693
  case GGML_OP_REPEAT:
3601
3694
  ggml_sycl_repeat(ctx, dst);
3602
3695
  break;
3696
+ case GGML_OP_REPEAT_BACK:
3697
+ ggml_sycl_repeat_back(ctx, dst);
3698
+ break;
3603
3699
  case GGML_OP_GET_ROWS:
3604
3700
  ggml_sycl_get_rows(ctx, dst);
3605
3701
  break;
3702
+ case GGML_OP_SET:
3703
+ ggml_sycl_op_set(ctx, dst);
3704
+ break;
3705
+ case GGML_OP_SET_ROWS:
3706
+ ggml_sycl_op_set_rows(ctx, dst);
3707
+ break;
3606
3708
  case GGML_OP_DUP:
3607
3709
  ggml_sycl_dup(ctx, dst);
3608
3710
  break;
@@ -3610,9 +3712,15 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
3610
3712
  case GGML_OP_ADD1: // TODO: more efficient implementation
3611
3713
  ggml_sycl_add(ctx, dst);
3612
3714
  break;
3715
+ case GGML_OP_ADD_ID:
3716
+ ggml_sycl_add_id(ctx, dst);
3717
+ break;
3613
3718
  case GGML_OP_SUB:
3614
3719
  ggml_sycl_sub(ctx, dst);
3615
3720
  break;
3721
+ case GGML_OP_COUNT_EQUAL:
3722
+ ggml_sycl_count_equal(ctx, dst);
3723
+ break;
3616
3724
  case GGML_OP_ACC:
3617
3725
  ggml_sycl_acc(ctx, dst);
3618
3726
  break;
@@ -3672,6 +3780,18 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
3672
3780
  case GGML_UNARY_OP_ELU:
3673
3781
  ggml_sycl_elu(ctx, dst);
3674
3782
  break;
3783
+ case GGML_UNARY_OP_FLOOR:
3784
+ ggml_sycl_floor(ctx, dst);
3785
+ break;
3786
+ case GGML_UNARY_OP_CEIL:
3787
+ ggml_sycl_ceil(ctx, dst);
3788
+ break;
3789
+ case GGML_UNARY_OP_ROUND:
3790
+ ggml_sycl_round(ctx, dst);
3791
+ break;
3792
+ case GGML_UNARY_OP_TRUNC:
3793
+ ggml_sycl_trunc(ctx, dst);
3794
+ break;
3675
3795
  default:
3676
3796
  return false;
3677
3797
  }
@@ -3687,6 +3807,15 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
3687
3807
  case GGML_GLU_OP_SWIGLU:
3688
3808
  ggml_sycl_swiglu(ctx, dst);
3689
3809
  break;
3810
+ case GGML_GLU_OP_SWIGLU_OAI:
3811
+ ggml_sycl_swiglu_oai(ctx, dst);
3812
+ break;
3813
+ case GGML_GLU_OP_GEGLU_ERF:
3814
+ ggml_sycl_geglu_erf(ctx, dst);
3815
+ break;
3816
+ case GGML_GLU_OP_GEGLU_QUICK:
3817
+ ggml_sycl_geglu_quick(ctx, dst);
3818
+ break;
3690
3819
  default:
3691
3820
  return false;
3692
3821
  }
@@ -3700,6 +3829,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
3700
3829
  case GGML_OP_CONCAT:
3701
3830
  ggml_sycl_op_concat(ctx, dst);
3702
3831
  break;
3832
+ case GGML_OP_PAD_REFLECT_1D:
3833
+ ggml_sycl_op_pad_reflect_1d(ctx,dst);
3834
+ break;
3703
3835
  case GGML_OP_UPSCALE:
3704
3836
  ggml_sycl_upscale(ctx, dst);
3705
3837
  break;
@@ -3709,6 +3841,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
3709
3841
  case GGML_OP_LEAKY_RELU:
3710
3842
  ggml_sycl_leaky_relu(ctx, dst);
3711
3843
  break;
3844
+ case GGML_OP_RMS_NORM_BACK:
3845
+ ggml_sycl_rms_norm_back(ctx, dst);
3846
+ break;
3712
3847
  case GGML_OP_RMS_NORM:
3713
3848
  ggml_sycl_rms_norm(ctx, dst);
3714
3849
  break;
@@ -3768,6 +3903,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
3768
3903
  case GGML_OP_SOFT_MAX:
3769
3904
  ggml_sycl_op_soft_max(ctx, dst);
3770
3905
  break;
3906
+ case GGML_OP_SOFT_MAX_BACK:
3907
+ ggml_sycl_op_soft_max_back(ctx, dst);
3908
+ break;
3771
3909
  case GGML_OP_ROPE:
3772
3910
  ggml_sycl_rope(ctx, dst);
3773
3911
  break;
@@ -3783,6 +3921,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
3783
3921
  case GGML_OP_SUM_ROWS:
3784
3922
  ggml_sycl_sum_rows(ctx, dst);
3785
3923
  break;
3924
+ case GGML_OP_MEAN:
3925
+ ggml_sycl_mean(ctx, dst);
3926
+ break;
3786
3927
  case GGML_OP_ARGSORT:
3787
3928
  ggml_sycl_argsort(ctx, dst);
3788
3929
  break;
@@ -3798,6 +3939,15 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
3798
3939
  case GGML_OP_GATED_LINEAR_ATTN:
3799
3940
  ggml_sycl_op_gated_linear_attn(ctx, dst);
3800
3941
  break;
3942
+ case GGML_OP_SSM_CONV:
3943
+ ggml_sycl_ssm_conv(ctx, dst);
3944
+ break;
3945
+ case GGML_OP_ROLL:
3946
+ ggml_sycl_roll(ctx, dst);
3947
+ break;
3948
+ case GGML_OP_ARANGE:
3949
+ ggml_sycl_arange(ctx, dst);
3950
+ break;
3801
3951
  default:
3802
3952
  return false;
3803
3953
  }
@@ -3805,6 +3955,7 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
3805
3955
  return true;
3806
3956
  } catch (sycl::exception & e) {
3807
3957
  std::cerr << e.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
3958
+ std::cerr << "Error OP "<<ggml_op_name(dst->op)<< std::endl;
3808
3959
  std::exit(1);
3809
3960
  }
3810
3961
 
@@ -3999,6 +4150,18 @@ static bool check_graph_compatibility(ggml_cgraph * cgraph) {
3999
4150
  GGML_LOG_INFO("%s: disabling SYCL graphs due to unsupported node type %s\n", __func__,
4000
4151
  ggml_op_name(node_op));
4001
4152
  return false;
4153
+ case GGML_OP_MUL_MAT:
4154
+ // We cannot use graphs with ggml_sycl_mul_mat() when SYCL async memory allocation extensions are not available,
4155
+ // as SYCL malloc / free and host wait calls are not supported when recording to a graph which are all present
4156
+ // in reordering.
4157
+ if (!g_ggml_sycl_use_async_mem_op) {
4158
+ GGML_LOG_INFO(
4159
+ "%s: disabling SYCL graphs due to unsupported node type when using a compiler without the "
4160
+ "oneAPI async memory allocation extension "
4161
+ "%s\n",
4162
+ __func__, ggml_op_name(node_op));
4163
+ return false;
4164
+ }
4002
4165
  }
4003
4166
  }
4004
4167
  return true;
@@ -4100,6 +4263,7 @@ static ggml_backend_i ggml_backend_sycl_interface = {
4100
4263
  /* .graph_compute = */ ggml_backend_sycl_graph_compute,
4101
4264
  /* .event_record = */ ggml_backend_sycl_event_record,
4102
4265
  /* .event_wait = */ ggml_backend_sycl_event_wait,
4266
+ /* .graph_optimize = */ NULL,
4103
4267
  };
4104
4268
 
4105
4269
  static ggml_guid_t ggml_backend_sycl_guid() {
@@ -4122,6 +4286,7 @@ struct ggml_backend_sycl_device_context {
4122
4286
  int device;
4123
4287
  std::string name;
4124
4288
  std::string description;
4289
+ int op_offload_min_batch_size;
4125
4290
  };
4126
4291
 
4127
4292
  static const char * ggml_backend_sycl_device_get_name(ggml_backend_dev_t dev) {
@@ -4192,6 +4357,9 @@ static ggml_backend_buffer_t ggml_backend_sycl_device_buffer_from_host_ptr(ggml_
4192
4357
  }
4193
4358
 
4194
4359
  static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
4360
+ ggml_backend_sycl_device_context *sycl_ctx =
4361
+ (ggml_backend_sycl_device_context *)dev->context;
4362
+ int device = sycl_ctx->device;
4195
4363
  switch (op->op) {
4196
4364
  case GGML_OP_CONV_TRANSPOSE_1D:
4197
4365
  {
@@ -4204,21 +4372,26 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
4204
4372
  }
4205
4373
  case GGML_OP_UNARY:
4206
4374
  switch (ggml_get_unary_op(op)) {
4375
+ case GGML_UNARY_OP_SGN:
4376
+ case GGML_UNARY_OP_ABS:
4207
4377
  case GGML_UNARY_OP_NEG:
4208
4378
  case GGML_UNARY_OP_STEP:
4379
+ case GGML_UNARY_OP_RELU:
4380
+ case GGML_UNARY_OP_HARDSIGMOID:
4381
+ case GGML_UNARY_OP_TANH:
4209
4382
  case GGML_UNARY_OP_GELU:
4210
4383
  case GGML_UNARY_OP_SILU:
4211
- case GGML_UNARY_OP_RELU:
4212
4384
  case GGML_UNARY_OP_SIGMOID:
4213
- case GGML_UNARY_OP_HARDSIGMOID:
4214
4385
  case GGML_UNARY_OP_HARDSWISH:
4215
4386
  case GGML_UNARY_OP_GELU_QUICK:
4216
4387
  case GGML_UNARY_OP_GELU_ERF:
4217
- case GGML_UNARY_OP_TANH:
4218
4388
  case GGML_UNARY_OP_EXP:
4219
- case GGML_UNARY_OP_SGN:
4220
- case GGML_UNARY_OP_ABS:
4221
4389
  case GGML_UNARY_OP_ELU:
4390
+ return true;
4391
+ case GGML_UNARY_OP_FLOOR:
4392
+ case GGML_UNARY_OP_CEIL:
4393
+ case GGML_UNARY_OP_ROUND:
4394
+ case GGML_UNARY_OP_TRUNC:
4222
4395
  #if defined (GGML_SYCL_F16)
4223
4396
  return ggml_is_contiguous(op->src[0]) && (op->type == op->src[0]->type);
4224
4397
  #else
@@ -4232,6 +4405,9 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
4232
4405
  case GGML_GLU_OP_REGLU:
4233
4406
  case GGML_GLU_OP_GEGLU:
4234
4407
  case GGML_GLU_OP_SWIGLU:
4408
+ case GGML_GLU_OP_SWIGLU_OAI:
4409
+ case GGML_GLU_OP_GEGLU_ERF:
4410
+ case GGML_GLU_OP_GEGLU_QUICK:
4235
4411
  return ggml_is_contiguous_1(op->src[0]);
4236
4412
  default:
4237
4413
  return false;
@@ -4240,15 +4416,9 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
4240
4416
  case GGML_OP_MUL_MAT:
4241
4417
  case GGML_OP_MUL_MAT_ID:
4242
4418
  {
4243
- struct ggml_tensor * a;
4244
- struct ggml_tensor * b;
4245
- if (op->op == GGML_OP_MUL_MAT) {
4246
- a = op->src[0];
4247
- b = op->src[1];
4248
- } else {
4249
- a = op->src[2];
4250
- b = op->src[1];
4251
- }
4419
+ struct ggml_tensor * a = op->src[0];
4420
+ struct ggml_tensor * b = op->src[1];
4421
+
4252
4422
  if (a->ne[3] != b->ne[3]) {
4253
4423
  return false;
4254
4424
  }
@@ -4263,7 +4433,21 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
4263
4433
  }
4264
4434
  }
4265
4435
  ggml_type src0_type = op->src[0]->type;
4266
- if (src0_type == GGML_TYPE_BF16) {
4436
+ if (src0_type == GGML_TYPE_BF16 ) {
4437
+ // TODO: support GGML_TYPE_BF16
4438
+ // FIXME: keep a list of supported types to avoid breaking the backend when a new type is added
4439
+ return false;
4440
+ }
4441
+
4442
+ // TODO: The configuration below needs more work to be supported with oneDNN
4443
+ if (ggml_is_permuted(a) && !ggml_is_contiguous(a) &&
4444
+ a->ne[2] > 1 && a->ne[3] > 1 && src0_type == GGML_TYPE_F16) {
4445
+ return false;
4446
+ }
4447
+
4448
+ // TODO: This specific configuration can fail with oneDNN and needs more debugging
4449
+ if (!ggml_is_permuted(a) && ggml_is_permuted(b) && b->ne[2] > 1 && b->ne[3] > 1 &&
4450
+ a->ne[0] > 128 && a->ne[2] == 1 && src0_type == GGML_TYPE_F16) {
4267
4451
  return false;
4268
4452
  }
4269
4453
  return true;
@@ -4285,6 +4469,20 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
4285
4469
  return false;
4286
4470
  }
4287
4471
  }
4472
+ case GGML_OP_SET:
4473
+ return (op->type == GGML_TYPE_F32) &&
4474
+ (op->src[0] && op->src[1]) &&
4475
+ (op->src[0]->type == GGML_TYPE_F32) &&
4476
+ (op->src[1]->type == GGML_TYPE_F32);
4477
+
4478
+ case GGML_OP_SET_ROWS:
4479
+ {
4480
+ return ((op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16 ||
4481
+ op->type == GGML_TYPE_Q8_0 || op->type == GGML_TYPE_Q5_1 || op->type == GGML_TYPE_Q5_0 ||
4482
+ op->type == GGML_TYPE_Q4_1 || op->type == GGML_TYPE_Q4_0 || op->type == GGML_TYPE_IQ4_NL) &&
4483
+ (op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32));
4484
+ }
4485
+ break;
4288
4486
  case GGML_OP_CPY:
4289
4487
  {
4290
4488
  ggml_type src0_type = op->src[0]->type;
@@ -4354,11 +4552,12 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
4354
4552
  }
4355
4553
  return false;
4356
4554
  }
4357
- case GGML_OP_CONCAT:
4555
+ case GGML_OP_REPEAT_BACK:
4358
4556
  {
4359
4557
  ggml_type src0_type = op->src[0]->type;
4360
- return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
4558
+ return src0_type == GGML_TYPE_F32;
4361
4559
  }
4560
+ case GGML_OP_CONCAT:
4362
4561
  case GGML_OP_DUP:
4363
4562
  case GGML_OP_ARGMAX:
4364
4563
  case GGML_OP_NONE:
@@ -4366,14 +4565,17 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
4366
4565
  case GGML_OP_VIEW:
4367
4566
  case GGML_OP_PERMUTE:
4368
4567
  case GGML_OP_TRANSPOSE:
4369
- return true;
4370
4568
  case GGML_OP_ADD:
4371
4569
  case GGML_OP_ADD1:
4570
+ case GGML_OP_ADD_ID:
4372
4571
  case GGML_OP_SUB:
4572
+ case GGML_OP_COUNT_EQUAL:
4373
4573
  case GGML_OP_MUL:
4374
4574
  case GGML_OP_DIV:
4375
4575
  case GGML_OP_REPEAT:
4376
4576
  return true;
4577
+ case GGML_OP_PAD_REFLECT_1D:
4578
+ return ggml_is_contiguous(op->src[0]) && op-> type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
4377
4579
  case GGML_OP_SQR:
4378
4580
  case GGML_OP_SQRT:
4379
4581
  case GGML_OP_SIN:
@@ -4386,35 +4588,62 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
4386
4588
  return (op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32) && (op->type == op->src[0]->type);
4387
4589
  #endif
4388
4590
  case GGML_OP_NORM:
4389
- case GGML_OP_RMS_NORM:
4390
4591
  return true;
4391
4592
  case GGML_OP_L2_NORM:
4392
4593
  case GGML_OP_GROUP_NORM:
4393
4594
  return ggml_is_contiguous(op->src[0]);
4595
+ case GGML_OP_RMS_NORM:
4596
+ return ((op->src[0]->ne[0] % WARP_SIZE) == 0);
4597
+ case GGML_OP_RMS_NORM_BACK:
4598
+ return ((op->src[0]->ne[0] % WARP_SIZE) == 0);
4394
4599
  case GGML_OP_SCALE:
4395
4600
  return true;
4396
4601
  case GGML_OP_CONT:
4397
4602
  return op->src[0]->type != GGML_TYPE_BF16;
4398
4603
  case GGML_OP_DIAG_MASK_INF:
4604
+ return true;
4399
4605
  case GGML_OP_SOFT_MAX:
4400
4606
  return true;
4607
+ case GGML_OP_SOFT_MAX_BACK: {
4608
+ float max_bias = 0.0f;
4609
+ memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float));
4610
+ return max_bias == 0.0f;
4611
+ }
4401
4612
  case GGML_OP_ROPE:
4402
4613
  case GGML_OP_IM2COL:
4403
4614
  return true;
4404
4615
  case GGML_OP_UPSCALE:
4405
- return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
4406
- case GGML_OP_POOL_2D:
4616
+ return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS);
4407
4617
  case GGML_OP_SUM:
4408
4618
  case GGML_OP_SUM_ROWS:
4619
+ case GGML_OP_MEAN:
4620
+ return ggml_is_contiguous(op->src[0]);
4409
4621
  case GGML_OP_ARGSORT:
4622
+ return op->src[0]->ne[0] * sizeof(int) <=
4623
+ ggml_sycl_info().devices[device].smpbo;
4624
+ case GGML_OP_POOL_2D:
4410
4625
  case GGML_OP_ACC:
4626
+ return true;
4411
4627
  case GGML_OP_PAD:
4628
+ // TODO: add circular padding support for syscl, see https://github.com/ggml-org/llama.cpp/pull/16985
4629
+ if (ggml_get_op_params_i32(op, 8) != 0) {
4630
+ return false;
4631
+ }
4632
+ return ggml_is_contiguous(op->src[0]);
4412
4633
  case GGML_OP_LEAKY_RELU:
4413
4634
  case GGML_OP_TIMESTEP_EMBEDDING:
4414
4635
  case GGML_OP_RWKV_WKV6:
4415
4636
  case GGML_OP_RWKV_WKV7:
4416
4637
  case GGML_OP_GATED_LINEAR_ATTN:
4417
4638
  return true;
4639
+ case GGML_OP_SSM_CONV:
4640
+ return op->type == GGML_TYPE_F32 &&
4641
+ op->src[0]->type == GGML_TYPE_F32 &&
4642
+ op->src[1]->type == GGML_TYPE_F32;
4643
+ case GGML_OP_ROLL:
4644
+ return op->type == GGML_TYPE_F32;
4645
+ case GGML_OP_ARANGE:
4646
+ return op->type == GGML_TYPE_F32;
4418
4647
  default:
4419
4648
  return false;
4420
4649
  }
@@ -4446,9 +4675,8 @@ static int64_t get_op_batch_size(const ggml_tensor * op) {
4446
4675
  }
4447
4676
 
4448
4677
  static bool ggml_backend_sycl_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
4449
- const int min_batch_size = 32;
4450
- return get_op_batch_size(op) >= min_batch_size;
4451
- GGML_UNUSED(dev);
4678
+ ggml_backend_sycl_device_context * sycl_ctx = (ggml_backend_sycl_device_context *)dev->context;
4679
+ return get_op_batch_size(op) >= sycl_ctx->op_offload_min_batch_size;
4452
4680
  }
4453
4681
 
4454
4682
  static ggml_backend_event_t
@@ -4571,6 +4799,7 @@ ggml_backend_reg_t ggml_backend_sycl_reg() {
4571
4799
  std::lock_guard<std::mutex> lock(mutex);
4572
4800
  if (!initialized) {
4573
4801
  ggml_backend_sycl_reg_context * ctx = new ggml_backend_sycl_reg_context;
4802
+ const int min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32;
4574
4803
 
4575
4804
  for (int i = 0; i < ggml_sycl_info().device_count; i++) {
4576
4805
  ggml_backend_sycl_device_context * dev_ctx = new ggml_backend_sycl_device_context;
@@ -4584,6 +4813,7 @@ ggml_backend_reg_t ggml_backend_sycl_reg() {
4584
4813
  prop, dpct::dev_mgr::instance().get_device(i))));
4585
4814
 
4586
4815
  dev_ctx->description = prop.get_name();
4816
+ dev_ctx->op_offload_min_batch_size = min_batch_size;
4587
4817
 
4588
4818
  ggml_backend_dev_t dev = new ggml_backend_device {
4589
4819
  /* .iface = */ ggml_backend_sycl_device_interface,
@@ -4619,10 +4849,10 @@ ggml_backend_t ggml_backend_sycl_init(int device) {
4619
4849
  };
4620
4850
 
4621
4851
  ggml_backend_t sycl_backend = new ggml_backend {
4622
- /* .guid = */ ggml_backend_sycl_guid(),
4623
- /* .interface = */ ggml_backend_sycl_interface,
4624
- /* .device = */ ggml_backend_reg_dev_get(ggml_backend_sycl_reg(), device),
4625
- /* .context = */ ctx
4852
+ /* .guid = */ ggml_backend_sycl_guid(),
4853
+ /* .iface = */ ggml_backend_sycl_interface,
4854
+ /* .device = */ ggml_backend_reg_dev_get(ggml_backend_sycl_reg(), device),
4855
+ /* .context = */ ctx
4626
4856
  };
4627
4857
 
4628
4858
  return sycl_backend;