whispercpp 1.3.1 → 1.3.2

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (797)
  1. checksums.yaml +4 -4
  2. data/.gitignore +4 -3
  3. data/README.md +92 -31
  4. data/Rakefile +26 -7
  5. data/ext/.gitignore +5 -7
  6. data/ext/dependencies.rb +61 -0
  7. data/ext/extconf.rb +21 -198
  8. data/ext/options.rb +221 -0
  9. data/ext/ruby_whisper.c +159 -0
  10. data/ext/ruby_whisper.h +17 -2
  11. data/ext/ruby_whisper_context.c +641 -0
  12. data/ext/ruby_whisper_error.c +52 -0
  13. data/ext/ruby_whisper_model.c +232 -0
  14. data/ext/ruby_whisper_params.c +1301 -0
  15. data/ext/ruby_whisper_segment.c +143 -0
  16. data/ext/ruby_whisper_transcribe.cpp +87 -0
  17. data/ext/ruby_whisper_vad_params.c +288 -0
  18. data/ext/sources/.dockerignore +3 -0
  19. data/ext/sources/.github/workflows/bindings-ruby.yml +21 -0
  20. data/ext/sources/CMakeGraphVizOptions.cmake +8 -0
  21. data/ext/sources/CMakeLists.txt +251 -0
  22. data/ext/sources/bindings/javascript/CMakeLists.txt +41 -0
  23. data/ext/sources/bindings/javascript/emscripten.cpp +93 -0
  24. data/ext/sources/bindings/javascript/libwhisper.worker.js +1 -0
  25. data/ext/sources/bindings/javascript/package-tmpl.json +26 -0
  26. data/ext/sources/bindings/javascript/package.json +26 -0
  27. data/ext/sources/bindings/javascript/whisper.js +19 -0
  28. data/ext/sources/build-xcframework.sh +547 -0
  29. data/ext/sources/ci/run.sh +336 -0
  30. data/ext/sources/close-issue.yml +28 -0
  31. data/ext/sources/cmake/DefaultTargetOptions.cmake +16 -0
  32. data/ext/sources/cmake/FindFFmpeg.cmake +163 -0
  33. data/ext/sources/cmake/build-info.cmake +60 -0
  34. data/ext/sources/cmake/git-vars.cmake +22 -0
  35. data/ext/sources/cmake/whisper-config.cmake.in +65 -0
  36. data/ext/sources/cmake/whisper.pc.in +10 -0
  37. data/ext/sources/examples/CMakeLists.txt +124 -0
  38. data/ext/sources/examples/addon.node/CMakeLists.txt +31 -0
  39. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +37 -0
  40. data/ext/sources/examples/addon.node/addon.cpp +438 -0
  41. data/ext/sources/examples/addon.node/index.js +54 -0
  42. data/ext/sources/examples/addon.node/package.json +16 -0
  43. data/ext/sources/examples/bench/CMakeLists.txt +8 -0
  44. data/ext/sources/examples/bench/bench.cpp +175 -0
  45. data/ext/sources/examples/bench.wasm/CMakeLists.txt +49 -0
  46. data/ext/sources/examples/bench.wasm/emscripten.cpp +87 -0
  47. data/ext/sources/examples/bench.wasm/index-tmpl.html +284 -0
  48. data/ext/sources/examples/cli/CMakeLists.txt +8 -0
  49. data/ext/sources/examples/cli/cli.cpp +1294 -0
  50. data/ext/sources/examples/coi-serviceworker.js +146 -0
  51. data/ext/sources/examples/command/CMakeLists.txt +10 -0
  52. data/ext/sources/examples/command/command.cpp +776 -0
  53. data/ext/sources/examples/command/commands.txt +9 -0
  54. data/ext/sources/examples/command.wasm/CMakeLists.txt +50 -0
  55. data/ext/sources/examples/command.wasm/emscripten.cpp +327 -0
  56. data/ext/sources/examples/command.wasm/index-tmpl.html +414 -0
  57. data/ext/sources/examples/common-ggml.cpp +238 -0
  58. data/ext/sources/examples/common-ggml.h +18 -0
  59. data/ext/sources/examples/common-sdl.cpp +227 -0
  60. data/ext/sources/examples/common-sdl.h +49 -0
  61. data/ext/sources/examples/common-whisper.cpp +168 -0
  62. data/ext/sources/examples/common-whisper.h +24 -0
  63. data/ext/sources/examples/common.cpp +675 -0
  64. data/ext/sources/examples/common.h +322 -0
  65. data/ext/sources/examples/deprecation-warning/CMakeLists.txt +6 -0
  66. data/ext/sources/examples/deprecation-warning/deprecation-warning.cpp +38 -0
  67. data/ext/sources/examples/ffmpeg-transcode.cpp +368 -0
  68. data/ext/sources/examples/generate-karaoke.sh +57 -0
  69. data/ext/sources/examples/grammar-parser.cpp +423 -0
  70. data/ext/sources/examples/grammar-parser.h +29 -0
  71. data/ext/sources/examples/helpers.js +191 -0
  72. data/ext/sources/examples/json.hpp +24596 -0
  73. data/ext/sources/examples/livestream.sh +112 -0
  74. data/ext/sources/examples/lsp/CMakeLists.txt +9 -0
  75. data/ext/sources/examples/lsp/lsp.cpp +467 -0
  76. data/ext/sources/examples/lsp/whisper.vim +362 -0
  77. data/ext/sources/examples/miniaudio.h +93468 -0
  78. data/ext/sources/examples/python/test_whisper_processor.py +7 -0
  79. data/ext/sources/examples/python/whisper_processor.py +54 -0
  80. data/ext/sources/examples/quantize/CMakeLists.txt +6 -0
  81. data/ext/sources/examples/quantize/quantize.cpp +223 -0
  82. data/ext/sources/examples/server/CMakeLists.txt +12 -0
  83. data/ext/sources/examples/server/bench.js +29 -0
  84. data/ext/sources/examples/server/httplib.h +10497 -0
  85. data/ext/sources/examples/server/server.cpp +1091 -0
  86. data/ext/sources/examples/server.py +115 -0
  87. data/ext/sources/examples/stb_vorbis.c +5584 -0
  88. data/ext/sources/examples/stream/CMakeLists.txt +10 -0
  89. data/ext/sources/examples/stream/stream.cpp +429 -0
  90. data/ext/sources/examples/stream.wasm/CMakeLists.txt +49 -0
  91. data/ext/sources/examples/stream.wasm/emscripten.cpp +216 -0
  92. data/ext/sources/examples/stream.wasm/index-tmpl.html +414 -0
  93. data/ext/sources/examples/sycl/CMakeLists.txt +9 -0
  94. data/ext/sources/examples/sycl/build.sh +22 -0
  95. data/ext/sources/examples/sycl/ls-sycl-device.cpp +11 -0
  96. data/ext/sources/examples/sycl/run-whisper.sh +17 -0
  97. data/ext/sources/examples/talk-llama/CMakeLists.txt +40 -0
  98. data/ext/sources/examples/talk-llama/eleven-labs.py +80 -0
  99. data/ext/sources/examples/talk-llama/llama-adapter.cpp +388 -0
  100. data/ext/sources/examples/talk-llama/llama-adapter.h +76 -0
  101. data/ext/sources/examples/talk-llama/llama-arch.cpp +1746 -0
  102. data/ext/sources/examples/talk-llama/llama-arch.h +437 -0
  103. data/ext/sources/examples/talk-llama/llama-batch.cpp +374 -0
  104. data/ext/sources/examples/talk-llama/llama-batch.h +89 -0
  105. data/ext/sources/examples/talk-llama/llama-chat.cpp +663 -0
  106. data/ext/sources/examples/talk-llama/llama-chat.h +58 -0
  107. data/ext/sources/examples/talk-llama/llama-context.cpp +2676 -0
  108. data/ext/sources/examples/talk-llama/llama-context.h +276 -0
  109. data/ext/sources/examples/talk-llama/llama-cparams.cpp +5 -0
  110. data/ext/sources/examples/talk-llama/llama-cparams.h +41 -0
  111. data/ext/sources/examples/talk-llama/llama-grammar.cpp +1229 -0
  112. data/ext/sources/examples/talk-llama/llama-grammar.h +173 -0
  113. data/ext/sources/examples/talk-llama/llama-graph.cpp +1618 -0
  114. data/ext/sources/examples/talk-llama/llama-graph.h +640 -0
  115. data/ext/sources/examples/talk-llama/llama-hparams.cpp +95 -0
  116. data/ext/sources/examples/talk-llama/llama-hparams.h +190 -0
  117. data/ext/sources/examples/talk-llama/llama-impl.cpp +167 -0
  118. data/ext/sources/examples/talk-llama/llama-impl.h +61 -0
  119. data/ext/sources/examples/talk-llama/llama-io.cpp +15 -0
  120. data/ext/sources/examples/talk-llama/llama-io.h +35 -0
  121. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2739 -0
  122. data/ext/sources/examples/talk-llama/llama-kv-cache.h +502 -0
  123. data/ext/sources/examples/talk-llama/llama-kv-cells.h +379 -0
  124. data/ext/sources/examples/talk-llama/llama-memory.cpp +1 -0
  125. data/ext/sources/examples/talk-llama/llama-memory.h +32 -0
  126. data/ext/sources/examples/talk-llama/llama-mmap.cpp +600 -0
  127. data/ext/sources/examples/talk-llama/llama-mmap.h +68 -0
  128. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +1138 -0
  129. data/ext/sources/examples/talk-llama/llama-model-loader.h +169 -0
  130. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +281 -0
  131. data/ext/sources/examples/talk-llama/llama-model-saver.h +37 -0
  132. data/ext/sources/examples/talk-llama/llama-model.cpp +13814 -0
  133. data/ext/sources/examples/talk-llama/llama-model.h +425 -0
  134. data/ext/sources/examples/talk-llama/llama-quant.cpp +966 -0
  135. data/ext/sources/examples/talk-llama/llama-quant.h +1 -0
  136. data/ext/sources/examples/talk-llama/llama-sampling.cpp +2575 -0
  137. data/ext/sources/examples/talk-llama/llama-sampling.h +32 -0
  138. data/ext/sources/examples/talk-llama/llama-vocab.cpp +3340 -0
  139. data/ext/sources/examples/talk-llama/llama-vocab.h +131 -0
  140. data/ext/sources/examples/talk-llama/llama.cpp +354 -0
  141. data/ext/sources/examples/talk-llama/llama.h +1377 -0
  142. data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +23 -0
  143. data/ext/sources/examples/talk-llama/speak +40 -0
  144. data/ext/sources/examples/talk-llama/speak.bat +1 -0
  145. data/ext/sources/examples/talk-llama/speak.ps1 +14 -0
  146. data/ext/sources/examples/talk-llama/talk-llama.cpp +808 -0
  147. data/ext/sources/examples/talk-llama/unicode-data.cpp +7034 -0
  148. data/ext/sources/examples/talk-llama/unicode-data.h +20 -0
  149. data/ext/sources/examples/talk-llama/unicode.cpp +849 -0
  150. data/ext/sources/examples/talk-llama/unicode.h +66 -0
  151. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +8 -0
  152. data/ext/sources/examples/vad-speech-segments/speech.cpp +143 -0
  153. data/ext/sources/examples/wchess/CMakeLists.txt +10 -0
  154. data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +19 -0
  155. data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +803 -0
  156. data/ext/sources/examples/wchess/libwchess/Chessboard.h +33 -0
  157. data/ext/sources/examples/wchess/libwchess/WChess.cpp +193 -0
  158. data/ext/sources/examples/wchess/libwchess/WChess.h +63 -0
  159. data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +117 -0
  160. data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +8 -0
  161. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +249 -0
  162. data/ext/sources/examples/whisper.wasm/CMakeLists.txt +50 -0
  163. data/ext/sources/examples/whisper.wasm/emscripten.cpp +118 -0
  164. data/ext/sources/examples/whisper.wasm/index-tmpl.html +658 -0
  165. data/ext/sources/ggml/CMakeLists.txt +390 -0
  166. data/ext/sources/ggml/cmake/BuildTypes.cmake +54 -0
  167. data/ext/sources/ggml/cmake/GitVars.cmake +22 -0
  168. data/ext/sources/ggml/cmake/common.cmake +26 -0
  169. data/ext/sources/ggml/cmake/ggml-config.cmake.in +152 -0
  170. data/ext/{ggml → sources/ggml}/include/ggml-alloc.h +1 -1
  171. data/ext/{ggml → sources/ggml}/include/ggml-backend.h +9 -7
  172. data/ext/{ggml → sources/ggml}/include/ggml-cpp.h +2 -1
  173. data/ext/{ggml → sources/ggml}/include/ggml-cpu.h +9 -1
  174. data/ext/{ggml → sources/ggml}/include/ggml-metal.h +1 -1
  175. data/ext/{ggml → sources/ggml}/include/ggml-opt.h +49 -28
  176. data/ext/{ggml → sources/ggml}/include/ggml-rpc.h +6 -1
  177. data/ext/{ggml → sources/ggml}/include/ggml-vulkan.h +0 -2
  178. data/ext/{ggml → sources/ggml}/include/ggml.h +182 -265
  179. data/ext/sources/ggml/include/gguf.h +202 -0
  180. data/ext/sources/ggml/src/CMakeLists.txt +346 -0
  181. data/ext/{ggml → sources/ggml}/src/ggml-alloc.c +34 -29
  182. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +107 -0
  183. data/ext/{ggml → sources/ggml}/src/ggml-backend-impl.h +1 -2
  184. data/ext/{ggml → sources/ggml}/src/ggml-backend-reg.cpp +87 -53
  185. data/ext/{ggml → sources/ggml}/src/ggml-backend.cpp +26 -14
  186. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +87 -0
  187. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +74 -0
  188. data/ext/sources/ggml/src/ggml-cann/Doxyfile +2579 -0
  189. data/ext/{ggml → sources/ggml}/src/ggml-cann/acl_tensor.cpp +10 -4
  190. data/ext/{ggml → sources/ggml}/src/ggml-cann/acl_tensor.h +5 -5
  191. data/ext/{ggml → sources/ggml}/src/ggml-cann/aclnn_ops.cpp +1272 -1506
  192. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +1125 -0
  193. data/ext/{ggml → sources/ggml}/src/ggml-cann/common.h +135 -1
  194. data/ext/{ggml → sources/ggml}/src/ggml-cann/ggml-cann.cpp +564 -146
  195. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +30 -0
  196. data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/dup.cpp +3 -5
  197. data/ext/{ggml → sources/ggml}/src/ggml-common.h +12 -8
  198. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +504 -0
  199. data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/amx.cpp +2 -1
  200. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
  201. data/ext/sources/ggml/src/ggml-cpu/binary-ops.h +16 -0
  202. data/ext/sources/ggml/src/ggml-cpu/cmake/FindSIMD.cmake +100 -0
  203. data/ext/sources/ggml/src/ggml-cpu/common.h +72 -0
  204. data/ext/{ggml → sources/ggml}/src/ggml-cpu/cpu-feats-x86.cpp +5 -1
  205. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +6431 -0
  206. data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-impl.h +163 -41
  207. data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-quants.c +4029 -1117
  208. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +3510 -0
  209. data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu.cpp +67 -18
  210. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +337 -0
  211. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +95 -0
  212. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +482 -0
  213. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.h +17 -0
  214. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3544 -0
  215. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +14 -0
  216. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +8903 -0
  217. data/ext/sources/ggml/src/ggml-cpu/ops.h +110 -0
  218. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +892 -0
  219. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
  220. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +28 -0
  221. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +252 -0
  222. data/ext/sources/ggml/src/ggml-cpu/vec.h +818 -0
  223. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +184 -0
  224. data/ext/sources/ggml/src/ggml-cuda/acc.cu +61 -0
  225. data/ext/sources/ggml/src/ggml-cuda/acc.cuh +5 -0
  226. data/ext/sources/ggml/src/ggml-cuda/arange.cu +34 -0
  227. data/ext/sources/ggml/src/ggml-cuda/arange.cuh +5 -0
  228. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +91 -0
  229. data/ext/sources/ggml/src/ggml-cuda/argmax.cuh +3 -0
  230. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +104 -0
  231. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +3 -0
  232. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +363 -0
  233. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +9 -0
  234. data/ext/sources/ggml/src/ggml-cuda/clamp.cu +45 -0
  235. data/ext/sources/ggml/src/ggml-cuda/clamp.cuh +5 -0
  236. data/ext/sources/ggml/src/ggml-cuda/common.cuh +828 -0
  237. data/ext/sources/ggml/src/ggml-cuda/concat.cu +221 -0
  238. data/ext/sources/ggml/src/ggml-cuda/concat.cuh +5 -0
  239. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +89 -0
  240. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cuh +5 -0
  241. data/ext/sources/ggml/src/ggml-cuda/convert.cu +730 -0
  242. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +26 -0
  243. data/ext/sources/ggml/src/ggml-cuda/count-equal.cu +64 -0
  244. data/ext/sources/ggml/src/ggml-cuda/count-equal.cuh +5 -0
  245. data/ext/sources/ggml/src/ggml-cuda/cp-async.cuh +57 -0
  246. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +705 -0
  247. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +11 -0
  248. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +189 -0
  249. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cuh +7 -0
  250. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +103 -0
  251. data/ext/sources/ggml/src/ggml-cuda/diagmask.cu +40 -0
  252. data/ext/sources/ggml/src/ggml-cuda/diagmask.cuh +5 -0
  253. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +881 -0
  254. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +1471 -0
  255. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +357 -0
  256. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +3 -0
  257. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +365 -0
  258. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +3 -0
  259. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +482 -0
  260. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +472 -0
  261. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +634 -0
  262. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +3 -0
  263. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +346 -0
  264. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +3 -0
  265. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +275 -0
  266. data/ext/sources/ggml/src/ggml-cuda/getrows.cuh +15 -0
  267. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +3505 -0
  268. data/ext/sources/ggml/src/ggml-cuda/gla.cu +93 -0
  269. data/ext/sources/ggml/src/ggml-cuda/gla.cuh +3 -0
  270. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +103 -0
  271. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +5 -0
  272. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +396 -0
  273. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +324 -0
  274. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +3217 -0
  275. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +336 -0
  276. data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +12 -0
  277. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +595 -0
  278. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +12 -0
  279. data/ext/sources/ggml/src/ggml-cuda/norm.cu +458 -0
  280. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +11 -0
  281. data/ext/sources/ggml/src/ggml-cuda/opt-step-adamw.cu +78 -0
  282. data/ext/sources/ggml/src/ggml-cuda/opt-step-adamw.cuh +5 -0
  283. data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +68 -0
  284. data/ext/sources/ggml/src/ggml-cuda/out-prod.cuh +3 -0
  285. data/ext/sources/ggml/src/ggml-cuda/pad.cu +49 -0
  286. data/ext/sources/ggml/src/ggml-cuda/pad.cuh +5 -0
  287. data/ext/sources/ggml/src/ggml-cuda/pool2d.cu +94 -0
  288. data/ext/sources/ggml/src/ggml-cuda/pool2d.cuh +5 -0
  289. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +190 -0
  290. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +27 -0
  291. data/ext/sources/ggml/src/ggml-cuda/rope.cu +456 -0
  292. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +7 -0
  293. data/ext/sources/ggml/src/ggml-cuda/scale.cu +31 -0
  294. data/ext/sources/ggml/src/ggml-cuda/scale.cuh +5 -0
  295. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +283 -0
  296. data/ext/sources/ggml/src/ggml-cuda/softmax.cuh +7 -0
  297. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +148 -0
  298. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +3 -0
  299. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +153 -0
  300. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cuh +3 -0
  301. data/ext/sources/ggml/src/ggml-cuda/sum.cu +45 -0
  302. data/ext/sources/ggml/src/ggml-cuda/sum.cuh +5 -0
  303. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +39 -0
  304. data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +5 -0
  305. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +5 -0
  306. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +10 -0
  307. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu +10 -0
  308. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu +10 -0
  309. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +10 -0
  310. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +5 -0
  311. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +10 -0
  312. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +10 -0
  313. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu +10 -0
  314. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu +10 -0
  315. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +5 -0
  316. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu +10 -0
  317. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +10 -0
  318. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +10 -0
  319. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu +10 -0
  320. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu +10 -0
  321. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu +10 -0
  322. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +10 -0
  323. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +10 -0
  324. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +5 -0
  325. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +5 -0
  326. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +5 -0
  327. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +5 -0
  328. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +5 -0
  329. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +5 -0
  330. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +5 -0
  331. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +5 -0
  332. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +5 -0
  333. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +5 -0
  334. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +5 -0
  335. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +5 -0
  336. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +5 -0
  337. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +5 -0
  338. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +5 -0
  339. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +5 -0
  340. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +5 -0
  341. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +5 -0
  342. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +5 -0
  343. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +5 -0
  344. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +5 -0
  345. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +5 -0
  346. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +5 -0
  347. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +5 -0
  348. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +5 -0
  349. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +5 -0
  350. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +5 -0
  351. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +5 -0
  352. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +5 -0
  353. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +5 -0
  354. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +5 -0
  355. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +5 -0
  356. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +5 -0
  357. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +5 -0
  358. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +5 -0
  359. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +5 -0
  360. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +5 -0
  361. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +5 -0
  362. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +5 -0
  363. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +5 -0
  364. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +5 -0
  365. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +5 -0
  366. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +5 -0
  367. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +5 -0
  368. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +5 -0
  369. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +5 -0
  370. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +5 -0
  371. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +5 -0
  372. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +5 -0
  373. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +5 -0
  374. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +5 -0
  375. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +5 -0
  376. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +5 -0
  377. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +5 -0
  378. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +5 -0
  379. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +5 -0
  380. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +5 -0
  381. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +5 -0
  382. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +5 -0
  383. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +5 -0
  384. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +5 -0
  385. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +5 -0
  386. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +5 -0
  387. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +5 -0
  388. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +5 -0
  389. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +5 -0
  390. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +5 -0
  391. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +5 -0
  392. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +5 -0
  393. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +5 -0
  394. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +5 -0
  395. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +5 -0
  396. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +5 -0
  397. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +5 -0
  398. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +5 -0
  399. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +5 -0
  400. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +5 -0
  401. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +5 -0
  402. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +5 -0
  403. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +5 -0
  404. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +5 -0
  405. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +5 -0
  406. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +5 -0
  407. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +5 -0
  408. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +5 -0
  409. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +5 -0
  410. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +78 -0
  411. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s.cu +5 -0
  412. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_s.cu +5 -0
  413. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xs.cu +5 -0
  414. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xxs.cu +5 -0
  415. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_s.cu +5 -0
  416. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_xxs.cu +5 -0
  417. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_nl.cu +5 -0
  418. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_xs.cu +5 -0
  419. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q2_k.cu +5 -0
  420. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q3_k.cu +5 -0
  421. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_0.cu +5 -0
  422. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_1.cu +5 -0
  423. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_k.cu +5 -0
  424. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_0.cu +5 -0
  425. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_1.cu +5 -0
  426. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_k.cu +5 -0
  427. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q6_k.cu +5 -0
  428. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q8_0.cu +5 -0
  429. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +47 -0
  430. data/ext/sources/ggml/src/ggml-cuda/tsembd.cuh +5 -0
  431. data/ext/sources/ggml/src/ggml-cuda/unary.cu +289 -0
  432. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +59 -0
  433. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +51 -0
  434. data/ext/sources/ggml/src/ggml-cuda/upscale.cuh +5 -0
  435. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +1135 -0
  436. data/ext/{ggml → sources/ggml}/src/ggml-cuda/vendors/cuda.h +1 -0
  437. data/ext/{ggml → sources/ggml}/src/ggml-cuda/vendors/hip.h +57 -0
  438. data/ext/{ggml → sources/ggml}/src/ggml-cuda/vendors/musa.h +7 -1
  439. data/ext/sources/ggml/src/ggml-cuda/wkv.cu +199 -0
  440. data/ext/sources/ggml/src/ggml-cuda/wkv.cuh +7 -0
  441. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +131 -0
  442. data/ext/{ggml → sources/ggml}/src/ggml-impl.h +64 -19
  443. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +166 -0
  444. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +112 -0
  445. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +58 -0
  446. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +25 -0
  447. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +52 -0
  448. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +52 -0
  449. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +52 -0
  450. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +52 -0
  451. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +30 -0
  452. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +22 -0
  453. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +17 -0
  454. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +31 -0
  455. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +31 -0
  456. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +38 -0
  457. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +39 -0
  458. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +44 -0
  459. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +52 -0
  460. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +69 -0
  461. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +51 -0
  462. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +33 -0
  463. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +35 -0
  464. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +140 -0
  465. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +106 -0
  466. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +73 -0
  467. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +52 -0
  468. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +28 -0
  469. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +84 -0
  470. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +21 -0
  471. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +53 -0
  472. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +52 -0
  473. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +52 -0
  474. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +52 -0
  475. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +52 -0
  476. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +19 -0
  477. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +23 -0
  478. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +22 -0
  479. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +72 -0
  480. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +71 -0
  481. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +120 -0
  482. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +622 -0
  483. data/ext/{ggml → sources/ggml}/src/ggml-metal/ggml-metal.m +2178 -1064
  484. data/ext/{ggml → sources/ggml}/src/ggml-metal/ggml-metal.metal +1575 -1218
  485. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +113 -0
  486. data/ext/sources/ggml/src/ggml-musa/mudnn.cu +112 -0
  487. data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +12 -0
  488. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +96 -0
  489. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +5124 -0
  490. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +83 -0
  491. data/ext/sources/ggml/src/ggml-opencl/kernels/clamp.cl +20 -0
  492. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +184 -0
  493. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +118 -0
  494. data/ext/sources/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl +58 -0
  495. data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +26 -0
  496. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +62 -0
  497. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle.cl +268 -0
  498. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general.cl +274 -0
  499. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +163 -0
  500. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +57 -0
  501. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +57 -0
  502. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +79 -0
  503. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_Ab_Bi_8x4.cl +139 -0
  504. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f16.cl +118 -0
  505. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32.cl +118 -0
  506. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_1row.cl +94 -0
  507. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_l4.cl +84 -0
  508. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f32_f32.cl +118 -0
  509. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32.cl +192 -0
  510. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_16x_flat.cl +307 -0
  511. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_8x_flat.cl +265 -0
  512. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_8x_flat.cl +272 -0
  513. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_v.cl +254 -0
  514. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl +190 -0
  515. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +81 -0
  516. data/ext/sources/ggml/src/ggml-opencl/kernels/relu.cl +16 -0
  517. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +96 -0
  518. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +721 -0
  519. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +16 -0
  520. data/ext/sources/ggml/src/ggml-opencl/kernels/silu.cl +30 -0
  521. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +87 -0
  522. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +87 -0
  523. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +86 -0
  524. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +86 -0
  525. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +84 -0
  526. data/ext/{ggml → sources/ggml}/src/ggml-opt.cpp +373 -190
  527. data/ext/{ggml → sources/ggml}/src/ggml-quants.c +114 -120
  528. data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +9 -0
  529. data/ext/{ggml → sources/ggml}/src/ggml-rpc/ggml-rpc.cpp +480 -73
  530. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +189 -0
  531. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +37 -0
  532. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +345 -0
  533. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +39 -0
  534. data/ext/{ggml → sources/ggml}/src/ggml-sycl/common.cpp +20 -32
  535. data/ext/sources/ggml/src/ggml-sycl/common.hpp +589 -0
  536. data/ext/{ggml → sources/ggml}/src/ggml-sycl/concat.cpp +32 -33
  537. data/ext/sources/ggml/src/ggml-sycl/concat.hpp +20 -0
  538. data/ext/{ggml → sources/ggml}/src/ggml-sycl/conv.cpp +4 -2
  539. data/ext/sources/ggml/src/ggml-sycl/conv.hpp +20 -0
  540. data/ext/{ggml → sources/ggml}/src/ggml-sycl/convert.cpp +104 -28
  541. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +34 -0
  542. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +700 -0
  543. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +11 -0
  544. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +791 -0
  545. data/ext/{ggml → sources/ggml}/src/ggml-sycl/dmmv.cpp +156 -17
  546. data/ext/sources/ggml/src/ggml-sycl/dmmv.hpp +27 -0
  547. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +2957 -0
  548. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +1511 -0
  549. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +75 -0
  550. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +99 -0
  551. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +309 -0
  552. data/ext/sources/ggml/src/ggml-sycl/getrows.hpp +20 -0
  553. data/ext/{ggml → sources/ggml}/src/ggml-sycl/ggml-sycl.cpp +1004 -1240
  554. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +106 -0
  555. data/ext/sources/ggml/src/ggml-sycl/gla.hpp +8 -0
  556. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +136 -0
  557. data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +21 -0
  558. data/ext/{ggml → sources/ggml}/src/ggml-sycl/mmq.cpp +0 -1
  559. data/ext/sources/ggml/src/ggml-sycl/mmq.hpp +33 -0
  560. data/ext/{ggml → sources/ggml}/src/ggml-sycl/mmvq.cpp +261 -166
  561. data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +27 -0
  562. data/ext/{ggml → sources/ggml}/src/ggml-sycl/norm.cpp +204 -81
  563. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +26 -0
  564. data/ext/{ggml → sources/ggml}/src/ggml-sycl/outprod.cpp +8 -17
  565. data/ext/sources/ggml/src/ggml-sycl/outprod.hpp +10 -0
  566. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +74 -0
  567. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +83 -0
  568. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +361 -0
  569. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +20 -0
  570. data/ext/{ggml → sources/ggml}/src/ggml-sycl/softmax.cpp +35 -25
  571. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +20 -0
  572. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +13 -0
  573. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +23 -0
  574. data/ext/{ggml → sources/ggml}/src/ggml-sycl/tsembd.cpp +3 -3
  575. data/ext/sources/ggml/src/ggml-sycl/tsembd.hpp +20 -0
  576. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +1215 -0
  577. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +293 -0
  578. data/ext/sources/ggml/src/ggml-sycl/wkv.hpp +10 -0
  579. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +196 -0
  580. data/ext/sources/ggml/src/ggml-vulkan/cmake/host-toolchain.cmake.in +15 -0
  581. data/ext/{ggml → sources/ggml}/src/ggml-vulkan/ggml-vulkan.cpp +3130 -1087
  582. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +39 -0
  583. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +29 -0
  584. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +29 -0
  585. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +51 -0
  586. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +69 -0
  587. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +17 -0
  588. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +41 -0
  589. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +49 -0
  590. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +105 -0
  591. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +23 -0
  592. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +51 -0
  593. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +242 -0
  594. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +17 -0
  595. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +31 -0
  596. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +20 -0
  597. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +462 -0
  598. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +699 -0
  599. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.comp +13 -0
  600. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +42 -0
  601. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +35 -0
  602. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +44 -0
  603. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +43 -0
  604. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +48 -0
  605. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +39 -0
  606. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +49 -0
  607. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +32 -0
  608. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +34 -0
  609. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +34 -0
  610. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +42 -0
  611. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +30 -0
  612. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +32 -0
  613. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +68 -0
  614. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +34 -0
  615. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +35 -0
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +70 -0
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +33 -0
  618. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +31 -0
  619. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +34 -0
  620. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +27 -0
  621. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +337 -0
  622. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +162 -0
  623. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +360 -0
  624. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +267 -0
  625. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +59 -0
  626. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +25 -0
  627. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +23 -0
  628. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +64 -0
  629. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.comp +9 -0
  630. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp +76 -0
  631. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +33 -0
  632. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +41 -0
  633. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +66 -0
  634. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +100 -0
  635. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +41 -0
  636. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +22 -0
  637. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +27 -0
  638. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_split_k_reduce.comp +48 -0
  639. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +169 -0
  640. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +118 -0
  641. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +82 -0
  642. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +79 -0
  643. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +90 -0
  644. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +87 -0
  645. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +87 -0
  646. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +90 -0
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +88 -0
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +118 -0
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +154 -0
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +130 -0
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +132 -0
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +136 -0
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +167 -0
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +130 -0
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +868 -0
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +441 -0
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +442 -0
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +99 -0
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +44 -0
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +42 -0
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +28 -0
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +74 -0
  663. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +77 -0
  664. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +21 -0
  665. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +26 -0
  666. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +37 -0
  667. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +52 -0
  668. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +55 -0
  669. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +58 -0
  670. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +60 -0
  671. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +43 -0
  672. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +43 -0
  673. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +47 -0
  674. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +24 -0
  675. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +20 -0
  676. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +22 -0
  677. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +26 -0
  678. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +17 -0
  679. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +173 -0
  680. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +50 -0
  681. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +17 -0
  682. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +29 -0
  683. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +37 -0
  684. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +20 -0
  685. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_bfloat16_support.comp +7 -0
  686. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat2_support.comp +7 -0
  687. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat_support.comp +7 -0
  688. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_integer_dot_support.comp +7 -0
  689. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +41 -0
  690. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +1373 -0
  691. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +36 -0
  692. data/ext/{ggml → sources/ggml}/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +193 -35
  693. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp +87 -0
  694. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp +91 -0
  695. data/ext/{ggml → sources/ggml}/src/ggml.c +676 -1820
  696. data/ext/sources/ggml/src/gguf.cpp +1330 -0
  697. data/ext/{include → sources/include}/whisper.h +68 -2
  698. data/ext/sources/src/CMakeLists.txt +143 -0
  699. data/ext/{src → sources/src}/coreml/whisper-decoder-impl.h +27 -15
  700. data/ext/{src → sources/src}/coreml/whisper-decoder-impl.m +35 -10
  701. data/ext/{src → sources/src}/coreml/whisper-encoder-impl.h +21 -9
  702. data/ext/{src → sources/src}/coreml/whisper-encoder-impl.m +28 -3
  703. data/ext/sources/src/coreml/whisper-encoder.mm +73 -0
  704. data/ext/sources/src/whisper-arch.h +197 -0
  705. data/ext/{src → sources/src}/whisper.cpp +1905 -374
  706. data/ext/sources/tests/CMakeLists.txt +105 -0
  707. data/ext/sources/tests/earnings21/eval.mk +58 -0
  708. data/ext/sources/tests/earnings21/eval.py +68 -0
  709. data/ext/sources/tests/earnings21/normalizers/__init__.py +2 -0
  710. data/ext/sources/tests/earnings21/normalizers/basic.py +80 -0
  711. data/ext/sources/tests/earnings21/normalizers/english.json +1741 -0
  712. data/ext/sources/tests/earnings21/normalizers/english.py +550 -0
  713. data/ext/sources/tests/earnings21/requirements.txt +6 -0
  714. data/ext/sources/tests/en-0-ref.txt +1 -0
  715. data/ext/sources/tests/en-1-ref.txt +1 -0
  716. data/ext/sources/tests/en-2-ref.txt +1 -0
  717. data/ext/sources/tests/es-0-ref.txt +1 -0
  718. data/ext/sources/tests/librispeech/eval.mk +39 -0
  719. data/ext/sources/tests/librispeech/eval.py +47 -0
  720. data/ext/sources/tests/librispeech/normalizers/__init__.py +2 -0
  721. data/ext/sources/tests/librispeech/normalizers/basic.py +80 -0
  722. data/ext/sources/tests/librispeech/normalizers/english.json +1741 -0
  723. data/ext/sources/tests/librispeech/normalizers/english.py +550 -0
  724. data/ext/sources/tests/librispeech/requirements.txt +6 -0
  725. data/ext/sources/tests/run-tests.sh +130 -0
  726. data/ext/sources/tests/test-c.c +3 -0
  727. data/ext/sources/tests/test-vad-full.cpp +54 -0
  728. data/ext/sources/tests/test-vad.cpp +83 -0
  729. data/ext/sources/tests/test-whisper.js +58 -0
  730. data/extsources.rb +33 -5
  731. data/lib/whisper/model/uri.rb +149 -128
  732. data/sig/whisper.rbs +480 -0
  733. data/tests/helper.rb +28 -0
  734. data/tests/test_callback.rb +45 -3
  735. data/tests/test_error.rb +2 -2
  736. data/tests/test_model.rb +38 -0
  737. data/tests/test_package.rb +18 -3
  738. data/tests/test_params.rb +145 -8
  739. data/tests/test_segment.rb +10 -19
  740. data/tests/test_vad.rb +19 -0
  741. data/tests/test_vad_params.rb +103 -0
  742. data/tests/test_whisper.rb +37 -37
  743. data/whispercpp.gemspec +5 -4
  744. metadata +766 -111
  745. data/ext/cpu.mk +0 -9
  746. data/ext/examples/dr_wav.h +0 -8815
  747. data/ext/ggml/src/ggml-cann/aclnn_ops.h +0 -592
  748. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -4262
  749. data/ext/ggml/src/ggml-cpu/ggml-cpu.c +0 -14123
  750. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.cpp +0 -1884
  751. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.h +0 -14
  752. data/ext/ggml/src/ggml-metal/ggml-metal-impl.h +0 -288
  753. data/ext/ggml/src/ggml-sycl/element_wise.cpp +0 -1030
  754. data/ext/ggml/src/ggml-sycl/im2col.cpp +0 -126
  755. data/ext/ggml/src/ggml-sycl/rope.cpp +0 -276
  756. data/ext/ggml/src/ggml-sycl/wkv6.cpp +0 -141
  757. data/ext/metal-embed.mk +0 -17
  758. data/ext/metal.mk +0 -6
  759. data/ext/ruby_whisper.cpp +0 -1909
  760. data/ext/scripts/get-flags.mk +0 -38
  761. data/lib/whisper.rb +0 -2
  762. /data/ext/{ggml → sources/ggml}/include/ggml-blas.h +0 -0
  763. /data/ext/{ggml → sources/ggml}/include/ggml-cann.h +0 -0
  764. /data/ext/{ggml → sources/ggml}/include/ggml-cuda.h +0 -0
  765. /data/ext/{ggml → sources/ggml}/include/ggml-kompute.h +0 -0
  766. /data/ext/{ggml → sources/ggml}/include/ggml-opencl.h +0 -0
  767. /data/ext/{ggml → sources/ggml}/include/ggml-sycl.h +0 -0
  768. /data/ext/{ggml → sources/ggml}/src/ggml-amx/common.h +0 -0
  769. /data/ext/{ggml → sources/ggml}/src/ggml-amx/ggml-amx.cpp +0 -0
  770. /data/ext/{ggml → sources/ggml}/src/ggml-amx/mmq.cpp +0 -0
  771. /data/ext/{ggml → sources/ggml}/src/ggml-amx/mmq.h +0 -0
  772. /data/ext/{ggml → sources/ggml}/src/ggml-blas/ggml-blas.cpp +0 -0
  773. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/ascendc_kernels.h +0 -0
  774. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_f16.cpp +0 -0
  775. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_f32.cpp +0 -0
  776. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -0
  777. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -0
  778. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -0
  779. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -0
  780. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -0
  781. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/amx.h +0 -0
  782. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/common.h +0 -0
  783. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/mmq.cpp +0 -0
  784. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/mmq.h +0 -0
  785. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-aarch64.h +0 -0
  786. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-hbm.cpp +0 -0
  787. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-hbm.h +0 -0
  788. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-quants.h +0 -0
  789. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-traits.cpp +0 -0
  790. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-traits.h +0 -0
  791. /data/ext/{ggml → sources/ggml}/src/ggml-kompute/ggml-kompute.cpp +0 -0
  792. /data/ext/{ggml → sources/ggml}/src/ggml-quants.h +0 -0
  793. /data/ext/{ggml → sources/ggml}/src/ggml-threading.cpp +0 -0
  794. /data/ext/{ggml → sources/ggml}/src/ggml-threading.h +0 -0
  795. /data/ext/{src → sources/src}/coreml/whisper-encoder.h +0 -0
  796. /data/ext/{src → sources/src}/openvino/whisper-openvino-encoder.cpp +0 -0
  797. /data/ext/{src → sources/src}/openvino/whisper-openvino-encoder.h +0 -0
@@ -37,10 +37,20 @@
  #include "ggml-backend-impl.h"

  #include "ggml-sycl/backend.hpp"
+ #include "ggml-sycl/common.hpp"
+ #include "ggml-sycl/element_wise.hpp"
  #include "ggml-sycl/presets.hpp"
  #include "ggml-sycl/gemm.hpp"
+ #include "ggml-sycl/sycl_hw.hpp"
+ #include "ggml-sycl/getrows.hpp"
+ #include "ggml.h"

  static bool g_sycl_loaded = false;
+ int g_ggml_sycl_debug = 0;
+ int g_ggml_sycl_disable_optimize = 0;
+ int g_ggml_sycl_disable_graph = 0;
+ int g_ggml_sycl_disable_dnn = 0;
+ int g_ggml_sycl_prioritize_dmmv = 0;

  static ggml_sycl_device_info ggml_sycl_init() {
  ggml_sycl_device_info info = {};
@@ -54,29 +64,27 @@ static ggml_sycl_device_info ggml_sycl_init() {
  GGML_ASSERT(info.device_count <= GGML_SYCL_MAX_DEVICES);

  int64_t total_vram = 0;
- #if defined(GGML_SYCL_FORCE_MMQ)
- GGML_LOG_INFO("%s: GGML_SYCL_FORCE_MMQ: yes\n", __func__);
- #else
- GGML_LOG_INFO("%s: GGML_SYCL_FORCE_MMQ: no\n", __func__);
- #endif
- #if defined(SYCL_USE_XMX)
- GGML_LOG_INFO("%s: SYCL_USE_XMX: yes\n", __func__);
- #else
- GGML_LOG_INFO("%s: SYCL_USE_XMX: no\n", __func__);
- #endif
- GGML_LOG_INFO("%s: found %d %s devices:\n", __func__, info.device_count, GGML_SYCL_NAME);
-
+ /* This is a bit misleading; reserved for later */
+ // #if defined(SYCL_USE_XMX)
+ // GGML_LOG_INFO("%s: SYCL_USE_XMX: yes\n", __func__);
+ // #else
+ // GGML_LOG_INFO("%s: SYCL_USE_XMX: no\n", __func__);
+ // #endif
  for (int i = 0; i < info.device_count; ++i) {
  info.devices[i].vmm = 0;
  dpct::device_info prop;
+ sycl::device device = dpct::dev_mgr::instance().get_device(i);
+
  SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info(
- prop, dpct::dev_mgr::instance().get_device(i))));
+ prop, device)));

  info.default_tensor_split[i] = total_vram;
  total_vram += prop.get_global_mem_size();

  info.devices[i].cc =
  100 * prop.get_major_version() + 10 * prop.get_minor_version();
+ info.devices[i].hw_info = get_device_hw_info(&device);
+ info.devices[i].opt_feature = check_gpu_optimize_feature(info.devices[i].hw_info.arch);

  info.max_work_group_sizes[i] = prop.get_max_work_group_size();
  }
@@ -92,7 +100,7 @@ const ggml_sycl_device_info & ggml_sycl_info() {
  return info;
  }

- void print_device_detail(int id, sycl::device &device, std::string device_type) {
+ static void print_device_detail(int id, sycl::device &device, std::string device_type) {

  dpct::device_info prop;
  SYCL_CHECK(CHECK_TRY_ERROR(
@@ -109,13 +117,33 @@ void print_device_detail(int id, sycl::device &device, std::string device_type)
  name = std::regex_replace(name, std::regex("\\(TM\\)"), "");

  auto global_mem_size = prop.get_global_mem_size()/1000000;
-
  GGML_LOG_INFO("|%2d|%19s|%39s|%7s|%7d|%8d|%5d|%6luM|%21s|\n", id, device_type.c_str(),
  name.c_str(), version.c_str(), prop.get_max_compute_units(),
  prop.get_max_work_group_size(), prop.get_max_sub_group_size(),
  global_mem_size, device.get_info<sycl::info::device::driver_version>().c_str());
  }

+ static void print_device_opt_feature(int device_count) {
+ GGML_LOG_INFO("SYCL Optimization Feature:\n");
+ GGML_LOG_INFO(
+ "|ID| Device Type|Reorder|\n");
+ GGML_LOG_INFO(
+ "|--|-------------------|-------|\n");
+ std::map<std::string, size_t> DeviceNums;
+ for (int id = 0; id < device_count; ++id) {
+ sycl::device device = dpct::dev_mgr::instance().get_device(id);
+ std::string backend_type = get_device_backend_and_type(device);
+ int type_id = DeviceNums[backend_type]++;
+ std::stringstream device_type;
+ device_type << "[" << backend_type << ":" << std::to_string(type_id)
+ << "]";
+ std::string device_type_s = device_type.str();
+ device_type_s = std::regex_replace(device_type_s, std::regex("ext_oneapi_"), "");
+ GGML_LOG_INFO("|%2d|%19s|%7s|\n", id, device_type_s.c_str(),
+ ggml_sycl_info().devices[id].opt_feature.reorder ? "Y": "N");
+ }
+
+ }
  void ggml_backend_sycl_print_sycl_devices() {
  GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_print_sycl_devices\n");
  int device_count = dpct::dev_mgr::instance().device_count();
@@ -144,6 +172,8 @@ void ggml_backend_sycl_print_sycl_devices() {
  << "]";
  print_device_detail(id, device, device_type.str());
  }
+
+ print_device_opt_feature(device_count);
  }

  static inline int get_sycl_env(const char *env_name, int default_val) {
@@ -164,14 +194,36 @@ static void ggml_check_sycl() try {
  static bool initialized = false;

  if (!initialized) {
- GGML_LOG_INFO("[SYCL] call ggml_check_sycl\n");
  g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0);
- GGML_LOG_INFO("%s: GGML_SYCL_DEBUG: %d\n", __func__, g_ggml_sycl_debug);
-
+ g_ggml_sycl_disable_optimize= get_sycl_env("GGML_SYCL_DISABLE_OPT", 1);
+ g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1);
+ g_ggml_sycl_disable_dnn = get_sycl_env("GGML_SYCL_DISABLE_DNN", 0);
+ g_ggml_sycl_prioritize_dmmv = get_sycl_env("GGML_SYCL_PRIORITIZE_DMMV", 0);
+ GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n");
+ GGML_LOG_INFO("Running with Environment Variables:\n");
+ GGML_LOG_INFO(" GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug);
+ GGML_LOG_INFO(" GGML_SYCL_DISABLE_OPT: %d\n", g_ggml_sycl_disable_optimize);
+ #ifdef GGML_SYCL_GRAPH
+ GGML_LOG_INFO(" GGML_SYCL_DISABLE_GRAPH: %d\n", g_ggml_sycl_disable_graph);
+ #else
+ GGML_LOG_INFO(" GGML_SYCL_DISABLE_GRAPH: graph disabled by compile flag\n");
+ #endif
+ #if GGML_SYCL_DNNL
+ GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: %d\n", g_ggml_sycl_disable_dnn);
+ #else
+ GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: DNN disabled by compile flag\n");
+ #endif
+ GGML_LOG_INFO(" GGML_SYCL_PRIORITIZE_DMMV: %d\n", g_ggml_sycl_prioritize_dmmv);
+ GGML_LOG_INFO("Build with Macros:\n");
+ #if defined(GGML_SYCL_FORCE_MMQ)
+ GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: yes\n");
+ #else
+ GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: no\n");
+ #endif
  #if defined(GGML_SYCL_F16)
- GGML_LOG_INFO("%s: GGML_SYCL_F16: yes\n", __func__);
+ GGML_LOG_INFO(" GGML_SYCL_F16: yes\n");
  #else
- GGML_LOG_INFO("%s: GGML_SYCL_F16: no\n", __func__);
+ GGML_LOG_INFO(" GGML_SYCL_F16: no\n");
  #endif

  /* NOT REMOVE, keep it for next optimize for XMX.
@@ -243,19 +295,27 @@ struct ggml_backend_sycl_buffer_context {
  void * dev_ptr = nullptr;
  queue_ptr stream;
  std::string name;
+ optimize_feature opt_feature;
+ std::vector<ggml_tensor_extra_gpu *> tensor_extras;

- ggml_backend_sycl_buffer_context(int device, void * dev_ptr, queue_ptr stream) :
+ ggml_backend_sycl_buffer_context(int device, void * dev_ptr, queue_ptr stream) :
  device(device), dev_ptr(dev_ptr), stream(stream) {
  check_allow_gpu_index(device);
  name = (GGML_SYCL_NAME + std::to_string(device));
+ opt_feature = ggml_sycl_info().devices[device].opt_feature;
  }

-
  ~ggml_backend_sycl_buffer_context() {
  if (dev_ptr != nullptr) {
  ggml_sycl_set_device(device);
  SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(dev_ptr, *stream)));
  }
+
+ //release extra used by tensors
+ for (ggml_tensor_extra_gpu * extra : tensor_extras) {
+ release_extra_gpu(extra);
+ }
+
  }
  };

@@ -283,18 +343,22 @@ static void * ggml_backend_sycl_buffer_get_base(ggml_backend_buffer_t buffer) {
  return ctx->dev_ptr;
  }

- static void
+ static enum ggml_status
  ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer,
  ggml_tensor *tensor) try {
+ GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
+ debug_print_tensor(": tensor=", tensor, "\n");
  ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *)buffer->context;

- if (tensor->view_src != NULL && tensor->view_offs == 0) {
+ if (tensor->view_src != NULL) {
  assert(tensor->view_src->buffer->buft == buffer->buft);
- tensor->backend = tensor->view_src->backend;
- tensor->extra = tensor->view_src->extra;
- return;
+ return GGML_STATUS_SUCCESS;
+ }
+ if ((tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q4_K) && !g_ggml_sycl_disable_optimize) {
+ ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
+ tensor->extra = extra;
+ ctx->tensor_extras.push_back(extra); //used to release it when destroy ctx.
  }
-

  if (ggml_is_quantized(tensor->type)) {
  // initialize padding to 0 to avoid possible NaN values
@@ -307,6 +371,7 @@ ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer,
  padded_size - original_size).wait()));
  }
  }
+ return GGML_STATUS_SUCCESS;
  }
  catch (sycl::exception const &exc) {
  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
@@ -318,19 +383,23 @@ static void ggml_backend_sycl_buffer_set_tensor(ggml_backend_buffer_t buffer,
  ggml_tensor *tensor,
  const void *data, size_t offset,
  size_t size) try {
-
+ GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
+ debug_print_tensor(": tensor=", tensor);
+ GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset);
  ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context;
-
  ggml_sycl_set_device(ctx->device);
  auto stream = &(dpct::dev_mgr::instance().get_device(ctx->device).default_queue());
- SYCL_CHECK(
- CHECK_TRY_ERROR(dpct::dev_mgr::instance().get_device(ctx->device).queues_wait_and_throw()));
- char* host_buf = (char*)malloc(size);
+ SYCL_CHECK(CHECK_TRY_ERROR(dpct::dev_mgr::instance().get_device(ctx->device).queues_wait_and_throw()));
+ #ifndef _WIN32
+ // Note: Use host buffer to save the data from mmap(), then copy to device. It's workaround for mmap() issue on PVC GPU.
+ // This function will be called during load model from disk. Use memory buffer replace dynamic won't save more time and brings potential memory leak risk here.
+ char * host_buf = (char *) malloc(size);
  memcpy(host_buf, data, size);
- SYCL_CHECK(
- CHECK_TRY_ERROR((*stream).memcpy((char *)tensor->data + offset, host_buf, size)
- .wait()));
+ SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy((char *) tensor->data + offset, host_buf, size).wait()));
  free(host_buf);
+ #else
+ SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy((char *) tensor->data + offset, data, size).wait()));
+ #endif
  }
  catch (sycl::exception const &exc) {
  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
@@ -342,7 +411,9 @@ static void ggml_backend_sycl_buffer_get_tensor(ggml_backend_buffer_t buffer,
  const ggml_tensor *tensor,
  void *data, size_t offset,
  size_t size) try {
-
+ GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
+ debug_print_tensor(": tensor=", tensor);
+ GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset);
  ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context;

  ggml_sycl_set_device(ctx->device);
@@ -358,7 +429,7 @@ catch (sycl::exception const &exc) {
  std::exit(1);
  }

- void dev2dev_memcpy(sycl::queue &q_dst, sycl::queue &q_src, void *ptr_dst,
+ static void dev2dev_memcpy(sycl::queue &q_dst, sycl::queue &q_src, void *ptr_dst,
  const void *ptr_src, size_t size) {
  char *host_buf = (char *)malloc(size);
  q_src.memcpy(host_buf, (const char *)ptr_src, size).wait();
@@ -370,7 +441,12 @@ static bool
  ggml_backend_sycl_buffer_cpy_tensor(ggml_backend_buffer_t buffer,
  const ggml_tensor *src,
  ggml_tensor *dst) try {
- if (ggml_backend_buffer_is_sycl(src->buffer)) {
+ bool is_cpy_supported = ggml_backend_buffer_is_sycl(src->buffer);
+ GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
+ debug_print_tensor(": dst=", dst);
+ debug_print_tensor(" src=", src);
+ GGML_SYCL_DEBUG(" is_cpy_supported=%d\n", is_cpy_supported);
+ if (is_cpy_supported) {
  ggml_backend_sycl_buffer_context * src_ctx = (ggml_backend_sycl_buffer_context *)src->buffer->context;
  ggml_backend_sycl_buffer_context * dst_ctx = (ggml_backend_sycl_buffer_context *)dst->buffer->context;

@@ -427,7 +503,8 @@ ggml_backend_sycl_buffer_cpy_tensor(ggml_backend_buffer_t buffer,

  static void ggml_backend_sycl_buffer_clear(ggml_backend_buffer_t buffer,
  uint8_t value) try {
- ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context;
+ GGML_SYCL_DEBUG("[SYCL] call %s: size=%zu\n", __func__, buffer->size);
+ ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *) buffer->context;

  ggml_sycl_set_device(ctx->device);
  queue_ptr stream = ctx->stream;
@@ -444,16 +521,51 @@ catch (sycl::exception const &exc) {
  std::exit(1);
  }

+ static void ggml_backend_sycl_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value,
+ size_t offset, size_t size) {
+ GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
+ debug_print_tensor(": tensor=", tensor);
+ GGML_SYCL_DEBUG(" size=%zu offset=%zu value=%u\n", size, offset, value);
+ ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *) buffer->context;
+ SYCL_CHECK(ggml_sycl_set_device(ctx->device));
+ auto stream = &(dpct::dev_mgr::instance().get_device(ctx->device).default_queue());
+ if (size == 0) {
+ return; // Nothing to do
+ }
+ if (tensor->data == nullptr) {
+ GGML_ABORT("Error: Tensor data pointer is null.\n");
+ }
+ void * target_ptr = static_cast<char *>(tensor->data) + offset;
+ SYCL_CHECK(CHECK_TRY_ERROR((*stream).memset(target_ptr, value, size)));
+ SYCL_CHECK(CHECK_TRY_ERROR((*stream).wait()));
+ }
+
+ static void ggml_backend_sycl_buffer_reset(ggml_backend_buffer_t buffer) {
+ GGML_SYCL_DEBUG("[SYCL] call %s\n", __func__);
+ if (buffer == nullptr) {
+ return;
+ }
+
+ ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *) buffer->context;
+
+ if (ctx != nullptr) {
+ for (ggml_tensor_extra_gpu * extra : ctx->tensor_extras) {
+ release_extra_gpu(extra);
+ }
+ ctx->tensor_extras.clear(); // reset the tensor_extras vector
+ }
+ }
+
  static const ggml_backend_buffer_i ggml_backend_sycl_buffer_interface = {
  /* .free_buffer = */ ggml_backend_sycl_buffer_free_buffer,
  /* .get_base = */ ggml_backend_sycl_buffer_get_base,
  /* .init_tensor = */ ggml_backend_sycl_buffer_init_tensor,
- /* .memset_tensor = */ NULL,
+ /* .memset_tensor = */ ggml_backend_sycl_buffer_memset_tensor,
  /* .set_tensor = */ ggml_backend_sycl_buffer_set_tensor,
  /* .get_tensor = */ ggml_backend_sycl_buffer_get_tensor,
  /* .cpy_tensor = */ ggml_backend_sycl_buffer_cpy_tensor,
  /* .clear = */ ggml_backend_sycl_buffer_clear,
- /* .reset = */ NULL,
+ /* .reset = */ ggml_backend_sycl_buffer_reset,
  };

  // sycl buffer type
@@ -534,12 +646,11 @@ ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device) {
  static std::mutex mutex;
  std::lock_guard<std::mutex> lock(mutex);

- GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_buffer_type\n");

  auto dev_count = ggml_backend_sycl_get_device_count();

  if (device>=dev_count or device<0) {
- printf("ggml_backend_sycl_buffer_type error: device_index:%d is out of range [0, %d], miss to call ggml_backend_sycl_set_single_device()\n",
+ GGML_LOG_ERROR("ggml_backend_sycl_buffer_type error: device_index:%d is out of range [0, %d], miss to call ggml_backend_sycl_set_single_device()\n",
  device, dev_count-1);
  GGML_ASSERT(device<dev_count);
  }
@@ -562,12 +673,12 @@ ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device) {
  return &ggml_backend_sycl_buffer_types[device];
  }

- ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(ggml_backend_sycl_context * ctx) {
+ static ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(ggml_backend_sycl_context * ctx) {
  GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_buffer_type\n");

  int device = ctx->device;
  if (device>=ggml_sycl_info().device_count or device<0) {
- printf("ggml_backend_sycl_buffer_type error: device_index:%d is out of range [0, %d], miss to call ggml_backend_sycl_set_single_device()\n",
+ GGML_LOG_ERROR("ggml_backend_sycl_buffer_type error: device_index:%d is out of range [0, %d], miss to call ggml_backend_sycl_set_single_device()\n",
  device, ggml_sycl_info().device_count-1);
  GGML_ASSERT(device<ggml_sycl_info().device_count);
  }
@@ -664,32 +775,7 @@ struct ggml_backend_sycl_split_buffer_type_context {
  struct ggml_backend_sycl_split_buffer_context {
  ~ggml_backend_sycl_split_buffer_context() try {
  for (ggml_tensor_extra_gpu * extra : tensor_extras) {
- for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
- for (int64_t is = 0; is < GGML_SYCL_MAX_STREAMS; ++is) {
- if (extra->events[i][is] != nullptr) {
- /*
- DPCT1009:206: SYCL uses exceptions to report errors and
- does not use the error codes. The original code was
- commented out and a warning string was inserted. You
- need to rewrite this code.
- */
- SYCL_CHECK(CHECK_TRY_ERROR(
- dpct::destroy_event(extra->events[i][is])));
- }
- }
- if (extra->data_device[i] != nullptr) {
- /*
- DPCT1009:207: SYCL uses exceptions to report errors and does
- not use the error codes. The original code was commented out
- and a warning string was inserted. You need to rewrite this
- code.
- */
- ggml_sycl_set_device(i);
- SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(
- extra->data_device[i], *(streams[i]))));
- }
- }
- delete extra;
+ release_extra_gpu(extra, streams);
  }
  }
  catch (sycl::exception const &exc) {
@@ -714,9 +800,11 @@ static void * ggml_backend_sycl_split_buffer_get_base(ggml_backend_buffer_t buff
  GGML_UNUSED(buffer);
  }

- static void
+ static enum ggml_status
  ggml_backend_sycl_split_buffer_init_tensor(ggml_backend_buffer_t buffer,
  ggml_tensor *tensor) try {
+ GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
+ debug_print_tensor(": tensor=", tensor, "\n");
  GGML_ASSERT(tensor->view_src == nullptr); // views of split tensors are not supported

  ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context;
@@ -727,7 +815,7 @@ ggml_backend_sycl_split_buffer_init_tensor(ggml_backend_buffer_t buffer,
  ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};

  ctx->tensor_extras.push_back(extra);
- ctx->streams.push_back(&(dpct::get_current_device().default_queue()));
+ ctx->streams.push_back(&(dpct::get_current_device().default_queue()));

  for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
  int64_t row_low, row_high;
@@ -746,7 +834,7 @@ ggml_backend_sycl_split_buffer_init_tensor(ggml_backend_buffer_t buffer,
  size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
  }

- // FIXME: do not crash if cudaMalloc fails
+ // FIXME: do not crash if SYCL Buffer alloc fails
  // currently, init_tensor cannot fail, it needs to be fixed in ggml-backend first
  ggml_sycl_set_device(i);
  const queue_ptr stream = ctx->streams[i];
@@ -788,8 +876,8 @@ ggml_backend_sycl_split_buffer_init_tensor(ggml_backend_buffer_t buffer,
  CHECK_TRY_ERROR(extra->events[i][is] = new sycl::event()));
  }
  }
- tensor->backend = GGML_BACKEND_TYPE_GPU_SPLIT;
  tensor->extra = extra;
+ return GGML_STATUS_SUCCESS;
  }
  catch (sycl::exception const &exc) {
  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
@@ -801,6 +889,9 @@ static void
  ggml_backend_sycl_split_buffer_set_tensor(ggml_backend_buffer_t buffer,
  ggml_tensor *tensor, const void *data,
  size_t offset, size_t size) try {
+ GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
+ debug_print_tensor(": tensor=", tensor);
+ GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset);
  // split tensors must always be set in their entirety at once
  GGML_ASSERT(offset == 0);
  GGML_ASSERT(size == ggml_nbytes(tensor));
@@ -854,6 +945,9 @@ static void
  ggml_backend_sycl_split_buffer_get_tensor(ggml_backend_buffer_t buffer,
  const ggml_tensor *tensor, void *data,
  size_t offset, size_t size) try {
+ GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
+ debug_print_tensor(": tensor=", tensor);
+ GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset);
  // split tensors must always be set in their entirety at once
  GGML_ASSERT(offset == 0);
  GGML_ASSERT(size == ggml_nbytes(tensor));
@@ -1178,6 +1272,85 @@ struct ggml_sycl_pool_leg : public ggml_sycl_pool {
  }
  };

+ struct ggml_sycl_pool_host : public ggml_sycl_pool {
+ queue_ptr qptr;
+ int device;
+
+ inline static int counter{ 0 };
+
+ struct ggml_sycl_buffer {
+ void * ptr = nullptr;
+ size_t size = 0;
+ };
+
+ // Set arbitrarly to 64
+ static constexpr int MAX_POOL_SIZE{ 64 };
+ std::vector<ggml_sycl_buffer> buffer_pool = std::vector<ggml_sycl_buffer>(MAX_POOL_SIZE);
+ size_t pool_size = 0;
+
+ explicit ggml_sycl_pool_host(queue_ptr qptr_, int device_) : qptr(qptr_), device(device_) {}
+
+ ~ggml_sycl_pool_host() {
+ for (int i = 0; i < MAX_POOL_SIZE; ++i) {
+ ggml_sycl_buffer & b = buffer_pool[i];
+ if (b.ptr != nullptr) {
+ SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(b.ptr, *qptr)));
+ b.ptr = nullptr;
+ pool_size -= b.size;
+ b.size = 0;
+ }
+ }
+ counter = 0;
+ }
+
+ void * alloc(size_t size, size_t * actual_size) override {
+ if (counter == MAX_POOL_SIZE) {
+ ggml_sycl_buffer b = buffer_pool[0];
+ void * ptr = b.ptr;
+ *actual_size = b.size;
+ counter = 1;
+ return ptr;
+ }
+ ggml_sycl_buffer & b = buffer_pool[counter];
+
+ if (b.ptr == nullptr) {
+ void * ptr;
+
+ SYCL_CHECK(CHECK_TRY_ERROR(ptr = (void *) sycl::malloc_host(size, *qptr)));
+ if (!ptr) {
+ GGML_LOG_ERROR("%s: can't allocate %lu Bytes of memory on host\n", __func__, size);
+ return nullptr;
+ }
+ pool_size += size;
+ *actual_size = size;
+ counter = counter + 1;
+ return ptr;
+ } else {
+ ++counter;
+ b.size = size;
+ return b.ptr;
+ }
+ }
+
+ void free(void * ptr, size_t size) override {
+ // if the pool is not completed add the pointer to it in place of the first nullptr found.
+ // Otherwise do nothing, pointers will be freed once the pool is deallocated.
+ for (int i = 0; i < MAX_POOL_SIZE; ++i) {
+ ggml_sycl_buffer & b = buffer_pool[i];
+ if (b.ptr == nullptr) {
+ b.ptr = ptr;
+ b.size = size;
+ return;
+ }
+ }
+ }
+ };
+
+ std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_host(queue_ptr qptr, int device) {
+ // return pool for the host to speed up memory management
+ return std::unique_ptr<ggml_sycl_pool>(new ggml_sycl_pool_host(qptr, device));
+ }
+
  std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_device(queue_ptr qptr, int device) {
  // TBD: NO VMM support
@@ -1190,9 +1363,6 @@ std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_device(q
  // struct ggml_sycl_pool_vmm : public ggml_sycl_pool

  /// kernels
-
- typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
- typedef void (*ggml_sycl_func_t)(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
  typedef void (*ggml_sycl_op_mul_mat_t)(
  ggml_backend_sycl_context & ctx,
  const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
@@ -1264,83 +1434,6 @@ static void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy,
  reinterpret_cast<sycl::half &>(y[ib].ds.y()) = sum;
  }

- template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
- static void k_get_rows(
- const void * src0, const int32_t * src1, dst_t * dst,
- int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
- /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
- /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
- /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
- size_t s10, size_t s11, size_t s12,
- const sycl::nd_item<3> &item_ct1/*, size_t s13*/) {
-
- const int i00 = (item_ct1.get_group(2) * item_ct1.get_local_range(2) +
- item_ct1.get_local_id(2)) *
- 2;
- const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
- item_ct1.get_local_id(1);
- const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
- item_ct1.get_local_id(0)) /
- ne12;
- const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
- item_ct1.get_local_id(0)) %
- ne12;
-
- if (i00 >= ne00) {
- return;
- }
-
- const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
-
- dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
- const void * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03;
-
- const int ib = i00/qk; // block index
- const int iqs = (i00%qk)/qr; // quant index
- const int iybs = i00 - i00%qk; // dst block start index
- const int y_offset = qr == 1 ? 1 : qk/2;
-
- // dequantize
- dfloat2 v;
- dequantize_kernel(src0_row, ib, iqs, v);
-
- dst_row[iybs + iqs + 0] = v.x();
- dst_row[iybs + iqs + y_offset] = v.y();
- }
-
- template<typename src0_t, typename dst_t>
- static void k_get_rows_float(
- const src0_t * src0, const int32_t * src1, dst_t * dst,
- int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
- /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
- /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
- /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
- size_t s10, size_t s11, size_t s12,
- const sycl::nd_item<3> &item_ct1/*, size_t s13*/) {
-
- const int i00 = item_ct1.get_group(2) * item_ct1.get_local_range(2) +
- item_ct1.get_local_id(2);
- const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
- item_ct1.get_local_id(1);
- const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
- item_ct1.get_local_id(0)) /
- ne12;
- const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
- item_ct1.get_local_id(0)) %
- ne12;
-
- if (i00 >= ne00) {
- return;
- }
-
- const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
-
- dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
- const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03);
-
- dst_row[i00] = src0_row[i00];
- }
-
  static void mul_mat_p021_f16_f32(
  const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst,
  const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y,
@@ -1451,193 +1544,6 @@ static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
  }
  }

- static void cpy_1_f32_f32(const char * cxi, char * cdsti) {
- const float * xi = (const float *) cxi;
- float * dsti = (float *) cdsti;
-
- *dsti = *xi;
- }
-
- static void cpy_1_f32_f16(const char * cxi, char * cdsti) {
- const float * xi = (const float *) cxi;
- sycl::half *dsti = (sycl::half *)cdsti;
-
- *dsti = sycl::vec<float, 1>(*xi)
- .convert<sycl::half, sycl::rounding_mode::automatic>()[0];
- }
-
- static void cpy_1_f16_f16(const char * cxi, char * cdsti) {
- const sycl::half *xi = (const sycl::half *)cxi;
- sycl::half *dsti = (sycl::half *)cdsti;
-
- *dsti = *xi;
- }
-
- static void cpy_1_f16_f32(const char * cxi, char * cdsti) {
- const sycl::half *xi = (const sycl::half *)cxi;
- float * dsti = (float *) cdsti;
-
- *dsti = *xi;
- }
-
- static void cpy_1_i16_i16(const char * cxi, char * cdsti) {
- const int16_t *xi = (const int16_t *)cxi;
- int16_t *dsti = (int16_t *)cdsti;
-
- *dsti = *xi;
- }
-
- static void cpy_1_i32_i32(const char * cxi, char * cdsti) {
- const int32_t *xi = (const int32_t *)cxi;
- int32_t *dsti = (int32_t *)cdsti;
-
- *dsti = *xi;
- }
-
- template <cpy_kernel_t cpy_1>
- static void cpy_f32_f16(const char * cx, char * cdst, const int ne,
- const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
- const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
- const int nb12, const int nb13, const sycl::nd_item<3> &item_ct1) {
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
- item_ct1.get_local_id(2);
-
- if (i >= ne) {
- return;
- }
-
- // determine indices i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
- // then combine those indices with the corresponding byte offsets to get the total offsets
- const int i03 = i/(ne00 * ne01 * ne02);
- const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
- const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
- const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
- const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
-
- const int i13 = i/(ne10 * ne11 * ne12);
- const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
- const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
- const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
- const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13;
-
- cpy_1(cx + x_offset, cdst + dst_offset);
- }
-
- static void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
- const float * xi = (const float *) cxi;
- block_q8_0 * dsti = (block_q8_0 *) cdsti;
-
- float amax = 0.0f; // absolute max
-
- for (int j = 0; j < QK8_0; j++) {
- const float v = xi[j];
- amax = sycl::fmax(amax, sycl::fabs((float)v));
- }
-
- const float d = amax / ((1 << 7) - 1);
- const float id = d ? 1.0f/d : 0.0f;
-
- dsti->d = d;
-
- for (int j = 0; j < QK8_0; ++j) {
- const float x0 = xi[j]*id;
-
- dsti->qs[j] = sycl::round((float)x0);
- }
- }
-
- static void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
- const float * xi = (const float *) cxi;
- block_q4_0 * dsti = (block_q4_0 *) cdsti;
-
- float amax = 0.0f;
- float vmax = 0.0f;
-
- for (int j = 0; j < QK4_0; ++j) {
- const float v = xi[j];
- if (amax < sycl::fabs((float)v)) {
- amax = sycl::fabs((float)v);
- vmax = v;
- }
- }
-
- const float d = vmax / -8;
- const float id = d ? 1.0f/d : 0.0f;
-
- dsti->d = d;
-
- for (int j = 0; j < QK4_0/2; ++j) {
- const float x0 = xi[0 + j]*id;
- const float x1 = xi[QK4_0/2 + j]*id;
-
- const uint8_t xi0 = dpct::min(15, (int8_t)(x0 + 8.5f));
- const uint8_t xi1 = dpct::min(15, (int8_t)(x1 + 8.5f));
-
- dsti->qs[j] = xi0;
- dsti->qs[j] |= xi1 << 4;
- }
- }
-
- static void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
- const float * xi = (const float *) cxi;
- block_q4_1 * dsti = (block_q4_1 *) cdsti;
-
- float vmin = FLT_MAX;
- float vmax = -FLT_MAX;
-
- for (int j = 0; j < QK4_1; ++j) {
- const float v = xi[j];
-
- if (v < vmin) vmin = v;
- if (v > vmax) vmax = v;
- }
-
- const float d = (vmax - vmin) / ((1 << 4) - 1);
- const float id = d ? 1.0f/d : 0.0f;
-
- dsti->dm.x() = d;
- dsti->dm.y() = vmin;
-
- for (int j = 0; j < QK4_1/2; ++j) {
- const float x0 = (xi[0 + j] - vmin)*id;
- const float x1 = (xi[QK4_1/2 + j] - vmin)*id;
-
- const uint8_t xi0 = dpct::min(15, (int8_t)(x0 + 0.5f));
- const uint8_t xi1 = dpct::min(15, (int8_t)(x1 + 0.5f));
-
- dsti->qs[j] = xi0;
- dsti->qs[j] |= xi1 << 4;
- }
- }
-
- template <cpy_kernel_t cpy_blck, int qk>
- static void cpy_f32_q(const char * cx, char * cdst, const int ne,
- const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
- const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
- const int nb12, const int nb13, const sycl::nd_item<3> &item_ct1) {
- const int i = (item_ct1.get_local_range(2) * item_ct1.get_group(2) +
- item_ct1.get_local_id(2)) *
- qk;
-
- if (i >= ne) {
- return;
- }
-
- const int i03 = i/(ne00 * ne01 * ne02);
- const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
- const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
- const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
- const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
-
- const int i13 = i/(ne10 * ne11 * ne12);
- const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
- const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
- const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
- const int dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
-
- cpy_blck(cx + x_offset, cdst + dst_offset);
- }
-
  static void k_sum_rows_f32(const float * x, float * dst, const int ncols,
  const sycl::nd_item<3> &item_ct1) {
  const int row = item_ct1.get_group(1);
@@ -1749,17 +1655,6 @@ static void scale_f32(const float * x, float * dst, const float scale, const int
  dst[i] = scale * x[i];
  }

- static void clamp_f32(const float * x, float * dst, const float min, const float max, const int k,
- const sycl::nd_item<3> &item_ct1) {
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
- item_ct1.get_local_id(2);
-
- if (i >= k) {
- return;
- }
-
- dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
- }

  template <typename Ti, typename To>
  static void pool2d_nchw_kernel(
@@ -1823,81 +1718,6 @@ static void pool2d_nchw_kernel(
  o_ptr[cur_oh * ow + cur_ow] = res;
  }

- template <int qk, int qr, dequantize_kernel_t dq>
- static void get_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
- ggml_tensor *dst, const void *src0_dd,
- const int32_t *src1_dd, float *dst_dd,
- queue_ptr stream) {
-
- GGML_TENSOR_BINARY_OP_LOCALS
-
- const sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE);
- const int block_num_x = (ne00 + 2*SYCL_GET_ROWS_BLOCK_SIZE - 1) / (2*SYCL_GET_ROWS_BLOCK_SIZE);
- const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x);
-
- // strides in elements
- //const size_t s0 = nb0 / ggml_element_size(dst);
- const size_t s1 = nb1 / ggml_element_size(dst);
- const size_t s2 = nb2 / ggml_element_size(dst);
- const size_t s3 = nb3 / ggml_element_size(dst);
-
- const size_t s10 = nb10 / ggml_element_size(src1);
- const size_t s11 = nb11 / ggml_element_size(src1);
- const size_t s12 = nb12 / ggml_element_size(src1);
- //const size_t s13 = nb13 / ggml_element_size(src1);
-
- GGML_ASSERT(ne00 % 2 == 0);
-
- stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
- [=](sycl::nd_item<3> item_ct1) {
- k_get_rows<qk, qr, dq>(
- src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2,
- s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
- });
-
- GGML_UNUSED(dst);
- GGML_UNUSED(ctx);
- }
-
- template <typename src0_t>
- static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
- const ggml_tensor *src1, ggml_tensor *dst,
- const src0_t *src0_dd, const int32_t *src1_dd,
- float *dst_dd, queue_ptr stream) {
-
- GGML_TENSOR_BINARY_OP_LOCALS
-
- const sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE);
- const int block_num_x = (ne00 + SYCL_GET_ROWS_BLOCK_SIZE - 1) / SYCL_GET_ROWS_BLOCK_SIZE;
- const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x);
-
- // strides in elements
- //const size_t s0 = nb0 / ggml_element_size(dst);
- const size_t s1 = nb1 / ggml_element_size(dst);
- const size_t s2 = nb2 / ggml_element_size(dst);
- const size_t s3 = nb3 / ggml_element_size(dst);
-
- const size_t s10 = nb10 / ggml_element_size(src1);
- const size_t s11 = nb11 / ggml_element_size(src1);
- const size_t s12 = nb12 / ggml_element_size(src1);
- //const size_t s13 = nb13 / ggml_element_size(src1);
-
- {
- dpct::has_capability_or_fail(stream->get_device(),
- {sycl::aspect::fp16});
-
- stream->parallel_for(
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
- [=](sycl::nd_item<3> item_ct1) {
- k_get_rows_float(src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2,
- s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
- });
- }
-
- GGML_UNUSED(dst);
- GGML_UNUSED(ctx);
- }
-
  static void quantize_row_q8_1_sycl(const float *x, void *vy, const int kx,
  const int ky, const int kx_padded,
  queue_ptr stream) {
@@ -1912,7 +1732,7 @@ static void quantize_row_q8_1_sycl(const float *x, void *vy, const int kx,

  stream->parallel_for(
  sycl::nd_range<3>(num_blocks * block_size, block_size),
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
  quantize_q8_1<QUANT_BLOCK_TILE>(x, vy, kx, kx_padded, item_ct1);
  });
  }
@@ -1933,7 +1753,7 @@ static void ggml_mul_mat_p021_f16_f32_sycl(const void *vx, const float *y,

  stream->parallel_for(
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
  mul_mat_p021_f16_f32(vx, y, dst, ncols_x, nrows_x, nchannels_x,
  nchannels_y, item_ct1);
  });
@@ -1953,7 +1773,7 @@ static void ggml_mul_mat_vec_nc_f16_f32_sycl(

  stream->parallel_for(
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
  mul_mat_vec_nc_f16_f32(vx, y, dst, ncols_x, nrows_x,
  row_stride_x, channel_stride_x,
  nchannels_y / nchannels_x, item_ct1);
@@ -1961,231 +1781,7 @@ static void ggml_mul_mat_vec_nc_f16_f32_sycl(
  }
  }

- static void
- ggml_cpy_f16_f32_sycl(const char *cx, char *cdst, const int ne, const int ne00,
- const int ne01, const int ne02, const int nb00,
- const int nb01, const int nb02, const int nb03,
- const int ne10, const int ne11, const int ne12,
- const int nb10, const int nb11, const int nb12,
- const int nb13, queue_ptr stream) {
-
- const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
- {
- dpct::has_capability_or_fail(stream->get_device(),
- {sycl::aspect::fp16});
-
- stream->parallel_for(
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
- [=](sycl::nd_item<3> item_ct1) {
- cpy_f32_f16<cpy_1_f16_f32>(cx, cdst, ne, ne00, ne01, ne02, nb00,
- nb01, nb02, nb03, ne10, ne11, ne12,
- nb10, nb11, nb12, nb13, item_ct1);
- });
- }
- }
-
- static void ggml_cpy_f32_f32_sycl(const char *cx, char *cdst, const int ne,
- const int ne00, const int ne01,
- const int ne02, const int nb00,
- const int nb01, const int nb02,
- const int nb03, const int ne10,
- const int ne11, const int ne12,
- const int nb10, const int nb11,
- const int nb12, const int nb13,
- queue_ptr stream) {
-
- const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
- {
- dpct::has_capability_or_fail(stream->get_device(),
- {sycl::aspect::fp16});
-
- stream->parallel_for(
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
- [=](sycl::nd_item<3> item_ct1) {
- cpy_f32_f16<cpy_1_f32_f32>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
- nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
- item_ct1);
- });
- }
- }
-
- static void ggml_cpy_f32_f16_sycl(const char *cx, char *cdst, const int ne,
- const int ne00, const int ne01,
- const int ne02, const int nb00,
- const int nb01, const int nb02,
- const int nb03, const int ne10,
- const int ne11, const int ne12,
- const int nb10, const int nb11,
- const int nb12, const int nb13,
- queue_ptr stream) {
-
- const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
- {
- dpct::has_capability_or_fail(stream->get_device(),
- {sycl::aspect::fp16});
-
- stream->parallel_for(
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
- [=](sycl::nd_item<3> item_ct1) {
- cpy_f32_f16<cpy_1_f32_f16>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
- nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
- item_ct1);
- });
- }
- }
-
- static void ggml_cpy_f32_q8_0_sycl(const char *cx, char *cdst, const int ne,
- const int ne00, const int ne01,
- const int ne02, const int nb00,
- const int nb01, const int nb02,
- const int nb03, const int ne10,
- const int ne11, const int ne12,
- const int nb10, const int nb11,
- const int nb12, const int nb13,
- queue_ptr stream) {
-
- GGML_ASSERT(ne % QK8_0 == 0);
- const int num_blocks = ne / QK8_0;
- stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks),
- sycl::range<3>(1, 1, 1)),
- [=](sycl::nd_item<3> item_ct1) {
- cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>(
- cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
- nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
- item_ct1);
- });
- }
-
- static void ggml_cpy_f32_q4_0_sycl(const char *cx, char *cdst, const int ne,
- const int ne00, const int ne01,
- const int ne02, const int nb00,
- const int nb01, const int nb02,
- const int nb03, const int ne10,
- const int ne11, const int ne12,
- const int nb10, const int nb11,
- const int nb12, const int nb13,
- queue_ptr stream) {
-
- GGML_ASSERT(ne % QK4_0 == 0);
- const int num_blocks = ne / QK4_0;
- stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks),
- sycl::range<3>(1, 1, 1)),
- [=](sycl::nd_item<3> item_ct1) {
- cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>(
- cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
- nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
- item_ct1);
- });
- }
-
- static void ggml_cpy_f32_q4_1_sycl(const char *cx, char *cdst, const int ne,
- const int ne00, const int ne01,
- const int ne02, const int nb00,
- const int nb01, const int nb02,
- const int nb03, const int ne10,
- const int ne11, const int ne12,
- const int nb10, const int nb11,
- const int nb12, const int nb13,
- queue_ptr stream) {
-
- GGML_ASSERT(ne % QK4_1 == 0);
- const int num_blocks = ne / QK4_1;
- stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks),
- sycl::range<3>(1, 1, 1)),
- [=](sycl::nd_item<3> item_ct1) {
- cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>(
- cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
- nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
- item_ct1);
- });
- }
-
- static void ggml_cpy_f16_f16_sycl(const char *cx, char *cdst, const int ne,
- const int ne00, const int ne01,
- const int ne02, const int nb00,
- const int nb01, const int nb02,
- const int nb03, const int ne10,
- const int ne11, const int ne12,
- const int nb10, const int nb11,
- const int nb12, const int nb13,
- queue_ptr stream) {
-
- const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
- {
- dpct::has_capability_or_fail(stream->get_device(),
- {sycl::aspect::fp16});
-
- stream->parallel_for(
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
- [=](sycl::nd_item<3> item_ct1) {
- cpy_f32_f16<cpy_1_f16_f16>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
- nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
- item_ct1);
- });
- }
- }
-
- static void ggml_cpy_i16_i16_sycl(const char *cx, char *cdst, const int ne,
- const int ne00, const int ne01,
- const int ne02, const int nb00,
- const int nb01, const int nb02,
- const int nb03, const int ne10,
- const int ne11, const int ne12,
- const int nb10, const int nb11,
- const int nb12, const int nb13,
- queue_ptr stream) {
-
- const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
- {
- // dpct::has_capability_or_fail(stream->get_device(),
- // {sycl::aspect::fp16});
-
- stream->parallel_for(
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
- [=](sycl::nd_item<3> item_ct1) {
- cpy_f32_f16<cpy_1_i16_i16>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
- nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
- item_ct1);
- });
- }
- }

- static void ggml_cpy_i32_i32_sycl(const char *cx, char *cdst, const int ne,
- const int ne00, const int ne01,
- const int ne02, const int nb00,
- const int nb01, const int nb02,
- const int nb03, const int ne10,
- const int ne11, const int ne12,
- const int nb10, const int nb11,
- const int nb12, const int nb13,
- queue_ptr stream) {
-
- const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
- {
- // dpct::has_capability_or_fail(stream->get_device(),
- // {sycl::aspect::fp16});
-
- stream->parallel_for(
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
- [=](sycl::nd_item<3> item_ct1) {
- cpy_f32_f16<cpy_1_i32_i32>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
- nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
- item_ct1);
- });
- }
- }

  static void scale_f32_sycl(const float *x, float *dst, const float scale,
  const int k, queue_ptr stream) {
@@ -2199,18 +1795,6 @@ static void scale_f32_sycl(const float *x, float *dst, const float scale,
  });
  }

- static void clamp_f32_sycl(const float *x, float *dst, const float min,
- const float max, const int k,
- queue_ptr stream) {
- const int num_blocks = (k + SYCL_CLAMP_BLOCK_SIZE - 1) / SYCL_CLAMP_BLOCK_SIZE;
- stream->parallel_for(
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
- sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE),
- sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE)),
- [=](sycl::nd_item<3> item_ct1) {
- clamp_f32(x, dst, min, max, k, item_ct1);
- });
- }

  static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
  const int nrows, queue_ptr stream) {
@@ -2218,7 +1802,7 @@ static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
  const sycl::range<3> block_nums(1, nrows, 1);
  stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
  [=](sycl::nd_item<3> item_ct1)
- [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
  k_sum_rows_f32(x, dst, ncols, item_ct1);
  });
  }
@@ -2349,12 +1933,22 @@ static dpct::err0 ggml_sycl_cpy_tensor_2d(void *dst,

  dpct::memcpy_direction kind;
  char * src_ptr;
- if (src->backend == GGML_BACKEND_TYPE_CPU) {
+ if (ggml_backend_buffer_is_host(src->buffer)) {
  kind = dpct::host_to_device;
+ //GGML_SYCL_DEBUG("%s: Host buffer type src tensor\n", __func__);
  src_ptr = (char *) src->data;
  // GGML_SYCL_DEBUG("ggml_sycl_cpy_tensor_2d GGML_BACKEND_TYPE_CPU src_ptr %p\n", src_ptr);
- } else if (src->backend == GGML_BACKEND_TYPE_GPU || src->backend == GGML_BACKEND_TYPE_GPU_SPLIT) {
- GGML_ASSERT(src->backend != GGML_BACKEND_TYPE_GPU_SPLIT || (i1_low == 0 && i1_high == src->ne[1]));
+ } else if (ggml_backend_buffer_is_sycl(src->buffer)) {
+ // If buffer is a SYCL buffer
+ //GGML_SYCL_DEBUG("%s: SYCL buffer type src tensor\n", __func__);
+ kind = dpct::device_to_device;
+ src_ptr = (char *) src->data;
+ } else if (ggml_backend_buffer_is_sycl_split(src->buffer)) {
+ /*
+ If buffer is a SYCL split buffer
+ */
+ //GGML_SYCL_DEBUG("%s: Split buffer type src tensor\n", __func__);
+ GGML_ASSERT(i1_low == 0 && i1_high == src->ne[1]);
  kind = dpct::device_to_device;
  ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src->extra;
  int id;
@@ -2411,65 +2005,6 @@ catch (sycl::exception const &exc) {
  std::exit(1);
  }

- static void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
- const ggml_tensor *src1, ggml_tensor *dst,
- const float *src0_d, const float *src1_d,
- float *dst_d, const queue_ptr &stream) {
-
- GGML_ASSERT(src1->type == GGML_TYPE_I32);
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
-
- GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
- GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
- GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
-
- const int32_t * src1_i32 = (const int32_t *) src1_d;
-
- switch (src0->type) {
- case GGML_TYPE_F16:
- get_rows_sycl_float(ctx, src0, src1, dst, (const sycl::half *)src0_d,
- src1_i32, dst_d, stream);
- break;
- case GGML_TYPE_F32:
- get_rows_sycl_float(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
- break;
- case GGML_TYPE_Q4_0:
- get_rows_sycl<QK4_0, QR4_0, dequantize_q4_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
- break;
- case GGML_TYPE_Q4_1:
- get_rows_sycl<QK4_1, QR4_1, dequantize_q4_1>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
- break;
- case GGML_TYPE_Q5_0:
- get_rows_sycl<QK5_0, QR5_0, dequantize_q5_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
- break;
- case GGML_TYPE_Q5_1:
- get_rows_sycl<QK5_1, QR5_1, dequantize_q5_1>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
- break;
- case GGML_TYPE_Q8_0:
- get_rows_sycl<QK8_0, QR8_0, dequantize_q8_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
- break;
- default:
- // TODO: k-quants
- GGML_LOG_ERROR("%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type));
- GGML_ABORT("fatal error");
- break;
- }
- }
-
-
- static void ggml_sycl_op_repeat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
- const ggml_tensor *src1, ggml_tensor *dst,
- const float *src0_d, const float *src1_d,
- float *dst_d,
- const queue_ptr &main_stream) {
-
- ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_repeat>>(ctx, dst, src0, dst, nullptr, src0_d, dst_d, main_stream);
-
- GGML_UNUSED(src1);
- GGML_UNUSED(src1_d);
- }
-
-
  inline void ggml_sycl_op_mul_mat_sycl(
  ggml_backend_sycl_context & ctx,
  const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
@@ -2484,33 +2019,31 @@ inline void ggml_sycl_op_mul_mat_sycl(
 
  const int64_t ne00 = src0->ne[0];
  const int64_t ne10 = src1->ne[0];
-
+ GGML_ASSERT(ne00 == ne10);
 
  const int64_t row_diff = row_high - row_low;
 
  int id;
  SYCL_CHECK(
  CHECK_TRY_ERROR(id = get_current_device_id()));
- #if !GGML_SYCL_DNNL
- const int64_t ne0 = dst->ne[0];
+
+ const int64_t ne0 = dst->ne[0]; // used by MKL only
  // the main device has a larger memory buffer to hold the results from all GPUs
  // ldc == nrows of the matrix that cuBLAS writes into
- int ldc = id == ctx.device ? ne0 : row_diff;
- #endif
+ int ldc = id == ctx.device ? ne0 : row_diff; // used by MKL only
 
  #ifdef GGML_SYCL_F16
  bool use_fp16 = true; // TODO(Yu) SYCL capability check
  #else
  bool use_fp16 = false;
  #endif
- if ((src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
- use_fp16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1] &&
- dst->op_params[0] == GGML_PREC_DEFAULT) {
-
- // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp16 path\n");
+ if ((src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && use_fp16 && ggml_is_contiguous(src0) &&
+ row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
  ggml_sycl_pool_alloc<sycl::half> src0_as_f16(ctx.pool());
  if (src0->type != GGML_TYPE_F16) {
- const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src0->type);
+ scope_op_debug_print scope_dbg_print(__func__, "/to_fp16_sycl", dst, /*num_src=*/2,
+ " : converting src0 to fp16");
+ const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src0->type, dst);
  GGML_ASSERT(to_fp16_sycl != nullptr);
  size_t ne = row_diff*ne00;
  src0_as_f16.alloc(ne);
@@ -2522,7 +2055,9 @@ inline void ggml_sycl_op_mul_mat_sycl(
 
  ggml_sycl_pool_alloc<sycl::half> src1_as_f16(ctx.pool());
  if (src1->type != GGML_TYPE_F16) {
- const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type);
+ scope_op_debug_print scope_dbg_print(__func__, "/to_fp16_sycl", dst, /*num_src=*/2,
+ " : converting src1 to fp16");
+ const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst);
  GGML_ASSERT(to_fp16_sycl != nullptr);
  size_t ne = src1_ncols*ne10;
  src1_as_f16.alloc(ne);
@@ -2533,38 +2068,48 @@ inline void ggml_sycl_op_mul_mat_sycl(
  : src1_as_f16.get();
  ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool(), row_diff * src1_ncols);
 
- #if !GGML_SYCL_DNNL
- const sycl::half alpha_f16 = 1.0f;
- const sycl::half beta_f16 = 0.0f;
- SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
- *stream, oneapi::mkl::transpose::trans,
- oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
- &alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00,
- src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16,
- dst_f16.get(), dpct::library_data_t::real_half, ldc,
- dpct::library_data_t::real_half)));
- const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
- to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
- #else
- auto dnnl_stream = ctx.stream_dnnl(stream);
- DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
- src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(), dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>());
- const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
- to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
+ #if GGML_SYCL_DNNL
+ if (!g_ggml_sycl_disable_dnn) {
+ DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ptr,
+ DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
+ dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>(), stream);
+ scope_op_debug_print scope_dbg_print(__func__, "/to_fp32_sycl", dst, /*num_src=*/2,
+ " : converting dst to fp32");
+ const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
+ to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
+ }
+ else
  #endif
- }
- else {
- // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp32 path\n");
+ {
+ const sycl::half alpha_f16 = 1.0f;
+ const sycl::half beta_f16 = 0.0f;
+ SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
+ *stream, oneapi::math::transpose::trans,
+ oneapi::math::transpose::nontrans, row_diff, src1_ncols, ne10,
+ &alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00,
+ src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16,
+ dst_f16.get(), dpct::library_data_t::real_half, ldc,
+ dpct::library_data_t::real_half)));
+ scope_op_debug_print scope_dbg_print(__func__, "/to_fp32_sycl", dst, /*num_src=*/2,
+ " : converting dst to fp32");
+ const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
+ to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
+ }
+ } else {
  ggml_sycl_pool_alloc<float> src0_ddq_as_f32(ctx.pool());
  ggml_sycl_pool_alloc<float> src1_ddq_as_f32(ctx.pool());
  if (src0->type != GGML_TYPE_F32) {
- const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src0->type);
+ scope_op_debug_print scope_dbg_print(__func__, "/to_fp32_sycl", dst, /*num_src=*/2,
+ " : converting src0 to fp32");
+ const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src0->type, dst);
  GGML_ASSERT(to_fp32_sycl != nullptr);
  src0_ddq_as_f32.alloc(row_diff*ne00);
  to_fp32_sycl(src0_dd_i, src0_ddq_as_f32.get(), row_diff*ne00, stream);
  }
  if (src1->type != GGML_TYPE_F32) {
- const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src1->type);
+ scope_op_debug_print scope_dbg_print(__func__, "/to_fp32_sycl", dst, /*num_src=*/2,
+ " : converting src1 to fp32");
+ const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src1->type, dst);
  GGML_ASSERT(to_fp32_sycl != nullptr);
  src1_ddq_as_f32.alloc(src1_ncols*ne10);
  to_fp32_sycl(src1_ddf_i, src1_ddq_as_f32.get(), src1_ncols*ne10, stream);
@@ -2572,25 +2117,22 @@ inline void ggml_sycl_op_mul_mat_sycl(
  const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32.get();
  const float * src1_ddf1_i = src1->type == GGML_TYPE_F32 ? (const float *) src1_ddf_i : src1_ddq_as_f32.get();
 
- #if !GGML_SYCL_DNNL
- const float alpha = 1.0f;
- const float beta = 0.0f;
- # ifdef GGML_SYCL_NVIDIA
- SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
- oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ *stream }, oneapi::mkl::transpose::trans,
- oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i,
- ne00, src1_ddf1_i, ne10, dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
- # else
- SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
- *stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
- dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10, dpct::get_value(&beta, *stream),
- dst_dd_i, ldc)));
- # endif
- #else
- auto dnnl_stream = ctx.stream_dnnl(stream);
- DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(),
- src0_ddf_i, DnnlGemmWrapper::to_dt<float>(), dst_dd_i, DnnlGemmWrapper::to_dt<float>());
+ #if GGML_SYCL_DNNL
+ if (!g_ggml_sycl_disable_dnn) {
+ DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ddf1_i,
+ DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(),
+ dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
+ }
+ else
  #endif
+ {
+ const float alpha = 1.0f;
+ const float beta = 0.0f;
+ SYCL_CHECK(CHECK_TRY_ERROR(oneapi::math::blas::column_major::gemm(
+ get_onemath_backend(*stream), oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, row_diff,
+ src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10,
+ dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
+ }
  }
  GGML_UNUSED(dst);
  GGML_UNUSED(src1_ddq_i);
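
Two things change throughout this function: `oneapi::mkl` becomes `oneapi::math` (the oneMath library), and the compile-time `#if !GGML_SYCL_DNNL` split becomes a run-time choice, with oneDNN preferred and the oneMath GEMM as the fallback. The NVIDIA-specific `backend_selector` branch is folded into the `get_onemath_backend(*stream)` helper. A sketch of a plain fp32 call like the fallback path, assuming the oneMath USM BLAS API and passing the queue directly rather than through that helper:

```cpp
#include <sycl/sycl.hpp>
#include <oneapi/math.hpp> // oneMath, the open-source successor of the oneMKL interfaces

// C(m x n) = A^T * B in column-major storage, alpha = 1, beta = 0, matching
// the trans/nontrans arguments used by the diff's gemm call.
void gemm_f32(sycl::queue & q, int64_t m, int64_t n, int64_t k,
              const float * a /* k x m */, const float * b /* k x n */, float * c /* m x n */) {
    const float alpha = 1.0f;
    const float beta  = 0.0f;
    oneapi::math::blas::column_major::gemm(
        q, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans,
        m, n, k, alpha, a, /*lda=*/k, b, /*ldb=*/k, beta, c, /*ldc=*/m).wait();
}
```
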
@@ -2602,13 +2144,13 @@ catch (sycl::exception const &exc) {
  std::exit(1);
  }
 
- static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
- const ggml_tensor *src1, ggml_tensor *dst,
- const float *src0_dd, const float *src1_dd,
- float *dst_dd, const queue_ptr &main_stream) {
-
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
  GGML_ASSERT( dst->type == GGML_TYPE_F32);
+ dpct::queue_ptr main_stream = ctx.stream();
+ SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+ const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
+ float * dst_dd = static_cast<float *>(dst->data);
 
  const int32_t * opts = (const int32_t *)dst->op_params;
  enum ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
@@ -2619,8 +2161,8 @@ static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, const ggml_tens
  const int p0 = opts[5];
  const int p1 = opts[6];
 
- const int64_t IH = src0->ne[1];
- const int64_t IW = src0->ne[0];
+ const int64_t IH = dst->src[0]->ne[1];
+ const int64_t IW = dst->src[0]->ne[0];
 
  const int64_t N = dst->ne[3];
  const int64_t OC = dst->ne[2];
@@ -2639,163 +2181,101 @@ static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, const ggml_tens
  parallel_elements, src0_dd, dst_dd, op,
  item_ct1);
  });
-
- GGML_UNUSED(src1);
- GGML_UNUSED(src1_dd);
- GGML_UNUSED(ctx);
  }
 
- inline void ggml_sycl_op_sum(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
- const ggml_tensor *src1, ggml_tensor *dst,
- const float *src0_dd, const float *src1_dd,
- float *dst_dd,
- const queue_ptr &main_stream) {
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ inline void ggml_sycl_op_sum(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
+ GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
  GGML_ASSERT( dst->type == GGML_TYPE_F32);
+ dpct::queue_ptr main_stream = ctx.stream();
+ SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+ const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
+ float * dst_dd = static_cast<float *>(dst->data);
 
- const int64_t ne = ggml_nelements(src0);
+ const int64_t ne = ggml_nelements(dst->src[0]);
 
  sum_rows_f32_sycl(src0_dd, dst_dd, ne, 1, main_stream);
-
- GGML_UNUSED(src1);
- GGML_UNUSED(dst);
- GGML_UNUSED(src1_dd);
- GGML_UNUSED(ctx);
  }
 
- inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
- const ggml_tensor *src1, ggml_tensor *dst,
- const float *src0_dd, const float *src1_dd,
- float *dst_dd,
- const queue_ptr &main_stream) {
-
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
  GGML_ASSERT( dst->type == GGML_TYPE_F32);
+ dpct::queue_ptr main_stream = ctx.stream();
+ SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+ const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
+ float * dst_dd = static_cast<float *>(dst->data);
 
- const int64_t ncols = src0->ne[0];
- const int64_t nrows = ggml_nrows(src0);
+ const int64_t ncols = dst->src[0]->ne[0];
+ const int64_t nrows = ggml_nrows(dst->src[0]);
 
  sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
-
- GGML_UNUSED(src1);
- GGML_UNUSED(dst);
- GGML_UNUSED(src1_dd);
- GGML_UNUSED(ctx);
  }
 
- inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
- const ggml_tensor *src1, ggml_tensor *dst,
- const float *src0_dd, const float *src1_dd,
- float *dst_dd,
- const queue_ptr &main_stream) {
+ inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->type == GGML_TYPE_I32);
+ dpct::queue_ptr main_stream = ctx.stream();
+ SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+ const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
+ int32_t * dst_dd = static_cast<int32_t *>(dst->data);
 
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
- GGML_ASSERT( dst->type == GGML_TYPE_I32);
 
- const int64_t ncols = src0->ne[0];
- const int64_t nrows = ggml_nrows(src0);
+ const int64_t ncols = dst->src[0]->ne[0];
+ const int64_t nrows = ggml_nrows(dst->src[0]);
 
  enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
 
- argsort_f32_i32_sycl(src0_dd, (int *)dst_dd, ncols, nrows, order, main_stream);
-
- GGML_UNUSED(src1);
- GGML_UNUSED(dst);
- GGML_UNUSED(src1_dd);
- GGML_UNUSED(ctx);
+ argsort_f32_i32_sycl(src0_dd, (int *) dst_dd, ncols, nrows, order, main_stream);
  }
 
- inline void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
- const ggml_tensor *src1, ggml_tensor *dst,
- const float *src0_dd, const float *src1_dd,
- float *dst_dd,
- const queue_ptr &main_stream) {
-
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ inline void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
  GGML_ASSERT( dst->type == GGML_TYPE_I32);
 
- const int64_t ncols = src0->ne[0];
- const int64_t nrows = ggml_nrows(src0);
-
- argmax_f32_i32_sycl(src0_dd, (int *)dst_dd, ncols, nrows, main_stream);
-
- GGML_UNUSED(src1);
- GGML_UNUSED(dst);
- GGML_UNUSED(src1_dd);
- GGML_UNUSED(ctx);
- }
-
- inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
- const ggml_tensor *src1,
- ggml_tensor *dst, const float *src0_dd,
- const float *src1_dd, float *dst_dd,
- const queue_ptr &main_stream) {
-
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
- GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
- const int64_t ne00 = src0->ne[0];
- const int64_t ne01 = src0->ne[1];
- const int nrows0 = ggml_nrows(src0);
-
- const int n_past = ((int32_t *) dst->op_params)[0];
+ dpct::queue_ptr main_stream = ctx.stream();
+ SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+ const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
+ int32_t * dst_dd = static_cast<int32_t *>(dst->data);
 
- diag_mask_inf_f32_sycl(src0_dd, dst_dd, ne00, nrows0, ne01, n_past, main_stream);
+ const int64_t ncols = dst->src[0]->ne[0];
+ const int64_t nrows = ggml_nrows(dst->src[0]);
 
- GGML_UNUSED(src1);
- GGML_UNUSED(dst);
- GGML_UNUSED(src1_dd);
- GGML_UNUSED(ctx);
+ argmax_f32_i32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
  }
 
- inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
- ggml_tensor *dst, const float *src0_dd,
- const float *src1_dd, float *dst_dd,
- const queue_ptr &main_stream) {
-
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
  GGML_ASSERT( dst->type == GGML_TYPE_F32);
+ dpct::queue_ptr main_stream = ctx.stream();
+ SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+ const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
+ float * dst_dd = static_cast<float *>(dst->data);
 
- float scale;
- memcpy(&scale, dst->op_params, sizeof(float));
+ const int64_t ne00 = dst->src[0]->ne[0];
+ const int64_t ne01 = dst->src[0]->ne[1];
+ const int nrows0 = ggml_nrows(dst->src[0]);
 
- scale_f32_sycl(src0_dd, dst_dd, scale, ggml_nelements(src0), main_stream);
- /*
- DPCT1010:87: SYCL uses exceptions to report errors and does not use the
- error codes. The call was replaced with 0. You need to rewrite this code.
- */
- SYCL_CHECK(0);
+ const int n_past = ((int32_t *) dst->op_params)[0];
 
- GGML_UNUSED(src1);
- GGML_UNUSED(dst);
- GGML_UNUSED(src1_dd);
- GGML_UNUSED(ctx);
+ diag_mask_inf_f32_sycl(src0_dd, dst_dd, ne00, nrows0, ne01, n_past, main_stream);
  }
 
- inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
- ggml_tensor *dst, const float *src0_dd,
- const float *src1_dd, float *dst_dd,
- const queue_ptr &main_stream) {
-
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
  GGML_ASSERT( dst->type == GGML_TYPE_F32);
+ dpct::queue_ptr main_stream = ctx.stream();
+ SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+ const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
+ float * dst_dd = static_cast<float *>(dst->data);
 
- float min;
- float max;
- memcpy(&min, dst->op_params, sizeof(float));
- memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
+ float scale;
+ memcpy(&scale, dst->op_params, sizeof(float));
 
- clamp_f32_sycl(src0_dd, dst_dd, min, max, ggml_nelements(src0), main_stream);
+ scale_f32_sycl(src0_dd, dst_dd, scale, ggml_nelements(dst->src[0]), main_stream);
  /*
- DPCT1010:88: SYCL uses exceptions to report errors and does not use the
+ DPCT1010:87: SYCL uses exceptions to report errors and does not use the
  error codes. The call was replaced with 0. You need to rewrite this code.
  */
  SYCL_CHECK(0);
-
- GGML_UNUSED(src1);
- GGML_UNUSED(dst);
- GGML_UNUSED(src1_dd);
- GGML_UNUSED(ctx);
  }
 
  static void ggml_sycl_set_peer_access(const int n_tokens, int main_device) {
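
The pattern repeated across pool2d, sum, sum_rows, argsort, argmax, diag_mask_inf, and scale is the same: the old eight-argument signature collapses to `(ctx, dst)`, and each op recovers its input as `dst->src[0]`, its stream as `ctx.stream()`, and the raw pointers from the tensors' `data` fields, which removes the trailing `GGML_UNUSED` lists. A toy restatement of the calling convention (all types below are simplified stand-ins, not the backend's own):

```cpp
#include <cstddef>

// Simplified stand-ins for ggml_tensor and ggml_backend_sycl_context.
struct toy_tensor {
    toy_tensor * src[3];  // operation inputs, as in ggml_tensor::src
    float *      data;
    std::size_t  nelements;
};

struct toy_context { int device; };

// Old shape: op(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream)
// New shape: op(ctx, dst), with everything else derived inside the op.
void toy_op_scale(toy_context & /*ctx*/, toy_tensor * dst, float scale) {
    const toy_tensor * src0 = dst->src[0];  // real code also takes ctx.stream() here
    for (std::size_t i = 0; i < dst->nelements; ++i) {
        dst->data[i] = src0->data[i] * scale; // real op reads scale from dst->op_params
    }
}
```
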
@@ -2857,8 +2337,8 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
  const int nb2 = dst->nb[2];
  const int nb3 = dst->nb[3];
 
- GGML_ASSERT(dst->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
- GGML_ASSERT(src1->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
+ GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(dst->buffer));
+ GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src1->buffer));
  GGML_ASSERT(src1->type == GGML_TYPE_F32 || (src1->ne[2] == 1 && src1->ne[3] == 1));
 
  GGML_ASSERT(ne12 >= ne02 && ne12 % ne02 == 0);
@@ -2878,7 +2358,7 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
 
  int64_t src1_padded_col_size = GGML_PAD(ne10, MATRIX_ROW_PADDING);
 
- const bool split = src0->backend == GGML_BACKEND_TYPE_GPU_SPLIT;
+ const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer);
  GGML_ASSERT(!(split && ne02 > 1));
  GGML_ASSERT(!(split && ne03 > 1));
  GGML_ASSERT(!(split && ne02 < ne12));
@@ -2966,6 +2446,8 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
  dev[i].src1_ddq = dev[i].src1_ddq_alloc.alloc(ctx.pool(i), nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs);
 
  if (src1_on_device && src1_is_contiguous) {
+ scope_op_debug_print scope_dbg_print(__func__, "/quantize_row_q8_1_sycl", dst,
+ /*num_src=*/2, " : converting src1 to Q8_1");
  quantize_row_q8_1_sycl(dev[i].src1_ddf, dev[i].src1_ddq, ne10, nrows1, src1_padded_col_size, stream);
  /*
  DPCT1010:90: SYCL uses exceptions to report errors and does not
@@ -3002,7 +2484,6 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
  for (int64_t src1_col_0 = 0; src1_col_0 < ne11; src1_col_0 += src1_col_stride) {
  const int64_t is = split ? (src1_col_0/src1_col_stride) % GGML_SYCL_MAX_STREAMS : 0;
  const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? ne11 - src1_col_0 : src1_col_stride;
-
  for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
  if ((!split && i != ctx.device) || dev[i].row_low == dev[i].row_high) {
  continue;
@@ -3071,6 +2552,8 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
  }
 
  if (convert_src1_to_q8_1 && !src1_is_contiguous) {
+ scope_op_debug_print scope_dbg_print(__func__, "/quantize_row_q8_1_sycl", dst,
+ /*num_src=*/2, " : converting src1 to Q8_1");
  quantize_row_q8_1_sycl(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, stream);
  /*
  DPCT1010:92: SYCL uses exceptions to report errors and does
@@ -3164,41 +2647,36 @@ catch (sycl::exception const &exc) {
  }
 
 
- static void ggml_sycl_repeat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
- GGML_SYCL_DEBUG("call %s\n", __func__);
- ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_repeat);
- GGML_SYCL_DEBUG("call %s done\n", __func__);
+ static void ggml_sycl_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
+ ggml_sycl_op_get_rows(ctx, dst);
  }
 
- static void ggml_sycl_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
- GGML_SYCL_DEBUG("call %s\n", __func__);
- ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_get_rows);
- GGML_SYCL_DEBUG("call %s done\n", __func__);
+ static void ggml_sycl_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+ ggml_sycl_op_norm(ctx, dst);
  }
 
- static void ggml_sycl_norm(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
- GGML_SYCL_DEBUG("call %s\n", __func__);
- ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_norm);
- GGML_SYCL_DEBUG("call %s done\n", __func__);
+ static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+ ggml_sycl_op_rms_norm(ctx, dst);
  }
 
- static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
- GGML_SYCL_DEBUG("call %s\n", __func__);
- ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_rms_norm);
- GGML_SYCL_DEBUG("call %s done\n", __func__);
+ static void ggml_sycl_l2_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+ ggml_sycl_op_l2_norm(ctx, dst);
  }
 
- static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
- GGML_SYCL_DEBUG("call %s\n", __func__);
- ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_group_norm);
- GGML_SYCL_DEBUG("call %s done\n", __func__);
+ static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+ ggml_sycl_op_group_norm(ctx, dst);
  }
 
  static void ggml_sycl_mul_mat_vec_p021(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
  const ggml_tensor *src1,
  ggml_tensor *dst) try {
  GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
- GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
+ GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer));
  GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation
  GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // 0213 permutation
  GGML_ASSERT(src0->type == GGML_TYPE_F16);
@@ -3231,7 +2709,7 @@ static void ggml_sycl_mul_mat_vec_nc(ggml_backend_sycl_context & ctx, const ggml
  GGML_ASSERT(!ggml_is_transposed(src0));
  GGML_ASSERT(!ggml_is_transposed(src1));
  GGML_ASSERT(!ggml_is_permuted(src0));
- GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
+ GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer));
  GGML_ASSERT(src0->type == GGML_TYPE_F16);
  GGML_ASSERT(src1->type == GGML_TYPE_F32);
 
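
The per-op wrappers above swap their paired `GGML_SYCL_DEBUG("call %s")` / `"call %s done"` lines for a single `scope_op_debug_print` object, and the `quantize_row_q8_1_sycl` call sites gain the same object. A minimal sketch of that RAII idea (only the logging skeleton; the real class also prints the dst tensor and its sources):

```cpp
#include <cstdio>

// Logs entry on construction and exit on destruction, so every return path
// of the enclosing op still produces a matching "done" line.
struct scope_logger {
    const char * name;
    explicit scope_logger(const char * n) : name(n) { std::fprintf(stderr, "call %s\n", name); }
    ~scope_logger() { std::fprintf(stderr, "call %s done\n", name); }
};

void some_op() {
    scope_logger dbg(__func__); // replaces the paired debug calls
    // ... body may return early; the destructor still logs the exit ...
}
```
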
@@ -3262,146 +2740,182 @@ catch (sycl::exception const &exc) {
  std::exit(1);
  }
 
- static void k_compute_batched_ptrs(const sycl::half *src0_as_f16,
- const sycl::half *src1_as_f16, char *dst,
- const void **ptrs_src, void **ptrs_dst,
- int64_t ne12, int64_t ne13, int64_t ne23,
- size_t nb02, size_t nb03, size_t nb12,
- size_t nb13, size_t nbd2, size_t nbd3,
- int64_t r2, int64_t r3,
- const sycl::nd_item<3> &item_ct1) {
- int64_t i13 = item_ct1.get_group(2) * item_ct1.get_local_range(2) +
- item_ct1.get_local_id(2);
- int64_t i12 = item_ct1.get_group(1) * item_ct1.get_local_range(1) +
- item_ct1.get_local_id(1);
+ static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::half * src1_as_f16, void * dst,
+ const void ** ptrs_src, void ** ptrs_dst, int64_t ne12, int64_t ne13, int64_t ne23,
+ size_t nb02, size_t nb03, size_t nb12, size_t nb13, size_t nbd2, size_t nbd3,
+ int64_t r2, int64_t r3, const sycl::nd_item<3> & item_ct1) {
+ const int64_t i13 = item_ct1.get_group(2) * item_ct1.get_local_range(2) + item_ct1.get_local_id(2);
+ const int64_t i12 = item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1);
 
  if (i13 >= ne13 || i12 >= ne12) {
  return;
  }
 
- int64_t i03 = i13 / r3;
- int64_t i02 = i12 / r2;
+ const int64_t i03 = i13 / r3;
+ const int64_t i02 = i12 / r2;
+
+ const uint8_t * src0_bytes = reinterpret_cast<const uint8_t *>(src0_as_f16);
+ const uint8_t * src1_bytes = reinterpret_cast<const uint8_t *>(src1_as_f16);
+ uint8_t * dst_bytes = static_cast<uint8_t *>(dst);
 
- ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03;
- ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12 + i13*nb13;
- ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst + i12*nbd2 + i13*nbd3;
+ ptrs_src[0 * ne23 + i12 + i13 * ne12] = src0_bytes + i02 * nb02 + i03 * nb03;
+ ptrs_src[1 * ne23 + i12 + i13 * ne12] = src1_bytes + i12 * nb12 + i13 * nb13;
+ ptrs_dst[0 * ne23 + i12 + i13 * ne12] = dst_bytes + i12 * nbd2 + i13 * nbd3;
  }
 
- static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
- const ggml_tensor *src0,
- const ggml_tensor *src1,
- ggml_tensor *dst) try {
+ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor * src0,
+ const ggml_tensor * src1, ggml_tensor * dst) try {
  GGML_ASSERT(!ggml_is_transposed(src0));
  GGML_ASSERT(!ggml_is_transposed(src1));
- GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
+ GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer));
  GGML_ASSERT(src0->type == GGML_TYPE_F16);
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
 
  GGML_TENSOR_BINARY_OP_LOCALS
 
+ // TODO: see https://github.com/ggml-org/llama.cpp/pull/13155
+ // Batched mul_mat requires a rewrite to support both oneDNN and non-contiguous dst
+ GGML_ASSERT(ggml_is_contiguous(dst));
 
  SYCL_CHECK(ggml_sycl_set_device(ctx.device));
- queue_ptr main_stream = ctx.stream();;
+ queue_ptr queue = ctx.stream();
 
- void * src0_ddq = src0->data;
- sycl::half *src0_as_f16 = (sycl::half *)src0_ddq;
- float * src1_ddf = (float *) src1->data;
- float * dst_ddf = (float *) dst->data;
+ dpct::has_capability_or_fail(queue->get_device(), { sycl::aspect::fp16 });
 
- // convert src1 to fp16
+ const sycl::half * src0_f16 = static_cast<const sycl::half *>(src0->data);
+ float * dst_ddf = static_cast<float *>(dst->data);
+
+ const sycl::half * src1_f16 = static_cast<const sycl::half *>(src1->data);
+ const size_t type_size_src1 = ggml_type_size(src1->type);
+ GGML_ASSERT(nb10 == type_size_src1);
+
+ // SRC1 strides
+ int64_t s11 = nb11 / type_size_src1;
+ int64_t s12 = nb12 / type_size_src1;
+ int64_t s13 = nb13 / type_size_src1;
  ggml_sycl_pool_alloc<sycl::half> src1_f16_alloc(ctx.pool());
+
+ // convert src1 to fp16
  if (src1->type != GGML_TYPE_F16) {
- const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type);
+ scope_op_debug_print scope_dbg_print(__func__, "/to_fp16_nc_sycl", dst, /*num_src=*/2,
+ " : converting src1 to fp16");
+ const to_fp16_nc_sycl_t to_fp16_nc_sycl = get_to_fp16_nc_sycl(src1->type);
+ GGML_ASSERT(to_fp16_nc_sycl != nullptr);
  const int64_t ne_src1 = ggml_nelements(src1);
  src1_f16_alloc.alloc(ne_src1);
- GGML_ASSERT(to_fp16_sycl != nullptr);
- to_fp16_sycl(src1_ddf, src1_f16_alloc.get(), ne_src1, main_stream);
+ to_fp16_nc_sycl(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, queue);
+
+ src1_f16 = src1_f16_alloc.get();
+ s11 = ne10;
+ s12 = ne11 * s11;
+ s13 = ne12 * s12;
  }
- sycl::half *src1_f16 = src1->type == GGML_TYPE_F16 ? (sycl::half *)src1_ddf
- : src1_f16_alloc.get();
 
- char * dst_t;
+ ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool());
 
- dpct::library_data_t cu_compute_type = dpct::library_data_t::real_float;
- dpct::library_data_t cu_data_type = dpct::library_data_t::real_float;
+ dpct::library_data_t mkl_compute_type = dpct::library_data_t::real_float;
+ dpct::library_data_t mkl_data_type = dpct::library_data_t::real_float;
 
  // dst strides
  size_t nbd2 = dst->nb[2];
  size_t nbd3 = dst->nb[3];
 
  const float alpha_f32 = 1.0f;
- const float beta_f32 = 0.0f;
+ const float beta_f32 = 0.0f;
 
  const void * alpha = &alpha_f32;
  const void * beta = &beta_f32;
 
- dst_t = (char *) dst_ddf;
-
  GGML_ASSERT(ne12 % ne02 == 0);
  GGML_ASSERT(ne13 % ne03 == 0);
+ GGML_ASSERT(ne01 == static_cast<int64_t>(nb1/nb0));
+ GGML_ASSERT(ne10 == ne00);
 
  // broadcast factors
- const int64_t r2 = ne12/ne02;
- const int64_t r3 = ne13/ne03;
-
- if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
- // there is no broadcast and src0, src1 are contiguous across dims 2, 3
- SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
- *main_stream, oneapi::mkl::transpose::trans,
- oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
- (const char *)src0_as_f16, dpct::library_data_t::real_half,
- nb01 / nb00, nb02 / nb00,
- (const char *)src1_f16, dpct::library_data_t::real_half,
- nb11 / nb10, nb12 / nb10, beta,
- (char *)dst_t, cu_data_type, ne01, nb2 / nb0,
- ne12 * ne13, cu_compute_type)));
- } else {
- const int ne23 = ne12*ne13;
-
- ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
- ggml_sycl_pool_alloc< void *> ptrs_dst(ctx.pool(), 1*ne23);
-
- sycl::range<3> block_dims(1, ne12, ne13);
- /*
- DPCT1049:47: The work-group size passed to the SYCL kernel may exceed
- the limit. To get the device limit, query
- info::device::max_work_group_size. Adjust the work-group size if needed.
- */
- {
- dpct::has_capability_or_fail(main_stream->get_device(),
- {sycl::aspect::fp16});
-
- main_stream->submit([&](sycl::handler &cgh) {
- const void **ptrs_src_get = ptrs_src.get();
- void **ptrs_dst_get = ptrs_dst.get();
- size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : nb12 / 2;
- size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : nb13 / 2;
- cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims),
- [=](sycl::nd_item<3> item_ct1) {
- k_compute_batched_ptrs(
- src0_as_f16, src1_f16,
- dst_t, ptrs_src_get,
- ptrs_dst_get, ne12, ne13, ne23,
- nb02, nb03, nb12_scaled, nb13_scaled,
- nbd2, nbd3, r2, r3, item_ct1);
- });
+ const int64_t r2 = ne12 / ne02;
+ const int64_t r3 = ne13 / ne03;
+
+ #if GGML_SYCL_DNNL
+ if (!g_ggml_sycl_disable_dnn) {
+ auto dnn_gemm = [&ctx, queue, ne11, ne01, ne10, nb00, nb01, nb02, s11, s12]
+ (const sycl::half* src1, const sycl::half* src0, float* dst, const dnnl_dim_t batches_a, const dnnl_dim_t batches_b) {
+
+ DnnlGemmWrapper::gemm(ctx, ne11,ne01, ne10,
+ src1, DnnlGemmWrapper::to_dt<sycl::half>(), s11, 1, s12,
+ src0, DnnlGemmWrapper::to_dt<sycl::half>(), 1, nb01/nb00, nb02/nb00,
+ dst, DnnlGemmWrapper::to_dt<float>(), queue, batches_a, batches_b);
+ };
+
+ if (r2 == 1 && r3 == 1) {
+ if (ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
+ dnn_gemm(src1_f16, src0_f16, dst_ddf, ne12*ne13, ne02 * ne03);
+ }
+ else {
+ for (int64_t ie03 = 0; ie03 < ne03; ++ie03) {
+ const sycl::half* src0_f16_shifted = src0_f16 + ((ie03*nb03)/sizeof(sycl::half)); // nb is in bytes
+ const sycl::half* src1_f16_shifted = src1_f16 + ie03*s13;
+ float* dst_shifted = dst_ddf + ((ie03*nb3)/sizeof(float));
+ dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, ne12, ne02);
+ }
+ }
+ } else {
+ // iterate over batches from smaller set of matrices (matrix 0)
+ for (int64_t ie02 = 0; ie02 < ne02; ++ie02) {
+ for (int64_t ie03 = 0; ie03 < ne03; ++ie03) {
+ const sycl::half* src0_f16_shifted = src0_f16 + ((ie02*nb02 + ie03*nb03)/sizeof(sycl::half));
+ const sycl::half* src1_f16_shifted = src1_f16 + ie02*s12*r2 + ie03*s13*r3;
+ float* dst_shifted = dst_ddf + ((ie02*nb2*r2 + ie03*nb3*r3)/sizeof(float));
+ dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, r2*r3, 1);
+ }
+ }
+ }
+ }
+ else
+ #endif
+ {
+ if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
+ // there is no broadcast and src0, src1 are contiguous across dims 2, 3
+ SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::math::transpose::trans,
+ oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
+ src0_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00,
+ src1_f16, dpct::library_data_t::real_half, s11, s12, beta, dst_ddf,
+ mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type)));
+ } else {
+ const int ne23 = ne12 * ne13;
+
+ ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2 * ne23);
+ ggml_sycl_pool_alloc<void *> ptrs_dst(ctx.pool(), 1 * ne23);
+ ggml_sycl_pool_alloc<matrix_info_t<float>> matrix_info(ctx.host_pool(), 1);
+
+ sycl::range<3> block_dims(1, ne12, ne13);
+ queue->submit([&](sycl::handler & cgh) {
+ const void ** ptrs_src_get = ptrs_src.get();
+ void ** ptrs_dst_get = ptrs_dst.get();
+ size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : s12 * sizeof(sycl::half);
+ size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : s13 * sizeof(sycl::half);
+ cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+ k_compute_batched_ptrs(src0_f16, src1_f16, dst_ddf, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02,
+ nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1);
+ });
  });
+
+ SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
+ *queue, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
+ (const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
+ (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, s11, beta,
+ (void **) (ptrs_dst.get() + 0 * ne23), mkl_data_type, ne0, ne23, mkl_compute_type, matrix_info.get())));
  }
- SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
- *main_stream, oneapi::mkl::transpose::trans,
- oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
- (const void **)(ptrs_src.get() + 0 * ne23),
- dpct::library_data_t::real_half, nb01 / nb00,
- (const void **)(ptrs_src.get() + 1 * ne23),
- dpct::library_data_t::real_half, nb11 / nb10, beta,
- (void **)(ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23,
- cu_compute_type)));
  }
+ } catch (const sycl::exception & exc) {
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
+ std::exit(1);
  }
- catch (sycl::exception const &exc) {
- std::cerr << exc.what() << "Exception caught at file:" << __FILE__
- << ", line:" << __LINE__ << std::endl;
- std::exit(1);
- }
+
+ enum class mul_mat_algo {
+ DMMV = 0,
+ MMVQ = 1,
+ MUL_MAT_SYCL = 2,
+ };
 
  inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
  // TODO: accuracy issues in MMQ
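
Batched multiplication now follows the same oneDNN-first policy as the flat path: when built with `GGML_SYCL_DNNL`, a `dnn_gemm` lambda drives `DnnlGemmWrapper::gemm` over the broadcast cases, and the `dpct::gemm_batch` call on oneMath runs only otherwise, or when oneDNN is disabled at run time through `g_ggml_sycl_disable_dnn`. A schematic of that gate (reading the flag from an environment variable is an assumption made for illustration, named after the global):

```cpp
#include <cstdlib>

// Illustrative: the diff consults a global g_ggml_sycl_disable_dnn; tying it
// to a GGML_SYCL_DISABLE_DNN environment variable is an assumption here.
static bool dnn_disabled() {
    const char * v = std::getenv("GGML_SYCL_DISABLE_DNN");
    return v != nullptr && v[0] == '1';
}

template <typename DnnGemm, typename MathGemm>
void run_gemm(DnnGemm && dnn_gemm, MathGemm && math_gemm) {
#if GGML_SYCL_DNNL            // compile-time capability...
    if (!dnn_disabled()) {    // ...combined with a run-time opt-out
        dnn_gemm();           // oneDNN matmul primitive
        return;
    }
#endif
    math_gemm();              // oneMath (BLAS-style) fallback
}
```
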
@@ -3409,7 +2923,37 @@ inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
  return false;
  }
 
- bool ggml_sycl_supports_dmmv(enum ggml_type type) {
+ inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) {
+ switch (type) {
+ case GGML_TYPE_Q4_0:
+ return true;
+ case GGML_TYPE_Q4_K:
+ return !g_ggml_sycl_prioritize_dmmv;
+ default:
+ return false;
+ }
+ }
+
+ inline bool ggml_sycl_supports_reorder_dmmv(enum ggml_type type) {
+ switch (type) {
+ case GGML_TYPE_Q4_0:
+ return true;
+ default:
+ return false;
+ }
+ }
+
+ inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) {
+ switch (type) {
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_K:
+ return true;
+ default:
+ return false;
+ }
+ }
+
+ static bool ggml_sycl_supports_dmmv(enum ggml_type type) {
  switch (type) {
  case GGML_TYPE_Q4_0:
  case GGML_TYPE_Q4_1:
@@ -3428,12 +2972,143 @@ bool ggml_sycl_supports_dmmv(enum ggml_type type) {
  }
  }
 
+ static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset,
+ dpct::queue_ptr stream) {
+ auto * tmp_buf = sycl::malloc_shared<uint8_t>(size, *stream);
+ SYCL_CHECK(
+ CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size)
+ .wait()));
+ GGML_ASSERT((size % sizeof(block_q4_0) == 0));
+ GGML_ASSERT((offset % sizeof(block_q4_0) == 0));
+ int offset_blks = offset / sizeof(block_q4_0);
+ auto qs_ptr = data_device + offset_blks * QK4_0 / 2;
+ auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows / 2) + offset_blks;
+
+ stream->parallel_for(
+ size / sizeof(block_q4_0),
+ [=](auto i) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+ const block_q4_0* x = (const block_q4_0*)tmp_buf;
+ const int ib = i;
+
+ for (int j = 0; j < QK4_0/2; j ++)
+ {
+ *(qs_ptr + ib * QK4_0 / 2 + j) = x[ib].qs[j];
+ }
+ *(d_ptr + ib) = x[ib].d;
+ }).wait_and_throw();
+
+ sycl::free(tmp_buf, *stream);
+ }
+
+ static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
+ GGML_ASSERT(size % sizeof(block_q4_K) == 0);
+ GGML_ASSERT(offset % sizeof(block_q4_K) == 0);
+
+ const int nblocks = size / sizeof(block_q4_K);
+
+ auto * tmp_buf = sycl::malloc_shared<uint8_t>(size, *stream);
+ SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size).wait()));
+
+ auto * qs_ptr = data_device;
+ auto * scales_ptr = qs_ptr + QK_K / 2 * nblocks;
+ auto * dm_ptr = (sycl::half2 *) (scales_ptr + K_SCALE_SIZE * nblocks);
+
+ stream->parallel_for(nblocks, [=](auto i) {
+ const block_q4_K * x = (const block_q4_K *) tmp_buf;
+ const int ib = i;
+
+ for (int j = 0; j < QK_K / 2; ++j) {
+ qs_ptr[ib * (QK_K / 2) + j] = x[ib].qs[j];
+ }
+
+ for (int j = 0; j < K_SCALE_SIZE; ++j) {
+ scales_ptr[ib * K_SCALE_SIZE + j] = x[ib].scales[j];
+ }
+
+ dm_ptr[ib] = x[ib].dm;
+ }).wait_and_throw();
+
+ sycl::free(tmp_buf, *stream);
+ }
+
+ static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
+ uint8_t * data_device = (uint8_t *) src0->data;
+ size_t ncols = src0->ne[0];
+ size_t nrows = src0->ne[1];
+ size_t size = ggml_nbytes(src0);
+
+ switch (src0->type) {
+ case GGML_TYPE_Q4_0:
+ reorder_qw_q4_0(data_device, ncols, nrows, size, 0, stream);
+ break;
+ case GGML_TYPE_Q4_K:
+ reorder_qw_q4_k(data_device, size, 0, stream);
+ break;
+ default:
+ GGML_ABORT("reorder_qw() called with unsupported type");
+ break;
+ }
+ }
+
+ static bool should_reorder_tensor(ggml_backend_sycl_context& ctx, const ggml_tensor * dst) {
+ return !g_ggml_sycl_disable_optimize && //allow optimize, controlled by $GGML_SYCL_DISABLE_OPT
+ ctx.opt_feature.reorder && //allow this device due to good perf, skip the devices with bad perf.
+ dst->op == GGML_OP_MUL_MAT && //limit to some supported cases of Q4_0, to do for more cases.
+ dst->src[1]->ne[1]==1 && dst->src[1]->ne[2]==1 && dst->src[1]->ne[3]==1;
+ }
+
+ static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor * src0, const ggml_tensor * /* src1 */,
+ ggml_tensor * dst, mul_mat_algo mm_algorithm) {
+ if (!should_reorder_tensor(*ctx, dst)) {
+ return;
+ }
+
+ ggml_tensor_extra_gpu * extra = static_cast<ggml_tensor_extra_gpu *>(src0->extra);
+ if (!extra || extra->optimized_feature.reorder) {
+ return; // Skip permutations and already reordered tensors
+ }
+
+ switch (mm_algorithm) {
+ case mul_mat_algo::DMMV:
+ if (!ggml_sycl_supports_reorder_dmmv(src0->type)) {
+ return;
+ }
+ break;
+ case mul_mat_algo::MMVQ:
+ if (!ggml_sycl_supports_reorder_mmvq(src0->type)) {
+ return;
+ }
+ break;
+ case mul_mat_algo::MUL_MAT_SYCL:
+ if (!ggml_sycl_supports_reorder_mul_mat_sycl(src0->type)) {
+ return;
+ }
+ break;
+ }
+
+ reorder_qw(src0, ctx->stream());
+ extra->optimized_feature.reorder = true; // Used to decode/dequan in next steps and avoid re-reordering
+ }
+
+
+ static bool can_use_dequantize_mul_mat_vec(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ return ggml_sycl_supports_dmmv(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
+ src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;
+ }
+
+ static bool can_use_mul_mat_vec_q(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ return ggml_is_quantized(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
+ src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
+ }
+
  static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
  const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer);
  int64_t min_compute_capability = INT_MAX;
 
  if (split) {
- ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context;
+ ggml_backend_sycl_split_buffer_type_context * buft_ctx =
+ (ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context;
  auto & tensor_split = buft_ctx->tensor_split;
  for (int id = 0; id < ggml_sycl_info().device_count; ++id) {
  // skip devices that are not going to do any work:
@@ -3446,17 +3121,13 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
  }
  }
  } else {
- min_compute_capability = ggml_sycl_info().devices[ctx.device].cc;
+ min_compute_capability = ggml_sycl_info().devices[ctx.device].cc;
  }
 
  // check data types and tensor shapes for custom matrix multiplication kernels:
- bool use_dequantize_mul_mat_vec = ggml_sycl_supports_dmmv(src0->type)
- && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
- && src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;
+ bool use_dequantize_mul_mat_vec = can_use_dequantize_mul_mat_vec(src0, src1, dst);
 
- bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
- && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
- && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
+ bool use_mul_mat_vec_q = can_use_mul_mat_vec_q(src0, src1, dst);
 
  bool use_mul_mat_q = ggml_sycl_supports_mmq(src0->type)
  && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
@@ -3468,9 +3139,15 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
  use_mul_mat_q = use_mul_mat_q && (src1->ne[1] <= MMQ_MAX_BATCH_SIZE);
  #endif // SYCL_USE_XMX
 
+
  // mmvq path is faster in the CUDA backend.
- if (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda)
+ if (!g_ggml_sycl_prioritize_dmmv && (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda
+ // Dispatch becomes obscure with the reorder, MMVQ when the reorder optimization
+ // is enabled takes precedence over DMMV, the current if-else implementation
+ // requires disabling DMMV if both conditions are met
+ || (should_reorder_tensor(ctx, dst) && ggml_sycl_supports_reorder_mmvq(src0->type)))) {
  use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
+ }
 
  if (!split && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
  // TODO: Refactor and cleanup of mul mat dispatching.
@@ -3482,20 +3159,26 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
  // The kernel from the if path is faster for that specific case, but does not support all mul mats.
  ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
  }
- } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
+ } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
  // KQV single-batch
  ggml_sycl_mul_mat_vec_nc(ctx, src0, src1, dst);
  } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
  // KQ + KQV multi-batch
  ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
  } else if (use_dequantize_mul_mat_vec) {
- ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false);
+ constexpr bool convert_src1_to_q8_1 = false;
+ opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::DMMV);
+ ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, convert_src1_to_q8_1);
  } else if (use_mul_mat_vec_q) {
- ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true);
+ constexpr bool convert_src1_to_q8_1 = true;
+ opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::MMVQ);
+ ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, convert_src1_to_q8_1);
  } else if (use_mul_mat_q) {
- ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, true);
+ constexpr bool convert_src1_to_q8_1 = true;
+ ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, convert_src1_to_q8_1);
  } else {
- ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
+ constexpr bool convert_src1_to_q8_1 = false;
+ ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, convert_src1_to_q8_1);
  }
  }
 
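
The reorder optimization rewrites a quantized weight tensor from array-of-blocks to struct-of-arrays: for Q4_0, all quant nibbles come first, then all fp16 scales, so the mul-mat kernels read each component with unit stride. A host-side sketch of the Q4_0 layout transform (the device version above does the same per block inside a `parallel_for`; `uint16_t` stands in for the fp16 scale, and this sketch writes to a fresh buffer rather than in place):

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

constexpr int QK4_0 = 32; // quants per block, as in ggml

struct block_q4_0 {
    uint16_t d;             // fp16 scale, treated as a raw bit pattern here
    uint8_t  qs[QK4_0 / 2]; // 32 4-bit quants packed two per byte
};

// AoS -> SoA: [d0 qs0][d1 qs1]... becomes [qs0 qs1 ...][d0 d1 ...]
void reorder_q4_0(const block_q4_0 * blocks, std::size_t nblocks, std::vector<uint8_t> & out) {
    out.resize(nblocks * sizeof(block_q4_0));
    uint8_t  * qs_ptr = out.data();                                      // nblocks * 16 bytes of quants
    uint16_t * d_ptr  = reinterpret_cast<uint16_t *>(qs_ptr + nblocks * (QK4_0 / 2));
    for (std::size_t ib = 0; ib < nblocks; ++ib) {
        std::memcpy(qs_ptr + ib * (QK4_0 / 2), blocks[ib].qs, QK4_0 / 2);
        d_ptr[ib] = blocks[ib].d;
    }
}
```
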
@@ -3565,9 +3248,11 @@ __dpct_inline__ static void k_copy_dst_from_contiguous(
3565
3248
  }
3566
3249
  }
3567
3250
 
3568
- static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
3569
- const ggml_tensor *src1,
3251
+ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
3570
3252
  ggml_tensor *dst) try {
3253
+ scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/3);
3254
+ const ggml_tensor *src0 = dst->src[0];
3255
+ const ggml_tensor *src1 = dst->src[1];
3571
3256
  GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer) && "mul_mat_id does not support split buffers");
3572
3257
 
3573
3258
  const ggml_tensor *ids = dst->src[2];
@@ -3621,8 +3306,8 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const ggml_ten
3621
3306
  const int64_t i2 = i12;
3622
3307
 
3623
3308
  src0_row.data = src0_original + i02*nb02;
3624
- src1_row.data = src1_original + + i11*nb11 + i12*nb12;
3625
- dst_row.data = dst_original + i1*nb1 + i2*nb2;
3309
+ src1_row.data = src1_original + i11*nb11 + i12*nb12;
3310
+ dst_row.data = dst_original + i1*nb1 + i2*nb2;
3626
3311
 
3627
3312
  ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
3628
3313
  }
@@ -3733,117 +3418,52 @@ catch (sycl::exception const &exc) {
     std::exit(1);
 }
 
-static void ggml_sycl_scale(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_scale);
-}
-
-static void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_clamp);
-}
-
-static void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
-                          ggml_tensor *dst) try {
-    const int64_t ne = ggml_nelements(src0);
-    GGML_ASSERT(ne == ggml_nelements(src1));
-
-    GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX);
-    GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX);
-
-    GGML_TENSOR_BINARY_OP_LOCALS01;
-
-    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
-    queue_ptr main_stream = ctx.stream();
-
-    char * src0_ddc = (char *) src0->data;
-    char * src1_ddc = (char *) src1->data;
-
-    if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_f32_f32_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
-    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_f32_f16_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
-    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
-        ggml_cpy_f32_q8_0_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
-    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
-        ggml_cpy_f32_q4_0_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
-    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
-        ggml_cpy_f32_q4_1_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
-    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_f16_f32_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
-    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_f16_f16_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
-    } else if (src0->type == GGML_TYPE_I16 && src1->type == GGML_TYPE_I16) {
-        ggml_cpy_i16_i16_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
-    } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) {
-        ggml_cpy_i32_i32_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
-    } else {
-        GGML_LOG_ERROR("%s: unsupported type combination (%s to %s)\n", __func__,
-                       ggml_type_name(src0->type), ggml_type_name(src1->type));
-        GGML_ABORT("fatal error");
-    }
-
-    GGML_UNUSED(dst);
-}
-catch (sycl::exception const &exc) {
-    std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-              << ", line:" << __LINE__ << std::endl;
-    std::exit(1);
-}
-
-static void ggml_sycl_dup(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    // TODO: why do we pass dst as src1 here?
-    ggml_sycl_cpy(ctx, src0, dst, nullptr);
-    GGML_UNUSED(src1);
-}
-
-static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_diag_mask_inf);
-}
-
-static void ggml_sycl_soft_max(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_soft_max);
+static void ggml_sycl_scale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    ggml_sycl_op_scale(ctx, dst);
 }
 
-static void ggml_sycl_rope(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(ggml_is_contiguous(src0)); // TODO: this restriction is temporary until non-cont support is implemented
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_rope);
+static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    ggml_sycl_op_diag_mask_inf(ctx, dst);
 }
 
-static void ggml_sycl_pool2d(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_pool2d);
+static void ggml_sycl_pool2d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    ggml_sycl_op_pool2d(ctx, dst);
 }
 
-static void ggml_sycl_im2col(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_im2col);
+static void ggml_sycl_im2col(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
+    ggml_sycl_op_im2col(ctx, dst);
 }
 
-static void ggml_sycl_sum(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sum);
+static void ggml_sycl_sum(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
+    ggml_sycl_op_sum(ctx, dst);
 }
 
-static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sum_rows);
+static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
+    ggml_sycl_op_sum_rows(ctx, dst);
 }
 
-static void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_argsort);
+static void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
+    ggml_sycl_op_argsort(ctx, dst);
 }
 
-static void ggml_sycl_argmax(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_argmax);
+static void ggml_sycl_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
+    ggml_sycl_op_argmax(ctx, dst);
 }
 
-static void ggml_sycl_nop(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_UNUSED(src0);
-    GGML_UNUSED(src1);
-    GGML_UNUSED(dst);
-    GGML_UNUSED(ctx);
-}
 
-void ggml_sycl_set_main_device(const int main_device) try {
+static void ggml_sycl_set_main_device(const int main_device) try {
     if (dpct::get_current_device_id() == static_cast<unsigned int> (main_device)) {
         return;
     }
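
The hunk above rewrites the element-wise wrappers from the old (ctx, src0, src1, dst) shape, routed through ggml_sycl_op_flatten, to a (ctx, dst) shape in which each op recovers its operands from dst->src[]. A minimal sketch of that calling convention, using stand-in types rather than the real ggml structs:

    // Sketch only: "tensor" stands in for ggml_tensor; the real ops also take the
    // backend context and launch SYCL kernels instead of computing on the host.
    #include <cstdio>

    struct tensor {
        tensor * src[2] {};   // operands travel attached to the destination node
        float value {};
    };

    static void op_add(tensor * dst) {   // new shape: one dst parameter
        dst->value = dst->src[0]->value + dst->src[1]->value;
    }

    int main() {
        tensor a { {}, 1.0f }, b { {}, 2.0f }, out {};
        out.src[0] = &a;
        out.src[1] = &b;
        op_add(&out);                    // no separate src0/src1 arguments
        std::printf("%.1f\n", out.value);  // prints 3.0
        return 0;
    }

The scope_op_debug_print guard in the real wrappers appears to emit per-op trace output on entry, with num_src recording how many dst->src[] slots the op reads.
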
@@ -3864,192 +3484,211 @@ catch (sycl::exception const &exc) {
     std::exit(1);
 }
 
-bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tensor * tensor) {
+static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tensor * dst) try {
     if (!g_sycl_loaded) return false;
 
-    ggml_sycl_func_t func;
+    if (dst->src[0] != nullptr && ggml_backend_buffer_is_sycl_split(dst->src[0]->buffer)) {
+        ggml_sycl_set_peer_access(dst->src[1]->ne[1], ctx.device);
+    }
 
-    switch (tensor->op) {
+    switch (dst->op) {
         case GGML_OP_ARGMAX:
-            func = ggml_sycl_argmax;
+            ggml_sycl_argmax(ctx, dst);
            break;
        case GGML_OP_CONV_TRANSPOSE_1D:
-            func = ggml_sycl_op_conv_transpose_1d;
+            ggml_sycl_op_conv_transpose_1d(ctx, dst);
            break;
        case GGML_OP_REPEAT:
-            func = ggml_sycl_repeat;
+            ggml_sycl_repeat(ctx, dst);
            break;
        case GGML_OP_GET_ROWS:
-            func = ggml_sycl_get_rows;
+            ggml_sycl_get_rows(ctx, dst);
            break;
        case GGML_OP_DUP:
-            func = ggml_sycl_dup;
+            ggml_sycl_dup(ctx, dst);
            break;
        case GGML_OP_ADD:
        case GGML_OP_ADD1: // TODO: more efficient implementation
-            func = ggml_sycl_add;
+            ggml_sycl_add(ctx, dst);
            break;
        case GGML_OP_SUB:
-            func = ggml_sycl_sub;
+            ggml_sycl_sub(ctx, dst);
            break;
        case GGML_OP_ACC:
-            func = ggml_sycl_acc;
+            ggml_sycl_acc(ctx, dst);
            break;
        case GGML_OP_MUL:
-            func = ggml_sycl_mul;
+            ggml_sycl_mul(ctx, dst);
            break;
        case GGML_OP_LOG:
-            func = ggml_sycl_log;
+            ggml_sycl_log(ctx, dst);
            break;
        case GGML_OP_DIV:
-            func = ggml_sycl_div;
+            ggml_sycl_div(ctx, dst);
            break;
        case GGML_OP_UNARY:
-            switch (ggml_get_unary_op(tensor)) {
+            switch (ggml_get_unary_op(dst)) {
                case GGML_UNARY_OP_NEG:
-                    func = ggml_sycl_neg;
+                    ggml_sycl_neg(ctx, dst);
                    break;
                case GGML_UNARY_OP_STEP:
-                    func = ggml_sycl_step;
+                    ggml_sycl_step(ctx, dst);
                    break;
                case GGML_UNARY_OP_GELU:
-                    func = ggml_sycl_gelu;
+                    ggml_sycl_gelu(ctx, dst);
                    break;
                case GGML_UNARY_OP_SILU:
-                    func = ggml_sycl_silu;
+                    ggml_sycl_silu(ctx, dst);
                    break;
                case GGML_UNARY_OP_GELU_QUICK:
-                    func = ggml_sycl_gelu_quick;
+                    ggml_sycl_gelu_quick(ctx, dst);
                    break;
                case GGML_UNARY_OP_TANH:
-                    func = ggml_sycl_tanh;
+                    ggml_sycl_tanh(ctx, dst);
                    break;
                case GGML_UNARY_OP_RELU:
-                    func = ggml_sycl_relu;
+                    ggml_sycl_relu(ctx, dst);
                    break;
                case GGML_UNARY_OP_SIGMOID:
-                    func = ggml_sycl_sigmoid;
+                    ggml_sycl_sigmoid(ctx, dst);
                    break;
                case GGML_UNARY_OP_HARDSIGMOID:
-                    func = ggml_sycl_hardsigmoid;
+                    ggml_sycl_hardsigmoid(ctx, dst);
                    break;
                case GGML_UNARY_OP_HARDSWISH:
-                    func = ggml_sycl_hardswish;
+                    ggml_sycl_hardswish(ctx, dst);
                    break;
                case GGML_UNARY_OP_EXP:
-                    func = ggml_sycl_exp;
+                    ggml_sycl_exp(ctx, dst);
+                    break;
+                case GGML_UNARY_OP_SGN:
+                    ggml_sycl_sgn(ctx, dst);
+                    break;
+                case GGML_UNARY_OP_ABS:
+                    ggml_sycl_abs(ctx, dst);
+                    break;
+                case GGML_UNARY_OP_ELU:
+                    ggml_sycl_elu(ctx, dst);
                    break;
                default:
                    return false;
            }
            break;
        case GGML_OP_NORM:
-            func = ggml_sycl_norm;
+            ggml_sycl_norm(ctx, dst);
            break;
        case GGML_OP_GROUP_NORM:
-            func = ggml_sycl_group_norm;
+            ggml_sycl_group_norm(ctx, dst);
            break;
        case GGML_OP_CONCAT:
-            func = ggml_sycl_op_concat;
+            ggml_sycl_op_concat(ctx, dst);
            break;
        case GGML_OP_UPSCALE:
-            func = ggml_sycl_upscale;
+            ggml_sycl_upscale(ctx, dst);
            break;
        case GGML_OP_PAD:
-            func = ggml_sycl_pad;
+            ggml_sycl_pad(ctx, dst);
            break;
        case GGML_OP_LEAKY_RELU:
-            func = ggml_sycl_leaky_relu;
+            ggml_sycl_leaky_relu(ctx, dst);
            break;
        case GGML_OP_RMS_NORM:
-            func = ggml_sycl_rms_norm;
+            ggml_sycl_rms_norm(ctx, dst);
+            break;
+        case GGML_OP_L2_NORM:
+            ggml_sycl_l2_norm(ctx, dst);
            break;
        case GGML_OP_MUL_MAT:
-            if (tensor->src[0]->ne[3] != tensor->src[1]->ne[3]) {
+            if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
                return false;
            }
-            func = ggml_sycl_mul_mat;
+            /* ggml_sycl_mul_mat_id is dependent on ggml_sycl_mul_mat */
+            ggml_sycl_mul_mat(ctx, dst->src[0], dst->src[1], dst);
            break;
        case GGML_OP_MUL_MAT_ID:
-            if (tensor->src[0]->ne[3] != tensor->src[1]->ne[3]) {
+            if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
                return false;
            }
-            func = ggml_sycl_mul_mat_id;
+            ggml_sycl_mul_mat_id(ctx, dst);
            break;
        case GGML_OP_OUT_PROD:
-            func = ggml_sycl_op_out_prod;
+            ggml_sycl_op_out_prod(ctx, dst);
            break;
        case GGML_OP_SCALE:
-            func = ggml_sycl_scale;
+            ggml_sycl_scale(ctx, dst);
            break;
        case GGML_OP_SQR:
-            func = ggml_sycl_sqr;
+            ggml_sycl_sqr(ctx, dst);
            break;
        case GGML_OP_SQRT:
-            func = ggml_sycl_sqrt;
+            ggml_sycl_sqrt(ctx, dst);
            break;
        case GGML_OP_SIN:
-            func = ggml_sycl_sin;
+            ggml_sycl_sin(ctx, dst);
            break;
        case GGML_OP_COS:
-            func = ggml_sycl_cos;
+            ggml_sycl_cos(ctx, dst);
            break;
        case GGML_OP_CLAMP:
-            func = ggml_sycl_clamp;
+            ggml_sycl_clamp(ctx, dst);
            break;
        case GGML_OP_CPY:
-            func = ggml_sycl_cpy;
+            ggml_sycl_cpy(ctx, dst->src[0], dst->src[1]);
            break;
        case GGML_OP_CONT:
-            func = ggml_sycl_dup;
+            ggml_sycl_dup(ctx, dst);
            break;
        case GGML_OP_NONE:
        case GGML_OP_RESHAPE:
        case GGML_OP_VIEW:
        case GGML_OP_PERMUTE:
        case GGML_OP_TRANSPOSE:
-            func = ggml_sycl_nop;
+            GGML_SYCL_DEBUG("%s: Tensor NO-OP\n", __func__);
            break;
        case GGML_OP_DIAG_MASK_INF:
-            func = ggml_sycl_diag_mask_inf;
+            ggml_sycl_diag_mask_inf(ctx, dst);
            break;
        case GGML_OP_SOFT_MAX:
-            func = ggml_sycl_soft_max;
+            ggml_sycl_op_soft_max(ctx, dst);
            break;
        case GGML_OP_ROPE:
-            func = ggml_sycl_rope;
+            ggml_sycl_rope(ctx, dst);
            break;
        case GGML_OP_IM2COL:
-            func = ggml_sycl_im2col;
+            ggml_sycl_im2col(ctx, dst);
            break;
        case GGML_OP_POOL_2D:
-            func = ggml_sycl_pool2d;
+            ggml_sycl_pool2d(ctx, dst);
            break;
        case GGML_OP_SUM:
-            func = ggml_sycl_sum;
+            ggml_sycl_sum(ctx, dst);
            break;
        case GGML_OP_SUM_ROWS:
-            func = ggml_sycl_sum_rows;
+            ggml_sycl_sum_rows(ctx, dst);
            break;
        case GGML_OP_ARGSORT:
-            func = ggml_sycl_argsort;
+            ggml_sycl_argsort(ctx, dst);
            break;
        case GGML_OP_TIMESTEP_EMBEDDING:
-            func = ggml_sycl_op_timestep_embedding;
+            ggml_sycl_op_timestep_embedding(ctx, dst);
            break;
        case GGML_OP_RWKV_WKV6:
-            func = ggml_sycl_op_rwkv_wkv6;
+            ggml_sycl_op_rwkv_wkv6(ctx, dst);
+            break;
+        case GGML_OP_RWKV_WKV7:
+            ggml_sycl_op_rwkv_wkv7(ctx, dst);
+            break;
+        case GGML_OP_GATED_LINEAR_ATTN:
+            ggml_sycl_op_gated_linear_attn(ctx, dst);
            break;
        default:
            return false;
    }
 
-    if (tensor->src[0] != nullptr && ggml_backend_buffer_is_sycl_split(tensor->src[0]->buffer)) {
-        ggml_sycl_set_peer_access(tensor->src[1]->ne[1], ctx.device);
-    }
-
-    func(ctx, tensor->src[0], tensor->src[1], tensor);
     return true;
+} catch (sycl::exception & e) {
+    std::cerr << e.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
+    std::exit(1);
 }
 
 GGML_API void ggml_backend_sycl_get_device_description(int device, char *description,
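
Two things change shape in ggml_sycl_compute_forward above: the function-pointer dispatch (ggml_sycl_func_t func; ... func(ctx, src0, src1, tensor)) becomes a direct call in each case, and the body becomes a function-try-block so a sycl::exception escaping any op is reported once, at the dispatch level. A compact sketch of the function-try-block idiom, with a stand-in exception type rather than the real backend:

    // Sketch only: the try spans the whole function body, and the handler
    // mirrors the patch's catch (sycl::exception &) by logging and exiting.
    #include <cstdlib>
    #include <iostream>
    #include <stdexcept>

    static bool compute_sketch(int op) try {
        if (op < 0) {
            throw std::runtime_error("bad op");
        }
        return true;                   // per-op dispatch would happen here
    } catch (std::exception & e) {
        std::cerr << e.what() << '\n';
        std::exit(1);                  // noreturn, so the handler never falls through
    }

    int main() {
        std::cout << compute_sketch(1) << '\n';  // prints 1
        return 0;
    }
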
@@ -4112,6 +3751,9 @@ static void ggml_backend_sycl_set_tensor_async(ggml_backend_t backend,
                                                ggml_tensor *tensor,
                                                const void *data, size_t offset,
                                                size_t size) try {
+    GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
+    debug_print_tensor(": tensor=", tensor);
+    GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset);
     ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
     ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
 
@@ -4130,13 +3772,16 @@ static void ggml_backend_sycl_get_tensor_async(ggml_backend_t backend,
                                                const ggml_tensor *tensor,
                                                void *data, size_t offset,
                                                size_t size) try {
+    GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
+    debug_print_tensor(": tensor=", tensor);
+    GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset);
     ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
     ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
 
     GGML_ASSERT(buf->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && "unsupported buffer type");
     const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0);
     SYCL_CHECK(CHECK_TRY_ERROR((stream)->memcpy(
-        data, (const char *)tensor->data + offset, size).wait()));
+        data, (const char *)tensor->data + offset, size)));
 }
 catch (sycl::exception const &exc) {
     std::cerr << exc.what() << "Exception caught at file:" << __FILE__
@@ -4148,7 +3793,13 @@ static bool ggml_backend_sycl_cpy_tensor_async(ggml_backend_t backend,
                                                const ggml_tensor *src,
                                                ggml_tensor *dst) try {
     ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
-    if (dst->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && ggml_backend_buffer_is_sycl(src->buffer)) {
+    bool is_cpy_supported = dst->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) &&
+                            ggml_backend_buffer_is_sycl(src->buffer);
+    GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
+    debug_print_tensor(": dst=", dst);
+    debug_print_tensor(" src=", src);
+    GGML_SYCL_DEBUG(" is_cpy_supported=%d\n", is_cpy_supported);
+    if (is_cpy_supported) {
         /*
         DPCT1009:215: SYCL uses exceptions to report errors and does not use the
         error codes. The original code was commented out and a warning string
@@ -4156,7 +3807,7 @@ static bool ggml_backend_sycl_cpy_tensor_async(ggml_backend_t backend,
         */
         const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0);
         SYCL_CHECK(CHECK_TRY_ERROR((stream)->memcpy(
-            dst->data, src->data, ggml_nbytes(dst)).wait()));
+            dst->data, src->data, ggml_nbytes(dst))));
         return true;
     }
 
@@ -4169,6 +3820,7 @@ catch (sycl::exception const &exc) {
 }
 
 static void ggml_backend_sycl_synchronize(ggml_backend_t backend) try {
+    GGML_SYCL_DEBUG("[SYCL] call %s\n", __func__);
     ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
     const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0);
     SYCL_CHECK(CHECK_TRY_ERROR((stream)->wait()));
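
The three hunks above drop the per-call .wait() on the set/get/copy paths, so those entry points become genuinely asynchronous; ordering comes from the in-order SYCL queue, and the single blocking point is the stream wait() in ggml_backend_sycl_synchronize. A minimal sketch of that pattern in plain SYCL, with illustrative sizes and names:

    // Enqueue memcpys without waiting per call; synchronize once at the end,
    // as the backend's synchronize hook does. The in_order property guarantees
    // the two memcpys execute in submission order.
    #include <sycl/sycl.hpp>
    #include <vector>

    int main() {
        sycl::queue q{sycl::property::queue::in_order{}};
        std::vector<float> host(16, 1.0f), back(16, 0.0f);
        float * dev = sycl::malloc_device<float>(16, q);

        q.memcpy(dev, host.data(), 16 * sizeof(float));   // no .wait() here
        q.memcpy(back.data(), dev, 16 * sizeof(float));   // ordered by the queue
        q.wait();                                         // one sync point

        sycl::free(dev, q);
        return (back[0] == 1.0f) ? 0 : 1;
    }
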
@@ -4181,11 +3833,9 @@ catch (sycl::exception const &exc) {
     std::exit(1);
 }
 
-static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
-    ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
+static void ggml_backend_sycl_graph_compute_impl(ggml_backend_sycl_context * sycl_ctx, ggml_cgraph * cgraph) {
     ggml_sycl_set_main_device(sycl_ctx->device);
 
-
     for (int i = 0; i < cgraph->n_nodes; i++) {
         ggml_tensor * node = cgraph->nodes[i];
         if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
@@ -4205,7 +3855,82 @@ static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_
         }
         GGML_ASSERT(ok);
     }
+}
+
+#ifdef GGML_SYCL_GRAPH
+static bool check_graph_compatibility(ggml_cgraph * cgraph) {
+    if (ggml_sycl_info().device_count > 1) {
+        // A sycl_ex::command_graph object can only be created for a single device
+        GGML_LOG_INFO("%s: disabling SYCL graphs due to multiple devices\n", __func__);
+        return false;
+    }
+
+    for (int i = 0; i < cgraph->n_nodes; i++) {
+        const ggml_op node_op = cgraph->nodes[i]->op;
+        switch (node_op) {
+            default:
+                break;
+            case GGML_OP_CONCAT:
+                // ggml_sycl_op_concat() does a blocking host wait after memcpy operations,
+                // but wait() can't be called on the events returned by a queue recording
+                // to a graph.
+                [[fallthrough]];
+            case GGML_OP_MUL_MAT_ID:
+                // ggml_sycl_mul_mat_id() does a blocking host wait on the sycl queue after
+                // submitting a memcpy operation, but wait() can't be called on a queue that
+                // is recording to a graph.
+                GGML_LOG_INFO("%s: disabling SYCL graphs due to unsupported node type %s\n", __func__,
+                              ggml_op_name(node_op));
+                return false;
+        }
+    }
+    return true;
+}
+#endif
+
+static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
+    auto * sycl_ctx = static_cast<ggml_backend_sycl_context *>(backend->context);
+
+#ifdef GGML_SYCL_GRAPH
+    bool use_sycl_graph = !g_ggml_sycl_disable_graph && check_graph_compatibility(cgraph);
+    if (use_sycl_graph) {
+        const bool graph_support = dpct::get_device(sycl_ctx->device).has(sycl::aspect::ext_oneapi_limited_graph);
+        if (!graph_support) {
+            GGML_SYCL_DEBUG("[SYCL-GRAPH] can not use graphs on device:%d\n", sycl_ctx->device);
+            ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);
+            return GGML_STATUS_SUCCESS;
+        }
+
+        sycl_ex::command_graph model_sycl_graph(*(sycl_ctx->stream()), {sycl_ex::property::graph::assume_buffer_outlives_graph{}});
+
+        model_sycl_graph.begin_recording(*(sycl_ctx->stream()));
+        ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);
+        model_sycl_graph.end_recording();
+
+        const bool graph_update_support = dpct::get_device(sycl_ctx->device).has(sycl::aspect::ext_oneapi_graph);
+        if (!sycl_ctx->exec_graph || !graph_update_support) {
+            auto exec_graph = graph_update_support ? model_sycl_graph.finalize(sycl_ex::property::graph::updatable{}) :
+                                                     model_sycl_graph.finalize();
+            sycl_ctx->exec_graph = std::make_unique<
+                sycl_ex::command_graph<sycl_ex::graph_state::executable>>(exec_graph);
+        } else {
+            try {
+                sycl_ctx->exec_graph->update(model_sycl_graph);
+                GGML_SYCL_DEBUG("[SYCL-GRAPH] update success\n");
+            } catch (sycl::exception const & e) {
+                GGML_SYCL_DEBUG("[SYCL-GRAPH] Exception when updating graph, %s\n", e.what());
+                auto exec_graph = model_sycl_graph.finalize({sycl_ex::property::graph::updatable{}});
+                sycl_ctx->exec_graph = std::make_unique<
+                    sycl_ex::command_graph<sycl_ex::graph_state::executable>>(exec_graph);
+            }
+        }
 
+        sycl_ctx->stream()->ext_oneapi_graph(*(sycl_ctx->exec_graph));
+    } else
+#endif
+    {
+        ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);
+    }
     return GGML_STATUS_SUCCESS;
 }
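
The new GGML_SYCL_GRAPH path above records the per-node kernel launches into a sycl_ex::command_graph, finalizes it into an executable graph (built updatable when the device reports sycl::aspect::ext_oneapi_graph, so a later recording can patch it in place via update()), and replays it with queue::ext_oneapi_graph. A minimal record/finalize/replay sketch against oneAPI's experimental graph extension; this only compiles where that extension is available:

    #include <sycl/sycl.hpp>
    #include <sycl/ext/oneapi/experimental/graph.hpp>

    namespace sycl_ex = sycl::ext::oneapi::experimental;

    int main() {
        sycl::queue q{sycl::property::queue::in_order{}};
        int * data = sycl::malloc_device<int>(1, q);

        sycl_ex::command_graph g(q.get_context(), q.get_device(),
                                 {sycl_ex::property::graph::assume_buffer_outlives_graph{}});
        g.begin_recording(q);
        q.single_task([=] { *data = 42; });   // captured into the graph, not executed yet
        g.end_recording();

        auto exec = g.finalize();             // build the executable graph once
        q.ext_oneapi_graph(exec);             // replay it (repeatable, low launch overhead)
        q.wait();

        sycl::free(data, q);
        return 0;
    }

The payoff of this design is that graph construction and kernel-launch overhead are paid once per topology, while per-iteration execution is a single submission.
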
 
@@ -4229,7 +3954,7 @@ catch (sycl::exception const &exc)
 }
 
 static void ggml_backend_sycl_event_wait(ggml_backend_t backend, ggml_backend_event_t event) try {
-
+    GGML_SYCL_DEBUG("[SYCL] call %s\n", __func__);
     sycl::event* sycl_event = static_cast<sycl::event*>(event->context);
 
     if (ggml_backend_is_sycl(backend)) {
@@ -4270,7 +3995,6 @@ bool ggml_backend_is_sycl(ggml_backend_t backend) {
 }
 
 int ggml_backend_sycl_get_device_count() {
-    GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_device_count\n");
     return ggml_sycl_info().device_count;
 }
 
@@ -4360,7 +4084,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
                 return true;
             }
             return false;
-        } break;
+        }
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(op)) {
                 case GGML_UNARY_OP_NEG:
@@ -4374,11 +4098,17 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
                 case GGML_UNARY_OP_GELU_QUICK:
                 case GGML_UNARY_OP_TANH:
                 case GGML_UNARY_OP_EXP:
-                    return ggml_is_contiguous(op->src[0]);
+                case GGML_UNARY_OP_SGN:
+                case GGML_UNARY_OP_ABS:
+                case GGML_UNARY_OP_ELU:
+#if defined (GGML_SYCL_F16)
+                    return ggml_is_contiguous(op->src[0]) && (op->type == op->src[0]->type);
+#else
+                    return ggml_is_contiguous(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) && (op->type == op->src[0]->type);
+#endif
                 default:
                     return false;
             }
-            break;
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
             {
@@ -4409,7 +4139,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
                     return false;
                 }
                 return true;
-            } break;
+            }
         case GGML_OP_OUT_PROD:
             return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
         case GGML_OP_GET_ROWS:
@@ -4426,7 +4156,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
                 default:
                     return false;
             }
-        } break;
+        }
         case GGML_OP_CPY:
             {
                 ggml_type src0_type = op->src[0]->type;
@@ -4452,35 +4182,70 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
                 if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
                     return true;
                 }
+                if (src0_type == GGML_TYPE_Q8_0 && src1_type == GGML_TYPE_F32) {
+                    return true;
+                }
+                if (src0_type == GGML_TYPE_Q4_0 && src1_type == GGML_TYPE_F32) {
+                    return true;
+                }
+                if (src0_type == GGML_TYPE_Q4_1 && src1_type == GGML_TYPE_F32) {
+                    return true;
+                }
+                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_0) {
+                    return true;
+                }
+                if (src0_type == GGML_TYPE_Q5_0 && src1_type == GGML_TYPE_F32) {
+                    return true;
+                }
+                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_1) {
+                    return true;
+                }
+                if (src0_type == GGML_TYPE_Q5_1 && src1_type == GGML_TYPE_F32) {
+                    return true;
+                }
+                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) {
+                    return true;
+                }
                 return false;
-            } break;
+            }
         case GGML_OP_CONCAT:
             {
                 ggml_type src0_type = op->src[0]->type;
                 return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
-            } break;
+            }
         case GGML_OP_DUP:
         case GGML_OP_ARGMAX:
         case GGML_OP_NONE:
         case GGML_OP_RESHAPE:
-        case GGML_OP_REPEAT:
         case GGML_OP_VIEW:
         case GGML_OP_PERMUTE:
         case GGML_OP_TRANSPOSE:
-        case GGML_OP_NORM:
+            return true;
         case GGML_OP_ADD:
         case GGML_OP_ADD1:
-        case GGML_OP_LOG:
         case GGML_OP_SUB:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
-        case GGML_OP_RMS_NORM:
-        case GGML_OP_SCALE:
+        case GGML_OP_REPEAT:
+            return true;
         case GGML_OP_SQR:
         case GGML_OP_SQRT:
         case GGML_OP_SIN:
         case GGML_OP_COS:
         case GGML_OP_CLAMP:
+        case GGML_OP_LOG:
+#if defined (GGML_SYCL_F16)
+            return ((op->type == GGML_TYPE_F32 || op->type == GGML_SYCL_F16) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_SYCL_F16) && (op->type == op->src[0]->type));
+#else
+            return (op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32) && (op->type == op->src[0]->type);
+#endif
+        case GGML_OP_NORM:
+        case GGML_OP_RMS_NORM:
+            return true;
+        case GGML_OP_L2_NORM:
+        case GGML_OP_GROUP_NORM:
+            return ggml_is_contiguous(op->src[0]);
+        case GGML_OP_SCALE:
             return true;
         case GGML_OP_CONT:
             return op->src[0]->type != GGML_TYPE_BF16;
@@ -4490,28 +4255,27 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_ROPE:
             {
                 const int mode = ((const int32_t *) op->op_params)[2];
-                if (mode & GGML_ROPE_TYPE_MROPE) {
+                // mode is not used as a bitmask in practice, the various rope type modes are independent implementations
+                if (mode == GGML_ROPE_TYPE_MROPE) {
                     return false;
                 }
-                if (mode & GGML_ROPE_TYPE_VISION) {
-                    return false;
-                }
-                return ggml_is_contiguous(op->src[0]);
+                return true;
             }
         case GGML_OP_IM2COL:
-            // TODO: add support for the new F32 operations
-            return op->src[0]->type == GGML_TYPE_F16;
+            return true;
+        case GGML_OP_UPSCALE:
+            return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
         case GGML_OP_POOL_2D:
         case GGML_OP_SUM:
         case GGML_OP_SUM_ROWS:
         case GGML_OP_ARGSORT:
         case GGML_OP_ACC:
-        case GGML_OP_GROUP_NORM:
-        case GGML_OP_UPSCALE:
         case GGML_OP_PAD:
         case GGML_OP_LEAKY_RELU:
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_RWKV_WKV6:
+        case GGML_OP_RWKV_WKV7:
+        case GGML_OP_GATED_LINEAR_ATTN:
             return true;
         default:
             return false;
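
The supports_op hunks above tighten the element-wise and unary cases to "same dtype in and out, contiguous input, F16 only when built with GGML_SYCL_F16". A small stand-in predicate capturing that intent (illustrative enum, not the real ggml types):

    #include <cstdio>

    enum dtype { F32, F16 };

    // Sketch of the gating logic: contiguity, matching in/out dtype, and an
    // F16 path that exists only when the backend was compiled with F16 support.
    static bool unary_supported(dtype in, dtype out, bool contiguous, bool f16_build) {
        if (!contiguous || in != out) return false;
        if (in == F16) return f16_build;
        return in == F32;
    }

    int main() {
        std::printf("%d %d\n",
                    unary_supported(F32, F32, true, false),   // 1
                    unary_supported(F16, F16, true, false));  // 0: F16 needs the F16 build
        return 0;
    }
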
@@ -4586,6 +4350,7 @@ static void ggml_backend_sycl_device_event_free(ggml_backend_dev_t dev, ggml_bac
 
 static void ggml_backend_sycl_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) try {
     GGML_UNUSED(dev);
+    GGML_SYCL_DEBUG("[SYCL] call %s\n", __func__);
 
     sycl::event *sycl_event = static_cast<sycl::event *>(event->context);
     SYCL_CHECK(CHECK_TRY_ERROR(sycl_event->wait()));
@@ -4638,10 +4403,9 @@ static ggml_backend_dev_t ggml_backend_sycl_reg_get_device(ggml_backend_reg_t re
 static void *ggml_backend_sycl_reg_get_proc_address(ggml_backend_reg_t reg, const char *name) {
     GGML_UNUSED(reg);
 
-    // TODO: update to the current function signature
-    //if (strcmp(name, "ggml_backend_split_buffer_type") == 0) {
-    //    return (void *)ggml_backend_sycl_split_buffer_type;
-    //}
+    if (strcmp(name, "ggml_backend_split_buffer_type") == 0) {
+        return (void *)ggml_backend_sycl_split_buffer_type;
+    }
 
     // SYCL doesn't support registering host memory, left here for reference
     // "ggml_backend_register_host_buffer"