whispercpp 1.3.1 → 1.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (857)
  1. checksums.yaml +4 -4
  2. data/.gitignore +7 -3
  3. data/README.md +161 -43
  4. data/Rakefile +45 -13
  5. data/ext/.gitignore +4 -8
  6. data/ext/dependencies.rb +73 -0
  7. data/ext/extconf.rb +21 -198
  8. data/ext/options.rb +85 -0
  9. data/ext/ruby_whisper.c +177 -0
  10. data/ext/ruby_whisper.h +17 -2
  11. data/ext/ruby_whisper_context.c +672 -0
  12. data/ext/ruby_whisper_error.c +52 -0
  13. data/ext/ruby_whisper_model.c +232 -0
  14. data/ext/ruby_whisper_params.c +1303 -0
  15. data/ext/ruby_whisper_segment.c +220 -0
  16. data/ext/ruby_whisper_transcribe.cpp +93 -0
  17. data/ext/ruby_whisper_vad_params.c +288 -0
  18. data/ext/sources/CMakeGraphVizOptions.cmake +8 -0
  19. data/ext/sources/CMakeLists.txt +255 -0
  20. data/ext/sources/bindings/javascript/CMakeLists.txt +41 -0
  21. data/ext/sources/bindings/javascript/emscripten.cpp +93 -0
  22. data/ext/sources/bindings/javascript/libwhisper.worker.js +1 -0
  23. data/ext/sources/bindings/javascript/package-tmpl.json +26 -0
  24. data/ext/sources/bindings/javascript/package.json +26 -0
  25. data/ext/sources/bindings/javascript/whisper.js +19 -0
  26. data/ext/sources/build-xcframework.sh +547 -0
  27. data/ext/sources/cmake/DefaultTargetOptions.cmake +16 -0
  28. data/ext/sources/cmake/FindFFmpeg.cmake +163 -0
  29. data/ext/sources/cmake/build-info.cmake +60 -0
  30. data/ext/sources/cmake/git-vars.cmake +22 -0
  31. data/ext/sources/cmake/whisper-config.cmake.in +65 -0
  32. data/ext/sources/cmake/whisper.pc.in +10 -0
  33. data/ext/sources/examples/CMakeLists.txt +124 -0
  34. data/ext/sources/examples/addon.node/CMakeLists.txt +31 -0
  35. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +133 -0
  36. data/ext/sources/examples/addon.node/addon.cpp +557 -0
  37. data/ext/sources/examples/addon.node/index.js +57 -0
  38. data/ext/sources/examples/addon.node/package.json +16 -0
  39. data/ext/sources/examples/addon.node/vad-example.js +132 -0
  40. data/ext/sources/examples/bench/CMakeLists.txt +8 -0
  41. data/ext/sources/examples/bench/bench.cpp +176 -0
  42. data/ext/sources/examples/bench.wasm/CMakeLists.txt +49 -0
  43. data/ext/sources/examples/bench.wasm/emscripten.cpp +87 -0
  44. data/ext/sources/examples/bench.wasm/index-tmpl.html +284 -0
  45. data/ext/sources/examples/cli/CMakeLists.txt +8 -0
  46. data/ext/sources/examples/cli/cli.cpp +1295 -0
  47. data/ext/sources/examples/coi-serviceworker.js +146 -0
  48. data/ext/sources/examples/command/CMakeLists.txt +10 -0
  49. data/ext/sources/examples/command/command.cpp +800 -0
  50. data/ext/sources/examples/command/commands.txt +9 -0
  51. data/ext/sources/examples/command.wasm/CMakeLists.txt +50 -0
  52. data/ext/sources/examples/command.wasm/emscripten.cpp +327 -0
  53. data/ext/sources/examples/command.wasm/index-tmpl.html +414 -0
  54. data/ext/sources/examples/common-ggml.cpp +238 -0
  55. data/ext/sources/examples/common-ggml.h +18 -0
  56. data/ext/sources/examples/common-sdl.cpp +227 -0
  57. data/ext/sources/examples/common-sdl.h +49 -0
  58. data/ext/sources/examples/common-whisper.cpp +175 -0
  59. data/ext/sources/examples/common-whisper.h +24 -0
  60. data/ext/sources/examples/common.cpp +675 -0
  61. data/ext/sources/examples/common.h +322 -0
  62. data/ext/sources/examples/deprecation-warning/CMakeLists.txt +6 -0
  63. data/ext/sources/examples/deprecation-warning/deprecation-warning.cpp +38 -0
  64. data/ext/sources/examples/ffmpeg-transcode.cpp +368 -0
  65. data/ext/sources/examples/generate-karaoke.sh +57 -0
  66. data/ext/sources/examples/grammar-parser.cpp +423 -0
  67. data/ext/sources/examples/grammar-parser.h +29 -0
  68. data/ext/sources/examples/helpers.js +191 -0
  69. data/ext/sources/examples/json.hpp +24596 -0
  70. data/ext/sources/examples/livestream.sh +112 -0
  71. data/ext/sources/examples/lsp/CMakeLists.txt +9 -0
  72. data/ext/sources/examples/lsp/lsp.cpp +469 -0
  73. data/ext/sources/examples/lsp/whisper.vim +362 -0
  74. data/ext/sources/examples/miniaudio.h +93468 -0
  75. data/ext/sources/examples/python/test_whisper_processor.py +7 -0
  76. data/ext/sources/examples/python/whisper_processor.py +54 -0
  77. data/ext/sources/examples/quantize/CMakeLists.txt +6 -0
  78. data/ext/sources/examples/quantize/quantize.cpp +226 -0
  79. data/ext/sources/examples/server/CMakeLists.txt +15 -0
  80. data/ext/sources/examples/server/bench.js +29 -0
  81. data/ext/sources/examples/server/httplib.h +10497 -0
  82. data/ext/sources/examples/server/server.cpp +1238 -0
  83. data/ext/sources/examples/server.py +115 -0
  84. data/ext/sources/examples/stb_vorbis.c +5584 -0
  85. data/ext/sources/examples/stream/CMakeLists.txt +10 -0
  86. data/ext/sources/examples/stream/stream.cpp +435 -0
  87. data/ext/sources/examples/stream.wasm/CMakeLists.txt +49 -0
  88. data/ext/sources/examples/stream.wasm/emscripten.cpp +216 -0
  89. data/ext/sources/examples/stream.wasm/index-tmpl.html +414 -0
  90. data/ext/sources/examples/sycl/CMakeLists.txt +9 -0
  91. data/ext/sources/examples/sycl/build.sh +22 -0
  92. data/ext/sources/examples/sycl/ls-sycl-device.cpp +11 -0
  93. data/ext/sources/examples/sycl/run-whisper.sh +17 -0
  94. data/ext/sources/examples/talk-llama/CMakeLists.txt +43 -0
  95. data/ext/sources/examples/talk-llama/eleven-labs.py +80 -0
  96. data/ext/sources/examples/talk-llama/llama-adapter.cpp +388 -0
  97. data/ext/sources/examples/talk-llama/llama-adapter.h +76 -0
  98. data/ext/sources/examples/talk-llama/llama-arch.cpp +1914 -0
  99. data/ext/sources/examples/talk-llama/llama-arch.h +464 -0
  100. data/ext/sources/examples/talk-llama/llama-batch.cpp +843 -0
  101. data/ext/sources/examples/talk-llama/llama-batch.h +147 -0
  102. data/ext/sources/examples/talk-llama/llama-chat.cpp +685 -0
  103. data/ext/sources/examples/talk-llama/llama-chat.h +59 -0
  104. data/ext/sources/examples/talk-llama/llama-context.cpp +2845 -0
  105. data/ext/sources/examples/talk-llama/llama-context.h +297 -0
  106. data/ext/sources/examples/talk-llama/llama-cparams.cpp +5 -0
  107. data/ext/sources/examples/talk-llama/llama-cparams.h +41 -0
  108. data/ext/sources/examples/talk-llama/llama-grammar.cpp +1229 -0
  109. data/ext/sources/examples/talk-llama/llama-grammar.h +173 -0
  110. data/ext/sources/examples/talk-llama/llama-graph.cpp +1693 -0
  111. data/ext/sources/examples/talk-llama/llama-graph.h +710 -0
  112. data/ext/sources/examples/talk-llama/llama-hparams.cpp +103 -0
  113. data/ext/sources/examples/talk-llama/llama-hparams.h +207 -0
  114. data/ext/sources/examples/talk-llama/llama-impl.cpp +167 -0
  115. data/ext/sources/examples/talk-llama/llama-impl.h +61 -0
  116. data/ext/sources/examples/talk-llama/llama-io.cpp +15 -0
  117. data/ext/sources/examples/talk-llama/llama-io.h +35 -0
  118. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +279 -0
  119. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.h +128 -0
  120. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +1841 -0
  121. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +303 -0
  122. data/ext/sources/examples/talk-llama/llama-kv-cache.h +44 -0
  123. data/ext/sources/examples/talk-llama/llama-kv-cells.h +439 -0
  124. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +246 -0
  125. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +138 -0
  126. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1125 -0
  127. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +183 -0
  128. data/ext/sources/examples/talk-llama/llama-memory.cpp +59 -0
  129. data/ext/sources/examples/talk-llama/llama-memory.h +116 -0
  130. data/ext/sources/examples/talk-llama/llama-mmap.cpp +600 -0
  131. data/ext/sources/examples/talk-llama/llama-mmap.h +68 -0
  132. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +1163 -0
  133. data/ext/sources/examples/talk-llama/llama-model-loader.h +169 -0
  134. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +282 -0
  135. data/ext/sources/examples/talk-llama/llama-model-saver.h +37 -0
  136. data/ext/sources/examples/talk-llama/llama-model.cpp +15114 -0
  137. data/ext/sources/examples/talk-llama/llama-model.h +452 -0
  138. data/ext/sources/examples/talk-llama/llama-quant.cpp +1049 -0
  139. data/ext/sources/examples/talk-llama/llama-quant.h +1 -0
  140. data/ext/sources/examples/talk-llama/llama-sampling.cpp +2575 -0
  141. data/ext/sources/examples/talk-llama/llama-sampling.h +32 -0
  142. data/ext/sources/examples/talk-llama/llama-vocab.cpp +3377 -0
  143. data/ext/sources/examples/talk-llama/llama-vocab.h +132 -0
  144. data/ext/sources/examples/talk-llama/llama.cpp +358 -0
  145. data/ext/sources/examples/talk-llama/llama.h +1484 -0
  146. data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +23 -0
  147. data/ext/sources/examples/talk-llama/speak +40 -0
  148. data/ext/sources/examples/talk-llama/speak.bat +1 -0
  149. data/ext/sources/examples/talk-llama/speak.ps1 +14 -0
  150. data/ext/sources/examples/talk-llama/talk-llama.cpp +810 -0
  151. data/ext/sources/examples/talk-llama/unicode-data.cpp +7034 -0
  152. data/ext/sources/examples/talk-llama/unicode-data.h +20 -0
  153. data/ext/sources/examples/talk-llama/unicode.cpp +854 -0
  154. data/ext/sources/examples/talk-llama/unicode.h +66 -0
  155. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +8 -0
  156. data/ext/sources/examples/vad-speech-segments/speech.cpp +149 -0
  157. data/ext/sources/examples/wchess/CMakeLists.txt +10 -0
  158. data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +19 -0
  159. data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +803 -0
  160. data/ext/sources/examples/wchess/libwchess/Chessboard.h +33 -0
  161. data/ext/sources/examples/wchess/libwchess/WChess.cpp +193 -0
  162. data/ext/sources/examples/wchess/libwchess/WChess.h +63 -0
  163. data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +117 -0
  164. data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +8 -0
  165. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +251 -0
  166. data/ext/sources/examples/whisper.wasm/CMakeLists.txt +50 -0
  167. data/ext/sources/examples/whisper.wasm/emscripten.cpp +118 -0
  168. data/ext/sources/examples/whisper.wasm/index-tmpl.html +658 -0
  169. data/ext/sources/ggml/CMakeLists.txt +435 -0
  170. data/ext/sources/ggml/cmake/BuildTypes.cmake +54 -0
  171. data/ext/sources/ggml/cmake/GitVars.cmake +22 -0
  172. data/ext/sources/ggml/cmake/common.cmake +50 -0
  173. data/ext/sources/ggml/cmake/ggml-config.cmake.in +152 -0
  174. data/ext/{ggml → sources/ggml}/include/ggml-alloc.h +1 -1
  175. data/ext/{ggml → sources/ggml}/include/ggml-backend.h +10 -8
  176. data/ext/{ggml → sources/ggml}/include/ggml-cpp.h +2 -1
  177. data/ext/{ggml → sources/ggml}/include/ggml-cpu.h +11 -1
  178. data/ext/{ggml → sources/ggml}/include/ggml-metal.h +1 -1
  179. data/ext/{ggml → sources/ggml}/include/ggml-opt.h +49 -28
  180. data/ext/{ggml → sources/ggml}/include/ggml-rpc.h +6 -1
  181. data/ext/{ggml → sources/ggml}/include/ggml-vulkan.h +0 -2
  182. data/ext/{ggml → sources/ggml}/include/ggml.h +325 -269
  183. data/ext/sources/ggml/include/gguf.h +202 -0
  184. data/ext/sources/ggml/src/CMakeLists.txt +404 -0
  185. data/ext/{ggml → sources/ggml}/src/ggml-alloc.c +34 -29
  186. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +107 -0
  187. data/ext/{ggml → sources/ggml}/src/ggml-backend-impl.h +1 -2
  188. data/ext/{ggml → sources/ggml}/src/ggml-backend-reg.cpp +92 -53
  189. data/ext/{ggml → sources/ggml}/src/ggml-backend.cpp +69 -34
  190. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +87 -0
  191. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +75 -0
  192. data/ext/sources/ggml/src/ggml-cann/Doxyfile +2579 -0
  193. data/ext/{ggml → sources/ggml}/src/ggml-cann/acl_tensor.cpp +10 -4
  194. data/ext/{ggml → sources/ggml}/src/ggml-cann/acl_tensor.h +5 -5
  195. data/ext/{ggml → sources/ggml}/src/ggml-cann/aclnn_ops.cpp +1272 -1506
  196. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +1125 -0
  197. data/ext/{ggml → sources/ggml}/src/ggml-cann/common.h +140 -1
  198. data/ext/{ggml → sources/ggml}/src/ggml-cann/ggml-cann.cpp +588 -146
  199. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +30 -0
  200. data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/dup.cpp +3 -5
  201. data/ext/{ggml → sources/ggml}/src/ggml-common.h +16 -8
  202. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +597 -0
  203. data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/amx.cpp +3 -2
  204. data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/mmq.cpp +11 -10
  205. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  206. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +4114 -0
  207. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2163 -0
  208. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2639 -0
  209. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  210. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2732 -0
  211. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +2069 -0
  212. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +397 -0
  213. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1300 -0
  214. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1481 -0
  215. data/ext/{ggml/src/ggml-cpu/cpu-feats-x86.cpp → sources/ggml/src/ggml-cpu/arch/x86/cpu-feats.cpp} +5 -1
  216. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +4311 -0
  217. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +3285 -0
  218. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +184 -0
  219. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
  220. data/ext/sources/ggml/src/ggml-cpu/binary-ops.h +16 -0
  221. data/ext/sources/ggml/src/ggml-cpu/cmake/FindSIMD.cmake +100 -0
  222. data/ext/sources/ggml/src/ggml-cpu/common.h +73 -0
  223. data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-impl.h +172 -41
  224. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +3551 -0
  225. data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu.cpp +78 -25
  226. data/ext/{ggml/src/ggml-cpu/ggml-cpu-hbm.cpp → sources/ggml/src/ggml-cpu/hbm.cpp} +1 -1
  227. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +337 -0
  228. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +95 -0
  229. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +482 -0
  230. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.h +17 -0
  231. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3594 -0
  232. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +19 -0
  233. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +9786 -0
  234. data/ext/sources/ggml/src/ggml-cpu/ops.h +118 -0
  235. data/ext/sources/ggml/src/ggml-cpu/quants.c +1158 -0
  236. data/ext/{ggml/src/ggml-cpu/ggml-cpu-quants.h → sources/ggml/src/ggml-cpu/quants.h} +26 -0
  237. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1571 -0
  238. data/ext/sources/ggml/src/ggml-cpu/repack.h +98 -0
  239. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +1184 -0
  240. data/ext/{ggml/src/ggml-cpu/ggml-cpu-traits.cpp → sources/ggml/src/ggml-cpu/traits.cpp} +1 -1
  241. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
  242. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +28 -0
  243. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +345 -0
  244. data/ext/sources/ggml/src/ggml-cpu/vec.h +1027 -0
  245. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +184 -0
  246. data/ext/sources/ggml/src/ggml-cuda/acc.cu +61 -0
  247. data/ext/sources/ggml/src/ggml-cuda/acc.cuh +5 -0
  248. data/ext/sources/ggml/src/ggml-cuda/arange.cu +34 -0
  249. data/ext/sources/ggml/src/ggml-cuda/arange.cuh +5 -0
  250. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +91 -0
  251. data/ext/sources/ggml/src/ggml-cuda/argmax.cuh +3 -0
  252. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +104 -0
  253. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +3 -0
  254. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +363 -0
  255. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +9 -0
  256. data/ext/sources/ggml/src/ggml-cuda/clamp.cu +45 -0
  257. data/ext/sources/ggml/src/ggml-cuda/clamp.cuh +5 -0
  258. data/ext/sources/ggml/src/ggml-cuda/common.cuh +851 -0
  259. data/ext/sources/ggml/src/ggml-cuda/concat.cu +221 -0
  260. data/ext/sources/ggml/src/ggml-cuda/concat.cuh +5 -0
  261. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +89 -0
  262. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cuh +5 -0
  263. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  264. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  265. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  266. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  267. data/ext/sources/ggml/src/ggml-cuda/convert.cu +752 -0
  268. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +31 -0
  269. data/ext/sources/ggml/src/ggml-cuda/count-equal.cu +64 -0
  270. data/ext/sources/ggml/src/ggml-cuda/count-equal.cuh +5 -0
  271. data/ext/sources/ggml/src/ggml-cuda/cp-async.cuh +57 -0
  272. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +705 -0
  273. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +11 -0
  274. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +189 -0
  275. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cuh +7 -0
  276. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +103 -0
  277. data/ext/sources/ggml/src/ggml-cuda/diagmask.cu +40 -0
  278. data/ext/sources/ggml/src/ggml-cuda/diagmask.cuh +5 -0
  279. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +881 -0
  280. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +1474 -0
  281. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +357 -0
  282. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +3 -0
  283. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +365 -0
  284. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +3 -0
  285. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +482 -0
  286. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +472 -0
  287. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +638 -0
  288. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +3 -0
  289. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +346 -0
  290. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +3 -0
  291. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +275 -0
  292. data/ext/sources/ggml/src/ggml-cuda/getrows.cuh +15 -0
  293. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +3647 -0
  294. data/ext/sources/ggml/src/ggml-cuda/gla.cu +93 -0
  295. data/ext/sources/ggml/src/ggml-cuda/gla.cuh +3 -0
  296. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +103 -0
  297. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +5 -0
  298. data/ext/sources/ggml/src/ggml-cuda/mean.cu +19 -0
  299. data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
  300. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +396 -0
  301. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +324 -0
  302. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +3217 -0
  303. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +506 -0
  304. data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +11 -0
  305. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +595 -0
  306. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +12 -0
  307. data/ext/sources/ggml/src/ggml-cuda/norm.cu +458 -0
  308. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +11 -0
  309. data/ext/sources/ggml/src/ggml-cuda/opt-step-adamw.cu +78 -0
  310. data/ext/sources/ggml/src/ggml-cuda/opt-step-adamw.cuh +5 -0
  311. data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +68 -0
  312. data/ext/sources/ggml/src/ggml-cuda/out-prod.cuh +3 -0
  313. data/ext/sources/ggml/src/ggml-cuda/pad.cu +49 -0
  314. data/ext/sources/ggml/src/ggml-cuda/pad.cuh +5 -0
  315. data/ext/sources/ggml/src/ggml-cuda/pool2d.cu +94 -0
  316. data/ext/sources/ggml/src/ggml-cuda/pool2d.cuh +5 -0
  317. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +190 -0
  318. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +27 -0
  319. data/ext/sources/ggml/src/ggml-cuda/rope.cu +456 -0
  320. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +7 -0
  321. data/ext/sources/ggml/src/ggml-cuda/scale.cu +31 -0
  322. data/ext/sources/ggml/src/ggml-cuda/scale.cuh +5 -0
  323. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +283 -0
  324. data/ext/sources/ggml/src/ggml-cuda/softmax.cuh +7 -0
  325. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +148 -0
  326. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +3 -0
  327. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +155 -0
  328. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cuh +3 -0
  329. data/ext/sources/ggml/src/ggml-cuda/sum.cu +45 -0
  330. data/ext/sources/ggml/src/ggml-cuda/sum.cuh +5 -0
  331. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +26 -0
  332. data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +4 -0
  333. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +5 -0
  334. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +10 -0
  335. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu +10 -0
  336. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu +10 -0
  337. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +10 -0
  338. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +5 -0
  339. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +10 -0
  340. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +10 -0
  341. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu +10 -0
  342. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu +10 -0
  343. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +5 -0
  344. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu +10 -0
  345. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +10 -0
  346. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +10 -0
  347. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu +10 -0
  348. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu +10 -0
  349. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu +10 -0
  350. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +10 -0
  351. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +10 -0
  352. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +5 -0
  353. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +5 -0
  354. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +5 -0
  355. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +5 -0
  356. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +5 -0
  357. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +5 -0
  358. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +5 -0
  359. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +5 -0
  360. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +5 -0
  361. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +5 -0
  362. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +5 -0
  363. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +5 -0
  364. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +5 -0
  365. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +5 -0
  366. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +5 -0
  367. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +5 -0
  368. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +5 -0
  369. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +5 -0
  370. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +5 -0
  371. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +5 -0
  372. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +5 -0
  373. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +5 -0
  374. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +5 -0
  375. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +5 -0
  376. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +5 -0
  377. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +5 -0
  378. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +5 -0
  379. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +5 -0
  380. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +5 -0
  381. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +5 -0
  382. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +5 -0
  383. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +5 -0
  384. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +5 -0
  385. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +5 -0
  386. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +5 -0
  387. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +5 -0
  388. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +5 -0
  389. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +5 -0
  390. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +5 -0
  391. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +5 -0
  392. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +5 -0
  393. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +5 -0
  394. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +5 -0
  395. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +5 -0
  396. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +5 -0
  397. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +5 -0
  398. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +5 -0
  399. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +5 -0
  400. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +5 -0
  401. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +5 -0
  402. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +5 -0
  403. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +5 -0
  404. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +5 -0
  405. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +5 -0
  406. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +5 -0
  407. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +5 -0
  408. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +5 -0
  409. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +5 -0
  410. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +5 -0
  411. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +5 -0
  412. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +5 -0
  413. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +5 -0
  414. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +5 -0
  415. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +5 -0
  416. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +5 -0
  417. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +5 -0
  418. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +5 -0
  419. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +5 -0
  420. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +5 -0
  421. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +5 -0
  422. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +5 -0
  423. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +5 -0
  424. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +5 -0
  425. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +5 -0
  426. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +5 -0
  427. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +5 -0
  428. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +5 -0
  429. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +5 -0
  430. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +5 -0
  431. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +5 -0
  432. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +5 -0
  433. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +5 -0
  434. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +5 -0
  435. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +5 -0
  436. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +5 -0
  437. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +5 -0
  438. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +78 -0
  439. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s.cu +5 -0
  440. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_s.cu +5 -0
  441. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xs.cu +5 -0
  442. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xxs.cu +5 -0
  443. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_s.cu +5 -0
  444. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_xxs.cu +5 -0
  445. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_nl.cu +5 -0
  446. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_xs.cu +5 -0
  447. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q2_k.cu +5 -0
  448. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q3_k.cu +5 -0
  449. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_0.cu +5 -0
  450. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_1.cu +5 -0
  451. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_k.cu +5 -0
  452. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_0.cu +5 -0
  453. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_1.cu +5 -0
  454. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_k.cu +5 -0
  455. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q6_k.cu +5 -0
  456. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q8_0.cu +5 -0
  457. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +47 -0
  458. data/ext/sources/ggml/src/ggml-cuda/tsembd.cuh +5 -0
  459. data/ext/sources/ggml/src/ggml-cuda/unary.cu +378 -0
  460. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +66 -0
  461. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +51 -0
  462. data/ext/sources/ggml/src/ggml-cuda/upscale.cuh +5 -0
  463. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +1135 -0
  464. data/ext/{ggml → sources/ggml}/src/ggml-cuda/vendors/cuda.h +1 -0
  465. data/ext/{ggml → sources/ggml}/src/ggml-cuda/vendors/hip.h +57 -0
  466. data/ext/{ggml → sources/ggml}/src/ggml-cuda/vendors/musa.h +7 -1
  467. data/ext/sources/ggml/src/ggml-cuda/wkv.cu +199 -0
  468. data/ext/sources/ggml/src/ggml-cuda/wkv.cuh +7 -0
  469. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +135 -0
  470. data/ext/{ggml → sources/ggml}/src/ggml-impl.h +147 -158
  471. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +166 -0
  472. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +112 -0
  473. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +58 -0
  474. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +25 -0
  475. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +52 -0
  476. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +52 -0
  477. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +52 -0
  478. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +52 -0
  479. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +30 -0
  480. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +22 -0
  481. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +17 -0
  482. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +31 -0
  483. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +31 -0
  484. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +38 -0
  485. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +39 -0
  486. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +44 -0
  487. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +52 -0
  488. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +69 -0
  489. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +51 -0
  490. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +33 -0
  491. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +35 -0
  492. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +140 -0
  493. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +106 -0
  494. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +73 -0
  495. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +52 -0
  496. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +28 -0
  497. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +84 -0
  498. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +21 -0
  499. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +53 -0
  500. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +52 -0
  501. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +52 -0
  502. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +52 -0
  503. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +52 -0
  504. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +19 -0
  505. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +23 -0
  506. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +22 -0
  507. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +72 -0
  508. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +71 -0
  509. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +121 -0
  510. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +649 -0
  511. data/ext/{ggml → sources/ggml}/src/ggml-metal/ggml-metal.m +2504 -1108
  512. data/ext/{ggml → sources/ggml}/src/ggml-metal/ggml-metal.metal +2102 -1463
  513. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +113 -0
  514. data/ext/sources/ggml/src/ggml-musa/mudnn.cu +112 -0
  515. data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +12 -0
  516. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +110 -0
  517. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +6494 -0
  518. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +83 -0
  519. data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
  520. data/ext/sources/ggml/src/ggml-opencl/kernels/clamp.cl +20 -0
  521. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  522. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +184 -0
  523. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +118 -0
  524. data/ext/sources/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl +58 -0
  525. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +72 -0
  526. data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +26 -0
  527. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +62 -0
  528. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle.cl +268 -0
  529. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general.cl +274 -0
  530. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +163 -0
  531. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +201 -0
  532. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +72 -0
  533. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +57 -0
  534. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +57 -0
  535. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +79 -0
  536. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_Ab_Bi_8x4.cl +139 -0
  537. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f16.cl +118 -0
  538. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32.cl +118 -0
  539. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_1row.cl +94 -0
  540. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_l4.cl +84 -0
  541. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f32_f32.cl +118 -0
  542. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  543. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32.cl +192 -0
  544. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_16x_flat.cl +307 -0
  545. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_8x_flat.cl +265 -0
  546. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_8x_flat.cl +272 -0
  547. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_v.cl +254 -0
  548. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl +190 -0
  549. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +81 -0
  550. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  551. data/ext/sources/ggml/src/ggml-opencl/kernels/relu.cl +16 -0
  552. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  553. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +96 -0
  554. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +721 -0
  555. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +16 -0
  556. data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
  557. data/ext/sources/ggml/src/ggml-opencl/kernels/silu.cl +30 -0
  558. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +87 -0
  559. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +87 -0
  560. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +86 -0
  561. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +86 -0
  562. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +72 -0
  563. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
  564. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  565. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +84 -0
  566. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  567. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
  568. data/ext/{ggml → sources/ggml}/src/ggml-opt.cpp +373 -190
  569. data/ext/{ggml → sources/ggml}/src/ggml-quants.c +120 -128
  570. data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +9 -0
  571. data/ext/{ggml → sources/ggml}/src/ggml-rpc/ggml-rpc.cpp +494 -84
  572. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +189 -0
  573. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +37 -0
  574. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +344 -0
  575. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +39 -0
  576. data/ext/{ggml → sources/ggml}/src/ggml-sycl/common.cpp +20 -32
  577. data/ext/sources/ggml/src/ggml-sycl/common.hpp +561 -0
  578. data/ext/{ggml → sources/ggml}/src/ggml-sycl/concat.cpp +56 -70
  579. data/ext/sources/ggml/src/ggml-sycl/concat.hpp +20 -0
  580. data/ext/{ggml → sources/ggml}/src/ggml-sycl/conv.cpp +8 -12
  581. data/ext/sources/ggml/src/ggml-sycl/conv.hpp +20 -0
  582. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +575 -0
  583. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +34 -0
  584. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +839 -0
  585. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +11 -0
  586. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +823 -0
  587. data/ext/{ggml → sources/ggml}/src/ggml-sycl/dmmv.cpp +188 -67
  588. data/ext/sources/ggml/src/ggml-sycl/dmmv.hpp +27 -0
  589. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +2987 -0
  590. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +1120 -0
  591. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +84 -0
  592. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +102 -0
  593. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +212 -0
  594. data/ext/sources/ggml/src/ggml-sycl/getrows.hpp +20 -0
  595. data/ext/{ggml → sources/ggml}/src/ggml-sycl/ggml-sycl.cpp +1197 -1295
  596. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +106 -0
  597. data/ext/sources/ggml/src/ggml-sycl/gla.hpp +8 -0
  598. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +136 -0
  599. data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +21 -0
  600. data/ext/{ggml → sources/ggml}/src/ggml-sycl/mmq.cpp +60 -81
  601. data/ext/sources/ggml/src/ggml-sycl/mmq.hpp +33 -0
  602. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1065 -0
  603. data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +27 -0
  604. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +482 -0
  605. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +26 -0
  606. data/ext/{ggml → sources/ggml}/src/ggml-sycl/outprod.cpp +8 -17
  607. data/ext/sources/ggml/src/ggml-sycl/outprod.hpp +10 -0
  608. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +74 -0
  609. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +111 -0
  610. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +472 -0
  611. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +20 -0
  612. data/ext/{ggml → sources/ggml}/src/ggml-sycl/softmax.cpp +38 -28
  613. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +20 -0
  614. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +15 -0
  615. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +26 -0
  616. data/ext/{ggml → sources/ggml}/src/ggml-sycl/tsembd.cpp +6 -11
  617. data/ext/sources/ggml/src/ggml-sycl/tsembd.hpp +20 -0
  618. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +1307 -0
  619. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +289 -0
  620. data/ext/sources/ggml/src/ggml-sycl/wkv.hpp +10 -0
  621. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +200 -0
  622. data/ext/sources/ggml/src/ggml-vulkan/cmake/host-toolchain.cmake.in +15 -0
  623. data/ext/{ggml → sources/ggml}/src/ggml-vulkan/ggml-vulkan.cpp +3822 -1335
  624. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +31 -0
  625. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +29 -0
  626. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +29 -0
  627. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +51 -0
  628. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +69 -0
  629. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +17 -0
  630. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +41 -0
  631. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +49 -0
  632. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +105 -0
  633. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  634. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +23 -0
  635. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +51 -0
  636. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +242 -0
  637. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +17 -0
  638. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +31 -0
  639. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +20 -0
  640. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +462 -0
  641. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +699 -0
  642. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.comp +13 -0
  643. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +42 -0
  644. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +35 -0
  645. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +44 -0
  646. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +43 -0
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +48 -0
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +39 -0
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +49 -0
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +32 -0
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +34 -0
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +34 -0
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +42 -0
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +30 -0
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +32 -0
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +68 -0
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +34 -0
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +35 -0
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +70 -0
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +33 -0
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +31 -0
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +34 -0
  663. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +27 -0
  664. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +337 -0
  665. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +162 -0
  666. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +360 -0
  667. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +267 -0
  668. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +59 -0
  669. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  670. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +25 -0
  671. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +23 -0
  672. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +64 -0
  673. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.comp +9 -0
  674. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp +76 -0
  675. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +33 -0
  676. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +41 -0
  677. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
  678. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  679. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +66 -0
  680. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +100 -0
  681. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +41 -0
  682. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +22 -0
  683. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +27 -0
  684. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_split_k_reduce.comp +48 -0
  685. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +169 -0
  686. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +118 -0
  687. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +82 -0
  688. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +79 -0
  689. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +90 -0
  690. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +87 -0
  691. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +87 -0
  692. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +90 -0
  693. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +88 -0
  694. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +118 -0
  695. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +154 -0
  696. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +130 -0
  697. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +132 -0
  698. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +136 -0
  699. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +167 -0
  700. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +130 -0
  701. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +868 -0
  702. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +441 -0
  703. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +442 -0
  704. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +99 -0
  705. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +44 -0
  706. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +42 -0
  707. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +28 -0
  708. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +74 -0
  709. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +77 -0
  710. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  711. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +21 -0
  712. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +26 -0
  713. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +37 -0
  714. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +61 -0
  715. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +55 -0
  716. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +58 -0
  717. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +60 -0
  718. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +43 -0
  719. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +43 -0
  720. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +47 -0
  721. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +24 -0
  722. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +20 -0
  723. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +22 -0
  724. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +26 -0
  725. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +17 -0
  726. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +173 -0
  727. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +50 -0
  728. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +17 -0
  729. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +29 -0
  730. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +37 -0
  731. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  732. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +20 -0
  733. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_bfloat16_support.comp +7 -0
  734. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat2_support.comp +7 -0
  735. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat_support.comp +7 -0
  736. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_integer_dot_support.comp +7 -0
  737. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +41 -0
  738. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +1373 -0
  739. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +36 -0
  740. data/ext/{ggml → sources/ggml}/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +203 -36
  741. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp +87 -0
  742. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp +91 -0
  743. data/ext/{ggml → sources/ggml}/src/ggml.c +918 -1782
  744. data/ext/sources/ggml/src/ggml.cpp +26 -0
  745. data/ext/sources/ggml/src/gguf.cpp +1351 -0
  746. data/ext/{include → sources/include}/whisper.h +70 -2
  747. data/ext/sources/src/CMakeLists.txt +145 -0
  748. data/ext/sources/src/coreml/whisper-compat.h +10 -0
  749. data/ext/sources/src/coreml/whisper-compat.m +35 -0
  750. data/ext/{src → sources/src}/coreml/whisper-decoder-impl.h +27 -15
  751. data/ext/{src → sources/src}/coreml/whisper-decoder-impl.m +36 -10
  752. data/ext/{src → sources/src}/coreml/whisper-encoder-impl.h +21 -9
  753. data/ext/{src → sources/src}/coreml/whisper-encoder-impl.m +29 -3
  754. data/ext/sources/src/coreml/whisper-encoder.mm +73 -0
  755. data/ext/sources/src/whisper-arch.h +197 -0
  756. data/ext/{src → sources/src}/whisper.cpp +1966 -386
  757. data/ext/sources/tests/CMakeLists.txt +105 -0
  758. data/ext/sources/tests/earnings21/eval.mk +58 -0
  759. data/ext/sources/tests/earnings21/eval.py +68 -0
  760. data/ext/sources/tests/earnings21/normalizers/__init__.py +2 -0
  761. data/ext/sources/tests/earnings21/normalizers/basic.py +80 -0
  762. data/ext/sources/tests/earnings21/normalizers/english.json +1741 -0
  763. data/ext/sources/tests/earnings21/normalizers/english.py +550 -0
  764. data/ext/sources/tests/earnings21/requirements.txt +6 -0
  765. data/ext/sources/tests/en-0-ref.txt +1 -0
  766. data/ext/sources/tests/en-1-ref.txt +1 -0
  767. data/ext/sources/tests/en-2-ref.txt +1 -0
  768. data/ext/sources/tests/es-0-ref.txt +1 -0
  769. data/ext/sources/tests/librispeech/eval.mk +39 -0
  770. data/ext/sources/tests/librispeech/eval.py +47 -0
  771. data/ext/sources/tests/librispeech/normalizers/__init__.py +2 -0
  772. data/ext/sources/tests/librispeech/normalizers/basic.py +80 -0
  773. data/ext/sources/tests/librispeech/normalizers/english.json +1741 -0
  774. data/ext/sources/tests/librispeech/normalizers/english.py +550 -0
  775. data/ext/sources/tests/librispeech/requirements.txt +6 -0
  776. data/ext/sources/tests/run-tests.sh +130 -0
  777. data/ext/sources/tests/test-c.c +3 -0
  778. data/ext/sources/tests/test-vad-full.cpp +54 -0
  779. data/ext/sources/tests/test-vad.cpp +83 -0
  780. data/ext/sources/tests/test-whisper.js +58 -0
  781. data/extsources.rb +39 -5
  782. data/lib/whisper/context.rb +15 -0
  783. data/lib/whisper/model/uri.rb +202 -126
  784. data/lib/whisper/segment.rb +58 -0
  785. data/sig/whisper.rbs +510 -0
  786. data/test/helper.rb +24 -0
  787. data/{tests → test}/test_callback.rb +45 -3
  788. data/{tests → test}/test_error.rb +2 -2
  789. data/{tests → test}/test_model.rb +47 -0
  790. data/test/test_package.rb +51 -0
  791. data/test/test_params.rb +297 -0
  792. data/test/test_segment.rb +146 -0
  793. data/test/test_vad.rb +19 -0
  794. data/test/test_vad_params.rb +103 -0
  795. data/{tests → test}/test_whisper.rb +106 -36
  796. data/whispercpp.gemspec +5 -5
  797. metadata +837 -134
  798. data/ext/cpu.mk +0 -9
  799. data/ext/examples/dr_wav.h +0 -8815
  800. data/ext/ggml/src/ggml-cann/aclnn_ops.h +0 -592
  801. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -4262
  802. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  803. data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -10835
  804. data/ext/ggml/src/ggml-cpu/ggml-cpu.c +0 -14123
  805. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.cpp +0 -1884
  806. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.h +0 -14
  807. data/ext/ggml/src/ggml-metal/ggml-metal-impl.h +0 -288
  808. data/ext/ggml/src/ggml-sycl/convert.cpp +0 -547
  809. data/ext/ggml/src/ggml-sycl/element_wise.cpp +0 -1030
  810. data/ext/ggml/src/ggml-sycl/im2col.cpp +0 -126
  811. data/ext/ggml/src/ggml-sycl/mmvq.cpp +0 -1015
  812. data/ext/ggml/src/ggml-sycl/norm.cpp +0 -378
  813. data/ext/ggml/src/ggml-sycl/rope.cpp +0 -276
  814. data/ext/ggml/src/ggml-sycl/wkv6.cpp +0 -141
  815. data/ext/metal-embed.mk +0 -17
  816. data/ext/metal.mk +0 -6
  817. data/ext/ruby_whisper.cpp +0 -1909
  818. data/ext/scripts/get-flags.mk +0 -38
  819. data/lib/whisper.rb +0 -2
  820. data/tests/helper.rb +0 -7
  821. data/tests/test_package.rb +0 -31
  822. data/tests/test_params.rb +0 -160
  823. data/tests/test_segment.rb +0 -83
  824. /data/ext/{ggml → sources/ggml}/include/ggml-blas.h +0 -0
  825. /data/ext/{ggml → sources/ggml}/include/ggml-cann.h +0 -0
  826. /data/ext/{ggml → sources/ggml}/include/ggml-cuda.h +0 -0
  827. /data/ext/{ggml → sources/ggml}/include/ggml-kompute.h +0 -0
  828. /data/ext/{ggml → sources/ggml}/include/ggml-opencl.h +0 -0
  829. /data/ext/{ggml → sources/ggml}/include/ggml-sycl.h +0 -0
  830. /data/ext/{ggml → sources/ggml}/src/ggml-amx/common.h +0 -0
  831. /data/ext/{ggml → sources/ggml}/src/ggml-amx/ggml-amx.cpp +0 -0
  832. /data/ext/{ggml → sources/ggml}/src/ggml-amx/mmq.cpp +0 -0
  833. /data/ext/{ggml → sources/ggml}/src/ggml-amx/mmq.h +0 -0
  834. /data/ext/{ggml → sources/ggml}/src/ggml-blas/ggml-blas.cpp +0 -0
  835. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/ascendc_kernels.h +0 -0
  836. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_f16.cpp +0 -0
  837. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_f32.cpp +0 -0
  838. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -0
  839. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -0
  840. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -0
  841. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -0
  842. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -0
  843. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/amx.h +0 -0
  844. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/common.h +0 -0
  845. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/mmq.h +0 -0
  846. /data/ext/{ggml/src/ggml-cpu/ggml-cpu-hbm.h → sources/ggml/src/ggml-cpu/hbm.h} +0 -0
  847. /data/ext/{ggml/src/ggml-cpu/ggml-cpu-traits.h → sources/ggml/src/ggml-cpu/traits.h} +0 -0
  848. /data/ext/{ggml → sources/ggml}/src/ggml-kompute/ggml-kompute.cpp +0 -0
  849. /data/ext/{ggml → sources/ggml}/src/ggml-quants.h +0 -0
  850. /data/ext/{ggml → sources/ggml}/src/ggml-threading.cpp +0 -0
  851. /data/ext/{ggml → sources/ggml}/src/ggml-threading.h +0 -0
  852. /data/ext/{src → sources/src}/coreml/whisper-encoder.h +0 -0
  853. /data/ext/{src → sources/src}/openvino/whisper-openvino-encoder.cpp +0 -0
  854. /data/ext/{src → sources/src}/openvino/whisper-openvino-encoder.h +0 -0
  855. /data/{tests → test}/jfk_reader/.gitignore +0 -0
  856. /data/{tests → test}/jfk_reader/extconf.rb +0 -0
  857. /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
@@ -37,10 +37,20 @@
 #include "ggml-backend-impl.h"
 
 #include "ggml-sycl/backend.hpp"
+#include "ggml-sycl/common.hpp"
+#include "ggml-sycl/element_wise.hpp"
 #include "ggml-sycl/presets.hpp"
 #include "ggml-sycl/gemm.hpp"
+#include "ggml-sycl/sycl_hw.hpp"
+#include "ggml-sycl/getrows.hpp"
+#include "ggml.h"
 
 static bool g_sycl_loaded = false;
+int g_ggml_sycl_debug = 0;
+int g_ggml_sycl_disable_optimize = 0;
+int g_ggml_sycl_disable_graph = 0;
+int g_ggml_sycl_disable_dnn = 0;
+int g_ggml_sycl_prioritize_dmmv = 0;
 
 static ggml_sycl_device_info ggml_sycl_init() {
     ggml_sycl_device_info info = {};
@@ -54,30 +64,26 @@ static ggml_sycl_device_info ggml_sycl_init() {
     GGML_ASSERT(info.device_count <= GGML_SYCL_MAX_DEVICES);
 
     int64_t total_vram = 0;
-#if defined(GGML_SYCL_FORCE_MMQ)
-    GGML_LOG_INFO("%s: GGML_SYCL_FORCE_MMQ: yes\n", __func__);
-#else
-    GGML_LOG_INFO("%s: GGML_SYCL_FORCE_MMQ: no\n", __func__);
-#endif
-#if defined(SYCL_USE_XMX)
-    GGML_LOG_INFO("%s: SYCL_USE_XMX: yes\n", __func__);
-#else
-    GGML_LOG_INFO("%s: SYCL_USE_XMX: no\n", __func__);
-#endif
-    GGML_LOG_INFO("%s: found %d %s devices:\n", __func__, info.device_count, GGML_SYCL_NAME);
-
+    /* This is a bit misleading; reserved for later */
+    // #if defined(SYCL_USE_XMX)
+    // GGML_LOG_INFO("%s: SYCL_USE_XMX: yes\n", __func__);
+    // #else
+    // GGML_LOG_INFO("%s: SYCL_USE_XMX: no\n", __func__);
+    // #endif
     for (int i = 0; i < info.device_count; ++i) {
         info.devices[i].vmm = 0;
         dpct::device_info prop;
+        sycl::device device = dpct::dev_mgr::instance().get_device(i);
+
         SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info(
-            prop, dpct::dev_mgr::instance().get_device(i))));
+            prop, device)));
 
         info.default_tensor_split[i] = total_vram;
         total_vram += prop.get_global_mem_size();
 
         info.devices[i].cc =
             100 * prop.get_major_version() + 10 * prop.get_minor_version();
-
+        info.devices[i].opt_feature.reorder = !device.ext_oneapi_architecture_is(syclex::arch_category::intel_gpu);
         info.max_work_group_sizes[i] = prop.get_max_work_group_size();
     }
 
@@ -92,7 +98,7 @@ const ggml_sycl_device_info & ggml_sycl_info() {
     return info;
 }
 
-void print_device_detail(int id, sycl::device &device, std::string device_type) {
+static void print_device_detail(int id, sycl::device &device, std::string device_type) {
 
     dpct::device_info prop;
     SYCL_CHECK(CHECK_TRY_ERROR(
@@ -109,13 +115,33 @@ void print_device_detail(int id, sycl::device &device, std::string device_type)
     name = std::regex_replace(name, std::regex("\\(TM\\)"), "");
 
     auto global_mem_size = prop.get_global_mem_size()/1000000;
-
     GGML_LOG_INFO("|%2d|%19s|%39s|%7s|%7d|%8d|%5d|%6luM|%21s|\n", id, device_type.c_str(),
                   name.c_str(), version.c_str(), prop.get_max_compute_units(),
                   prop.get_max_work_group_size(), prop.get_max_sub_group_size(),
                   global_mem_size, device.get_info<sycl::info::device::driver_version>().c_str());
 }
 
+static void print_device_opt_feature(int device_count) {
+    GGML_LOG_INFO("SYCL Optimization Feature:\n");
+    GGML_LOG_INFO(
+        "|ID| Device Type|Reorder|\n");
+    GGML_LOG_INFO(
+        "|--|-------------------|-------|\n");
+    std::map<std::string, size_t> DeviceNums;
+    for (int id = 0; id < device_count; ++id) {
+        sycl::device device = dpct::dev_mgr::instance().get_device(id);
+        std::string backend_type = get_device_backend_and_type(device);
+        int type_id = DeviceNums[backend_type]++;
+        std::stringstream device_type;
+        device_type << "[" << backend_type << ":" << std::to_string(type_id)
+                    << "]";
+        std::string device_type_s = device_type.str();
+        device_type_s = std::regex_replace(device_type_s, std::regex("ext_oneapi_"), "");
+        GGML_LOG_INFO("|%2d|%19s|%7s|\n", id, device_type_s.c_str(),
+                      ggml_sycl_info().devices[id].opt_feature.reorder ? "Y": "N");
+    }
+
+}
 void ggml_backend_sycl_print_sycl_devices() {
     GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_print_sycl_devices\n");
     int device_count = dpct::dev_mgr::instance().device_count();
@@ -144,6 +170,8 @@ void ggml_backend_sycl_print_sycl_devices() {
                     << "]";
         print_device_detail(id, device, device_type.str());
     }
+
+    print_device_opt_feature(device_count);
 }
 
 static inline int get_sycl_env(const char *env_name, int default_val) {
@@ -164,14 +192,36 @@ static void ggml_check_sycl() try {
     static bool initialized = false;
 
     if (!initialized) {
-        GGML_LOG_INFO("[SYCL] call ggml_check_sycl\n");
         g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0);
-        GGML_LOG_INFO("%s: GGML_SYCL_DEBUG: %d\n", __func__, g_ggml_sycl_debug);
-
+        g_ggml_sycl_disable_optimize = get_sycl_env("GGML_SYCL_DISABLE_OPT", 0);
+        g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1);
+        g_ggml_sycl_disable_dnn = get_sycl_env("GGML_SYCL_DISABLE_DNN", 0);
+        g_ggml_sycl_prioritize_dmmv = get_sycl_env("GGML_SYCL_PRIORITIZE_DMMV", 0);
+        GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n");
+        GGML_LOG_INFO("Running with Environment Variables:\n");
+        GGML_LOG_INFO(" GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug);
+        GGML_LOG_INFO(" GGML_SYCL_DISABLE_OPT: %d\n", g_ggml_sycl_disable_optimize);
+#ifdef GGML_SYCL_GRAPH
+        GGML_LOG_INFO(" GGML_SYCL_DISABLE_GRAPH: %d\n", g_ggml_sycl_disable_graph);
+#else
+        GGML_LOG_INFO(" GGML_SYCL_DISABLE_GRAPH: graph disabled by compile flag\n");
+#endif
+#if GGML_SYCL_DNNL
+        GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: %d\n", g_ggml_sycl_disable_dnn);
+#else
+        GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: DNN disabled by compile flag\n");
+#endif
+        GGML_LOG_INFO(" GGML_SYCL_PRIORITIZE_DMMV: %d\n", g_ggml_sycl_prioritize_dmmv);
+        GGML_LOG_INFO("Build with Macros:\n");
+#if defined(GGML_SYCL_FORCE_MMQ)
+        GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: yes\n");
+#else
+        GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: no\n");
+#endif
 #if defined(GGML_SYCL_F16)
-        GGML_LOG_INFO("%s: GGML_SYCL_F16: yes\n", __func__);
+        GGML_LOG_INFO(" GGML_SYCL_F16: yes\n");
 #else
-        GGML_LOG_INFO("%s: GGML_SYCL_F16: no\n", __func__);
+        GGML_LOG_INFO(" GGML_SYCL_F16: no\n");
 #endif
 
 /* NOT REMOVE, keep it for next optimize for XMX.
@@ -243,19 +293,27 @@ struct ggml_backend_sycl_buffer_context {
     void * dev_ptr = nullptr;
     queue_ptr stream;
     std::string name;
+    optimize_feature opt_feature;
+    std::vector<ggml_tensor_extra_gpu *> tensor_extras;
 
-    ggml_backend_sycl_buffer_context(int device, void * dev_ptr, queue_ptr stream) :
+    ggml_backend_sycl_buffer_context(int device, void * dev_ptr, queue_ptr stream) :
         device(device), dev_ptr(dev_ptr), stream(stream) {
         check_allow_gpu_index(device);
         name = (GGML_SYCL_NAME + std::to_string(device));
+        opt_feature = ggml_sycl_info().devices[device].opt_feature;
     }
 
-
     ~ggml_backend_sycl_buffer_context() {
         if (dev_ptr != nullptr) {
             ggml_sycl_set_device(device);
             SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(dev_ptr, *stream)));
         }
+
+        //release extra used by tensors
+        for (ggml_tensor_extra_gpu * extra : tensor_extras) {
+            release_extra_gpu(extra);
+        }
+
     }
 };
 
@@ -283,18 +341,23 @@ static void * ggml_backend_sycl_buffer_get_base(ggml_backend_buffer_t buffer) {
     return ctx->dev_ptr;
 }
 
-static void
+static enum ggml_status
 ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer,
                                      ggml_tensor *tensor) try {
+    GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
+    GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor, "\n").c_str());
     ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *)buffer->context;
 
-    if (tensor->view_src != NULL && tensor->view_offs == 0) {
+    if (tensor->view_src != NULL) {
         assert(tensor->view_src->buffer->buft == buffer->buft);
-        tensor->backend = tensor->view_src->backend;
-        tensor->extra = tensor->view_src->extra;
-        return;
+        return GGML_STATUS_SUCCESS;
+    }
+    if ((tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q4_K || tensor->type == GGML_TYPE_Q6_K) &&
+        !g_ggml_sycl_disable_optimize) {
+        ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
+        tensor->extra = extra;
+        ctx->tensor_extras.push_back(extra); //used to release it when destroy ctx.
     }
-
 
     if (ggml_is_quantized(tensor->type)) {
         // initialize padding to 0 to avoid possible NaN values
@@ -307,6 +370,7 @@ ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer,
                     padded_size - original_size).wait()));
         }
     }
+    return GGML_STATUS_SUCCESS;
 }
 catch (sycl::exception const &exc) {
     std::cerr << exc.what() << "Exception caught at file:" << __FILE__
@@ -318,19 +382,23 @@ static void ggml_backend_sycl_buffer_set_tensor(ggml_backend_buffer_t buffer,
                                                 ggml_tensor *tensor,
                                                 const void *data, size_t offset,
                                                 size_t size) try {
-
+    GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
+    GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str());
+    GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset);
     ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context;
-
     ggml_sycl_set_device(ctx->device);
     auto stream = &(dpct::dev_mgr::instance().get_device(ctx->device).default_queue());
-    SYCL_CHECK(
-        CHECK_TRY_ERROR(dpct::dev_mgr::instance().get_device(ctx->device).queues_wait_and_throw()));
-    char* host_buf = (char*)malloc(size);
+    SYCL_CHECK(CHECK_TRY_ERROR(dpct::dev_mgr::instance().get_device(ctx->device).queues_wait_and_throw()));
+#ifndef _WIN32
+    // Note: Use host buffer to save the data from mmap(), then copy to device. It's workaround for mmap() issue on PVC GPU.
+    // This function will be called during load model from disk. Use memory buffer replace dynamic won't save more time and brings potential memory leak risk here.
+    char * host_buf = (char *) malloc(size);
     memcpy(host_buf, data, size);
-    SYCL_CHECK(
-        CHECK_TRY_ERROR((*stream).memcpy((char *)tensor->data + offset, host_buf, size)
-                             .wait()));
+    SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy((char *) tensor->data + offset, host_buf, size).wait()));
     free(host_buf);
+#else
+    SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy((char *) tensor->data + offset, data, size).wait()));
+#endif
 }
 catch (sycl::exception const &exc) {
     std::cerr << exc.what() << "Exception caught at file:" << __FILE__
@@ -342,7 +410,9 @@ static void ggml_backend_sycl_buffer_get_tensor(ggml_backend_buffer_t buffer,
                                                 const ggml_tensor *tensor,
                                                 void *data, size_t offset,
                                                 size_t size) try {
-
+    GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
+    GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str());
+    GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset);
     ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context;
 
     ggml_sycl_set_device(ctx->device);
@@ -358,7 +428,7 @@ catch (sycl::exception const &exc) {
   std::exit(1);
 }
 
-void dev2dev_memcpy(sycl::queue &q_dst, sycl::queue &q_src, void *ptr_dst,
+static void dev2dev_memcpy(sycl::queue &q_dst, sycl::queue &q_src, void *ptr_dst,
                     const void *ptr_src, size_t size) {
     char *host_buf = (char *)malloc(size);
     q_src.memcpy(host_buf, (const char *)ptr_src, size).wait();
@@ -370,7 +440,12 @@ static bool
 ggml_backend_sycl_buffer_cpy_tensor(ggml_backend_buffer_t buffer,
                                     const ggml_tensor *src,
                                     ggml_tensor *dst) try {
-    if (ggml_backend_buffer_is_sycl(src->buffer)) {
+    bool is_cpy_supported = ggml_backend_buffer_is_sycl(src->buffer);
+    GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
+    GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": dst", dst).c_str());
+    GGML_SYCL_DEBUG("%s", debug_get_tensor_str(" src", src).c_str());
+    GGML_SYCL_DEBUG(" is_cpy_supported=%d\n", is_cpy_supported);
+    if (is_cpy_supported) {
         ggml_backend_sycl_buffer_context * src_ctx = (ggml_backend_sycl_buffer_context *)src->buffer->context;
         ggml_backend_sycl_buffer_context * dst_ctx = (ggml_backend_sycl_buffer_context *)dst->buffer->context;
 
@@ -427,7 +502,8 @@ ggml_backend_sycl_buffer_cpy_tensor(ggml_backend_buffer_t buffer,
 
 static void ggml_backend_sycl_buffer_clear(ggml_backend_buffer_t buffer,
                                            uint8_t value) try {
-    ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context;
+    GGML_SYCL_DEBUG("[SYCL] call %s: size=%zu\n", __func__, buffer->size);
+    ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *) buffer->context;
 
     ggml_sycl_set_device(ctx->device);
     queue_ptr stream = ctx->stream;
@@ -444,16 +520,51 @@ catch (sycl::exception const &exc) {
   std::exit(1);
 }
 
+static void ggml_backend_sycl_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value,
+                                                   size_t offset, size_t size) {
+    GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
+    GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str());
+    GGML_SYCL_DEBUG(" size=%zu offset=%zu value=%u\n", size, offset, value);
+    ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *) buffer->context;
+    SYCL_CHECK(ggml_sycl_set_device(ctx->device));
+    auto stream = &(dpct::dev_mgr::instance().get_device(ctx->device).default_queue());
+    if (size == 0) {
+        return; // Nothing to do
+    }
+    if (tensor->data == nullptr) {
+        GGML_ABORT("Error: Tensor data pointer is null.\n");
+    }
+    void * target_ptr = static_cast<char *>(tensor->data) + offset;
+    SYCL_CHECK(CHECK_TRY_ERROR((*stream).memset(target_ptr, value, size)));
+    SYCL_CHECK(CHECK_TRY_ERROR((*stream).wait()));
+}
+
+static void ggml_backend_sycl_buffer_reset(ggml_backend_buffer_t buffer) {
+    GGML_SYCL_DEBUG("[SYCL] call %s\n", __func__);
+    if (buffer == nullptr) {
+        return;
+    }
+
+    ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *) buffer->context;
+
+    if (ctx != nullptr) {
+        for (ggml_tensor_extra_gpu * extra : ctx->tensor_extras) {
+            release_extra_gpu(extra);
+        }
+        ctx->tensor_extras.clear(); // reset the tensor_extras vector
+    }
+}
+
 static const ggml_backend_buffer_i ggml_backend_sycl_buffer_interface = {
     /* .free_buffer = */ ggml_backend_sycl_buffer_free_buffer,
     /* .get_base = */ ggml_backend_sycl_buffer_get_base,
     /* .init_tensor = */ ggml_backend_sycl_buffer_init_tensor,
-    /* .memset_tensor = */ NULL,
+    /* .memset_tensor = */ ggml_backend_sycl_buffer_memset_tensor,
     /* .set_tensor = */ ggml_backend_sycl_buffer_set_tensor,
     /* .get_tensor = */ ggml_backend_sycl_buffer_get_tensor,
     /* .cpy_tensor = */ ggml_backend_sycl_buffer_cpy_tensor,
     /* .clear = */ ggml_backend_sycl_buffer_clear,
-    /* .reset = */ NULL,
+    /* .reset = */ ggml_backend_sycl_buffer_reset,
 };
 
 // sycl buffer type
@@ -534,12 +645,11 @@ ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device) {
     static std::mutex mutex;
     std::lock_guard<std::mutex> lock(mutex);
 
-    GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_buffer_type\n");
 
     auto dev_count = ggml_backend_sycl_get_device_count();
 
     if (device>=dev_count or device<0) {
-        printf("ggml_backend_sycl_buffer_type error: device_index:%d is out of range [0, %d], miss to call ggml_backend_sycl_set_single_device()\n",
+        GGML_LOG_ERROR("ggml_backend_sycl_buffer_type error: device_index:%d is out of range [0, %d], miss to call ggml_backend_sycl_set_single_device()\n",
                device, dev_count-1);
         GGML_ASSERT(device<dev_count);
     }
@@ -562,12 +672,12 @@ ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device) {
     return &ggml_backend_sycl_buffer_types[device];
 }
 
-ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(ggml_backend_sycl_context * ctx) {
+static ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(ggml_backend_sycl_context * ctx) {
     GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_buffer_type\n");
 
     int device = ctx->device;
     if (device>=ggml_sycl_info().device_count or device<0) {
-        printf("ggml_backend_sycl_buffer_type error: device_index:%d is out of range [0, %d], miss to call ggml_backend_sycl_set_single_device()\n",
+        GGML_LOG_ERROR("ggml_backend_sycl_buffer_type error: device_index:%d is out of range [0, %d], miss to call ggml_backend_sycl_set_single_device()\n",
                device, ggml_sycl_info().device_count-1);
         GGML_ASSERT(device<ggml_sycl_info().device_count);
     }
@@ -664,32 +774,7 @@ struct ggml_backend_sycl_split_buffer_type_context {
 struct ggml_backend_sycl_split_buffer_context {
     ~ggml_backend_sycl_split_buffer_context() try {
         for (ggml_tensor_extra_gpu * extra : tensor_extras) {
-            for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
-                for (int64_t is = 0; is < GGML_SYCL_MAX_STREAMS; ++is) {
-                    if (extra->events[i][is] != nullptr) {
-                        /*
-                        DPCT1009:206: SYCL uses exceptions to report errors and
-                        does not use the error codes. The original code was
-                        commented out and a warning string was inserted. You
-                        need to rewrite this code.
-                        */
-                        SYCL_CHECK(CHECK_TRY_ERROR(
-                            dpct::destroy_event(extra->events[i][is])));
-                    }
-                }
-                if (extra->data_device[i] != nullptr) {
-                    /*
-                    DPCT1009:207: SYCL uses exceptions to report errors and does
-                    not use the error codes. The original code was commented out
-                    and a warning string was inserted. You need to rewrite this
-                    code.
-                    */
-                    ggml_sycl_set_device(i);
-                    SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(
-                        extra->data_device[i], *(streams[i]))));
-                }
-            }
-            delete extra;
+            release_extra_gpu(extra, streams);
         }
     }
     catch (sycl::exception const &exc) {
@@ -714,9 +799,11 @@ static void * ggml_backend_sycl_split_buffer_get_base(ggml_backend_buffer_t buff
     GGML_UNUSED(buffer);
 }
 
-static void
+static enum ggml_status
 ggml_backend_sycl_split_buffer_init_tensor(ggml_backend_buffer_t buffer,
                                            ggml_tensor *tensor) try {
+    GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
+    GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor, "\n").c_str());
     GGML_ASSERT(tensor->view_src == nullptr); // views of split tensors are not supported
 
     ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context;
@@ -727,7 +814,7 @@ ggml_backend_sycl_split_buffer_init_tensor(ggml_backend_buffer_t buffer,
     ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
 
     ctx->tensor_extras.push_back(extra);
-    ctx->streams.push_back(&(dpct::get_current_device().default_queue()));
+    ctx->streams.push_back(&(dpct::get_current_device().default_queue()));
 
     for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
         int64_t row_low, row_high;
@@ -746,7 +833,7 @@ ggml_backend_sycl_split_buffer_init_tensor(ggml_backend_buffer_t buffer,
             size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
         }
 
-        // FIXME: do not crash if cudaMalloc fails
+        // FIXME: do not crash if SYCL Buffer alloc fails
         // currently, init_tensor cannot fail, it needs to be fixed in ggml-backend first
         ggml_sycl_set_device(i);
         const queue_ptr stream = ctx->streams[i];
@@ -788,8 +875,8 @@ ggml_backend_sycl_split_buffer_init_tensor(ggml_backend_buffer_t buffer,
                 CHECK_TRY_ERROR(extra->events[i][is] = new sycl::event()));
         }
     }
-    tensor->backend = GGML_BACKEND_TYPE_GPU_SPLIT;
     tensor->extra = extra;
+    return GGML_STATUS_SUCCESS;
 }
 catch (sycl::exception const &exc) {
     std::cerr << exc.what() << "Exception caught at file:" << __FILE__
@@ -801,6 +888,9 @@ static void
 ggml_backend_sycl_split_buffer_set_tensor(ggml_backend_buffer_t buffer,
                                           ggml_tensor *tensor, const void *data,
                                           size_t offset, size_t size) try {
+    GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
+    GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str());
+    GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset);
     // split tensors must always be set in their entirety at once
     GGML_ASSERT(offset == 0);
     GGML_ASSERT(size == ggml_nbytes(tensor));
@@ -854,6 +944,9 @@ static void
 ggml_backend_sycl_split_buffer_get_tensor(ggml_backend_buffer_t buffer,
                                           const ggml_tensor *tensor, void *data,
                                           size_t offset, size_t size) try {
+    GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
+    GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str());
+    GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset);
    // split tensors must always be set in their entirety at once
    GGML_ASSERT(offset == 0);
    GGML_ASSERT(size == ggml_nbytes(tensor));
@@ -1178,6 +1271,85 @@ struct ggml_sycl_pool_leg : public ggml_sycl_pool {
     }
 };
 
+struct ggml_sycl_pool_host : public ggml_sycl_pool {
+    queue_ptr qptr;
+    int device;
+
+    inline static int counter{ 0 };
+
+    struct ggml_sycl_buffer {
+        void * ptr = nullptr;
+        size_t size = 0;
+    };
+
+    // Set arbitrarly to 64
+    static constexpr int MAX_POOL_SIZE{ 64 };
+    std::vector<ggml_sycl_buffer> buffer_pool = std::vector<ggml_sycl_buffer>(MAX_POOL_SIZE);
+    size_t pool_size = 0;
+
+    explicit ggml_sycl_pool_host(queue_ptr qptr_, int device_) : qptr(qptr_), device(device_) {}
+
+    ~ggml_sycl_pool_host() {
+        for (int i = 0; i < MAX_POOL_SIZE; ++i) {
+            ggml_sycl_buffer & b = buffer_pool[i];
+            if (b.ptr != nullptr) {
+                SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(b.ptr, *qptr)));
+                b.ptr = nullptr;
+                pool_size -= b.size;
+                b.size = 0;
+            }
+        }
+        counter = 0;
+    }
+
+    void * alloc(size_t size, size_t * actual_size) override {
+        if (counter == MAX_POOL_SIZE) {
+            ggml_sycl_buffer b = buffer_pool[0];
+            void * ptr = b.ptr;
+            *actual_size = b.size;
+            counter = 1;
+            return ptr;
+        }
+        ggml_sycl_buffer & b = buffer_pool[counter];
+
+        if (b.ptr == nullptr) {
+            void * ptr;
+
+            SYCL_CHECK(CHECK_TRY_ERROR(ptr = (void *) sycl::malloc_host(size, *qptr)));
+            if (!ptr) {
+                GGML_LOG_ERROR("%s: can't allocate %lu Bytes of memory on host\n", __func__, size);
+                return nullptr;
+            }
+            pool_size += size;
+            *actual_size = size;
+            counter = counter + 1;
+            return ptr;
+        } else {
+            ++counter;
+            b.size = size;
+            return b.ptr;
+        }
+    }
+
+    void free(void * ptr, size_t size) override {
+        // if the pool is not completed add the pointer to it in place of the first nullptr found.
+        // Otherwise do nothing, pointers will be freed once the pool is deallocated.
+        for (int i = 0; i < MAX_POOL_SIZE; ++i) {
+            ggml_sycl_buffer & b = buffer_pool[i];
+            if (b.ptr == nullptr) {
+                b.ptr = ptr;
+                b.size = size;
+                return;
+            }
+        }
+    }
+};
+
+std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_host(queue_ptr qptr, int device) {
+    // return pool for the host to speed up memory management
+    return std::unique_ptr<ggml_sycl_pool>(new ggml_sycl_pool_host(qptr, device));
+}
+
 std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_device(queue_ptr qptr, int device) {
     // TBD: NO VMM support
     // if (ggml_sycl_info().devices[device].vmm) {
@@ -1190,9 +1362,6 @@ std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_device(q
 // struct ggml_sycl_pool_vmm : public ggml_sycl_pool
 
 /// kernels
-
-typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
-typedef void (*ggml_sycl_func_t)(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
 typedef void (*ggml_sycl_op_mul_mat_t)(
     ggml_backend_sycl_context & ctx,
     const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
@@ -1264,81 +1433,57 @@ static void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy,
1264
1433
  reinterpret_cast<sycl::half &>(y[ib].ds.y()) = sum;
1265
1434
  }
1266
1435
 
1267
- template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
1268
- static void k_get_rows(
1269
- const void * src0, const int32_t * src1, dst_t * dst,
1270
- int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
1271
- /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
1272
- /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
1273
- /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
1274
- size_t s10, size_t s11, size_t s12,
1275
- const sycl::nd_item<3> &item_ct1/*, size_t s13*/) {
1276
-
1277
- const int i00 = (item_ct1.get_group(2) * item_ct1.get_local_range(2) +
1278
- item_ct1.get_local_id(2)) *
1279
- 2;
1280
- const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
1281
- item_ct1.get_local_id(1);
1282
- const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
1283
- item_ct1.get_local_id(0)) /
1284
- ne12;
1285
- const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
1286
- item_ct1.get_local_id(0)) %
1287
- ne12;
1288
-
1289
- if (i00 >= ne00) {
1290
- return;
1291
- }
1436
+ template <int ElementsPerWI>
1437
+ static __dpct_inline__ void quantize_and_reorder_q8_1(const float * __restrict__ x, void * reordered_q8_tensor,
1438
+ const int kx, const int kx_padded, const sycl::nd_item<1> & it) {
1439
+ /*
1440
+ Quantizes and reorders the resultant q8 tensor in a per row fashion
1441
+ Each sub-group calculates one quant block. i.e. QK8_1 quant values and the d and sum values
1442
+ */
1292
1443
 
1293
- const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
1444
+ auto subgroup_id = it.get_group(0);
1445
+ auto wi_id = it.get_local_id(0);
1294
1446
 
1295
- dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
1296
- const void * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03;
1447
+ const int num_blocks_per_row = kx / QK8_1;
1448
+ auto row = subgroup_id / num_blocks_per_row;
1449
+ auto col = subgroup_id % num_blocks_per_row;
1297
1450
 
1298
- const int ib = i00/qk; // block index
1299
- const int iqs = (i00%qk)/qr; // quant index
1300
- const int iybs = i00 - i00%qk; // dst block start index
1301
- const int y_offset = qr == 1 ? 1 : qk/2;
1451
+ auto row_offset = row * (kx_padded / QK8_1) * sizeof(block_q8_1);
1452
+ auto col_offset = QK8_1 * col + wi_id * ElementsPerWI;
1302
1453
 
1303
- // dequantize
1304
- dfloat2 v;
1305
- dequantize_kernel(src0_row, ib, iqs, v);
1454
+ auto quant_ptr = (int8_t *) ((char *) reordered_q8_tensor + row_offset + col_offset);
1455
+ auto ds_ptr = (sycl::half2 *) ((char *) reordered_q8_tensor + row_offset + kx + col * sizeof(sycl::half2));
1306
1456
 
1307
- dst_row[iybs + iqs + 0] = v.x();
1308
- dst_row[iybs + iqs + y_offset] = v.y();
1309
- }
1457
+ sycl::vec<float, ElementsPerWI> wi_f32_vals;
1458
+ sycl::vec<int8_t, ElementsPerWI> quantized_values;
1310
1459
 
1311
- template<typename src0_t, typename dst_t>
1312
- static void k_get_rows_float(
1313
- const src0_t * src0, const int32_t * src1, dst_t * dst,
1314
- int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
1315
- /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
1316
- /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
1317
- /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
1318
- size_t s10, size_t s11, size_t s12,
1319
- const sycl::nd_item<3> &item_ct1/*, size_t s13*/) {
1460
+ auto float_ptr_offset = subgroup_id * QK8_1 + ElementsPerWI * wi_id;
1461
+ wi_f32_vals = *reinterpret_cast<const sycl::vec<float, ElementsPerWI> *>(x + float_ptr_offset);
1320
1462
 
1321
- const int i00 = item_ct1.get_group(2) * item_ct1.get_local_range(2) +
1322
- item_ct1.get_local_id(2);
1323
- const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
1324
- item_ct1.get_local_id(1);
1325
- const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
1326
- item_ct1.get_local_id(0)) /
1327
- ne12;
1328
- const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
1329
- item_ct1.get_local_id(0)) %
1330
- ne12;
1331
-
1332
- if (i00 >= ne00) {
1333
- return;
1463
+ float sum = 0.0f;
1464
+ float amax = 0.0f;
1465
+
1466
+ #pragma unroll(ElementsPerWI)
1467
+ for (int i = 0; i < ElementsPerWI; i++) {
1468
+ sum += wi_f32_vals[i];
1469
+ amax = sycl::fmax(amax, sycl::fabs(wi_f32_vals[i]));
1470
+ quantized_values[i] = 0;
1334
1471
  }
1472
+ sum = sycl::reduce_over_group(it.get_group(), sum, sycl::plus<float>());
1473
+ amax = sycl::reduce_over_group(it.get_group(), amax, sycl::maximum<float>());
1474
+ float d = amax == 0 ? 1 : amax / 127;
1335
1475
 
1336
- const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
1476
+ #pragma unroll(ElementsPerWI)
1477
+ for (int i = 0; i < ElementsPerWI; i++) {
1478
+ quantized_values[i] = sycl::round(wi_f32_vals[i] / d);
1479
+ }
1337
1480
 
1338
- dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
1339
- const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03);
1481
+ d = amax == 0 ? 0 : d;
1340
1482
 
1341
- dst_row[i00] = src0_row[i00];
1483
+ *reinterpret_cast<sycl::vec<int8_t, ElementsPerWI> *>(quant_ptr) = quantized_values;
1484
+ if (wi_id == 0) {
1485
+ *ds_ptr = sycl::half2(sycl::half(d), sycl::half(sum));
1486
+ }
1342
1487
  }
1343
1488
 
1344
1489
  static void mul_mat_p021_f16_f32(
@@ -1451,193 +1596,6 @@ static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
1451
1596
  }
1452
1597
  }
1453
1598
 
1454
- static void cpy_1_f32_f32(const char * cxi, char * cdsti) {
1455
- const float * xi = (const float *) cxi;
1456
- float * dsti = (float *) cdsti;
1457
-
1458
- *dsti = *xi;
1459
- }
1460
-
1461
- static void cpy_1_f32_f16(const char * cxi, char * cdsti) {
1462
- const float * xi = (const float *) cxi;
1463
- sycl::half *dsti = (sycl::half *)cdsti;
1464
-
1465
- *dsti = sycl::vec<float, 1>(*xi)
1466
- .convert<sycl::half, sycl::rounding_mode::automatic>()[0];
1467
- }
1468
-
1469
- static void cpy_1_f16_f16(const char * cxi, char * cdsti) {
1470
- const sycl::half *xi = (const sycl::half *)cxi;
1471
- sycl::half *dsti = (sycl::half *)cdsti;
1472
-
1473
- *dsti = *xi;
1474
- }
1475
-
1476
- static void cpy_1_f16_f32(const char * cxi, char * cdsti) {
1477
- const sycl::half *xi = (const sycl::half *)cxi;
1478
- float * dsti = (float *) cdsti;
1479
-
1480
- *dsti = *xi;
1481
- }
1482
-
1483
- static void cpy_1_i16_i16(const char * cxi, char * cdsti) {
1484
- const int16_t *xi = (const int16_t *)cxi;
1485
- int16_t *dsti = (int16_t *)cdsti;
1486
-
1487
- *dsti = *xi;
1488
- }
1489
-
1490
- static void cpy_1_i32_i32(const char * cxi, char * cdsti) {
1491
- const int32_t *xi = (const int32_t *)cxi;
1492
- int32_t *dsti = (int32_t *)cdsti;
1493
-
1494
- *dsti = *xi;
1495
- }
1496
-
1497
- template <cpy_kernel_t cpy_1>
1498
- static void cpy_f32_f16(const char * cx, char * cdst, const int ne,
1499
- const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
1500
- const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
1501
- const int nb12, const int nb13, const sycl::nd_item<3> &item_ct1) {
1502
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
1503
- item_ct1.get_local_id(2);
1504
-
1505
- if (i >= ne) {
1506
- return;
1507
- }
1508
-
1509
- // determine indices i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
1510
- // then combine those indices with the corresponding byte offsets to get the total offsets
1511
- const int i03 = i/(ne00 * ne01 * ne02);
1512
- const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
1513
- const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
1514
- const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
1515
- const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
1516
-
1517
- const int i13 = i/(ne10 * ne11 * ne12);
1518
- const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
1519
- const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
1520
- const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
1521
- const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13;
1522
-
1523
- cpy_1(cx + x_offset, cdst + dst_offset);
1524
- }
1525
-
1526
- static void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
1527
- const float * xi = (const float *) cxi;
1528
- block_q8_0 * dsti = (block_q8_0 *) cdsti;
1529
-
1530
- float amax = 0.0f; // absolute max
1531
-
1532
- for (int j = 0; j < QK8_0; j++) {
1533
- const float v = xi[j];
1534
- amax = sycl::fmax(amax, sycl::fabs((float)v));
1535
- }
1536
-
1537
- const float d = amax / ((1 << 7) - 1);
1538
- const float id = d ? 1.0f/d : 0.0f;
1539
-
1540
- dsti->d = d;
1541
-
1542
- for (int j = 0; j < QK8_0; ++j) {
1543
- const float x0 = xi[j]*id;
1544
-
1545
- dsti->qs[j] = sycl::round((float)x0);
1546
- }
1547
- }
1548
-
1549
- static void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
1550
- const float * xi = (const float *) cxi;
1551
- block_q4_0 * dsti = (block_q4_0 *) cdsti;
1552
-
1553
- float amax = 0.0f;
1554
- float vmax = 0.0f;
1555
-
1556
- for (int j = 0; j < QK4_0; ++j) {
1557
- const float v = xi[j];
1558
- if (amax < sycl::fabs((float)v)) {
1559
- amax = sycl::fabs((float)v);
1560
- vmax = v;
1561
- }
1562
- }
1563
-
1564
- const float d = vmax / -8;
1565
- const float id = d ? 1.0f/d : 0.0f;
1566
-
1567
- dsti->d = d;
1568
-
1569
- for (int j = 0; j < QK4_0/2; ++j) {
1570
- const float x0 = xi[0 + j]*id;
1571
- const float x1 = xi[QK4_0/2 + j]*id;
1572
-
1573
- const uint8_t xi0 = dpct::min(15, (int8_t)(x0 + 8.5f));
1574
- const uint8_t xi1 = dpct::min(15, (int8_t)(x1 + 8.5f));
1575
-
1576
- dsti->qs[j] = xi0;
1577
- dsti->qs[j] |= xi1 << 4;
1578
- }
1579
- }
1580
-
1581
- static void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
1582
- const float * xi = (const float *) cxi;
1583
- block_q4_1 * dsti = (block_q4_1 *) cdsti;
1584
-
1585
- float vmin = FLT_MAX;
1586
- float vmax = -FLT_MAX;
1587
-
1588
- for (int j = 0; j < QK4_1; ++j) {
1589
- const float v = xi[j];
1590
-
1591
- if (v < vmin) vmin = v;
1592
- if (v > vmax) vmax = v;
1593
- }
1594
-
1595
- const float d = (vmax - vmin) / ((1 << 4) - 1);
1596
- const float id = d ? 1.0f/d : 0.0f;
1597
-
1598
- dsti->dm.x() = d;
1599
- dsti->dm.y() = vmin;
1600
-
1601
- for (int j = 0; j < QK4_1/2; ++j) {
1602
- const float x0 = (xi[0 + j] - vmin)*id;
1603
- const float x1 = (xi[QK4_1/2 + j] - vmin)*id;
1604
-
1605
- const uint8_t xi0 = dpct::min(15, (int8_t)(x0 + 0.5f));
1606
- const uint8_t xi1 = dpct::min(15, (int8_t)(x1 + 0.5f));
1607
-
1608
- dsti->qs[j] = xi0;
1609
- dsti->qs[j] |= xi1 << 4;
1610
- }
1611
- }
1612
-
1613
- template <cpy_kernel_t cpy_blck, int qk>
1614
- static void cpy_f32_q(const char * cx, char * cdst, const int ne,
1615
- const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
1616
- const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
1617
- const int nb12, const int nb13, const sycl::nd_item<3> &item_ct1) {
1618
- const int i = (item_ct1.get_local_range(2) * item_ct1.get_group(2) +
1619
- item_ct1.get_local_id(2)) *
1620
- qk;
1621
-
1622
- if (i >= ne) {
1623
- return;
1624
- }
1625
-
1626
- const int i03 = i/(ne00 * ne01 * ne02);
1627
- const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
1628
- const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
1629
- const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
1630
- const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
1631
-
1632
- const int i13 = i/(ne10 * ne11 * ne12);
1633
- const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
1634
- const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
1635
- const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
1636
- const int dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
1637
-
1638
- cpy_blck(cx + x_offset, cdst + dst_offset);
1639
- }
1640
-
1641
1599
  static void k_sum_rows_f32(const float * x, float * dst, const int ncols,
1642
1600
  const sycl::nd_item<3> &item_ct1) {
1643
1601
  const int row = item_ct1.get_group(1);
@@ -1749,17 +1707,6 @@ static void scale_f32(const float * x, float * dst, const float scale, const int
     dst[i] = scale * x[i];
 }
 
-static void clamp_f32(const float * x, float * dst, const float min, const float max, const int k,
-                      const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-
-    if (i >= k) {
-        return;
-    }
-
-    dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
-}
 
 template <typename Ti, typename To>
 static void pool2d_nchw_kernel(
@@ -1823,98 +1770,30 @@ static void pool2d_nchw_kernel(
1823
1770
  o_ptr[cur_oh * ow + cur_ow] = res;
1824
1771
  }
1825
1772
 
1826
- template <int qk, int qr, dequantize_kernel_t dq>
1827
- static void get_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
1828
- ggml_tensor *dst, const void *src0_dd,
1829
- const int32_t *src1_dd, float *dst_dd,
1830
- queue_ptr stream) {
1831
-
1832
- GGML_TENSOR_BINARY_OP_LOCALS
1833
-
1834
- const sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE);
1835
- const int block_num_x = (ne00 + 2*SYCL_GET_ROWS_BLOCK_SIZE - 1) / (2*SYCL_GET_ROWS_BLOCK_SIZE);
1836
- const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x);
1837
-
1838
- // strides in elements
1839
- //const size_t s0 = nb0 / ggml_element_size(dst);
1840
- const size_t s1 = nb1 / ggml_element_size(dst);
1841
- const size_t s2 = nb2 / ggml_element_size(dst);
1842
- const size_t s3 = nb3 / ggml_element_size(dst);
1843
-
1844
- const size_t s10 = nb10 / ggml_element_size(src1);
1845
- const size_t s11 = nb11 / ggml_element_size(src1);
1846
- const size_t s12 = nb12 / ggml_element_size(src1);
1847
- //const size_t s13 = nb13 / ggml_element_size(src1);
1848
-
1849
- GGML_ASSERT(ne00 % 2 == 0);
1850
-
1851
- stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
1852
- [=](sycl::nd_item<3> item_ct1) {
1853
- k_get_rows<qk, qr, dq>(
1854
- src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2,
1855
- s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
1856
- });
1857
-
1858
- GGML_UNUSED(dst);
1859
- GGML_UNUSED(ctx);
1860
- }
1861
-
1862
- template <typename src0_t>
1863
- static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
1864
- const ggml_tensor *src1, ggml_tensor *dst,
1865
- const src0_t *src0_dd, const int32_t *src1_dd,
1866
- float *dst_dd, queue_ptr stream) {
1867
-
1868
- GGML_TENSOR_BINARY_OP_LOCALS
1869
-
1870
- const sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE);
1871
- const int block_num_x = (ne00 + SYCL_GET_ROWS_BLOCK_SIZE - 1) / SYCL_GET_ROWS_BLOCK_SIZE;
1872
- const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x);
1873
-
1874
- // strides in elements
1875
- //const size_t s0 = nb0 / ggml_element_size(dst);
1876
- const size_t s1 = nb1 / ggml_element_size(dst);
1877
- const size_t s2 = nb2 / ggml_element_size(dst);
1878
- const size_t s3 = nb3 / ggml_element_size(dst);
1879
-
1880
- const size_t s10 = nb10 / ggml_element_size(src1);
1881
- const size_t s11 = nb11 / ggml_element_size(src1);
1882
- const size_t s12 = nb12 / ggml_element_size(src1);
1883
- //const size_t s13 = nb13 / ggml_element_size(src1);
1884
-
1885
- {
1886
- dpct::has_capability_or_fail(stream->get_device(),
1887
- {sycl::aspect::fp16});
1888
-
1889
- stream->parallel_for(
1890
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
1891
- [=](sycl::nd_item<3> item_ct1) {
1892
- k_get_rows_float(src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2,
1893
- s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
1894
- });
1895
- }
1896
-
1897
- GGML_UNUSED(dst);
1898
- GGML_UNUSED(ctx);
1899
- }
1900
-
1901
- static void quantize_row_q8_1_sycl(const float *x, void *vy, const int kx,
1902
- const int ky, const int kx_padded,
1903
- queue_ptr stream) {
1904
- const int block_num_x = (kx_padded + SYCL_QUANTIZE_BLOCK_SIZE - 1) / SYCL_QUANTIZE_BLOCK_SIZE;
1905
- const sycl::range<3> num_blocks(1, ky, block_num_x);
1906
- int constexpr QUANT_BLOCK_TILE = QK8_1 / WARP_SIZE;
1907
- static_assert(QK8_1 % WARP_SIZE == 0);
1908
- const sycl::range<3> block_size(1, 1, SYCL_QUANTIZE_BLOCK_SIZE / QUANT_BLOCK_TILE);
1909
- {
1910
- dpct::has_capability_or_fail(stream->get_device(),
1911
- {sycl::aspect::fp16});
1773
+ static void quantize_row_q8_1_sycl(const float * x, void * vy, const int kx, const int ky, const int kx_padded,
1774
+ bool reorder_q8_tensor, queue_ptr stream) {
1775
+ if (reorder_q8_tensor) {
1776
+ auto local_range = std::size_t(WARP_SIZE);
1777
+ auto num_quant_blocks = ky * (kx / QK8_1);
1778
+ auto global_range = num_quant_blocks * local_range;
1779
+ stream->parallel_for(sycl::nd_range<1>({ global_range }, { local_range }),
1780
+ [=](sycl::nd_item<1> it) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
1781
+ quantize_and_reorder_q8_1<QK8_1 / WARP_SIZE>(x, vy, kx, kx_padded, it);
1782
+ });
1783
+ } else {
1784
+ const int block_num_x = (kx_padded + SYCL_QUANTIZE_BLOCK_SIZE - 1) / SYCL_QUANTIZE_BLOCK_SIZE;
1785
+ const sycl::range<3> num_blocks(1, ky, block_num_x);
1786
+ int constexpr QUANT_BLOCK_TILE = QK8_1 / WARP_SIZE;
1787
+ static_assert(QK8_1 % WARP_SIZE == 0);
1788
+ const sycl::range<3> block_size(1, 1, SYCL_QUANTIZE_BLOCK_SIZE / QUANT_BLOCK_TILE);
1789
+ {
1790
+ dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
1912
1791
 
1913
- stream->parallel_for(
1914
- sycl::nd_range<3>(num_blocks * block_size, block_size),
1915
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
1916
- quantize_q8_1<QUANT_BLOCK_TILE>(x, vy, kx, kx_padded, item_ct1);
1917
- });
1792
+ stream->parallel_for(sycl::nd_range<3>(num_blocks * block_size, block_size),
1793
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
1794
+ quantize_q8_1<QUANT_BLOCK_TILE>(x, vy, kx, kx_padded, item_ct1);
1795
+ });
1796
+ }
1918
1797
  }
1919
1798
  }
1920
1799
 
@@ -1933,7 +1812,7 @@ static void ggml_mul_mat_p021_f16_f32_sycl(const void *vx, const float *y,
 
     stream->parallel_for(
         sycl::nd_range<3>(block_nums * block_dims, block_dims),
-        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+        [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
            mul_mat_p021_f16_f32(vx, y, dst, ncols_x, nrows_x, nchannels_x,
                                 nchannels_y, item_ct1);
        });
@@ -1953,7 +1832,7 @@ static void ggml_mul_mat_vec_nc_f16_f32_sycl(
 
     stream->parallel_for(
         sycl::nd_range<3>(block_nums * block_dims, block_dims),
-        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+        [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
            mul_mat_vec_nc_f16_f32(vx, y, dst, ncols_x, nrows_x,
                                   row_stride_x, channel_stride_x,
                                   nchannels_y / nchannels_x, item_ct1);
@@ -1961,231 +1840,7 @@ static void ggml_mul_mat_vec_nc_f16_f32_sycl(
1961
1840
  }
1962
1841
  }
1963
1842
 
1964
- static void
1965
- ggml_cpy_f16_f32_sycl(const char *cx, char *cdst, const int ne, const int ne00,
1966
- const int ne01, const int ne02, const int nb00,
1967
- const int nb01, const int nb02, const int nb03,
1968
- const int ne10, const int ne11, const int ne12,
1969
- const int nb10, const int nb11, const int nb12,
1970
- const int nb13, queue_ptr stream) {
1971
-
1972
- const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
1973
- {
1974
- dpct::has_capability_or_fail(stream->get_device(),
1975
- {sycl::aspect::fp16});
1976
1843
 
1977
- stream->parallel_for(
1978
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
1979
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
1980
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
1981
- [=](sycl::nd_item<3> item_ct1) {
1982
- cpy_f32_f16<cpy_1_f16_f32>(cx, cdst, ne, ne00, ne01, ne02, nb00,
1983
- nb01, nb02, nb03, ne10, ne11, ne12,
1984
- nb10, nb11, nb12, nb13, item_ct1);
1985
- });
1986
- }
1987
- }
1988
-
1989
- static void ggml_cpy_f32_f32_sycl(const char *cx, char *cdst, const int ne,
1990
- const int ne00, const int ne01,
1991
- const int ne02, const int nb00,
1992
- const int nb01, const int nb02,
1993
- const int nb03, const int ne10,
1994
- const int ne11, const int ne12,
1995
- const int nb10, const int nb11,
1996
- const int nb12, const int nb13,
1997
- queue_ptr stream) {
1998
-
1999
- const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
2000
- {
2001
- dpct::has_capability_or_fail(stream->get_device(),
2002
- {sycl::aspect::fp16});
2003
-
2004
- stream->parallel_for(
2005
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
2006
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
2007
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
2008
- [=](sycl::nd_item<3> item_ct1) {
2009
- cpy_f32_f16<cpy_1_f32_f32>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
2010
- nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
2011
- item_ct1);
2012
- });
2013
- }
2014
- }
2015
-
2016
- static void ggml_cpy_f32_f16_sycl(const char *cx, char *cdst, const int ne,
2017
- const int ne00, const int ne01,
2018
- const int ne02, const int nb00,
2019
- const int nb01, const int nb02,
2020
- const int nb03, const int ne10,
2021
- const int ne11, const int ne12,
2022
- const int nb10, const int nb11,
2023
- const int nb12, const int nb13,
2024
- queue_ptr stream) {
2025
-
2026
- const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
2027
- {
2028
- dpct::has_capability_or_fail(stream->get_device(),
2029
- {sycl::aspect::fp16});
2030
-
2031
- stream->parallel_for(
2032
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
2033
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
2034
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
2035
- [=](sycl::nd_item<3> item_ct1) {
2036
- cpy_f32_f16<cpy_1_f32_f16>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
2037
- nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
2038
- item_ct1);
2039
- });
2040
- }
2041
- }
2042
-
2043
- static void ggml_cpy_f32_q8_0_sycl(const char *cx, char *cdst, const int ne,
2044
- const int ne00, const int ne01,
2045
- const int ne02, const int nb00,
2046
- const int nb01, const int nb02,
2047
- const int nb03, const int ne10,
2048
- const int ne11, const int ne12,
2049
- const int nb10, const int nb11,
2050
- const int nb12, const int nb13,
2051
- queue_ptr stream) {
2052
-
2053
- GGML_ASSERT(ne % QK8_0 == 0);
2054
- const int num_blocks = ne / QK8_0;
2055
- stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks),
2056
- sycl::range<3>(1, 1, 1)),
2057
- [=](sycl::nd_item<3> item_ct1) {
2058
- cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>(
2059
- cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
2060
- nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
2061
- item_ct1);
2062
- });
2063
- }
2064
-
2065
- static void ggml_cpy_f32_q4_0_sycl(const char *cx, char *cdst, const int ne,
2066
- const int ne00, const int ne01,
2067
- const int ne02, const int nb00,
2068
- const int nb01, const int nb02,
2069
- const int nb03, const int ne10,
2070
- const int ne11, const int ne12,
2071
- const int nb10, const int nb11,
2072
- const int nb12, const int nb13,
2073
- queue_ptr stream) {
2074
-
2075
- GGML_ASSERT(ne % QK4_0 == 0);
2076
- const int num_blocks = ne / QK4_0;
2077
- stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks),
2078
- sycl::range<3>(1, 1, 1)),
2079
- [=](sycl::nd_item<3> item_ct1) {
2080
- cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>(
2081
- cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
2082
- nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
2083
- item_ct1);
2084
- });
2085
- }
2086
-
2087
- static void ggml_cpy_f32_q4_1_sycl(const char *cx, char *cdst, const int ne,
2088
- const int ne00, const int ne01,
2089
- const int ne02, const int nb00,
2090
- const int nb01, const int nb02,
2091
- const int nb03, const int ne10,
2092
- const int ne11, const int ne12,
2093
- const int nb10, const int nb11,
2094
- const int nb12, const int nb13,
2095
- queue_ptr stream) {
2096
-
2097
- GGML_ASSERT(ne % QK4_1 == 0);
2098
- const int num_blocks = ne / QK4_1;
2099
- stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks),
2100
- sycl::range<3>(1, 1, 1)),
2101
- [=](sycl::nd_item<3> item_ct1) {
2102
- cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>(
2103
- cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
2104
- nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
2105
- item_ct1);
2106
- });
2107
- }
2108
-
2109
- static void ggml_cpy_f16_f16_sycl(const char *cx, char *cdst, const int ne,
2110
- const int ne00, const int ne01,
2111
- const int ne02, const int nb00,
2112
- const int nb01, const int nb02,
2113
- const int nb03, const int ne10,
2114
- const int ne11, const int ne12,
2115
- const int nb10, const int nb11,
2116
- const int nb12, const int nb13,
2117
- queue_ptr stream) {
2118
-
2119
- const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
2120
- {
2121
- dpct::has_capability_or_fail(stream->get_device(),
2122
- {sycl::aspect::fp16});
2123
-
2124
- stream->parallel_for(
2125
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
2126
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
2127
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
2128
- [=](sycl::nd_item<3> item_ct1) {
2129
- cpy_f32_f16<cpy_1_f16_f16>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
2130
- nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
2131
- item_ct1);
2132
- });
2133
- }
2134
- }
2135
-
2136
- static void ggml_cpy_i16_i16_sycl(const char *cx, char *cdst, const int ne,
2137
- const int ne00, const int ne01,
2138
- const int ne02, const int nb00,
2139
- const int nb01, const int nb02,
2140
- const int nb03, const int ne10,
2141
- const int ne11, const int ne12,
2142
- const int nb10, const int nb11,
2143
- const int nb12, const int nb13,
2144
- queue_ptr stream) {
2145
-
2146
- const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
2147
- {
2148
- // dpct::has_capability_or_fail(stream->get_device(),
2149
- // {sycl::aspect::fp16});
2150
-
2151
- stream->parallel_for(
2152
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
2153
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
2154
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
2155
- [=](sycl::nd_item<3> item_ct1) {
2156
- cpy_f32_f16<cpy_1_i16_i16>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
2157
- nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
2158
- item_ct1);
2159
- });
2160
- }
2161
- }
2162
-
2163
- static void ggml_cpy_i32_i32_sycl(const char *cx, char *cdst, const int ne,
2164
- const int ne00, const int ne01,
2165
- const int ne02, const int nb00,
2166
- const int nb01, const int nb02,
2167
- const int nb03, const int ne10,
2168
- const int ne11, const int ne12,
2169
- const int nb10, const int nb11,
2170
- const int nb12, const int nb13,
2171
- queue_ptr stream) {
2172
-
2173
- const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
2174
- {
2175
- // dpct::has_capability_or_fail(stream->get_device(),
2176
- // {sycl::aspect::fp16});
2177
-
2178
- stream->parallel_for(
2179
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
2180
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
2181
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
2182
- [=](sycl::nd_item<3> item_ct1) {
2183
- cpy_f32_f16<cpy_1_i32_i32>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
2184
- nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
2185
- item_ct1);
2186
- });
2187
- }
2188
- }
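
Every copy launcher deleted above follows the same recipe: round the element count up to a whole number of work-groups, check the fp16 aspect where half precision is involved, and launch one work-item per element with a tail guard. A minimal, self-contained SYCL sketch of that launch pattern (the block-size constant and the element-wise functor are illustrative stand-ins, not the backend's own symbols):

    #include <sycl/sycl.hpp>

    constexpr int CPY_BLOCK_SIZE = 32;   // stand-in for SYCL_CPY_BLOCK_SIZE

    // Launch one work-item per element, rounded up to whole work-groups.
    template <typename CopyOne>
    void launch_copy(sycl::queue & q, const char * cx, char * cdst, int ne, CopyOne copy_one) {
        const int num_blocks = (ne + CPY_BLOCK_SIZE - 1) / CPY_BLOCK_SIZE;   // ceil(ne / block)
        q.parallel_for(
            sycl::nd_range<1>(sycl::range<1>(num_blocks * CPY_BLOCK_SIZE),
                              sycl::range<1>(CPY_BLOCK_SIZE)),
            [=](sycl::nd_item<1> it) {
                const int i = it.get_global_id(0);
                if (i < ne) {            // tail guard for the padded last work-group
                    copy_one(cx, cdst, i);
                }
            });
    }
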
2189
1844
 
2190
1845
  static void scale_f32_sycl(const float *x, float *dst, const float scale,
2191
1846
  const int k, queue_ptr stream) {
@@ -2199,18 +1854,6 @@ static void scale_f32_sycl(const float *x, float *dst, const float scale,
2199
1854
  });
2200
1855
  }
2201
1856
 
2202
- static void clamp_f32_sycl(const float *x, float *dst, const float min,
2203
- const float max, const int k,
2204
- queue_ptr stream) {
2205
- const int num_blocks = (k + SYCL_CLAMP_BLOCK_SIZE - 1) / SYCL_CLAMP_BLOCK_SIZE;
2206
- stream->parallel_for(
2207
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
2208
- sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE),
2209
- sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE)),
2210
- [=](sycl::nd_item<3> item_ct1) {
2211
- clamp_f32(x, dst, min, max, k, item_ct1);
2212
- });
2213
- }
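
The deleted clamp launcher wrapped a one-line kernel around the same grid arithmetic. For reference, an equivalent element-wise clamp written as a self-contained SYCL snippet (function and parameter names are illustrative only, not the backend's moved implementation):

    #include <sycl/sycl.hpp>

    void clamp_f32_example(sycl::queue & q, const float * x, float * dst, float lo, float hi, int k) {
        q.parallel_for(sycl::range<1>(k), [=](sycl::id<1> idx) {
            const int i = idx[0];
            dst[i] = sycl::fmin(sycl::fmax(x[i], lo), hi);   // clamp x[i] into [lo, hi]
        });
    }
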
2214
1857
 
2215
1858
  static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
2216
1859
  const int nrows, queue_ptr stream) {
@@ -2218,7 +1861,7 @@ static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
2218
1861
  const sycl::range<3> block_nums(1, nrows, 1);
2219
1862
  stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
2220
1863
  [=](sycl::nd_item<3> item_ct1)
2221
- [[intel::reqd_sub_group_size(WARP_SIZE)]] {
1864
+ [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
2222
1865
  k_sum_rows_f32(x, dst, ncols, item_ct1);
2223
1866
  });
2224
1867
  }
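
The only change to sum_rows_f32_sycl is the attribute spelling: the DPC++-specific [[intel::reqd_sub_group_size(...)]] gives way to the SYCL 2020 standard [[sycl::reqd_sub_group_size(...)]]. A stand-alone kernel showing the standard attribute together with a sub-group reduction, assuming a sub-group width of 32 for WARP_SIZE:

    #include <sycl/sycl.hpp>

    constexpr int WARP_SIZE = 32;   // assumed sub-group width

    void sum_one_row(sycl::queue & q, const float * x, float * out, int ncols) {
        q.parallel_for(
            sycl::nd_range<1>(sycl::range<1>(WARP_SIZE), sycl::range<1>(WARP_SIZE)),
            [=](sycl::nd_item<1> it) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                float acc = 0.0f;
                for (int c = it.get_local_id(0); c < ncols; c += WARP_SIZE) {
                    acc += x[c];                                   // strided per-lane partial sum
                }
                acc = sycl::reduce_over_group(it.get_sub_group(), acc, sycl::plus<float>());
                if (it.get_local_id(0) == 0) {
                    out[0] = acc;                                  // lane 0 writes the row sum
                }
            });
    }
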
@@ -2242,13 +1885,12 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
2242
1885
  const size_t shared_mem = ncols_pad * sizeof(int);
2243
1886
 
2244
1887
  if (order == GGML_SORT_ORDER_ASC) {
2245
- stream->submit([&](sycl::handler &cgh) {
1888
+ sycl_launch(stream, [&](sycl::handler & cgh) {
2246
1889
  sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
2247
1890
  sycl::range<1>(shared_mem), cgh);
2248
1891
 
2249
- cgh.parallel_for(
2250
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
2251
- [=](sycl::nd_item<3> item_ct1) {
1892
+ sycl_parallel_for(
1893
+ cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
2252
1894
  k_argsort_f32_i32<GGML_SORT_ORDER_ASC>(
2253
1895
  x, dst, ncols, ncols_pad, item_ct1,
2254
1896
  dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
@@ -2256,13 +1898,12 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
2256
1898
  });
2257
1899
  });
2258
1900
  } else if (order == GGML_SORT_ORDER_DESC) {
2259
- stream->submit([&](sycl::handler &cgh) {
1901
+ sycl_launch(stream, [&](sycl::handler & cgh) {
2260
1902
  sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
2261
1903
  sycl::range<1>(shared_mem), cgh);
2262
1904
 
2263
- cgh.parallel_for(
2264
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
2265
- [=](sycl::nd_item<3> item_ct1) {
1905
+ sycl_parallel_for(
1906
+ cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
2266
1907
  k_argsort_f32_i32<GGML_SORT_ORDER_DESC>(
2267
1908
  x, dst, ncols, ncols_pad, item_ct1,
2268
1909
  dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
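
The argsort hunk swaps direct stream->submit and cgh.parallel_for calls for sycl_launch and sycl_parallel_for helpers. Those helpers are internal to this backend and their bodies are not shown in this diff; the minimal shape they would need for the call sites above to compile is sketched below as an assumption, not as the project's actual definitions:

    #include <sycl/sycl.hpp>
    #include <utility>

    // Funnel every command-group submission through one helper ...
    template <typename CommandGroup>
    void sycl_launch(sycl::queue * stream, CommandGroup && cgf) {
        stream->submit(std::forward<CommandGroup>(cgf));
    }

    // ... and every kernel launch through another, so instrumentation (timing,
    // debug printing, graph capture) can later be added in a single place.
    template <typename Range, typename Kernel>
    void sycl_parallel_for(sycl::handler & cgh, Range && range, Kernel && kernel) {
        cgh.parallel_for(std::forward<Range>(range), std::forward<Kernel>(kernel));
    }
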
@@ -2280,50 +1921,47 @@ static void argmax_f32_i32_sycl(const float *x, int *dst, const int ncols,
2280
1921
  const sycl::range<3> block_nums(1, nrows, 1);
2281
1922
  const size_t shared_mem = 256 * sizeof(float);
2282
1923
 
2283
- stream->submit([&](sycl::handler &cgh) {
1924
+ sycl_launch(stream, [&](sycl::handler & cgh) {
2284
1925
  sycl::local_accessor<float, 1> shared_data(
2285
1926
  sycl::range<1>(shared_mem/sizeof(float)), cgh);
2286
1927
  sycl::local_accessor<int, 1> shared_indices(
2287
1928
  sycl::range<1>(shared_mem/sizeof(float)), cgh);
2288
1929
 
2289
- cgh.parallel_for(
2290
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
2291
- [=](sycl::nd_item<3> item_ct1) {
2292
- const int tid = item_ct1.get_local_id(2);
2293
- const int row = item_ct1.get_global_id(1);
2294
-
2295
- float max_val = -INFINITY;
2296
- int max_idx = -1;
2297
-
2298
- for (int col = tid; col < ncols; col += 256) {
2299
- float val = x[row * ncols + col];
2300
- if (val > max_val) {
2301
- max_val = val;
2302
- max_idx = col;
2303
- }
2304
- }
1930
+ sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
1931
+ const int tid = item_ct1.get_local_id(2);
1932
+ const int row = item_ct1.get_global_id(1);
2305
1933
 
2306
- shared_data[tid] = max_val;
2307
- shared_indices[tid] = max_idx;
2308
- item_ct1.barrier(sycl::access::fence_space::local_space);
1934
+ float max_val = -INFINITY;
1935
+ int max_idx = -1;
2309
1936
 
2310
- for (int stride = 256/2; stride > 0; stride >>= 1) {
2311
- if (tid < stride) {
2312
- float val1 = shared_data[tid];
2313
- float val2 = shared_data[tid + stride];
2314
- if (val2 > val1) {
2315
- shared_data[tid] = val2;
2316
- shared_indices[tid] = shared_indices[tid + stride];
2317
- }
2318
- }
2319
- item_ct1.barrier(sycl::access::fence_space::local_space);
1937
+ for (int col = tid; col < ncols; col += 256) {
1938
+ float val = x[row * ncols + col];
1939
+ if (val > max_val) {
1940
+ max_val = val;
1941
+ max_idx = col;
2320
1942
  }
1943
+ }
2321
1944
 
1945
+ shared_data[tid] = max_val;
1946
+ shared_indices[tid] = max_idx;
1947
+ item_ct1.barrier(sycl::access::fence_space::local_space);
2322
1948
 
2323
- if (tid == 0) {
2324
- dst[row] = shared_indices[0];
1949
+ for (int stride = 256 / 2; stride > 0; stride >>= 1) {
1950
+ if (tid < stride) {
1951
+ float val1 = shared_data[tid];
1952
+ float val2 = shared_data[tid + stride];
1953
+ if (val2 > val1) {
1954
+ shared_data[tid] = val2;
1955
+ shared_indices[tid] = shared_indices[tid + stride];
1956
+ }
2325
1957
  }
2326
- });
1958
+ item_ct1.barrier(sycl::access::fence_space::local_space);
1959
+ }
1960
+
1961
+ if (tid == 0) {
1962
+ dst[row] = shared_indices[0];
1963
+ }
1964
+ });
2327
1965
  });
2328
1966
  }
2329
1967
  static void diag_mask_inf_f32_sycl(const float *x, float *dst,
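
The reindented argmax kernel above is a classic two-phase reduction: each of the 256 work-items scans a strided slice of the row, then the partial (value, index) pairs are folded in local memory by halving the active stride each step. A host-side model of the folding phase, for readers who want the logic without the SYCL plumbing (pure illustration; assumes exactly 256 partials):

    #include <utility>
    #include <vector>

    std::pair<float, int> fold_argmax(std::vector<float> vals, std::vector<int> idxs) {
        // vals/idxs hold the 256 per-work-item partial maxima and their column indices.
        for (int stride = 256 / 2; stride > 0; stride >>= 1) {
            for (int tid = 0; tid < stride; ++tid) {      // on the device: one iteration per work-item
                if (vals[tid + stride] > vals[tid]) {
                    vals[tid] = vals[tid + stride];
                    idxs[tid] = idxs[tid + stride];
                }
            }
            // the kernel places item_ct1.barrier(local_space) here so every lane sees the updates
        }
        return { vals[0], idxs[0] };                      // what work-item 0 writes to dst[row]
    }
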
@@ -2349,12 +1987,22 @@ static dpct::err0 ggml_sycl_cpy_tensor_2d(void *dst,
2349
1987
 
2350
1988
  dpct::memcpy_direction kind;
2351
1989
  char * src_ptr;
2352
- if (src->backend == GGML_BACKEND_TYPE_CPU) {
1990
+ if (ggml_backend_buffer_is_host(src->buffer)) {
2353
1991
  kind = dpct::host_to_device;
1992
+ //GGML_SYCL_DEBUG("%s: Host buffer type src tensor\n", __func__);
2354
1993
  src_ptr = (char *) src->data;
2355
1994
  // GGML_SYCL_DEBUG("ggml_sycl_cpy_tensor_2d GGML_BACKEND_TYPE_CPU src_ptr %p\n", src_ptr);
2356
- } else if (src->backend == GGML_BACKEND_TYPE_GPU || src->backend == GGML_BACKEND_TYPE_GPU_SPLIT) {
2357
- GGML_ASSERT(src->backend != GGML_BACKEND_TYPE_GPU_SPLIT || (i1_low == 0 && i1_high == src->ne[1]));
1995
+ } else if (ggml_backend_buffer_is_sycl(src->buffer)) {
1996
+ // If buffer is a SYCL buffer
1997
+ //GGML_SYCL_DEBUG("%s: SYCL buffer type src tensor\n", __func__);
1998
+ kind = dpct::device_to_device;
1999
+ src_ptr = (char *) src->data;
2000
+ } else if (ggml_backend_buffer_is_sycl_split(src->buffer)) {
2001
+ /*
2002
+ If buffer is a SYCL split buffer
2003
+ */
2004
+ //GGML_SYCL_DEBUG("%s: Split buffer type src tensor\n", __func__);
2005
+ GGML_ASSERT(i1_low == 0 && i1_high == src->ne[1]);
2358
2006
  kind = dpct::device_to_device;
2359
2007
  ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src->extra;
2360
2008
  int id;
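
ggml_sycl_cpy_tensor_2d no longer inspects the removed tensor->backend field; it classifies the source by its buffer instead. Condensed into one hypothetical helper, the decision made by the branches above looks like this (it relies on the same ggml and dpct declarations as the surrounding file and is only a sketch of the dispatch):

    // Returns the dpct memcpy direction for a source tensor, mirroring the branches above.
    static dpct::memcpy_direction source_copy_kind(const ggml_tensor * src) {
        if (ggml_backend_buffer_is_host(src->buffer)) {
            return dpct::host_to_device;         // plain host memory: stage it onto the device
        }
        if (ggml_backend_buffer_is_sycl(src->buffer) ||
            ggml_backend_buffer_is_sycl_split(src->buffer)) {
            return dpct::device_to_device;       // already resident on a SYCL device (possibly row-split)
        }
        GGML_ABORT("unsupported source buffer type");
    }
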
@@ -2411,65 +2059,6 @@ catch (sycl::exception const &exc) {
2411
2059
  std::exit(1);
2412
2060
  }
2413
2061
 
2414
- static void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
2415
- const ggml_tensor *src1, ggml_tensor *dst,
2416
- const float *src0_d, const float *src1_d,
2417
- float *dst_d, const queue_ptr &stream) {
2418
-
2419
- GGML_ASSERT(src1->type == GGML_TYPE_I32);
2420
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
2421
-
2422
- GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
2423
- GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
2424
- GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
2425
-
2426
- const int32_t * src1_i32 = (const int32_t *) src1_d;
2427
-
2428
- switch (src0->type) {
2429
- case GGML_TYPE_F16:
2430
- get_rows_sycl_float(ctx, src0, src1, dst, (const sycl::half *)src0_d,
2431
- src1_i32, dst_d, stream);
2432
- break;
2433
- case GGML_TYPE_F32:
2434
- get_rows_sycl_float(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
2435
- break;
2436
- case GGML_TYPE_Q4_0:
2437
- get_rows_sycl<QK4_0, QR4_0, dequantize_q4_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
2438
- break;
2439
- case GGML_TYPE_Q4_1:
2440
- get_rows_sycl<QK4_1, QR4_1, dequantize_q4_1>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
2441
- break;
2442
- case GGML_TYPE_Q5_0:
2443
- get_rows_sycl<QK5_0, QR5_0, dequantize_q5_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
2444
- break;
2445
- case GGML_TYPE_Q5_1:
2446
- get_rows_sycl<QK5_1, QR5_1, dequantize_q5_1>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
2447
- break;
2448
- case GGML_TYPE_Q8_0:
2449
- get_rows_sycl<QK8_0, QR8_0, dequantize_q8_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
2450
- break;
2451
- default:
2452
- // TODO: k-quants
2453
- GGML_LOG_ERROR("%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type));
2454
- GGML_ABORT("fatal error");
2455
- break;
2456
- }
2457
- }
2458
-
2459
-
2460
- static void ggml_sycl_op_repeat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
2461
- const ggml_tensor *src1, ggml_tensor *dst,
2462
- const float *src0_d, const float *src1_d,
2463
- float *dst_d,
2464
- const queue_ptr &main_stream) {
2465
-
2466
- ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_repeat>>(ctx, dst, src0, dst, nullptr, src0_d, dst_d, main_stream);
2467
-
2468
- GGML_UNUSED(src1);
2469
- GGML_UNUSED(src1_d);
2470
- }
2471
-
2472
-
2473
2062
  inline void ggml_sycl_op_mul_mat_sycl(
2474
2063
  ggml_backend_sycl_context & ctx,
2475
2064
  const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
@@ -2484,33 +2073,31 @@ inline void ggml_sycl_op_mul_mat_sycl(
2484
2073
 
2485
2074
  const int64_t ne00 = src0->ne[0];
2486
2075
  const int64_t ne10 = src1->ne[0];
2487
-
2076
+ GGML_ASSERT(ne00 == ne10);
2488
2077
 
2489
2078
  const int64_t row_diff = row_high - row_low;
2490
2079
 
2491
2080
  int id;
2492
2081
  SYCL_CHECK(
2493
2082
  CHECK_TRY_ERROR(id = get_current_device_id()));
2494
- #if !GGML_SYCL_DNNL
2495
- const int64_t ne0 = dst->ne[0];
2083
+
2084
+ const int64_t ne0 = dst->ne[0]; // used by MKL only
2496
2085
  // the main device has a larger memory buffer to hold the results from all GPUs
2497
2086
  // ldc == nrows of the matrix that cuBLAS writes into
2498
- int ldc = id == ctx.device ? ne0 : row_diff;
2499
- #endif
2087
+ int ldc = id == ctx.device ? ne0 : row_diff; // used by MKL only
2500
2088
 
2501
2089
  #ifdef GGML_SYCL_F16
2502
2090
  bool use_fp16 = true; // TODO(Yu) SYCL capability check
2503
2091
  #else
2504
2092
  bool use_fp16 = false;
2505
2093
  #endif
2506
- if ((src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
2507
- use_fp16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1] &&
2508
- dst->op_params[0] == GGML_PREC_DEFAULT) {
2509
-
2510
- // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp16 path\n");
2094
+ if ((src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && use_fp16 && ggml_is_contiguous(src0) &&
2095
+ row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
2511
2096
  ggml_sycl_pool_alloc<sycl::half> src0_as_f16(ctx.pool());
2512
2097
  if (src0->type != GGML_TYPE_F16) {
2513
- const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src0->type);
2098
+ scope_op_debug_print scope_dbg_print(__func__, "/to_fp16_sycl", dst, /*num_src=*/2,
2099
+ " : converting src0 to fp16");
2100
+ const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src0->type, dst);
2514
2101
  GGML_ASSERT(to_fp16_sycl != nullptr);
2515
2102
  size_t ne = row_diff*ne00;
2516
2103
  src0_as_f16.alloc(ne);
@@ -2522,7 +2109,9 @@ inline void ggml_sycl_op_mul_mat_sycl(
2522
2109
 
2523
2110
  ggml_sycl_pool_alloc<sycl::half> src1_as_f16(ctx.pool());
2524
2111
  if (src1->type != GGML_TYPE_F16) {
2525
- const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type);
2112
+ scope_op_debug_print scope_dbg_print(__func__, "/to_fp16_sycl", dst, /*num_src=*/2,
2113
+ " : converting src1 to fp16");
2114
+ const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst);
2526
2115
  GGML_ASSERT(to_fp16_sycl != nullptr);
2527
2116
  size_t ne = src1_ncols*ne10;
2528
2117
  src1_as_f16.alloc(ne);
@@ -2531,40 +2120,47 @@ inline void ggml_sycl_op_mul_mat_sycl(
2531
2120
  const sycl::half *src1_ptr = src1->type == GGML_TYPE_F16
2532
2121
  ? (const sycl::half *)src1->data + src1_padded_row_size
2533
2122
  : src1_as_f16.get();
2534
- ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool(), row_diff * src1_ncols);
2535
-
2536
- #if !GGML_SYCL_DNNL
2537
- const sycl::half alpha_f16 = 1.0f;
2538
- const sycl::half beta_f16 = 0.0f;
2539
- SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
2540
- *stream, oneapi::mkl::transpose::trans,
2541
- oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
2542
- &alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00,
2543
- src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16,
2544
- dst_f16.get(), dpct::library_data_t::real_half, ldc,
2545
- dpct::library_data_t::real_half)));
2546
- const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
2547
- to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
2548
- #else
2549
- auto dnnl_stream = ctx.stream_dnnl(stream);
2550
- DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
2551
- src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(), dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>());
2552
- const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
2553
- to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
2123
+
2124
+ #if GGML_SYCL_DNNL
2125
+ if (!g_ggml_sycl_disable_dnn) {
2126
+ DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ptr,
2127
+ DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
2128
+ dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
2129
+ }
2130
+ else
2554
2131
  #endif
2555
- }
2556
- else {
2557
- // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp32 path\n");
2132
+ {
2133
+ ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool(), row_diff * src1_ncols);
2134
+
2135
+ const sycl::half alpha_f16 = 1.0f;
2136
+ const sycl::half beta_f16 = 0.0f;
2137
+ SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
2138
+ *stream, oneapi::math::transpose::trans,
2139
+ oneapi::math::transpose::nontrans, row_diff, src1_ncols, ne10,
2140
+ &alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00,
2141
+ src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16,
2142
+ dst_f16.get(), dpct::library_data_t::real_half, ldc,
2143
+ dpct::library_data_t::real_half)));
2144
+ scope_op_debug_print scope_dbg_print(__func__, "/to_fp32_sycl", dst, /*num_src=*/2,
2145
+ " : converting dst to fp32");
2146
+ const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
2147
+ to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
2148
+ }
2149
+ } else {
2558
2150
  ggml_sycl_pool_alloc<float> src0_ddq_as_f32(ctx.pool());
2559
2151
  ggml_sycl_pool_alloc<float> src1_ddq_as_f32(ctx.pool());
2560
2152
  if (src0->type != GGML_TYPE_F32) {
2561
- const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src0->type);
2153
+ scope_op_debug_print scope_dbg_print(__func__, "/to_fp32_sycl", dst, /*num_src=*/2,
2154
+ " : converting src0 to fp32");
2155
+ const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src0->type, dst);
2562
2156
  GGML_ASSERT(to_fp32_sycl != nullptr);
2563
2157
  src0_ddq_as_f32.alloc(row_diff*ne00);
2564
2158
  to_fp32_sycl(src0_dd_i, src0_ddq_as_f32.get(), row_diff*ne00, stream);
2565
2159
  }
2566
2160
  if (src1->type != GGML_TYPE_F32) {
2567
- const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src1->type);
2161
+ scope_op_debug_print scope_dbg_print(__func__, "/to_fp32_sycl", dst, /*num_src=*/2,
2162
+ " : converting src1 to fp32");
2163
+ const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src1->type, dst);
2568
2164
  GGML_ASSERT(to_fp32_sycl != nullptr);
2569
2165
  src1_ddq_as_f32.alloc(src1_ncols*ne10);
2570
2166
  to_fp32_sycl(src1_ddf_i, src1_ddq_as_f32.get(), src1_ncols*ne10, stream);
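
Both GEMM branches above now use the same fallback idiom: when the oneDNN build option is on, the oneDNN path runs unless disabled at runtime, and the brace block after the #endif doubles as both the runtime else-branch and the only path of a non-oneDNN build. A tiny self-contained C++ reproduction of that control flow (all names are placeholders, not the backend's symbols):

    #include <cstdio>

    #define WITH_FAST_PATH 1                    // stand-in for GGML_SYCL_DNNL
    static bool g_disable_fast_path = false;    // stand-in for g_ggml_sycl_disable_dnn

    static void fast_gemm()     { std::puts("oneDNN-style path"); }
    static void fallback_gemm() { std::puts("oneMath/BLAS-style path"); }

    void run_gemm() {
    #if WITH_FAST_PATH
        if (!g_disable_fast_path) {
            fast_gemm();
        } else
    #endif
        {   // runtime fallback, and the only branch when WITH_FAST_PATH is 0
            fallback_gemm();
        }
    }

    int main() { run_gemm(); }
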
@@ -2572,25 +2168,22 @@ inline void ggml_sycl_op_mul_mat_sycl(
2572
2168
  const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32.get();
2573
2169
  const float * src1_ddf1_i = src1->type == GGML_TYPE_F32 ? (const float *) src1_ddf_i : src1_ddq_as_f32.get();
2574
2170
 
2575
- #if !GGML_SYCL_DNNL
2576
- const float alpha = 1.0f;
2577
- const float beta = 0.0f;
2578
- # ifdef GGML_SYCL_NVIDIA
2579
- SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
2580
- oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ *stream }, oneapi::mkl::transpose::trans,
2581
- oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i,
2582
- ne00, src1_ddf1_i, ne10, dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
2583
- # else
2584
- SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
2585
- *stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
2586
- dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10, dpct::get_value(&beta, *stream),
2587
- dst_dd_i, ldc)));
2588
- # endif
2589
- #else
2590
- auto dnnl_stream = ctx.stream_dnnl(stream);
2591
- DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(),
2592
- src0_ddf_i, DnnlGemmWrapper::to_dt<float>(), dst_dd_i, DnnlGemmWrapper::to_dt<float>());
2171
+ #if GGML_SYCL_DNNL
2172
+ if (!g_ggml_sycl_disable_dnn) {
2173
+ DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ddf1_i,
2174
+ DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(),
2175
+ dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
2176
+ }
2177
+ else
2593
2178
  #endif
2179
+ {
2180
+ const float alpha = 1.0f;
2181
+ const float beta = 0.0f;
2182
+ SYCL_CHECK(CHECK_TRY_ERROR(oneapi::math::blas::column_major::gemm(
2183
+ get_onemath_backend(*stream), oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, row_diff,
2184
+ src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10,
2185
+ dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
2186
+ }
2594
2187
  }
2595
2188
  GGML_UNUSED(dst);
2596
2189
  GGML_UNUSED(src1_ddq_i);
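
For readers tracing the dimensions: both the fp16 and fp32 branches end in a standard GEMM with src0 passed transposed, so its row-major rows supply the shared dimension, and with alpha and beta fixed:

    $$ C \leftarrow \alpha\,\operatorname{op}(A)\,\operatorname{op}(B) + \beta\,C,
       \qquad \operatorname{op}(A) = A^{\mathsf T},\ \operatorname{op}(B) = B,\ \alpha = 1,\ \beta = 0 $$

with m = row_diff rows, n = src1_ncols columns and k = ne10 as the shared dimension, matching the arguments of the gemm calls above.
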
@@ -2602,13 +2195,13 @@ catch (sycl::exception const &exc) {
2602
2195
  std::exit(1);
2603
2196
  }
2604
2197
 
2605
- static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
2606
- const ggml_tensor *src1, ggml_tensor *dst,
2607
- const float *src0_dd, const float *src1_dd,
2608
- float *dst_dd, const queue_ptr &main_stream) {
2609
-
2610
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
2198
+ static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2199
+ GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
2611
2200
  GGML_ASSERT( dst->type == GGML_TYPE_F32);
2201
+ dpct::queue_ptr main_stream = ctx.stream();
2202
+ SYCL_CHECK(ggml_sycl_set_device(ctx.device));
2203
+ const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
2204
+ float * dst_dd = static_cast<float *>(dst->data);
2612
2205
 
2613
2206
  const int32_t * opts = (const int32_t *)dst->op_params;
2614
2207
  enum ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
@@ -2619,8 +2212,8 @@ static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, const ggml_tens
2619
2212
  const int p0 = opts[5];
2620
2213
  const int p1 = opts[6];
2621
2214
 
2622
- const int64_t IH = src0->ne[1];
2623
- const int64_t IW = src0->ne[0];
2215
+ const int64_t IH = dst->src[0]->ne[1];
2216
+ const int64_t IW = dst->src[0]->ne[0];
2624
2217
 
2625
2218
  const int64_t N = dst->ne[3];
2626
2219
  const int64_t OC = dst->ne[2];
@@ -2639,163 +2232,101 @@ static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, const ggml_tens
2639
2232
  parallel_elements, src0_dd, dst_dd, op,
2640
2233
  item_ct1);
2641
2234
  });
2642
-
2643
- GGML_UNUSED(src1);
2644
- GGML_UNUSED(src1_dd);
2645
- GGML_UNUSED(ctx);
2646
2235
  }
2647
2236
 
2648
- inline void ggml_sycl_op_sum(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
2649
- const ggml_tensor *src1, ggml_tensor *dst,
2650
- const float *src0_dd, const float *src1_dd,
2651
- float *dst_dd,
2652
- const queue_ptr &main_stream) {
2653
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
2237
+ inline void ggml_sycl_op_sum(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
2238
+ GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
2654
2239
  GGML_ASSERT( dst->type == GGML_TYPE_F32);
2240
+ dpct::queue_ptr main_stream = ctx.stream();
2241
+ SYCL_CHECK(ggml_sycl_set_device(ctx.device));
2242
+ const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
2243
+ float * dst_dd = static_cast<float *>(dst->data);
2655
2244
 
2656
- const int64_t ne = ggml_nelements(src0);
2245
+ const int64_t ne = ggml_nelements(dst->src[0]);
2657
2246
 
2658
2247
  sum_rows_f32_sycl(src0_dd, dst_dd, ne, 1, main_stream);
2659
-
2660
- GGML_UNUSED(src1);
2661
- GGML_UNUSED(dst);
2662
- GGML_UNUSED(src1_dd);
2663
- GGML_UNUSED(ctx);
2664
2248
  }
2665
2249
 
2666
- inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
2667
- const ggml_tensor *src1, ggml_tensor *dst,
2668
- const float *src0_dd, const float *src1_dd,
2669
- float *dst_dd,
2670
- const queue_ptr &main_stream) {
2671
-
2672
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
2250
+ inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2251
+ GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
2673
2252
  GGML_ASSERT( dst->type == GGML_TYPE_F32);
2253
+ dpct::queue_ptr main_stream = ctx.stream();
2254
+ SYCL_CHECK(ggml_sycl_set_device(ctx.device));
2255
+ const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
2256
+ float * dst_dd = static_cast<float *>(dst->data);
2674
2257
 
2675
- const int64_t ncols = src0->ne[0];
2676
- const int64_t nrows = ggml_nrows(src0);
2258
+ const int64_t ncols = dst->src[0]->ne[0];
2259
+ const int64_t nrows = ggml_nrows(dst->src[0]);
2677
2260
 
2678
2261
  sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
2679
-
2680
- GGML_UNUSED(src1);
2681
- GGML_UNUSED(dst);
2682
- GGML_UNUSED(src1_dd);
2683
- GGML_UNUSED(ctx);
2684
2262
  }
2685
2263
 
2686
- inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
2687
- const ggml_tensor *src1, ggml_tensor *dst,
2688
- const float *src0_dd, const float *src1_dd,
2689
- float *dst_dd,
2690
- const queue_ptr &main_stream) {
2264
+ inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2265
+ GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
2266
+ GGML_ASSERT(dst->type == GGML_TYPE_I32);
2267
+ dpct::queue_ptr main_stream = ctx.stream();
2268
+ SYCL_CHECK(ggml_sycl_set_device(ctx.device));
2269
+ const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
2270
+ int32_t * dst_dd = static_cast<int32_t *>(dst->data);
2691
2271
 
2692
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
2693
- GGML_ASSERT( dst->type == GGML_TYPE_I32);
2694
2272
 
2695
- const int64_t ncols = src0->ne[0];
2696
- const int64_t nrows = ggml_nrows(src0);
2273
+ const int64_t ncols = dst->src[0]->ne[0];
2274
+ const int64_t nrows = ggml_nrows(dst->src[0]);
2697
2275
 
2698
2276
  enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
2699
2277
 
2700
- argsort_f32_i32_sycl(src0_dd, (int *)dst_dd, ncols, nrows, order, main_stream);
2701
-
2702
- GGML_UNUSED(src1);
2703
- GGML_UNUSED(dst);
2704
- GGML_UNUSED(src1_dd);
2705
- GGML_UNUSED(ctx);
2278
+ argsort_f32_i32_sycl(src0_dd, (int *) dst_dd, ncols, nrows, order, main_stream);
2706
2279
  }
2707
2280
 
2708
- inline void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
2709
- const ggml_tensor *src1, ggml_tensor *dst,
2710
- const float *src0_dd, const float *src1_dd,
2711
- float *dst_dd,
2712
- const queue_ptr &main_stream) {
2713
-
2714
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
2281
+ inline void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2282
+ GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
2715
2283
  GGML_ASSERT( dst->type == GGML_TYPE_I32);
2716
2284
 
2717
- const int64_t ncols = src0->ne[0];
2718
- const int64_t nrows = ggml_nrows(src0);
2285
+ dpct::queue_ptr main_stream = ctx.stream();
2286
+ SYCL_CHECK(ggml_sycl_set_device(ctx.device));
2287
+ const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
2288
+ int32_t * dst_dd = static_cast<int32_t *>(dst->data);
2719
2289
 
2720
- argmax_f32_i32_sycl(src0_dd, (int *)dst_dd, ncols, nrows, main_stream);
2290
+ const int64_t ncols = dst->src[0]->ne[0];
2291
+ const int64_t nrows = ggml_nrows(dst->src[0]);
2721
2292
 
2722
- GGML_UNUSED(src1);
2723
- GGML_UNUSED(dst);
2724
- GGML_UNUSED(src1_dd);
2725
- GGML_UNUSED(ctx);
2293
+ argmax_f32_i32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
2726
2294
  }
2727
2295
 
2728
- inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
2729
- const ggml_tensor *src1,
2730
- ggml_tensor *dst, const float *src0_dd,
2731
- const float *src1_dd, float *dst_dd,
2732
- const queue_ptr &main_stream) {
2733
-
2734
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
2296
+ inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2297
+ GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
2735
2298
  GGML_ASSERT( dst->type == GGML_TYPE_F32);
2299
+ dpct::queue_ptr main_stream = ctx.stream();
2300
+ SYCL_CHECK(ggml_sycl_set_device(ctx.device));
2301
+ const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
2302
+ float * dst_dd = static_cast<float *>(dst->data);
2736
2303
 
2737
- const int64_t ne00 = src0->ne[0];
2738
- const int64_t ne01 = src0->ne[1];
2739
- const int nrows0 = ggml_nrows(src0);
2304
+ const int64_t ne00 = dst->src[0]->ne[0];
2305
+ const int64_t ne01 = dst->src[0]->ne[1];
2306
+ const int nrows0 = ggml_nrows(dst->src[0]);
2740
2307
 
2741
2308
  const int n_past = ((int32_t *) dst->op_params)[0];
2742
2309
 
2743
2310
  diag_mask_inf_f32_sycl(src0_dd, dst_dd, ne00, nrows0, ne01, n_past, main_stream);
2744
-
2745
- GGML_UNUSED(src1);
2746
- GGML_UNUSED(dst);
2747
- GGML_UNUSED(src1_dd);
2748
- GGML_UNUSED(ctx);
2749
2311
  }
2750
2312
 
2751
- inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
2752
- ggml_tensor *dst, const float *src0_dd,
2753
- const float *src1_dd, float *dst_dd,
2754
- const queue_ptr &main_stream) {
2755
-
2756
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
2313
+ inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2314
+ GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
2757
2315
  GGML_ASSERT( dst->type == GGML_TYPE_F32);
2316
+ dpct::queue_ptr main_stream = ctx.stream();
2317
+ SYCL_CHECK(ggml_sycl_set_device(ctx.device));
2318
+ const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
2319
+ float * dst_dd = static_cast<float *>(dst->data);
2758
2320
 
2759
2321
  float scale;
2760
2322
  memcpy(&scale, dst->op_params, sizeof(float));
2761
2323
 
2762
- scale_f32_sycl(src0_dd, dst_dd, scale, ggml_nelements(src0), main_stream);
2324
+ scale_f32_sycl(src0_dd, dst_dd, scale, ggml_nelements(dst->src[0]), main_stream);
2763
2325
  /*
2764
2326
  DPCT1010:87: SYCL uses exceptions to report errors and does not use the
2765
2327
  error codes. The call was replaced with 0. You need to rewrite this code.
2766
2328
  */
2767
2329
  SYCL_CHECK(0);
2768
-
2769
- GGML_UNUSED(src1);
2770
- GGML_UNUSED(dst);
2771
- GGML_UNUSED(src1_dd);
2772
- GGML_UNUSED(ctx);
2773
- }
2774
-
2775
- inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
2776
- ggml_tensor *dst, const float *src0_dd,
2777
- const float *src1_dd, float *dst_dd,
2778
- const queue_ptr &main_stream) {
2779
-
2780
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
2781
- GGML_ASSERT( dst->type == GGML_TYPE_F32);
2782
-
2783
- float min;
2784
- float max;
2785
- memcpy(&min, dst->op_params, sizeof(float));
2786
- memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
2787
-
2788
- clamp_f32_sycl(src0_dd, dst_dd, min, max, ggml_nelements(src0), main_stream);
2789
- /*
2790
- DPCT1010:88: SYCL uses exceptions to report errors and does not use the
2791
- error codes. The call was replaced with 0. You need to rewrite this code.
2792
- */
2793
- SYCL_CHECK(0);
2794
-
2795
- GGML_UNUSED(src1);
2796
- GGML_UNUSED(dst);
2797
- GGML_UNUSED(src1_dd);
2798
- GGML_UNUSED(ctx);
2799
2330
  }
2800
2331
 
2801
2332
  static void ggml_sycl_set_peer_access(const int n_tokens, int main_device) {
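
Every op in this hunk moves from the old flatten-style signature (src0, src1, dst plus pre-extracted device pointers and a stream) to a uniform (ctx, dst) signature that pulls its inputs from dst->src[]. The skeleton below condenses that shape; the kernel call at the end is a placeholder, not a real backend function:

    inline void ggml_sycl_op_example(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
        GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
        GGML_ASSERT(dst->type         == GGML_TYPE_F32);

        dpct::queue_ptr main_stream = ctx.stream();               // stream now comes from the context
        SYCL_CHECK(ggml_sycl_set_device(ctx.device));

        const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
        float *       dst_dd  = static_cast<float *>(dst->data);

        // example_kernel_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);  // placeholder launch
    }
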
@@ -2857,8 +2388,8 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
2857
2388
  const int nb2 = dst->nb[2];
2858
2389
  const int nb3 = dst->nb[3];
2859
2390
 
2860
- GGML_ASSERT(dst->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
2861
- GGML_ASSERT(src1->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
2391
+ GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(dst->buffer));
2392
+ GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src1->buffer));
2862
2393
  GGML_ASSERT(src1->type == GGML_TYPE_F32 || (src1->ne[2] == 1 && src1->ne[3] == 1));
2863
2394
 
2864
2395
  GGML_ASSERT(ne12 >= ne02 && ne12 % ne02 == 0);
@@ -2878,7 +2409,7 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
2878
2409
 
2879
2410
  int64_t src1_padded_col_size = GGML_PAD(ne10, MATRIX_ROW_PADDING);
2880
2411
 
2881
- const bool split = src0->backend == GGML_BACKEND_TYPE_GPU_SPLIT;
2412
+ const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer);
2882
2413
  GGML_ASSERT(!(split && ne02 > 1));
2883
2414
  GGML_ASSERT(!(split && ne03 > 1));
2884
2415
  GGML_ASSERT(!(split && ne02 < ne12));
@@ -2966,7 +2497,10 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
2966
2497
  dev[i].src1_ddq = dev[i].src1_ddq_alloc.alloc(ctx.pool(i), nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs);
2967
2498
 
2968
2499
  if (src1_on_device && src1_is_contiguous) {
2969
- quantize_row_q8_1_sycl(dev[i].src1_ddf, dev[i].src1_ddq, ne10, nrows1, src1_padded_col_size, stream);
2500
+ bool reorder_q8_tensor = src0->extra && ((ggml_tensor_extra_gpu *)src0->extra)->optimized_feature.reorder;
2501
+ scope_op_debug_print scope_dbg_print(__func__, "/quantize_row_q8_1_sycl", dst,
2502
+ /*num_src=*/2, " : converting src1 to Q8_1");
2503
+ quantize_row_q8_1_sycl(dev[i].src1_ddf, dev[i].src1_ddq, ne10, nrows1, src1_padded_col_size, reorder_q8_tensor, stream);
2970
2504
  /*
2971
2505
  DPCT1010:90: SYCL uses exceptions to report errors and does not
2972
2506
  use the error codes. The call was replaced with 0. You need to
@@ -3002,7 +2536,6 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
3002
2536
  for (int64_t src1_col_0 = 0; src1_col_0 < ne11; src1_col_0 += src1_col_stride) {
3003
2537
  const int64_t is = split ? (src1_col_0/src1_col_stride) % GGML_SYCL_MAX_STREAMS : 0;
3004
2538
  const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? ne11 - src1_col_0 : src1_col_stride;
3005
-
3006
2539
  for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
3007
2540
  if ((!split && i != ctx.device) || dev[i].row_low == dev[i].row_high) {
3008
2541
  continue;
@@ -3071,7 +2604,9 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
3071
2604
  }
3072
2605
 
3073
2606
  if (convert_src1_to_q8_1 && !src1_is_contiguous) {
3074
- quantize_row_q8_1_sycl(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, stream);
2607
+ scope_op_debug_print scope_dbg_print(__func__, "/quantize_row_q8_1_sycl", dst,
2608
+ /*num_src=*/2, " : converting src1 to Q8_1");
2609
+ quantize_row_q8_1_sycl(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, false, stream);
3075
2610
  /*
3076
2611
  DPCT1010:92: SYCL uses exceptions to report errors and does
3077
2612
  not use the error codes. The call was replaced with 0. You
@@ -3164,41 +2699,36 @@ catch (sycl::exception const &exc) {
3164
2699
  }
3165
2700
 
3166
2701
 
3167
- static void ggml_sycl_repeat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3168
- GGML_SYCL_DEBUG("call %s\n", __func__);
3169
- ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_repeat);
3170
- GGML_SYCL_DEBUG("call %s done\n", __func__);
2702
+ static void ggml_sycl_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2703
+ scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
2704
+ ggml_sycl_op_get_rows(ctx, dst);
3171
2705
  }
3172
2706
 
3173
- static void ggml_sycl_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3174
- GGML_SYCL_DEBUG("call %s\n", __func__);
3175
- ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_get_rows);
3176
- GGML_SYCL_DEBUG("call %s done\n", __func__);
2707
+ static void ggml_sycl_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2708
+ scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
2709
+ ggml_sycl_op_norm(ctx, dst);
3177
2710
  }
3178
2711
 
3179
- static void ggml_sycl_norm(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3180
- GGML_SYCL_DEBUG("call %s\n", __func__);
3181
- ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_norm);
3182
- GGML_SYCL_DEBUG("call %s done\n", __func__);
2712
+ static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2713
+ scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
2714
+ ggml_sycl_op_rms_norm(ctx, dst);
3183
2715
  }
3184
2716
 
3185
- static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3186
- GGML_SYCL_DEBUG("call %s\n", __func__);
3187
- ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_rms_norm);
3188
- GGML_SYCL_DEBUG("call %s done\n", __func__);
2717
+ static void ggml_sycl_l2_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2718
+ scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
2719
+ ggml_sycl_op_l2_norm(ctx, dst);
3189
2720
  }
3190
2721
 
3191
- static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3192
- GGML_SYCL_DEBUG("call %s\n", __func__);
3193
- ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_group_norm);
3194
- GGML_SYCL_DEBUG("call %s done\n", __func__);
2722
+ static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2723
+ scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
2724
+ ggml_sycl_op_group_norm(ctx, dst);
3195
2725
  }
3196
2726
 
3197
2727
  static void ggml_sycl_mul_mat_vec_p021(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
3198
2728
  const ggml_tensor *src1,
3199
2729
  ggml_tensor *dst) try {
3200
2730
  GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
3201
- GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
2731
+ GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer));
3202
2732
  GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation
3203
2733
  GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // 0213 permutation
3204
2734
  GGML_ASSERT(src0->type == GGML_TYPE_F16);
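
The wrappers above replace matched "call %s" / "call %s done" debug prints with a single scope_op_debug_print object whose constructor and destructor bracket the call. The real class also records the tensor and its source count; a stripped-down, self-contained model of the same RAII idea:

    #include <cstdio>

    struct scope_trace {
        const char * name;
        explicit scope_trace(const char * n) : name(n) { std::printf("call %s\n", name); }
        ~scope_trace()                                  { std::printf("call %s done\n", name); }
    };

    static void ggml_sycl_norm_like() {
        scope_trace trace(__func__);   // prints on entry, and on every exit path via the destructor
        // ... dispatch to the op implementation ...
    }

    int main() { ggml_sycl_norm_like(); }
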
@@ -3231,7 +2761,7 @@ static void ggml_sycl_mul_mat_vec_nc(ggml_backend_sycl_context & ctx, const ggml
3231
2761
  GGML_ASSERT(!ggml_is_transposed(src0));
3232
2762
  GGML_ASSERT(!ggml_is_transposed(src1));
3233
2763
  GGML_ASSERT(!ggml_is_permuted(src0));
3234
- GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
2764
+ GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer));
3235
2765
  GGML_ASSERT(src0->type == GGML_TYPE_F16);
3236
2766
  GGML_ASSERT(src1->type == GGML_TYPE_F32);
3237
2767
 
@@ -3262,146 +2792,182 @@ catch (sycl::exception const &exc) {
3262
2792
  std::exit(1);
3263
2793
  }
3264
2794
 
3265
- static void k_compute_batched_ptrs(const sycl::half *src0_as_f16,
3266
- const sycl::half *src1_as_f16, char *dst,
3267
- const void **ptrs_src, void **ptrs_dst,
3268
- int64_t ne12, int64_t ne13, int64_t ne23,
3269
- size_t nb02, size_t nb03, size_t nb12,
3270
- size_t nb13, size_t nbd2, size_t nbd3,
3271
- int64_t r2, int64_t r3,
3272
- const sycl::nd_item<3> &item_ct1) {
3273
- int64_t i13 = item_ct1.get_group(2) * item_ct1.get_local_range(2) +
3274
- item_ct1.get_local_id(2);
3275
- int64_t i12 = item_ct1.get_group(1) * item_ct1.get_local_range(1) +
3276
- item_ct1.get_local_id(1);
2795
+ static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::half * src1_as_f16, void * dst,
2796
+ const void ** ptrs_src, void ** ptrs_dst, int64_t ne12, int64_t ne13, int64_t ne23,
2797
+ size_t nb02, size_t nb03, size_t nb12, size_t nb13, size_t nbd2, size_t nbd3,
2798
+ int64_t r2, int64_t r3, const sycl::nd_item<3> & item_ct1) {
2799
+ const int64_t i13 = item_ct1.get_group(2) * item_ct1.get_local_range(2) + item_ct1.get_local_id(2);
2800
+ const int64_t i12 = item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1);
3277
2801
 
3278
2802
  if (i13 >= ne13 || i12 >= ne12) {
3279
2803
  return;
3280
2804
  }
3281
2805
 
3282
- int64_t i03 = i13 / r3;
3283
- int64_t i02 = i12 / r2;
2806
+ const int64_t i03 = i13 / r3;
2807
+ const int64_t i02 = i12 / r2;
3284
2808
 
3285
- ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03;
3286
- ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12 + i13*nb13;
3287
- ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst + i12*nbd2 + i13*nbd3;
2809
+ const uint8_t * src0_bytes = reinterpret_cast<const uint8_t *>(src0_as_f16);
2810
+ const uint8_t * src1_bytes = reinterpret_cast<const uint8_t *>(src1_as_f16);
2811
+ uint8_t * dst_bytes = static_cast<uint8_t *>(dst);
2812
+
2813
+ ptrs_src[0 * ne23 + i12 + i13 * ne12] = src0_bytes + i02 * nb02 + i03 * nb03;
2814
+ ptrs_src[1 * ne23 + i12 + i13 * ne12] = src1_bytes + i12 * nb12 + i13 * nb13;
2815
+ ptrs_dst[0 * ne23 + i12 + i13 * ne12] = dst_bytes + i12 * nbd2 + i13 * nbd3;
3288
2816
  }
3289
2817
 
3290
- static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
3291
- const ggml_tensor *src0,
3292
- const ggml_tensor *src1,
3293
- ggml_tensor *dst) try {
2818
+ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor * src0,
2819
+ const ggml_tensor * src1, ggml_tensor * dst) try {
3294
2820
  GGML_ASSERT(!ggml_is_transposed(src0));
3295
2821
  GGML_ASSERT(!ggml_is_transposed(src1));
3296
- GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
2822
+ GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer));
3297
2823
  GGML_ASSERT(src0->type == GGML_TYPE_F16);
2824
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
3298
2825
 
3299
2826
  GGML_TENSOR_BINARY_OP_LOCALS
3300
2827
 
2828
+ // TODO: see https://github.com/ggml-org/llama.cpp/pull/13155
2829
+ // Batched mul_mat requires a rewrite to support both oneDNN and non-contiguous dst
2830
+ GGML_ASSERT(ggml_is_contiguous(dst));
3301
2831
 
3302
2832
  SYCL_CHECK(ggml_sycl_set_device(ctx.device));
3303
- queue_ptr main_stream = ctx.stream();;
2833
+ queue_ptr queue = ctx.stream();
3304
2834
 
3305
- void * src0_ddq = src0->data;
3306
- sycl::half *src0_as_f16 = (sycl::half *)src0_ddq;
3307
- float * src1_ddf = (float *) src1->data;
3308
- float * dst_ddf = (float *) dst->data;
2835
+ dpct::has_capability_or_fail(queue->get_device(), { sycl::aspect::fp16 });
3309
2836
 
3310
- // convert src1 to fp16
2837
+ const sycl::half * src0_f16 = static_cast<const sycl::half *>(src0->data);
2838
+ float * dst_ddf = static_cast<float *>(dst->data);
2839
+
2840
+ const sycl::half * src1_f16 = static_cast<const sycl::half *>(src1->data);
2841
+ const size_t type_size_src1 = ggml_type_size(src1->type);
2842
+ GGML_ASSERT(nb10 == type_size_src1);
2843
+
2844
+ // SRC1 strides
2845
+ int64_t s11 = nb11 / type_size_src1;
2846
+ int64_t s12 = nb12 / type_size_src1;
2847
+ int64_t s13 = nb13 / type_size_src1;
3311
2848
  ggml_sycl_pool_alloc<sycl::half> src1_f16_alloc(ctx.pool());
2849
+
2850
+ // convert src1 to fp16
3312
2851
  if (src1->type != GGML_TYPE_F16) {
3313
- const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type);
2852
+ scope_op_debug_print scope_dbg_print(__func__, "/to_fp16_nc_sycl", dst, /*num_src=*/2,
2853
+ " : converting src1 to fp16");
2854
+ const to_fp16_nc_sycl_t to_fp16_nc_sycl = get_to_fp16_nc_sycl(src1->type);
2855
+ GGML_ASSERT(to_fp16_nc_sycl != nullptr);
3314
2856
  const int64_t ne_src1 = ggml_nelements(src1);
3315
2857
  src1_f16_alloc.alloc(ne_src1);
3316
- GGML_ASSERT(to_fp16_sycl != nullptr);
3317
- to_fp16_sycl(src1_ddf, src1_f16_alloc.get(), ne_src1, main_stream);
2858
+ to_fp16_nc_sycl(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, queue);
2859
+
2860
+ src1_f16 = src1_f16_alloc.get();
2861
+ s11 = ne10;
2862
+ s12 = ne11 * s11;
2863
+ s13 = ne12 * s12;
3318
2864
  }
3319
- sycl::half *src1_f16 = src1->type == GGML_TYPE_F16 ? (sycl::half *)src1_ddf
3320
- : src1_f16_alloc.get();
3321
2865
 
3322
- char * dst_t;
2866
+ ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool());
3323
2867
 
3324
- dpct::library_data_t cu_compute_type = dpct::library_data_t::real_float;
3325
- dpct::library_data_t cu_data_type = dpct::library_data_t::real_float;
2868
+ dpct::library_data_t mkl_compute_type = dpct::library_data_t::real_float;
2869
+ dpct::library_data_t mkl_data_type = dpct::library_data_t::real_float;
3326
2870
 
3327
2871
  // dst strides
3328
2872
  size_t nbd2 = dst->nb[2];
3329
2873
  size_t nbd3 = dst->nb[3];
3330
2874
 
3331
2875
  const float alpha_f32 = 1.0f;
3332
- const float beta_f32 = 0.0f;
2876
+ const float beta_f32 = 0.0f;
3333
2877
 
3334
2878
  const void * alpha = &alpha_f32;
3335
2879
  const void * beta = &beta_f32;
3336
2880
 
3337
- dst_t = (char *) dst_ddf;
3338
-
3339
2881
  GGML_ASSERT(ne12 % ne02 == 0);
3340
2882
  GGML_ASSERT(ne13 % ne03 == 0);
2883
+ GGML_ASSERT(ne01 == static_cast<int64_t>(nb1/nb0));
2884
+ GGML_ASSERT(ne10 == ne00);
3341
2885
 
3342
2886
  // broadcast factors
3343
- const int64_t r2 = ne12/ne02;
3344
- const int64_t r3 = ne13/ne03;
3345
-
3346
- if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
3347
- // there is no broadcast and src0, src1 are contiguous across dims 2, 3
3348
- SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
3349
- *main_stream, oneapi::mkl::transpose::trans,
3350
- oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
3351
- (const char *)src0_as_f16, dpct::library_data_t::real_half,
3352
- nb01 / nb00, nb02 / nb00,
3353
- (const char *)src1_f16, dpct::library_data_t::real_half,
3354
- nb11 / nb10, nb12 / nb10, beta,
3355
- (char *)dst_t, cu_data_type, ne01, nb2 / nb0,
3356
- ne12 * ne13, cu_compute_type)));
3357
- } else {
3358
- const int ne23 = ne12*ne13;
3359
-
3360
- ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
3361
- ggml_sycl_pool_alloc< void *> ptrs_dst(ctx.pool(), 1*ne23);
3362
-
3363
- sycl::range<3> block_dims(1, ne12, ne13);
3364
- /*
3365
- DPCT1049:47: The work-group size passed to the SYCL kernel may exceed
3366
- the limit. To get the device limit, query
3367
- info::device::max_work_group_size. Adjust the work-group size if needed.
3368
- */
3369
- {
3370
- dpct::has_capability_or_fail(main_stream->get_device(),
3371
- {sycl::aspect::fp16});
3372
-
3373
- main_stream->submit([&](sycl::handler &cgh) {
3374
- const void **ptrs_src_get = ptrs_src.get();
3375
- void **ptrs_dst_get = ptrs_dst.get();
3376
- size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : nb12 / 2;
3377
- size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : nb13 / 2;
3378
- cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims),
3379
- [=](sycl::nd_item<3> item_ct1) {
3380
- k_compute_batched_ptrs(
3381
- src0_as_f16, src1_f16,
3382
- dst_t, ptrs_src_get,
3383
- ptrs_dst_get, ne12, ne13, ne23,
3384
- nb02, nb03, nb12_scaled, nb13_scaled,
3385
- nbd2, nbd3, r2, r3, item_ct1);
3386
- });
2887
+ const int64_t r2 = ne12 / ne02;
2888
+ const int64_t r3 = ne13 / ne03;
2889
+
2890
+ #if GGML_SYCL_DNNL
2891
+ if (!g_ggml_sycl_disable_dnn) {
2892
+ auto dnn_gemm = [&ctx, queue, ne11, ne01, ne10, nb00, nb01, nb02, s11, s12]
2893
+ (const sycl::half* src1, const sycl::half* src0, float* dst, const dnnl_dim_t batches_a, const dnnl_dim_t batches_b) {
2894
+
2895
+ DnnlGemmWrapper::gemm(ctx, ne11,ne01, ne10,
2896
+ src1, DnnlGemmWrapper::to_dt<sycl::half>(), s11, 1, s12,
2897
+ src0, DnnlGemmWrapper::to_dt<sycl::half>(), 1, nb01/nb00, nb02/nb00,
2898
+ dst, DnnlGemmWrapper::to_dt<float>(), queue, batches_a, batches_b);
2899
+ };
2900
+
2901
+ if (r2 == 1 && r3 == 1) {
2902
+ if (ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
2903
+ dnn_gemm(src1_f16, src0_f16, dst_ddf, ne12*ne13, ne02 * ne03);
2904
+ }
2905
+ else {
2906
+ for (int64_t ie03 = 0; ie03 < ne03; ++ie03) {
2907
+ const sycl::half* src0_f16_shifted = src0_f16 + ((ie03*nb03)/sizeof(sycl::half)); // nb is in bytes
2908
+ const sycl::half* src1_f16_shifted = src1_f16 + ie03*s13;
2909
+ float* dst_shifted = dst_ddf + ((ie03*nb3)/sizeof(float));
2910
+ dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, ne12, ne02);
2911
+ }
2912
+ }
2913
+ } else {
2914
+ // iterate over batches from smaller set of matrices (matrix 0)
2915
+ for (int64_t ie02 = 0; ie02 < ne02; ++ie02) {
2916
+ for (int64_t ie03 = 0; ie03 < ne03; ++ie03) {
2917
+ const sycl::half* src0_f16_shifted = src0_f16 + ((ie02*nb02 + ie03*nb03)/sizeof(sycl::half));
2918
+ const sycl::half* src1_f16_shifted = src1_f16 + ie02*s12*r2 + ie03*s13*r3;
2919
+ float* dst_shifted = dst_ddf + ((ie02*nb2*r2 + ie03*nb3*r3)/sizeof(float));
2920
+ dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, r2*r3, 1);
2921
+ }
2922
+ }
2923
+ }
2924
+ }
2925
+ else
2926
+ #endif
2927
+ {
2928
+ if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
2929
+ // there is no broadcast and src0, src1 are contiguous across dims 2, 3
2930
+ SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::math::transpose::trans,
2931
+ oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
2932
+ src0_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00,
2933
+ src1_f16, dpct::library_data_t::real_half, s11, s12, beta, dst_ddf,
2934
+ mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type)));
2935
+ } else {
2936
+ const int ne23 = ne12 * ne13;
2937
+
2938
+ ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2 * ne23);
2939
+ ggml_sycl_pool_alloc<void *> ptrs_dst(ctx.pool(), 1 * ne23);
2940
+ ggml_sycl_pool_alloc<matrix_info_t<float>> matrix_info(ctx.host_pool(), 1);
2941
+
2942
+ sycl::range<3> block_dims(1, ne12, ne13);
2943
+ queue->submit([&](sycl::handler & cgh) {
2944
+ const void ** ptrs_src_get = ptrs_src.get();
2945
+ void ** ptrs_dst_get = ptrs_dst.get();
2946
+ size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : s12 * sizeof(sycl::half);
2947
+ size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : s13 * sizeof(sycl::half);
2948
+ sycl_parallel_for(cgh, sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
2949
+ k_compute_batched_ptrs(src0_f16, src1_f16, dst_ddf, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02,
2950
+ nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1);
2951
+ });
3387
2952
  });
2953
+
2954
+ SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
2955
+ *queue, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
2956
+ (const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
2957
+ (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, s11, beta,
2958
+ (void **) (ptrs_dst.get() + 0 * ne23), mkl_data_type, ne0, ne23, mkl_compute_type, matrix_info.get())));
3388
2959
  }
3389
- SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
3390
- *main_stream, oneapi::mkl::transpose::trans,
3391
- oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
3392
- (const void **)(ptrs_src.get() + 0 * ne23),
3393
- dpct::library_data_t::real_half, nb01 / nb00,
3394
- (const void **)(ptrs_src.get() + 1 * ne23),
3395
- dpct::library_data_t::real_half, nb11 / nb10, beta,
3396
- (void **)(ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23,
3397
- cu_compute_type)));
3398
2960
  }
2961
+ } catch (const sycl::exception & exc) {
2962
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
2963
+ std::exit(1);
3399
2964
  }
3400
- catch (sycl::exception const &exc) {
3401
- std::cerr << exc.what() << "Exception caught at file:" << __FILE__
3402
- << ", line:" << __LINE__ << std::endl;
3403
- std::exit(1);
3404
- }
2965
+
2966
+ enum class mul_mat_algo {
2967
+ DMMV = 0,
2968
+ MMVQ = 1,
2969
+ MUL_MAT_SYCL = 2,
2970
+ };
3405
2971
 
3406
2972
  inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
3407
2973
  // TODO: accuracy issues in MMQ
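
In the non-contiguous/broadcast branch of the batched matmul above, k_compute_batched_ptrs fills two pointer tables on the device before the grouped GEMM. A host-side equivalent of that index arithmetic, showing how the broadcast factors r2 = ne12/ne02 and r3 = ne13/ne03 map each (i12, i13) batch of src1/dst back to a src0 batch (all buffers and strides here are illustrative):

    #include <cstdint>
    #include <vector>

    void build_batch_ptrs(const uint8_t * src0, const uint8_t * src1, uint8_t * dst,
                          std::vector<const void *> & ptrs_src, std::vector<void *> & ptrs_dst,
                          int64_t ne12, int64_t ne13, int64_t r2, int64_t r3,
                          size_t nb02, size_t nb03, size_t nb12, size_t nb13,
                          size_t nbd2, size_t nbd3) {
        const int64_t ne23 = ne12 * ne13;
        ptrs_src.assign(2 * ne23, nullptr);    // [0, ne23): src0 pointers, [ne23, 2*ne23): src1 pointers
        ptrs_dst.assign(ne23, nullptr);
        for (int64_t i13 = 0; i13 < ne13; ++i13) {
            for (int64_t i12 = 0; i12 < ne12; ++i12) {
                const int64_t i03 = i13 / r3;  // src0 is broadcast along dims 2 and 3
                const int64_t i02 = i12 / r2;
                ptrs_src[0 * ne23 + i12 + i13 * ne12] = src0 + i02 * nb02 + i03 * nb03;
                ptrs_src[1 * ne23 + i12 + i13 * ne12] = src1 + i12 * nb12 + i13 * nb13;
                ptrs_dst[0 * ne23 + i12 + i13 * ne12] = dst  + i12 * nbd2 + i13 * nbd3;
            }
        }
    }
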
@@ -3409,7 +2975,39 @@ inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
3409
2975
  return false;
3410
2976
  }
3411
2977
 
3412
- bool ggml_sycl_supports_dmmv(enum ggml_type type) {
2978
+ inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) {
2979
+ switch (type) {
2980
+ case GGML_TYPE_Q4_0:
2981
+ return true;
2982
+ case GGML_TYPE_Q4_K:
2983
+ case GGML_TYPE_Q6_K:
2984
+ return !g_ggml_sycl_prioritize_dmmv;
2985
+ default:
2986
+ return false;
2987
+ }
2988
+ }
2989
+
2990
+ inline bool ggml_sycl_supports_reorder_dmmv(enum ggml_type type) {
2991
+ switch (type) {
2992
+ case GGML_TYPE_Q4_0:
2993
+ return true;
2994
+ default:
2995
+ return false;
2996
+ }
2997
+ }
2998
+
2999
+ inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) {
3000
+ switch (type) {
3001
+ case GGML_TYPE_Q4_0:
3002
+ case GGML_TYPE_Q4_K:
3003
+ case GGML_TYPE_Q6_K:
3004
+ return true;
3005
+ default:
3006
+ return false;
3007
+ }
3008
+ }
3009
+
3010
+ static bool ggml_sycl_supports_dmmv(enum ggml_type type) {
3413
3011
  switch (type) {
3414
3012
  case GGML_TYPE_Q4_0:
3415
3013
  case GGML_TYPE_Q4_1:
@@ -3428,12 +3026,190 @@ bool ggml_sycl_supports_dmmv(enum ggml_type type) {
3428
3026
  }
3429
3027
  }
3430
3028
 
3029
+ static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset,
3030
+ dpct::queue_ptr stream) {
3031
+ auto * tmp_buf = sycl::malloc_shared<uint8_t>(size, *stream);
3032
+ SYCL_CHECK(
3033
+ CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size)
3034
+ .wait()));
3035
+ GGML_ASSERT((size % sizeof(block_q4_0) == 0));
3036
+ GGML_ASSERT((offset % sizeof(block_q4_0) == 0));
3037
+ int offset_blks = offset / sizeof(block_q4_0);
3038
+ auto qs_ptr = data_device + offset_blks * QK4_0 / 2;
3039
+ auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows / 2) + offset_blks;
3040
+
3041
+ stream->parallel_for(
3042
+ size / sizeof(block_q4_0),
3043
+ [=](auto i) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
3044
+ const block_q4_0* x = (const block_q4_0*)tmp_buf;
3045
+ const int ib = i;
3046
+
3047
+ for (int j = 0; j < QK4_0/2; j ++)
3048
+ {
3049
+ *(qs_ptr + ib * QK4_0 / 2 + j) = x[ib].qs[j];
3050
+ }
3051
+ *(d_ptr + ib) = x[ib].d;
3052
+ }).wait_and_throw();
3053
+
3054
+ sycl::free(tmp_buf, *stream);
3055
+ }
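
reorder_qw_q4_0 above rewrites the weight buffer in place from an array-of-blocks layout, where each block interleaves 16 bytes of packed nibbles with its fp16 scale, to two dense regions: all quant bytes first, then all scales. A small host-side model of that transformation (the struct is a simplified stand-in for block_q4_0, with the scale kept as raw fp16 bits):

    #include <cstdint>
    #include <cstring>
    #include <vector>

    struct block_q4_0_model {
        uint16_t d;        // fp16 scale, kept as raw bits in this model
        uint8_t  qs[16];   // 32 4-bit quants packed two per byte (QK4_0 / 2 bytes)
    };

    void reorder_q4_0_model(const std::vector<block_q4_0_model> & blocks,
                            std::vector<uint8_t> & qs_out, std::vector<uint16_t> & d_out) {
        const size_t nblocks = blocks.size();
        qs_out.resize(nblocks * 16);
        d_out.resize(nblocks);
        for (size_t ib = 0; ib < nblocks; ++ib) {
            std::memcpy(&qs_out[ib * 16], blocks[ib].qs, 16);   // region 1: all packed nibbles
            d_out[ib] = blocks[ib].d;                           // region 2: all scales, appended after
        }
    }
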
3056
+
3057
+ static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
3058
+ GGML_ASSERT(size % sizeof(block_q4_K) == 0);
3059
+ GGML_ASSERT(offset % sizeof(block_q4_K) == 0);
3060
+
3061
+ const int nblocks = size / sizeof(block_q4_K);
3062
+
3063
+ auto * tmp_buf = sycl::malloc_shared<uint8_t>(size, *stream);
3064
+ SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size).wait()));
3065
+
3066
+ auto * qs_ptr = data_device;
3067
+ auto * scales_ptr = qs_ptr + QK_K / 2 * nblocks;
3068
+ auto * dm_ptr = (sycl::half2 *) (scales_ptr + K_SCALE_SIZE * nblocks);
3069
+
3070
+ stream->parallel_for(nblocks, [=](auto i) {
3071
+ const block_q4_K * x = (const block_q4_K *) tmp_buf;
3072
+ const int ib = i;
3073
+
3074
+ for (int j = 0; j < QK_K / 2; ++j) {
3075
+ qs_ptr[ib * (QK_K / 2) + j] = x[ib].qs[j];
3076
+ }
3077
+
3078
+ for (int j = 0; j < K_SCALE_SIZE; ++j) {
3079
+ scales_ptr[ib * K_SCALE_SIZE + j] = x[ib].scales[j];
3080
+ }
3081
+
3082
+ dm_ptr[ib] = x[ib].dm;
3083
+ }).wait_and_throw();
3084
+
3085
+ sycl::free(tmp_buf, *stream);
3086
+ }
3087
+
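For Q4_K the reordered buffer holds three contiguous regions per tensor: the packed 4-bit quants of every super-block, then the K_SCALE_SIZE scale/min codes, then one (d, dmin) half pair per super-block. A small sketch that only computes the region sizes, using the standard QK_K = 256 and K_SCALE_SIZE = 12 constants and treating the half pair as two 16-bit values:

```cpp
// Sketch only: byte sizes of the three regions written by reorder_qw_q4_k.
#include <cstddef>
#include <cstdint>
#include <cstdio>

constexpr int QK_K         = 256;
constexpr int K_SCALE_SIZE = 12;

struct q4_k_reordered_sizes {
    size_t qs_bytes;      // packed 4-bit quants, all super-blocks first
    size_t scales_bytes;  // scale/min codes next
    size_t dm_bytes;      // one (d, dmin) half pair per super-block at the end
};

static q4_k_reordered_sizes q4_k_sizes(size_t nblocks) {
    return { nblocks * (QK_K / 2), nblocks * K_SCALE_SIZE, nblocks * 2 * sizeof(uint16_t) };
}

int main() {
    const q4_k_reordered_sizes s = q4_k_sizes(128);
    std::printf("qs=%zu scales=%zu dm=%zu\n", s.qs_bytes, s.scales_bytes, s.dm_bytes);
}
```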
3088
+ static void reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
3089
+ GGML_ASSERT(size % sizeof(block_q6_K) == 0);
3090
+ GGML_ASSERT(offset % sizeof(block_q6_K) == 0);
3091
+
3092
+ const int nblocks = size / sizeof(block_q6_K);
3093
+
3094
+ auto * tmp_buf = sycl::malloc_shared<uint8_t>(size, *stream);
3095
+ SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size).wait()));
3096
+
3097
+ auto * ql_ptr = data_device;
3098
+ auto * qh_ptr = ql_ptr + (QK_K / 2) * nblocks;
3099
+ auto * scales_ptr = qh_ptr + (QK_K / 4) * nblocks;
3100
+ sycl::half * dm_ptr = (sycl::half *) (scales_ptr + (QK_K / 16) * nblocks);
3101
+
3102
+ stream
3103
+ ->parallel_for(nblocks,
3104
+ [=](auto i) {
3105
+ const block_q6_K * x = (const block_q6_K *) tmp_buf;
3106
+ const int ib = i;
3107
+
3108
+ const uint8_t * ql = x[ib].ql;
3109
+ const uint8_t * qh = x[ib].qh;
3110
+ uint8_t * base_ql_ptr = ql_ptr + (QK_K / 2) * ib;
3111
+ uint8_t * base_qh_ptr = qh_ptr + (QK_K / 4) * ib;
3112
+ uint8_t * base_scales_ptr = scales_ptr + (QK_K / 16) * ib;
3113
+
3114
+ for (int j = 0; j < QK_K / 2; ++j) {
3115
+ base_ql_ptr[j] = ql[j];
3116
+ }
3117
+ for (int j = 0; j < QK_K / 4; ++j) {
3118
+ base_qh_ptr[j] = qh[j];
3119
+ }
3120
+
3121
+ for (int j = 0; j < QK_K / 16; ++j) {
3122
+ base_scales_ptr[j] = x[ib].scales[j];
3123
+ }
3124
+
3125
+ dm_ptr[ib] = x[ib].d;
3126
+ })
3127
+ .wait_and_throw();
3128
+
3129
+ sycl::free(tmp_buf, *stream);
3130
+ }
3131
+
3132
+ static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
3133
+ uint8_t * data_device = (uint8_t *) src0->data;
3134
+ size_t ncols = src0->ne[0];
3135
+ size_t nrows = src0->ne[1];
3136
+ size_t size = ggml_nbytes(src0);
3137
+
3138
+ switch (src0->type) {
3139
+ case GGML_TYPE_Q4_0:
3140
+ reorder_qw_q4_0(data_device, ncols, nrows, size, 0, stream);
3141
+ break;
3142
+ case GGML_TYPE_Q4_K:
3143
+ reorder_qw_q4_k(data_device, size, 0, stream);
3144
+ break;
3145
+ case GGML_TYPE_Q6_K:
3146
+ reorder_qw_q6_k(data_device, size, 0, stream);
3147
+ break;
3148
+ default:
3149
+ GGML_ABORT("reorder_qw() called with unsupported type");
3150
+ break;
3151
+ }
3152
+ }
3153
+
3154
+ static bool should_reorder_tensor(ggml_backend_sycl_context& ctx, const ggml_tensor * dst) {
3155
+ return !g_ggml_sycl_disable_optimize && // optimization enabled (GGML_SYCL_DISABLE_OPT env var not set)
3156
+ ctx.opt_feature.reorder && // device opted in: reorder is a measured win here; skip devices where it is not
3157
+ dst->op == GGML_OP_MUL_MAT && // currently limited to MUL_MAT; extending to more cases is TODO
3158
+ dst->src[1]->ne[1]==1 && dst->src[1]->ne[2]==1 && dst->src[1]->ne[3]==1;
3159
+ }
3160
+
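should_reorder_tensor above combines a global kill switch, a per-device opt-in and the tensor shape. The global switch is driven by the GGML_SYCL_DISABLE_OPT environment variable referenced in the comment; a generic sketch of how such a flag is typically read once at startup (illustrative only, not ggml's parsing code):

```cpp
// Sketch only: reading a GGML_SYCL_DISABLE_OPT-style kill switch at startup.
#include <cstdlib>
#include <cstring>

static bool env_flag_set(const char * name) {
    const char * v = std::getenv(name);
    return v != nullptr && std::strcmp(v, "0") != 0;  // any non-"0" value enables the switch
}

static const bool g_disable_optimize_sketch = env_flag_set("GGML_SYCL_DISABLE_OPT");
```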
3161
+ static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor * src0, const ggml_tensor * /* src1 */,
3162
+ ggml_tensor * dst, mul_mat_algo mm_algorithm) {
3163
+ if (!should_reorder_tensor(*ctx, dst)) {
3164
+ return;
3165
+ }
3166
+
3167
+ ggml_tensor_extra_gpu * extra = static_cast<ggml_tensor_extra_gpu *>(src0->extra);
3168
+ if (!extra || extra->optimized_feature.reorder) {
3169
+ return; // skip tensors without extra data and tensors that were already reordered
3170
+ }
3171
+
3172
+ switch (mm_algorithm) {
3173
+ case mul_mat_algo::DMMV:
3174
+ if (!ggml_sycl_supports_reorder_dmmv(src0->type)) {
3175
+ return;
3176
+ }
3177
+ break;
3178
+ case mul_mat_algo::MMVQ:
3179
+ if (!ggml_sycl_supports_reorder_mmvq(src0->type)) {
3180
+ return;
3181
+ }
3182
+ break;
3183
+ case mul_mat_algo::MUL_MAT_SYCL:
3184
+ if (!ggml_sycl_supports_reorder_mul_mat_sycl(src0->type)) {
3185
+ return;
3186
+ }
3187
+ break;
3188
+ }
3189
+
3190
+ reorder_qw(src0, ctx->stream());
3191
+ extra->optimized_feature.reorder = true; // consumed by the dequantization steps that follow; also prevents re-reordering
3192
+ }
3193
+
3194
+
3195
+ static bool can_use_dequantize_mul_mat_vec(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3196
+ return ggml_sycl_supports_dmmv(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
3197
+ src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;
3198
+ }
3199
+
3200
+ static bool can_use_mul_mat_vec_q(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3201
+ return ggml_is_quantized(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
3202
+ src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
3203
+ }
3204
+
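can_use_dequantize_mul_mat_vec and can_use_mul_mat_vec_q above feed the kernel choice in ggml_sycl_mul_mat below. A rough, simplified sketch of that precedence; it compresses the real if/else chain and ignores split buffers and the batched FP16 special cases:

```cpp
// Sketch only: simplified mul-mat kernel selection order.
enum class mm_path { BATCHED_F16, DMMV, MMVQ, MMQ, GENERIC_SYCL };

static mm_path pick_mul_mat_path(bool batched_f16_ok, bool dmmv_ok, bool mmvq_ok, bool mmq_ok,
                                 bool prefer_mmvq /* CUDA backend, or reorder-enabled MMVQ type */) {
    if (batched_f16_ok)         return mm_path::BATCHED_F16;
    if (prefer_mmvq && mmvq_ok) return mm_path::MMVQ;   // MMVQ wins over DMMV in this case
    if (dmmv_ok)                return mm_path::DMMV;
    if (mmvq_ok)                return mm_path::MMVQ;
    if (mmq_ok)                 return mm_path::MMQ;
    return mm_path::GENERIC_SYCL;
}
```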
3431
3205
  static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3206
+ scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
3432
3207
  const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer);
3433
3208
  int64_t min_compute_capability = INT_MAX;
3434
3209
 
3435
3210
  if (split) {
3436
- ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context;
3211
+ ggml_backend_sycl_split_buffer_type_context * buft_ctx =
3212
+ (ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context;
3437
3213
  auto & tensor_split = buft_ctx->tensor_split;
3438
3214
  for (int id = 0; id < ggml_sycl_info().device_count; ++id) {
3439
3215
  // skip devices that are not going to do any work:
@@ -3446,17 +3222,13 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
3446
3222
  }
3447
3223
  }
3448
3224
  } else {
3449
- min_compute_capability = ggml_sycl_info().devices[ctx.device].cc;
3225
+ min_compute_capability = ggml_sycl_info().devices[ctx.device].cc;
3450
3226
  }
3451
3227
 
3452
3228
  // check data types and tensor shapes for custom matrix multiplication kernels:
3453
- bool use_dequantize_mul_mat_vec = ggml_sycl_supports_dmmv(src0->type)
3454
- && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
3455
- && src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;
3229
+ bool use_dequantize_mul_mat_vec = can_use_dequantize_mul_mat_vec(src0, src1, dst);
3456
3230
 
3457
- bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
3458
- && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
3459
- && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
3231
+ bool use_mul_mat_vec_q = can_use_mul_mat_vec_q(src0, src1, dst);
3460
3232
 
3461
3233
  bool use_mul_mat_q = ggml_sycl_supports_mmq(src0->type)
3462
3234
  && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
@@ -3468,9 +3240,15 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
3468
3240
  use_mul_mat_q = use_mul_mat_q && (src1->ne[1] <= MMQ_MAX_BATCH_SIZE);
3469
3241
  #endif // SYCL_USE_XMX
3470
3242
 
3243
+
3471
3244
  // mmvq path is faster in the CUDA backend.
3472
- if (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda)
3245
+ if (!g_ggml_sycl_prioritize_dmmv && (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda
3246
+ // Dispatch becomes unclear with the reorder: when the reorder optimization is
3247
+ // enabled, MMVQ takes precedence over DMMV, so the current if-else
3248
+ // implementation requires disabling DMMV when both conditions are met
3249
+ || (should_reorder_tensor(ctx, dst) && ggml_sycl_supports_reorder_mmvq(src0->type)))) {
3473
3250
  use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
3251
+ }
3474
3252
 
3475
3253
  if (!split && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
3476
3254
  // TODO: Refactor and cleanup of mul mat dispatching.
@@ -3482,20 +3260,26 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
3482
3260
  // The kernel from the if path is faster for that specific case, but does not support all mul mats.
3483
3261
  ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
3484
3262
  }
3485
- } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
3263
+ } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
3486
3264
  // KQV single-batch
3487
3265
  ggml_sycl_mul_mat_vec_nc(ctx, src0, src1, dst);
3488
3266
  } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
3489
3267
  // KQ + KQV multi-batch
3490
3268
  ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
3491
3269
  } else if (use_dequantize_mul_mat_vec) {
3492
- ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false);
3270
+ constexpr bool convert_src1_to_q8_1 = false;
3271
+ opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::DMMV);
3272
+ ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, convert_src1_to_q8_1);
3493
3273
  } else if (use_mul_mat_vec_q) {
3494
- ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true);
3274
+ constexpr bool convert_src1_to_q8_1 = true;
3275
+ opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::MMVQ);
3276
+ ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, convert_src1_to_q8_1);
3495
3277
  } else if (use_mul_mat_q) {
3496
- ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, true);
3278
+ constexpr bool convert_src1_to_q8_1 = true;
3279
+ ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, convert_src1_to_q8_1);
3497
3280
  } else {
3498
- ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
3281
+ constexpr bool convert_src1_to_q8_1 = false;
3282
+ ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, convert_src1_to_q8_1);
3499
3283
  }
3500
3284
  }
3501
3285
 
@@ -3565,9 +3349,11 @@ __dpct_inline__ static void k_copy_dst_from_contiguous(
3565
3349
  }
3566
3350
  }
3567
3351
 
3568
- static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
3569
- const ggml_tensor *src1,
3352
+ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
3570
3353
  ggml_tensor *dst) try {
3354
+ scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/3);
3355
+ const ggml_tensor *src0 = dst->src[0];
3356
+ const ggml_tensor *src1 = dst->src[1];
3571
3357
  GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer) && "mul_mat_id does not support split buffers");
3572
3358
 
3573
3359
  const ggml_tensor *ids = dst->src[2];
@@ -3621,8 +3407,8 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const ggml_ten
3621
3407
  const int64_t i2 = i12;
3622
3408
 
3623
3409
  src0_row.data = src0_original + i02*nb02;
3624
- src1_row.data = src1_original + + i11*nb11 + i12*nb12;
3625
- dst_row.data = dst_original + i1*nb1 + i2*nb2;
3410
+ src1_row.data = src1_original + i11*nb11 + i12*nb12;
3411
+ dst_row.data = dst_original + i1*nb1 + i2*nb2;
3626
3412
 
3627
3413
  ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
3628
3414
  }
@@ -3663,7 +3449,7 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const ggml_ten
3663
3449
  {
3664
3450
  sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, 768u));
3665
3451
  sycl::range<3> grid_dims(1, n_ids, ids->ne[1]);
3666
- stream->submit([&](sycl::handler &cgh) {
3452
+ sycl_launch(stream, [&](sycl::handler & cgh) {
3667
3453
  sycl::local_accessor<int, 0> src1_row_acc(cgh);
3668
3454
 
3669
3455
  char *__restrict src1_contiguous_get =
@@ -3675,9 +3461,8 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const ggml_ten
3675
3461
  size_t ids_nb_ct6 = ids->nb[1];
3676
3462
  size_t ids_nb_ct7 = ids->nb[0];
3677
3463
 
3678
- cgh.parallel_for(
3679
- sycl::nd_range<3>(grid_dims * block_dims, block_dims),
3680
- [=](sycl::nd_item<3> item_ct1) {
3464
+ sycl_parallel_for(
3465
+ cgh, sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
3681
3466
  k_copy_src1_to_contiguous(
3682
3467
  src1_original, src1_contiguous_get,
3683
3468
  dev_cur_src1_row_get,
@@ -3708,15 +3493,14 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const ggml_ten
3708
3493
  {
3709
3494
  sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, 768u));
3710
3495
  sycl::range<3> grid_dims(1, 1, num_src1_rows);
3711
- stream->submit([&](sycl::handler &cgh) {
3496
+ sycl_launch(stream, [&](sycl::handler & cgh) {
3712
3497
  const char *__restrict dst_contiguous_get =
3713
3498
  dst_contiguous.get();
3714
3499
  const mmid_row_mapping *__restrict dev_row_mapping_get =
3715
3500
  dev_row_mapping.get();
3716
3501
 
3717
- cgh.parallel_for(
3718
- sycl::nd_range<3>(grid_dims * block_dims, block_dims),
3719
- [=](sycl::nd_item<3> item_ct1) {
3502
+ sycl_parallel_for(
3503
+ cgh, sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
3720
3504
  k_copy_dst_from_contiguous(dst_original,
3721
3505
  dst_contiguous_get,
3722
3506
  dev_row_mapping_get,
@@ -3733,117 +3517,52 @@ catch (sycl::exception const &exc) {
3733
3517
  std::exit(1);
3734
3518
  }
3735
3519
 
3736
- static void ggml_sycl_scale(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3737
- ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_scale);
3738
- }
3739
-
3740
- static void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3741
- ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_clamp);
3742
- }
3743
-
3744
- static void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
3745
- ggml_tensor *dst) try {
3746
- const int64_t ne = ggml_nelements(src0);
3747
- GGML_ASSERT(ne == ggml_nelements(src1));
3748
-
3749
- GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX);
3750
- GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX);
3751
-
3752
- GGML_TENSOR_BINARY_OP_LOCALS01;
3753
-
3754
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
3755
- queue_ptr main_stream = ctx.stream();
3756
-
3757
- char * src0_ddc = (char *) src0->data;
3758
- char * src1_ddc = (char *) src1->data;
3759
-
3760
- if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
3761
- ggml_cpy_f32_f32_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
3762
- } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
3763
- ggml_cpy_f32_f16_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
3764
- } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
3765
- ggml_cpy_f32_q8_0_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
3766
- } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
3767
- ggml_cpy_f32_q4_0_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
3768
- } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
3769
- ggml_cpy_f32_q4_1_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
3770
- } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
3771
- ggml_cpy_f16_f32_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
3772
- } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
3773
- ggml_cpy_f16_f16_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
3774
- } else if (src0->type == GGML_TYPE_I16 && src1->type == GGML_TYPE_I16) {
3775
- ggml_cpy_i16_i16_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
3776
- } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) {
3777
- ggml_cpy_i32_i32_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
3778
- } else {
3779
- GGML_LOG_ERROR("%s: unsupported type combination (%s to %s)\n", __func__,
3780
- ggml_type_name(src0->type), ggml_type_name(src1->type));
3781
- GGML_ABORT("fatal error");
3782
- }
3783
-
3784
- GGML_UNUSED(dst);
3785
- }
3786
- catch (sycl::exception const &exc) {
3787
- std::cerr << exc.what() << "Exception caught at file:" << __FILE__
3788
- << ", line:" << __LINE__ << std::endl;
3789
- std::exit(1);
3790
- }
3791
-
3792
- static void ggml_sycl_dup(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3793
- // TODO: why do we pass dst as src1 here?
3794
- ggml_sycl_cpy(ctx, src0, dst, nullptr);
3795
- GGML_UNUSED(src1);
3796
- }
3797
-
3798
- static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3799
- ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_diag_mask_inf);
3800
- }
3801
-
3802
- static void ggml_sycl_soft_max(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3803
- ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_soft_max);
3520
+ static void ggml_sycl_scale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
3521
+ scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
3522
+ ggml_sycl_op_scale(ctx, dst);
3804
3523
  }
3805
3524
 
3806
- static void ggml_sycl_rope(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3807
- GGML_ASSERT(ggml_is_contiguous(src0)); // TODO: this restriction is temporary until non-cont support is implemented
3808
- ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_rope);
3525
+ static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
3526
+ scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
3527
+ ggml_sycl_op_diag_mask_inf(ctx, dst);
3809
3528
  }
3810
3529
 
3811
- static void ggml_sycl_pool2d(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3812
- ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_pool2d);
3530
+ static void ggml_sycl_pool2d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
3531
+ scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
3532
+ ggml_sycl_op_pool2d(ctx, dst);
3813
3533
  }
3814
3534
 
3815
- static void ggml_sycl_im2col(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3816
- ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_im2col);
3535
+ static void ggml_sycl_im2col(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
3536
+ scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
3537
+ ggml_sycl_op_im2col(ctx, dst);
3817
3538
  }
3818
3539
 
3819
- static void ggml_sycl_sum(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3820
- GGML_ASSERT(ggml_is_contiguous(src0));
3821
- ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sum);
3540
+ static void ggml_sycl_sum(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
3541
+ scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
3542
+ GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
3543
+ ggml_sycl_op_sum(ctx, dst);
3822
3544
  }
3823
3545
 
3824
- static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3825
- GGML_ASSERT(ggml_is_contiguous(src0));
3826
- ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sum_rows);
3546
+ static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
3547
+ scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
3548
+ GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
3549
+ ggml_sycl_op_sum_rows(ctx, dst);
3827
3550
  }
3828
3551
 
3829
- static void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3830
- GGML_ASSERT(ggml_is_contiguous(src0));
3831
- ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_argsort);
3552
+ static void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
3553
+ scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
3554
+ GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
3555
+ ggml_sycl_op_argsort(ctx, dst);
3832
3556
  }
3833
3557
 
3834
- static void ggml_sycl_argmax(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3835
- GGML_ASSERT(ggml_is_contiguous(src0));
3836
- ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_argmax);
3558
+ static void ggml_sycl_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
3559
+ scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
3560
+ GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
3561
+ ggml_sycl_op_argmax(ctx, dst);
3837
3562
  }
3838
3563
 
3839
- static void ggml_sycl_nop(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3840
- GGML_UNUSED(src0);
3841
- GGML_UNUSED(src1);
3842
- GGML_UNUSED(dst);
3843
- GGML_UNUSED(ctx);
3844
- }
3845
3564
 
3846
- void ggml_sycl_set_main_device(const int main_device) try {
3565
+ static void ggml_sycl_set_main_device(const int main_device) try {
3847
3566
  if (dpct::get_current_device_id() == static_cast<unsigned int> (main_device)) {
3848
3567
  return;
3849
3568
  }
@@ -3864,192 +3583,229 @@ catch (sycl::exception const &exc) {
3864
3583
  std::exit(1);
3865
3584
  }
3866
3585
 
3867
- bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tensor * tensor) {
3586
+ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tensor * dst) try {
3868
3587
  if (!g_sycl_loaded) return false;
3869
3588
 
3870
- ggml_sycl_func_t func;
3589
+ if (dst->src[0] != nullptr && ggml_backend_buffer_is_sycl_split(dst->src[0]->buffer)) {
3590
+ ggml_sycl_set_peer_access(dst->src[1]->ne[1], ctx.device);
3591
+ }
3871
3592
 
3872
- switch (tensor->op) {
3593
+ switch (dst->op) {
3873
3594
  case GGML_OP_ARGMAX:
3874
- func = ggml_sycl_argmax;
3595
+ ggml_sycl_argmax(ctx, dst);
3875
3596
  break;
3876
3597
  case GGML_OP_CONV_TRANSPOSE_1D:
3877
- func = ggml_sycl_op_conv_transpose_1d;
3598
+ ggml_sycl_op_conv_transpose_1d(ctx, dst);
3878
3599
  break;
3879
3600
  case GGML_OP_REPEAT:
3880
- func = ggml_sycl_repeat;
3601
+ ggml_sycl_repeat(ctx, dst);
3881
3602
  break;
3882
3603
  case GGML_OP_GET_ROWS:
3883
- func = ggml_sycl_get_rows;
3604
+ ggml_sycl_get_rows(ctx, dst);
3884
3605
  break;
3885
3606
  case GGML_OP_DUP:
3886
- func = ggml_sycl_dup;
3607
+ ggml_sycl_dup(ctx, dst);
3887
3608
  break;
3888
3609
  case GGML_OP_ADD:
3889
3610
  case GGML_OP_ADD1: // TODO: more efficient implementation
3890
- func = ggml_sycl_add;
3611
+ ggml_sycl_add(ctx, dst);
3891
3612
  break;
3892
3613
  case GGML_OP_SUB:
3893
- func = ggml_sycl_sub;
3614
+ ggml_sycl_sub(ctx, dst);
3894
3615
  break;
3895
3616
  case GGML_OP_ACC:
3896
- func = ggml_sycl_acc;
3617
+ ggml_sycl_acc(ctx, dst);
3897
3618
  break;
3898
3619
  case GGML_OP_MUL:
3899
- func = ggml_sycl_mul;
3620
+ ggml_sycl_mul(ctx, dst);
3900
3621
  break;
3901
3622
  case GGML_OP_LOG:
3902
- func = ggml_sycl_log;
3623
+ ggml_sycl_log(ctx, dst);
3903
3624
  break;
3904
3625
  case GGML_OP_DIV:
3905
- func = ggml_sycl_div;
3626
+ ggml_sycl_div(ctx, dst);
3906
3627
  break;
3907
3628
  case GGML_OP_UNARY:
3908
- switch (ggml_get_unary_op(tensor)) {
3629
+ switch (ggml_get_unary_op(dst)) {
3909
3630
  case GGML_UNARY_OP_NEG:
3910
- func = ggml_sycl_neg;
3631
+ ggml_sycl_neg(ctx, dst);
3911
3632
  break;
3912
3633
  case GGML_UNARY_OP_STEP:
3913
- func = ggml_sycl_step;
3634
+ ggml_sycl_step(ctx, dst);
3914
3635
  break;
3915
3636
  case GGML_UNARY_OP_GELU:
3916
- func = ggml_sycl_gelu;
3637
+ ggml_sycl_gelu(ctx, dst);
3917
3638
  break;
3918
3639
  case GGML_UNARY_OP_SILU:
3919
- func = ggml_sycl_silu;
3640
+ ggml_sycl_silu(ctx, dst);
3920
3641
  break;
3921
3642
  case GGML_UNARY_OP_GELU_QUICK:
3922
- func = ggml_sycl_gelu_quick;
3643
+ ggml_sycl_gelu_quick(ctx, dst);
3644
+ break;
3645
+ case GGML_UNARY_OP_GELU_ERF:
3646
+ ggml_sycl_gelu_erf(ctx, dst);
3923
3647
  break;
3924
3648
  case GGML_UNARY_OP_TANH:
3925
- func = ggml_sycl_tanh;
3649
+ ggml_sycl_tanh(ctx, dst);
3926
3650
  break;
3927
3651
  case GGML_UNARY_OP_RELU:
3928
- func = ggml_sycl_relu;
3652
+ ggml_sycl_relu(ctx, dst);
3929
3653
  break;
3930
3654
  case GGML_UNARY_OP_SIGMOID:
3931
- func = ggml_sycl_sigmoid;
3655
+ ggml_sycl_sigmoid(ctx, dst);
3932
3656
  break;
3933
3657
  case GGML_UNARY_OP_HARDSIGMOID:
3934
- func = ggml_sycl_hardsigmoid;
3658
+ ggml_sycl_hardsigmoid(ctx, dst);
3935
3659
  break;
3936
3660
  case GGML_UNARY_OP_HARDSWISH:
3937
- func = ggml_sycl_hardswish;
3661
+ ggml_sycl_hardswish(ctx, dst);
3938
3662
  break;
3939
3663
  case GGML_UNARY_OP_EXP:
3940
- func = ggml_sycl_exp;
3664
+ ggml_sycl_exp(ctx, dst);
3665
+ break;
3666
+ case GGML_UNARY_OP_SGN:
3667
+ ggml_sycl_sgn(ctx, dst);
3668
+ break;
3669
+ case GGML_UNARY_OP_ABS:
3670
+ ggml_sycl_abs(ctx, dst);
3671
+ break;
3672
+ case GGML_UNARY_OP_ELU:
3673
+ ggml_sycl_elu(ctx, dst);
3674
+ break;
3675
+ default:
3676
+ return false;
3677
+ }
3678
+ break;
3679
+ case GGML_OP_GLU:
3680
+ switch (ggml_get_glu_op(dst)) {
3681
+ case GGML_GLU_OP_REGLU:
3682
+ ggml_sycl_reglu(ctx, dst);
3683
+ break;
3684
+ case GGML_GLU_OP_GEGLU:
3685
+ ggml_sycl_geglu(ctx, dst);
3686
+ break;
3687
+ case GGML_GLU_OP_SWIGLU:
3688
+ ggml_sycl_swiglu(ctx, dst);
3941
3689
  break;
3942
3690
  default:
3943
3691
  return false;
3944
3692
  }
3945
3693
  break;
3946
3694
  case GGML_OP_NORM:
3947
- func = ggml_sycl_norm;
3695
+ ggml_sycl_norm(ctx, dst);
3948
3696
  break;
3949
3697
  case GGML_OP_GROUP_NORM:
3950
- func = ggml_sycl_group_norm;
3698
+ ggml_sycl_group_norm(ctx, dst);
3951
3699
  break;
3952
3700
  case GGML_OP_CONCAT:
3953
- func = ggml_sycl_op_concat;
3701
+ ggml_sycl_op_concat(ctx, dst);
3954
3702
  break;
3955
3703
  case GGML_OP_UPSCALE:
3956
- func = ggml_sycl_upscale;
3704
+ ggml_sycl_upscale(ctx, dst);
3957
3705
  break;
3958
3706
  case GGML_OP_PAD:
3959
- func = ggml_sycl_pad;
3707
+ ggml_sycl_pad(ctx, dst);
3960
3708
  break;
3961
3709
  case GGML_OP_LEAKY_RELU:
3962
- func = ggml_sycl_leaky_relu;
3710
+ ggml_sycl_leaky_relu(ctx, dst);
3963
3711
  break;
3964
3712
  case GGML_OP_RMS_NORM:
3965
- func = ggml_sycl_rms_norm;
3713
+ ggml_sycl_rms_norm(ctx, dst);
3714
+ break;
3715
+ case GGML_OP_L2_NORM:
3716
+ ggml_sycl_l2_norm(ctx, dst);
3966
3717
  break;
3967
3718
  case GGML_OP_MUL_MAT:
3968
- if (tensor->src[0]->ne[3] != tensor->src[1]->ne[3]) {
3719
+ if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
3969
3720
  return false;
3970
3721
  }
3971
- func = ggml_sycl_mul_mat;
3722
+ /* ggml_sycl_mul_mat_id is dependent on ggml_sycl_mul_mat */
3723
+ ggml_sycl_mul_mat(ctx, dst->src[0], dst->src[1], dst);
3972
3724
  break;
3973
3725
  case GGML_OP_MUL_MAT_ID:
3974
- if (tensor->src[0]->ne[3] != tensor->src[1]->ne[3]) {
3726
+ if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
3975
3727
  return false;
3976
3728
  }
3977
- func = ggml_sycl_mul_mat_id;
3729
+ ggml_sycl_mul_mat_id(ctx, dst);
3978
3730
  break;
3979
3731
  case GGML_OP_OUT_PROD:
3980
- func = ggml_sycl_op_out_prod;
3732
+ ggml_sycl_op_out_prod(ctx, dst);
3981
3733
  break;
3982
3734
  case GGML_OP_SCALE:
3983
- func = ggml_sycl_scale;
3735
+ ggml_sycl_scale(ctx, dst);
3984
3736
  break;
3985
3737
  case GGML_OP_SQR:
3986
- func = ggml_sycl_sqr;
3738
+ ggml_sycl_sqr(ctx, dst);
3987
3739
  break;
3988
3740
  case GGML_OP_SQRT:
3989
- func = ggml_sycl_sqrt;
3741
+ ggml_sycl_sqrt(ctx, dst);
3990
3742
  break;
3991
3743
  case GGML_OP_SIN:
3992
- func = ggml_sycl_sin;
3744
+ ggml_sycl_sin(ctx, dst);
3993
3745
  break;
3994
3746
  case GGML_OP_COS:
3995
- func = ggml_sycl_cos;
3747
+ ggml_sycl_cos(ctx, dst);
3996
3748
  break;
3997
3749
  case GGML_OP_CLAMP:
3998
- func = ggml_sycl_clamp;
3750
+ ggml_sycl_clamp(ctx, dst);
3999
3751
  break;
4000
3752
  case GGML_OP_CPY:
4001
- func = ggml_sycl_cpy;
3753
+ ggml_sycl_cpy(ctx, dst->src[0], dst->src[1]);
4002
3754
  break;
4003
3755
  case GGML_OP_CONT:
4004
- func = ggml_sycl_dup;
3756
+ ggml_sycl_dup(ctx, dst);
4005
3757
  break;
4006
3758
  case GGML_OP_NONE:
4007
3759
  case GGML_OP_RESHAPE:
4008
3760
  case GGML_OP_VIEW:
4009
3761
  case GGML_OP_PERMUTE:
4010
3762
  case GGML_OP_TRANSPOSE:
4011
- func = ggml_sycl_nop;
3763
+ GGML_SYCL_DEBUG("%s: Tensor NO-OP\n", __func__);
4012
3764
  break;
4013
3765
  case GGML_OP_DIAG_MASK_INF:
4014
- func = ggml_sycl_diag_mask_inf;
3766
+ ggml_sycl_diag_mask_inf(ctx, dst);
4015
3767
  break;
4016
3768
  case GGML_OP_SOFT_MAX:
4017
- func = ggml_sycl_soft_max;
3769
+ ggml_sycl_op_soft_max(ctx, dst);
4018
3770
  break;
4019
3771
  case GGML_OP_ROPE:
4020
- func = ggml_sycl_rope;
3772
+ ggml_sycl_rope(ctx, dst);
4021
3773
  break;
4022
3774
  case GGML_OP_IM2COL:
4023
- func = ggml_sycl_im2col;
3775
+ ggml_sycl_im2col(ctx, dst);
4024
3776
  break;
4025
3777
  case GGML_OP_POOL_2D:
4026
- func = ggml_sycl_pool2d;
3778
+ ggml_sycl_pool2d(ctx, dst);
4027
3779
  break;
4028
3780
  case GGML_OP_SUM:
4029
- func = ggml_sycl_sum;
3781
+ ggml_sycl_sum(ctx, dst);
4030
3782
  break;
4031
3783
  case GGML_OP_SUM_ROWS:
4032
- func = ggml_sycl_sum_rows;
3784
+ ggml_sycl_sum_rows(ctx, dst);
4033
3785
  break;
4034
3786
  case GGML_OP_ARGSORT:
4035
- func = ggml_sycl_argsort;
3787
+ ggml_sycl_argsort(ctx, dst);
4036
3788
  break;
4037
3789
  case GGML_OP_TIMESTEP_EMBEDDING:
4038
- func = ggml_sycl_op_timestep_embedding;
3790
+ ggml_sycl_op_timestep_embedding(ctx, dst);
4039
3791
  break;
4040
3792
  case GGML_OP_RWKV_WKV6:
4041
- func = ggml_sycl_op_rwkv_wkv6;
3793
+ ggml_sycl_op_rwkv_wkv6(ctx, dst);
3794
+ break;
3795
+ case GGML_OP_RWKV_WKV7:
3796
+ ggml_sycl_op_rwkv_wkv7(ctx, dst);
3797
+ break;
3798
+ case GGML_OP_GATED_LINEAR_ATTN:
3799
+ ggml_sycl_op_gated_linear_attn(ctx, dst);
4042
3800
  break;
4043
3801
  default:
4044
3802
  return false;
4045
3803
  }
4046
3804
 
4047
- if (tensor->src[0] != nullptr && ggml_backend_buffer_is_sycl_split(tensor->src[0]->buffer)) {
4048
- ggml_sycl_set_peer_access(tensor->src[1]->ne[1], ctx.device);
4049
- }
4050
-
4051
- func(ctx, tensor->src[0], tensor->src[1], tensor);
4052
3805
  return true;
3806
+ } catch (sycl::exception & e) {
3807
+ std::cerr << e.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
3808
+ std::exit(1);
4053
3809
  }
4054
3810
 
4055
3811
  GGML_API void ggml_backend_sycl_get_device_description(int device, char *description,
@@ -4112,6 +3868,9 @@ static void ggml_backend_sycl_set_tensor_async(ggml_backend_t backend,
4112
3868
  ggml_tensor *tensor,
4113
3869
  const void *data, size_t offset,
4114
3870
  size_t size) try {
3871
+ GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
3872
+ GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str());
3873
+ GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset);
4115
3874
  ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
4116
3875
  ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
4117
3876
 
@@ -4130,13 +3889,16 @@ static void ggml_backend_sycl_get_tensor_async(ggml_backend_t backend,
4130
3889
  const ggml_tensor *tensor,
4131
3890
  void *data, size_t offset,
4132
3891
  size_t size) try {
3892
+ GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
3893
+ GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str());
3894
+ GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset);
4133
3895
  ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
4134
3896
  ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
4135
3897
 
4136
3898
  GGML_ASSERT(buf->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && "unsupported buffer type");
4137
3899
  const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0);
4138
3900
  SYCL_CHECK(CHECK_TRY_ERROR((stream)->memcpy(
4139
- data, (const char *)tensor->data + offset, size).wait()));
3901
+ data, (const char *)tensor->data + offset, size)));
4140
3902
  }
4141
3903
  catch (sycl::exception const &exc) {
4142
3904
  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
@@ -4148,7 +3910,13 @@ static bool ggml_backend_sycl_cpy_tensor_async(ggml_backend_t backend,
4148
3910
  const ggml_tensor *src,
4149
3911
  ggml_tensor *dst) try {
4150
3912
  ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
4151
- if (dst->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && ggml_backend_buffer_is_sycl(src->buffer)) {
3913
+ bool is_cpy_supported = dst->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) &&
3914
+ ggml_backend_buffer_is_sycl(src->buffer);
3915
+ GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
3916
+ GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": dst", dst).c_str());
3917
+ GGML_SYCL_DEBUG("%s", debug_get_tensor_str(" src", src).c_str());
3918
+ GGML_SYCL_DEBUG(" is_cpy_supported=%d\n", is_cpy_supported);
3919
+ if (is_cpy_supported) {
4152
3920
  /*
4153
3921
  DPCT1009:215: SYCL uses exceptions to report errors and does not use the
4154
3922
  error codes. The original code was commented out and a warning string
@@ -4156,7 +3924,7 @@ static bool ggml_backend_sycl_cpy_tensor_async(ggml_backend_t backend,
4156
3924
  */
4157
3925
  const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0);
4158
3926
  SYCL_CHECK(CHECK_TRY_ERROR((stream)->memcpy(
4159
- dst->data, src->data, ggml_nbytes(dst)).wait()));
3927
+ dst->data, src->data, ggml_nbytes(dst))));
4160
3928
  return true;
4161
3929
  }
4162
3930
 
@@ -4169,6 +3937,7 @@ catch (sycl::exception const &exc) {
4169
3937
  }
4170
3938
 
4171
3939
  static void ggml_backend_sycl_synchronize(ggml_backend_t backend) try {
3940
+ GGML_SYCL_DEBUG("[SYCL] call %s\n", __func__);
4172
3941
  ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
4173
3942
  const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0);
4174
3943
  SYCL_CHECK(CHECK_TRY_ERROR((stream)->wait()));
@@ -4181,11 +3950,9 @@ catch (sycl::exception const &exc) {
4181
3950
  std::exit(1);
4182
3951
  }
4183
3952
 
4184
- static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
4185
- ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
3953
+ static void ggml_backend_sycl_graph_compute_impl(ggml_backend_sycl_context * sycl_ctx, ggml_cgraph * cgraph) {
4186
3954
  ggml_sycl_set_main_device(sycl_ctx->device);
4187
3955
 
4188
-
4189
3956
  for (int i = 0; i < cgraph->n_nodes; i++) {
4190
3957
  ggml_tensor * node = cgraph->nodes[i];
4191
3958
  if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
@@ -4205,7 +3972,82 @@ static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_
4205
3972
  }
4206
3973
  GGML_ASSERT(ok);
4207
3974
  }
3975
+ }
4208
3976
 
3977
+ #ifdef GGML_SYCL_GRAPH
3978
+ static bool check_graph_compatibility(ggml_cgraph * cgraph) {
3979
+ if (ggml_sycl_info().device_count > 1) {
3980
+ // A sycl_ex::command_graph object can only be created for a single device
3981
+ GGML_LOG_INFO("%s: disabling SYCL graphs due to multiple devices\n", __func__);
3982
+ return false;
3983
+ }
3984
+
3985
+ for (int i = 0; i < cgraph->n_nodes; i++) {
3986
+ const ggml_op node_op = cgraph->nodes[i]->op;
3987
+ switch (node_op) {
3988
+ default:
3989
+ break;
3990
+ case GGML_OP_CONCAT:
3991
+ // ggml_sycl_op_concat() does a blocking host wait after memcpy operations,
3992
+ // but wait() can't be called on the events returned by a queue recording
3993
+ // to a graph.
3994
+ [[fallthrough]];
3995
+ case GGML_OP_MUL_MAT_ID:
3996
+ // ggml_sycl_mul_mat_id() does a blocking host wait on the sycl queue after
3997
+ // submitting a memcpy operation, but wait() can't be called on a queue that
3998
+ // is recording to a graph.
3999
+ GGML_LOG_INFO("%s: disabling SYCL graphs due to unsupported node type %s\n", __func__,
4000
+ ggml_op_name(node_op));
4001
+ return false;
4002
+ }
4003
+ }
4004
+ return true;
4005
+ }
4006
+ #endif
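check_graph_compatibility above, together with the ext_oneapi_limited_graph / ext_oneapi_graph aspect checks in ggml_backend_sycl_graph_compute below, gates the SYCL-Graph path. A hedged standalone sketch of the record, finalize and replay flow; it assumes a SYCL implementation with the oneAPI graph extension, and the helper name is illustrative rather than ggml's:

```cpp
// Sketch only: record queue submissions into a command_graph, finalize, then replay.
#include <sycl/sycl.hpp>
#include <functional>

namespace sycl_ex = sycl::ext::oneapi::experimental;

static void run_with_graph(sycl::queue & q, const std::function<void()> & enqueue_work) {
    sycl_ex::command_graph graph(q.get_context(), q.get_device(),
                                 {sycl_ex::property::graph::assume_buffer_outlives_graph{}});
    graph.begin_recording(q);
    enqueue_work();               // submissions are captured, nothing executes yet
    graph.end_recording();

    auto exec = graph.finalize(); // executable graph
    q.ext_oneapi_graph(exec);     // replay the captured work
    q.wait();
}
```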
4007
+
4008
+ static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
4009
+ auto * sycl_ctx = static_cast<ggml_backend_sycl_context *>(backend->context);
4010
+
4011
+ #ifdef GGML_SYCL_GRAPH
4012
+ bool use_sycl_graph = !g_ggml_sycl_disable_graph && check_graph_compatibility(cgraph);
4013
+ if (use_sycl_graph) {
4014
+ const bool graph_support = dpct::get_device(sycl_ctx->device).has(sycl::aspect::ext_oneapi_limited_graph);
4015
+ if (!graph_support) {
4016
+ GGML_SYCL_DEBUG("[SYCL-GRAPH] can not use graphs on device:%d\n", sycl_ctx->device);
4017
+ ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);
4018
+ return GGML_STATUS_SUCCESS;
4019
+ }
4020
+
4021
+ sycl_ex::command_graph model_sycl_graph(*(sycl_ctx->stream()), {sycl_ex::property::graph::assume_buffer_outlives_graph{}});
4022
+
4023
+ model_sycl_graph.begin_recording(*(sycl_ctx->stream()));
4024
+ ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);
4025
+ model_sycl_graph.end_recording();
4026
+
4027
+ const bool graph_update_support = dpct::get_device(sycl_ctx->device).has(sycl::aspect::ext_oneapi_graph);
4028
+ if (!sycl_ctx->exec_graph || !graph_update_support) {
4029
+ auto exec_graph = graph_update_support ? model_sycl_graph.finalize(sycl_ex::property::graph::updatable{}) :
4030
+ model_sycl_graph.finalize();
4031
+ sycl_ctx->exec_graph = std::make_unique<
4032
+ sycl_ex::command_graph<sycl_ex::graph_state::executable>>(exec_graph);
4033
+ } else {
4034
+ try {
4035
+ sycl_ctx->exec_graph->update(model_sycl_graph);
4036
+ GGML_SYCL_DEBUG("[SYCL-GRAPH] update success\n");
4037
+ } catch (sycl::exception const & e) {
4038
+ GGML_SYCL_DEBUG("[SYCL-GRAPH] Exception when updating graph, %s\n", e.what());
4039
+ auto exec_graph = model_sycl_graph.finalize({sycl_ex::property::graph::updatable{}});
4040
+ sycl_ctx->exec_graph = std::make_unique<
4041
+ sycl_ex::command_graph<sycl_ex::graph_state::executable>>(exec_graph);
4042
+ }
4043
+ }
4044
+
4045
+ sycl_ctx->stream()->ext_oneapi_graph(*(sycl_ctx->exec_graph));
4046
+ } else
4047
+ #endif
4048
+ {
4049
+ ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);
4050
+ }
4209
4051
  return GGML_STATUS_SUCCESS;
4210
4052
  }
4211
4053
 
@@ -4229,7 +4071,7 @@ catch (sycl::exception const &exc)
4229
4071
  }
4230
4072
 
4231
4073
  static void ggml_backend_sycl_event_wait(ggml_backend_t backend, ggml_backend_event_t event) try {
4232
-
4074
+ GGML_SYCL_DEBUG("[SYCL] call %s\n", __func__);
4233
4075
  sycl::event* sycl_event = static_cast<sycl::event*>(event->context);
4234
4076
 
4235
4077
  if (ggml_backend_is_sycl(backend)) {
@@ -4270,7 +4112,6 @@ bool ggml_backend_is_sycl(ggml_backend_t backend) {
4270
4112
  }
4271
4113
 
4272
4114
  int ggml_backend_sycl_get_device_count() {
4273
- GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_device_count\n");
4274
4115
  return ggml_sycl_info().device_count;
4275
4116
  }
4276
4117
 
@@ -4360,7 +4201,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
4360
4201
  return true;
4361
4202
  }
4362
4203
  return false;
4363
- } break;
4204
+ }
4364
4205
  case GGML_OP_UNARY:
4365
4206
  switch (ggml_get_unary_op(op)) {
4366
4207
  case GGML_UNARY_OP_NEG:
@@ -4372,9 +4213,26 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
4372
4213
  case GGML_UNARY_OP_HARDSIGMOID:
4373
4214
  case GGML_UNARY_OP_HARDSWISH:
4374
4215
  case GGML_UNARY_OP_GELU_QUICK:
4216
+ case GGML_UNARY_OP_GELU_ERF:
4375
4217
  case GGML_UNARY_OP_TANH:
4376
4218
  case GGML_UNARY_OP_EXP:
4377
- return ggml_is_contiguous(op->src[0]);
4219
+ case GGML_UNARY_OP_SGN:
4220
+ case GGML_UNARY_OP_ABS:
4221
+ case GGML_UNARY_OP_ELU:
4222
+ #if defined (GGML_SYCL_F16)
4223
+ return ggml_is_contiguous(op->src[0]) && (op->type == op->src[0]->type);
4224
+ #else
4225
+ return ggml_is_contiguous(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) && (op->type == op->src[0]->type);
4226
+ #endif
4227
+ default:
4228
+ return false;
4229
+ }
4230
+ case GGML_OP_GLU:
4231
+ switch (ggml_get_glu_op(op)) {
4232
+ case GGML_GLU_OP_REGLU:
4233
+ case GGML_GLU_OP_GEGLU:
4234
+ case GGML_GLU_OP_SWIGLU:
4235
+ return ggml_is_contiguous_1(op->src[0]);
4378
4236
  default:
4379
4237
  return false;
4380
4238
  }
@@ -4409,7 +4267,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
4409
4267
  return false;
4410
4268
  }
4411
4269
  return true;
4412
- } break;
4270
+ }
4413
4271
  case GGML_OP_OUT_PROD:
4414
4272
  return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
4415
4273
  case GGML_OP_GET_ROWS:
@@ -4426,11 +4284,14 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
4426
4284
  default:
4427
4285
  return false;
4428
4286
  }
4429
- } break;
4287
+ }
4430
4288
  case GGML_OP_CPY:
4431
4289
  {
4432
4290
  ggml_type src0_type = op->src[0]->type;
4433
4291
  ggml_type src1_type = op->src[1]->type;
4292
+ if (src0_type == src1_type && (ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1])) && src0_type != GGML_TYPE_BF16) {
4293
+ return true;
4294
+ }
4434
4295
  if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
4435
4296
  return true;
4436
4297
  }
@@ -4452,35 +4313,85 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
4452
4313
  if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
4453
4314
  return true;
4454
4315
  }
4316
+ if (src0_type == GGML_TYPE_Q8_0 && src1_type == GGML_TYPE_F32) {
4317
+ return true;
4318
+ }
4319
+ if (src0_type == GGML_TYPE_Q4_0 && src1_type == GGML_TYPE_F32) {
4320
+ return true;
4321
+ }
4322
+ if (src0_type == GGML_TYPE_Q4_1 && src1_type == GGML_TYPE_F32) {
4323
+ return true;
4324
+ }
4325
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_0) {
4326
+ return true;
4327
+ }
4328
+ if (src0_type == GGML_TYPE_Q5_0 && src1_type == GGML_TYPE_F32) {
4329
+ return true;
4330
+ }
4331
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_1) {
4332
+ return true;
4333
+ }
4334
+ if (src0_type == GGML_TYPE_Q5_1 && src1_type == GGML_TYPE_F32) {
4335
+ return true;
4336
+ }
4337
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) {
4338
+ return true;
4339
+ }
4340
+ if(src0_type == GGML_TYPE_Q8_0 && src1_type == GGML_TYPE_Q8_0) {
4341
+ return true;
4342
+ }
4343
+ if(src0_type == GGML_TYPE_Q5_0 && src1_type == GGML_TYPE_Q5_0) {
4344
+ return true;
4345
+ }
4346
+ if(src0_type == GGML_TYPE_Q5_1 && src1_type == GGML_TYPE_Q5_1) {
4347
+ return true;
4348
+ }
4349
+ if(src0_type == GGML_TYPE_Q4_0 && src1_type == GGML_TYPE_Q4_0) {
4350
+ return true;
4351
+ }
4352
+ if(src0_type == GGML_TYPE_Q4_1 && src1_type == GGML_TYPE_Q4_1) {
4353
+ return true;
4354
+ }
4455
4355
  return false;
4456
- } break;
4356
+ }
4457
4357
  case GGML_OP_CONCAT:
4458
4358
  {
4459
4359
  ggml_type src0_type = op->src[0]->type;
4460
4360
  return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
4461
- } break;
4361
+ }
4462
4362
  case GGML_OP_DUP:
4463
4363
  case GGML_OP_ARGMAX:
4464
4364
  case GGML_OP_NONE:
4465
4365
  case GGML_OP_RESHAPE:
4466
- case GGML_OP_REPEAT:
4467
4366
  case GGML_OP_VIEW:
4468
4367
  case GGML_OP_PERMUTE:
4469
4368
  case GGML_OP_TRANSPOSE:
4470
- case GGML_OP_NORM:
4369
+ return true;
4471
4370
  case GGML_OP_ADD:
4472
4371
  case GGML_OP_ADD1:
4473
- case GGML_OP_LOG:
4474
4372
  case GGML_OP_SUB:
4475
4373
  case GGML_OP_MUL:
4476
4374
  case GGML_OP_DIV:
4477
- case GGML_OP_RMS_NORM:
4478
- case GGML_OP_SCALE:
4375
+ case GGML_OP_REPEAT:
4376
+ return true;
4479
4377
  case GGML_OP_SQR:
4480
4378
  case GGML_OP_SQRT:
4481
4379
  case GGML_OP_SIN:
4482
4380
  case GGML_OP_COS:
4483
4381
  case GGML_OP_CLAMP:
4382
+ case GGML_OP_LOG:
4383
+ #if defined (GGML_SYCL_F16)
4384
+ return ((op->type == GGML_TYPE_F32 || op->type == GGML_SYCL_F16) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_SYCL_F16) && (op->type == op->src[0]->type));
4385
+ #else
4386
+ return (op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32) && (op->type == op->src[0]->type);
4387
+ #endif
4388
+ case GGML_OP_NORM:
4389
+ case GGML_OP_RMS_NORM:
4390
+ return true;
4391
+ case GGML_OP_L2_NORM:
4392
+ case GGML_OP_GROUP_NORM:
4393
+ return ggml_is_contiguous(op->src[0]);
4394
+ case GGML_OP_SCALE:
4484
4395
  return true;
4485
4396
  case GGML_OP_CONT:
4486
4397
  return op->src[0]->type != GGML_TYPE_BF16;
@@ -4488,30 +4399,21 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
4488
4399
  case GGML_OP_SOFT_MAX:
4489
4400
  return true;
4490
4401
  case GGML_OP_ROPE:
4491
- {
4492
- const int mode = ((const int32_t *) op->op_params)[2];
4493
- if (mode & GGML_ROPE_TYPE_MROPE) {
4494
- return false;
4495
- }
4496
- if (mode & GGML_ROPE_TYPE_VISION) {
4497
- return false;
4498
- }
4499
- return ggml_is_contiguous(op->src[0]);
4500
- }
4501
4402
  case GGML_OP_IM2COL:
4502
- // TODO: add support for the new F32 operations
4503
- return op->src[0]->type == GGML_TYPE_F16;
4403
+ return true;
4404
+ case GGML_OP_UPSCALE:
4405
+ return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
4504
4406
  case GGML_OP_POOL_2D:
4505
4407
  case GGML_OP_SUM:
4506
4408
  case GGML_OP_SUM_ROWS:
4507
4409
  case GGML_OP_ARGSORT:
4508
4410
  case GGML_OP_ACC:
4509
- case GGML_OP_GROUP_NORM:
4510
- case GGML_OP_UPSCALE:
4511
4411
  case GGML_OP_PAD:
4512
4412
  case GGML_OP_LEAKY_RELU:
4513
4413
  case GGML_OP_TIMESTEP_EMBEDDING:
4514
4414
  case GGML_OP_RWKV_WKV6:
4415
+ case GGML_OP_RWKV_WKV7:
4416
+ case GGML_OP_GATED_LINEAR_ATTN:
4515
4417
  return true;
4516
4418
  default:
4517
4419
  return false;
@@ -4586,6 +4488,7 @@ static void ggml_backend_sycl_device_event_free(ggml_backend_dev_t dev, ggml_bac
4586
4488
 
4587
4489
  static void ggml_backend_sycl_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) try {
4588
4490
  GGML_UNUSED(dev);
4491
+ GGML_SYCL_DEBUG("[SYCL] call %s\n", __func__);
4589
4492
 
4590
4493
  sycl::event *sycl_event = static_cast<sycl::event *>(event->context);
4591
4494
  SYCL_CHECK(CHECK_TRY_ERROR(sycl_event->wait()));
@@ -4638,10 +4541,9 @@ static ggml_backend_dev_t ggml_backend_sycl_reg_get_device(ggml_backend_reg_t re
4638
4541
  static void *ggml_backend_sycl_reg_get_proc_address(ggml_backend_reg_t reg, const char *name) {
4639
4542
  GGML_UNUSED(reg);
4640
4543
 
4641
- // TODO: update to the current function signature
4642
- //if (strcmp(name, "ggml_backend_split_buffer_type") == 0) {
4643
- // return (void *)ggml_backend_sycl_split_buffer_type;
4644
- //}
4544
+ if (strcmp(name, "ggml_backend_split_buffer_type") == 0) {
4545
+ return (void *)ggml_backend_sycl_split_buffer_type;
4546
+ }
4645
4547
 
4646
4548
  // SYCL doesn't support registering host memory, left here for reference
4647
4549
  // "ggml_backend_register_host_buffer"