whispercpp 1.3.1 → 1.3.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (797):
  1. checksums.yaml +4 -4
  2. data/.gitignore +4 -3
  3. data/README.md +92 -31
  4. data/Rakefile +26 -7
  5. data/ext/.gitignore +5 -7
  6. data/ext/dependencies.rb +61 -0
  7. data/ext/extconf.rb +21 -198
  8. data/ext/options.rb +221 -0
  9. data/ext/ruby_whisper.c +159 -0
  10. data/ext/ruby_whisper.h +17 -2
  11. data/ext/ruby_whisper_context.c +641 -0
  12. data/ext/ruby_whisper_error.c +52 -0
  13. data/ext/ruby_whisper_model.c +232 -0
  14. data/ext/ruby_whisper_params.c +1301 -0
  15. data/ext/ruby_whisper_segment.c +143 -0
  16. data/ext/ruby_whisper_transcribe.cpp +87 -0
  17. data/ext/ruby_whisper_vad_params.c +288 -0
  18. data/ext/sources/.dockerignore +3 -0
  19. data/ext/sources/.github/workflows/bindings-ruby.yml +21 -0
  20. data/ext/sources/CMakeGraphVizOptions.cmake +8 -0
  21. data/ext/sources/CMakeLists.txt +251 -0
  22. data/ext/sources/bindings/javascript/CMakeLists.txt +41 -0
  23. data/ext/sources/bindings/javascript/emscripten.cpp +93 -0
  24. data/ext/sources/bindings/javascript/libwhisper.worker.js +1 -0
  25. data/ext/sources/bindings/javascript/package-tmpl.json +26 -0
  26. data/ext/sources/bindings/javascript/package.json +26 -0
  27. data/ext/sources/bindings/javascript/whisper.js +19 -0
  28. data/ext/sources/build-xcframework.sh +547 -0
  29. data/ext/sources/ci/run.sh +336 -0
  30. data/ext/sources/close-issue.yml +28 -0
  31. data/ext/sources/cmake/DefaultTargetOptions.cmake +16 -0
  32. data/ext/sources/cmake/FindFFmpeg.cmake +163 -0
  33. data/ext/sources/cmake/build-info.cmake +60 -0
  34. data/ext/sources/cmake/git-vars.cmake +22 -0
  35. data/ext/sources/cmake/whisper-config.cmake.in +65 -0
  36. data/ext/sources/cmake/whisper.pc.in +10 -0
  37. data/ext/sources/examples/CMakeLists.txt +124 -0
  38. data/ext/sources/examples/addon.node/CMakeLists.txt +31 -0
  39. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +37 -0
  40. data/ext/sources/examples/addon.node/addon.cpp +438 -0
  41. data/ext/sources/examples/addon.node/index.js +54 -0
  42. data/ext/sources/examples/addon.node/package.json +16 -0
  43. data/ext/sources/examples/bench/CMakeLists.txt +8 -0
  44. data/ext/sources/examples/bench/bench.cpp +175 -0
  45. data/ext/sources/examples/bench.wasm/CMakeLists.txt +49 -0
  46. data/ext/sources/examples/bench.wasm/emscripten.cpp +87 -0
  47. data/ext/sources/examples/bench.wasm/index-tmpl.html +284 -0
  48. data/ext/sources/examples/cli/CMakeLists.txt +8 -0
  49. data/ext/sources/examples/cli/cli.cpp +1294 -0
  50. data/ext/sources/examples/coi-serviceworker.js +146 -0
  51. data/ext/sources/examples/command/CMakeLists.txt +10 -0
  52. data/ext/sources/examples/command/command.cpp +776 -0
  53. data/ext/sources/examples/command/commands.txt +9 -0
  54. data/ext/sources/examples/command.wasm/CMakeLists.txt +50 -0
  55. data/ext/sources/examples/command.wasm/emscripten.cpp +327 -0
  56. data/ext/sources/examples/command.wasm/index-tmpl.html +414 -0
  57. data/ext/sources/examples/common-ggml.cpp +238 -0
  58. data/ext/sources/examples/common-ggml.h +18 -0
  59. data/ext/sources/examples/common-sdl.cpp +227 -0
  60. data/ext/sources/examples/common-sdl.h +49 -0
  61. data/ext/sources/examples/common-whisper.cpp +168 -0
  62. data/ext/sources/examples/common-whisper.h +24 -0
  63. data/ext/sources/examples/common.cpp +675 -0
  64. data/ext/sources/examples/common.h +322 -0
  65. data/ext/sources/examples/deprecation-warning/CMakeLists.txt +6 -0
  66. data/ext/sources/examples/deprecation-warning/deprecation-warning.cpp +38 -0
  67. data/ext/sources/examples/ffmpeg-transcode.cpp +368 -0
  68. data/ext/sources/examples/generate-karaoke.sh +57 -0
  69. data/ext/sources/examples/grammar-parser.cpp +423 -0
  70. data/ext/sources/examples/grammar-parser.h +29 -0
  71. data/ext/sources/examples/helpers.js +191 -0
  72. data/ext/sources/examples/json.hpp +24596 -0
  73. data/ext/sources/examples/livestream.sh +112 -0
  74. data/ext/sources/examples/lsp/CMakeLists.txt +9 -0
  75. data/ext/sources/examples/lsp/lsp.cpp +467 -0
  76. data/ext/sources/examples/lsp/whisper.vim +362 -0
  77. data/ext/sources/examples/miniaudio.h +93468 -0
  78. data/ext/sources/examples/python/test_whisper_processor.py +7 -0
  79. data/ext/sources/examples/python/whisper_processor.py +54 -0
  80. data/ext/sources/examples/quantize/CMakeLists.txt +6 -0
  81. data/ext/sources/examples/quantize/quantize.cpp +223 -0
  82. data/ext/sources/examples/server/CMakeLists.txt +12 -0
  83. data/ext/sources/examples/server/bench.js +29 -0
  84. data/ext/sources/examples/server/httplib.h +10497 -0
  85. data/ext/sources/examples/server/server.cpp +1091 -0
  86. data/ext/sources/examples/server.py +115 -0
  87. data/ext/sources/examples/stb_vorbis.c +5584 -0
  88. data/ext/sources/examples/stream/CMakeLists.txt +10 -0
  89. data/ext/sources/examples/stream/stream.cpp +429 -0
  90. data/ext/sources/examples/stream.wasm/CMakeLists.txt +49 -0
  91. data/ext/sources/examples/stream.wasm/emscripten.cpp +216 -0
  92. data/ext/sources/examples/stream.wasm/index-tmpl.html +414 -0
  93. data/ext/sources/examples/sycl/CMakeLists.txt +9 -0
  94. data/ext/sources/examples/sycl/build.sh +22 -0
  95. data/ext/sources/examples/sycl/ls-sycl-device.cpp +11 -0
  96. data/ext/sources/examples/sycl/run-whisper.sh +17 -0
  97. data/ext/sources/examples/talk-llama/CMakeLists.txt +40 -0
  98. data/ext/sources/examples/talk-llama/eleven-labs.py +80 -0
  99. data/ext/sources/examples/talk-llama/llama-adapter.cpp +388 -0
  100. data/ext/sources/examples/talk-llama/llama-adapter.h +76 -0
  101. data/ext/sources/examples/talk-llama/llama-arch.cpp +1746 -0
  102. data/ext/sources/examples/talk-llama/llama-arch.h +437 -0
  103. data/ext/sources/examples/talk-llama/llama-batch.cpp +374 -0
  104. data/ext/sources/examples/talk-llama/llama-batch.h +89 -0
  105. data/ext/sources/examples/talk-llama/llama-chat.cpp +663 -0
  106. data/ext/sources/examples/talk-llama/llama-chat.h +58 -0
  107. data/ext/sources/examples/talk-llama/llama-context.cpp +2676 -0
  108. data/ext/sources/examples/talk-llama/llama-context.h +276 -0
  109. data/ext/sources/examples/talk-llama/llama-cparams.cpp +5 -0
  110. data/ext/sources/examples/talk-llama/llama-cparams.h +41 -0
  111. data/ext/sources/examples/talk-llama/llama-grammar.cpp +1229 -0
  112. data/ext/sources/examples/talk-llama/llama-grammar.h +173 -0
  113. data/ext/sources/examples/talk-llama/llama-graph.cpp +1618 -0
  114. data/ext/sources/examples/talk-llama/llama-graph.h +640 -0
  115. data/ext/sources/examples/talk-llama/llama-hparams.cpp +95 -0
  116. data/ext/sources/examples/talk-llama/llama-hparams.h +190 -0
  117. data/ext/sources/examples/talk-llama/llama-impl.cpp +167 -0
  118. data/ext/sources/examples/talk-llama/llama-impl.h +61 -0
  119. data/ext/sources/examples/talk-llama/llama-io.cpp +15 -0
  120. data/ext/sources/examples/talk-llama/llama-io.h +35 -0
  121. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2739 -0
  122. data/ext/sources/examples/talk-llama/llama-kv-cache.h +502 -0
  123. data/ext/sources/examples/talk-llama/llama-kv-cells.h +379 -0
  124. data/ext/sources/examples/talk-llama/llama-memory.cpp +1 -0
  125. data/ext/sources/examples/talk-llama/llama-memory.h +32 -0
  126. data/ext/sources/examples/talk-llama/llama-mmap.cpp +600 -0
  127. data/ext/sources/examples/talk-llama/llama-mmap.h +68 -0
  128. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +1138 -0
  129. data/ext/sources/examples/talk-llama/llama-model-loader.h +169 -0
  130. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +281 -0
  131. data/ext/sources/examples/talk-llama/llama-model-saver.h +37 -0
  132. data/ext/sources/examples/talk-llama/llama-model.cpp +13814 -0
  133. data/ext/sources/examples/talk-llama/llama-model.h +425 -0
  134. data/ext/sources/examples/talk-llama/llama-quant.cpp +966 -0
  135. data/ext/sources/examples/talk-llama/llama-quant.h +1 -0
  136. data/ext/sources/examples/talk-llama/llama-sampling.cpp +2575 -0
  137. data/ext/sources/examples/talk-llama/llama-sampling.h +32 -0
  138. data/ext/sources/examples/talk-llama/llama-vocab.cpp +3340 -0
  139. data/ext/sources/examples/talk-llama/llama-vocab.h +131 -0
  140. data/ext/sources/examples/talk-llama/llama.cpp +354 -0
  141. data/ext/sources/examples/talk-llama/llama.h +1377 -0
  142. data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +23 -0
  143. data/ext/sources/examples/talk-llama/speak +40 -0
  144. data/ext/sources/examples/talk-llama/speak.bat +1 -0
  145. data/ext/sources/examples/talk-llama/speak.ps1 +14 -0
  146. data/ext/sources/examples/talk-llama/talk-llama.cpp +808 -0
  147. data/ext/sources/examples/talk-llama/unicode-data.cpp +7034 -0
  148. data/ext/sources/examples/talk-llama/unicode-data.h +20 -0
  149. data/ext/sources/examples/talk-llama/unicode.cpp +849 -0
  150. data/ext/sources/examples/talk-llama/unicode.h +66 -0
  151. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +8 -0
  152. data/ext/sources/examples/vad-speech-segments/speech.cpp +143 -0
  153. data/ext/sources/examples/wchess/CMakeLists.txt +10 -0
  154. data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +19 -0
  155. data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +803 -0
  156. data/ext/sources/examples/wchess/libwchess/Chessboard.h +33 -0
  157. data/ext/sources/examples/wchess/libwchess/WChess.cpp +193 -0
  158. data/ext/sources/examples/wchess/libwchess/WChess.h +63 -0
  159. data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +117 -0
  160. data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +8 -0
  161. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +249 -0
  162. data/ext/sources/examples/whisper.wasm/CMakeLists.txt +50 -0
  163. data/ext/sources/examples/whisper.wasm/emscripten.cpp +118 -0
  164. data/ext/sources/examples/whisper.wasm/index-tmpl.html +658 -0
  165. data/ext/sources/ggml/CMakeLists.txt +390 -0
  166. data/ext/sources/ggml/cmake/BuildTypes.cmake +54 -0
  167. data/ext/sources/ggml/cmake/GitVars.cmake +22 -0
  168. data/ext/sources/ggml/cmake/common.cmake +26 -0
  169. data/ext/sources/ggml/cmake/ggml-config.cmake.in +152 -0
  170. data/ext/{ggml → sources/ggml}/include/ggml-alloc.h +1 -1
  171. data/ext/{ggml → sources/ggml}/include/ggml-backend.h +9 -7
  172. data/ext/{ggml → sources/ggml}/include/ggml-cpp.h +2 -1
  173. data/ext/{ggml → sources/ggml}/include/ggml-cpu.h +9 -1
  174. data/ext/{ggml → sources/ggml}/include/ggml-metal.h +1 -1
  175. data/ext/{ggml → sources/ggml}/include/ggml-opt.h +49 -28
  176. data/ext/{ggml → sources/ggml}/include/ggml-rpc.h +6 -1
  177. data/ext/{ggml → sources/ggml}/include/ggml-vulkan.h +0 -2
  178. data/ext/{ggml → sources/ggml}/include/ggml.h +182 -265
  179. data/ext/sources/ggml/include/gguf.h +202 -0
  180. data/ext/sources/ggml/src/CMakeLists.txt +346 -0
  181. data/ext/{ggml → sources/ggml}/src/ggml-alloc.c +34 -29
  182. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +107 -0
  183. data/ext/{ggml → sources/ggml}/src/ggml-backend-impl.h +1 -2
  184. data/ext/{ggml → sources/ggml}/src/ggml-backend-reg.cpp +87 -53
  185. data/ext/{ggml → sources/ggml}/src/ggml-backend.cpp +26 -14
  186. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +87 -0
  187. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +74 -0
  188. data/ext/sources/ggml/src/ggml-cann/Doxyfile +2579 -0
  189. data/ext/{ggml → sources/ggml}/src/ggml-cann/acl_tensor.cpp +10 -4
  190. data/ext/{ggml → sources/ggml}/src/ggml-cann/acl_tensor.h +5 -5
  191. data/ext/{ggml → sources/ggml}/src/ggml-cann/aclnn_ops.cpp +1272 -1506
  192. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +1125 -0
  193. data/ext/{ggml → sources/ggml}/src/ggml-cann/common.h +135 -1
  194. data/ext/{ggml → sources/ggml}/src/ggml-cann/ggml-cann.cpp +564 -146
  195. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +30 -0
  196. data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/dup.cpp +3 -5
  197. data/ext/{ggml → sources/ggml}/src/ggml-common.h +12 -8
  198. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +504 -0
  199. data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/amx.cpp +2 -1
  200. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
  201. data/ext/sources/ggml/src/ggml-cpu/binary-ops.h +16 -0
  202. data/ext/sources/ggml/src/ggml-cpu/cmake/FindSIMD.cmake +100 -0
  203. data/ext/sources/ggml/src/ggml-cpu/common.h +72 -0
  204. data/ext/{ggml → sources/ggml}/src/ggml-cpu/cpu-feats-x86.cpp +5 -1
  205. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +6431 -0
  206. data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-impl.h +163 -41
  207. data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-quants.c +4029 -1117
  208. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +3510 -0
  209. data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu.cpp +67 -18
  210. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +337 -0
  211. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +95 -0
  212. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +482 -0
  213. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.h +17 -0
  214. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3544 -0
  215. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +14 -0
  216. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +8903 -0
  217. data/ext/sources/ggml/src/ggml-cpu/ops.h +110 -0
  218. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +892 -0
  219. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
  220. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +28 -0
  221. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +252 -0
  222. data/ext/sources/ggml/src/ggml-cpu/vec.h +818 -0
  223. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +184 -0
  224. data/ext/sources/ggml/src/ggml-cuda/acc.cu +61 -0
  225. data/ext/sources/ggml/src/ggml-cuda/acc.cuh +5 -0
  226. data/ext/sources/ggml/src/ggml-cuda/arange.cu +34 -0
  227. data/ext/sources/ggml/src/ggml-cuda/arange.cuh +5 -0
  228. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +91 -0
  229. data/ext/sources/ggml/src/ggml-cuda/argmax.cuh +3 -0
  230. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +104 -0
  231. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +3 -0
  232. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +363 -0
  233. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +9 -0
  234. data/ext/sources/ggml/src/ggml-cuda/clamp.cu +45 -0
  235. data/ext/sources/ggml/src/ggml-cuda/clamp.cuh +5 -0
  236. data/ext/sources/ggml/src/ggml-cuda/common.cuh +828 -0
  237. data/ext/sources/ggml/src/ggml-cuda/concat.cu +221 -0
  238. data/ext/sources/ggml/src/ggml-cuda/concat.cuh +5 -0
  239. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +89 -0
  240. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cuh +5 -0
  241. data/ext/sources/ggml/src/ggml-cuda/convert.cu +730 -0
  242. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +26 -0
  243. data/ext/sources/ggml/src/ggml-cuda/count-equal.cu +64 -0
  244. data/ext/sources/ggml/src/ggml-cuda/count-equal.cuh +5 -0
  245. data/ext/sources/ggml/src/ggml-cuda/cp-async.cuh +57 -0
  246. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +705 -0
  247. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +11 -0
  248. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +189 -0
  249. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cuh +7 -0
  250. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +103 -0
  251. data/ext/sources/ggml/src/ggml-cuda/diagmask.cu +40 -0
  252. data/ext/sources/ggml/src/ggml-cuda/diagmask.cuh +5 -0
  253. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +881 -0
  254. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +1471 -0
  255. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +357 -0
  256. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +3 -0
  257. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +365 -0
  258. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +3 -0
  259. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +482 -0
  260. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +472 -0
  261. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +634 -0
  262. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +3 -0
  263. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +346 -0
  264. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +3 -0
  265. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +275 -0
  266. data/ext/sources/ggml/src/ggml-cuda/getrows.cuh +15 -0
  267. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +3505 -0
  268. data/ext/sources/ggml/src/ggml-cuda/gla.cu +93 -0
  269. data/ext/sources/ggml/src/ggml-cuda/gla.cuh +3 -0
  270. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +103 -0
  271. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +5 -0
  272. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +396 -0
  273. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +324 -0
  274. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +3217 -0
  275. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +336 -0
  276. data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +12 -0
  277. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +595 -0
  278. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +12 -0
  279. data/ext/sources/ggml/src/ggml-cuda/norm.cu +458 -0
  280. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +11 -0
  281. data/ext/sources/ggml/src/ggml-cuda/opt-step-adamw.cu +78 -0
  282. data/ext/sources/ggml/src/ggml-cuda/opt-step-adamw.cuh +5 -0
  283. data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +68 -0
  284. data/ext/sources/ggml/src/ggml-cuda/out-prod.cuh +3 -0
  285. data/ext/sources/ggml/src/ggml-cuda/pad.cu +49 -0
  286. data/ext/sources/ggml/src/ggml-cuda/pad.cuh +5 -0
  287. data/ext/sources/ggml/src/ggml-cuda/pool2d.cu +94 -0
  288. data/ext/sources/ggml/src/ggml-cuda/pool2d.cuh +5 -0
  289. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +190 -0
  290. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +27 -0
  291. data/ext/sources/ggml/src/ggml-cuda/rope.cu +456 -0
  292. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +7 -0
  293. data/ext/sources/ggml/src/ggml-cuda/scale.cu +31 -0
  294. data/ext/sources/ggml/src/ggml-cuda/scale.cuh +5 -0
  295. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +283 -0
  296. data/ext/sources/ggml/src/ggml-cuda/softmax.cuh +7 -0
  297. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +148 -0
  298. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +3 -0
  299. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +153 -0
  300. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cuh +3 -0
  301. data/ext/sources/ggml/src/ggml-cuda/sum.cu +45 -0
  302. data/ext/sources/ggml/src/ggml-cuda/sum.cuh +5 -0
  303. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +39 -0
  304. data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +5 -0
  305. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +5 -0
  306. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +10 -0
  307. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu +10 -0
  308. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu +10 -0
  309. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +10 -0
  310. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +5 -0
  311. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +10 -0
  312. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +10 -0
  313. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu +10 -0
  314. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu +10 -0
  315. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +5 -0
  316. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu +10 -0
  317. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +10 -0
  318. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +10 -0
  319. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu +10 -0
  320. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu +10 -0
  321. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu +10 -0
  322. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +10 -0
  323. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +10 -0
  324. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +5 -0
  325. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +5 -0
  326. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +5 -0
  327. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +5 -0
  328. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +5 -0
  329. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +5 -0
  330. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +5 -0
  331. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +5 -0
  332. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +5 -0
  333. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +5 -0
  334. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +5 -0
  335. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +5 -0
  336. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +5 -0
  337. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +5 -0
  338. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +5 -0
  339. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +5 -0
  340. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +5 -0
  341. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +5 -0
  342. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +5 -0
  343. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +5 -0
  344. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +5 -0
  345. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +5 -0
  346. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +5 -0
  347. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +5 -0
  348. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +5 -0
  349. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +5 -0
  350. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +5 -0
  351. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +5 -0
  352. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +5 -0
  353. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +5 -0
  354. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +5 -0
  355. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +5 -0
  356. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +5 -0
  357. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +5 -0
  358. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +5 -0
  359. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +5 -0
  360. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +5 -0
  361. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +5 -0
  362. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +5 -0
  363. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +5 -0
  364. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +5 -0
  365. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +5 -0
  366. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +5 -0
  367. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +5 -0
  368. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +5 -0
  369. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +5 -0
  370. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +5 -0
  371. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +5 -0
  372. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +5 -0
  373. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +5 -0
  374. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +5 -0
  375. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +5 -0
  376. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +5 -0
  377. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +5 -0
  378. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +5 -0
  379. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +5 -0
  380. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +5 -0
  381. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +5 -0
  382. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +5 -0
  383. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +5 -0
  384. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +5 -0
  385. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +5 -0
  386. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +5 -0
  387. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +5 -0
  388. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +5 -0
  389. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +5 -0
  390. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +5 -0
  391. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +5 -0
  392. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +5 -0
  393. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +5 -0
  394. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +5 -0
  395. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +5 -0
  396. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +5 -0
  397. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +5 -0
  398. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +5 -0
  399. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +5 -0
  400. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +5 -0
  401. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +5 -0
  402. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +5 -0
  403. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +5 -0
  404. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +5 -0
  405. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +5 -0
  406. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +5 -0
  407. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +5 -0
  408. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +5 -0
  409. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +5 -0
  410. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +78 -0
  411. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s.cu +5 -0
  412. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_s.cu +5 -0
  413. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xs.cu +5 -0
  414. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xxs.cu +5 -0
  415. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_s.cu +5 -0
  416. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_xxs.cu +5 -0
  417. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_nl.cu +5 -0
  418. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_xs.cu +5 -0
  419. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q2_k.cu +5 -0
  420. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q3_k.cu +5 -0
  421. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_0.cu +5 -0
  422. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_1.cu +5 -0
  423. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_k.cu +5 -0
  424. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_0.cu +5 -0
  425. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_1.cu +5 -0
  426. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_k.cu +5 -0
  427. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q6_k.cu +5 -0
  428. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q8_0.cu +5 -0
  429. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +47 -0
  430. data/ext/sources/ggml/src/ggml-cuda/tsembd.cuh +5 -0
  431. data/ext/sources/ggml/src/ggml-cuda/unary.cu +289 -0
  432. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +59 -0
  433. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +51 -0
  434. data/ext/sources/ggml/src/ggml-cuda/upscale.cuh +5 -0
  435. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +1135 -0
  436. data/ext/{ggml → sources/ggml}/src/ggml-cuda/vendors/cuda.h +1 -0
  437. data/ext/{ggml → sources/ggml}/src/ggml-cuda/vendors/hip.h +57 -0
  438. data/ext/{ggml → sources/ggml}/src/ggml-cuda/vendors/musa.h +7 -1
  439. data/ext/sources/ggml/src/ggml-cuda/wkv.cu +199 -0
  440. data/ext/sources/ggml/src/ggml-cuda/wkv.cuh +7 -0
  441. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +131 -0
  442. data/ext/{ggml → sources/ggml}/src/ggml-impl.h +64 -19
  443. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +166 -0
  444. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +112 -0
  445. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +58 -0
  446. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +25 -0
  447. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +52 -0
  448. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +52 -0
  449. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +52 -0
  450. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +52 -0
  451. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +30 -0
  452. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +22 -0
  453. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +17 -0
  454. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +31 -0
  455. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +31 -0
  456. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +38 -0
  457. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +39 -0
  458. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +44 -0
  459. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +52 -0
  460. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +69 -0
  461. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +51 -0
  462. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +33 -0
  463. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +35 -0
  464. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +140 -0
  465. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +106 -0
  466. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +73 -0
  467. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +52 -0
  468. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +28 -0
  469. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +84 -0
  470. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +21 -0
  471. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +53 -0
  472. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +52 -0
  473. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +52 -0
  474. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +52 -0
  475. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +52 -0
  476. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +19 -0
  477. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +23 -0
  478. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +22 -0
  479. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +72 -0
  480. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +71 -0
  481. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +120 -0
  482. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +622 -0
  483. data/ext/{ggml → sources/ggml}/src/ggml-metal/ggml-metal.m +2178 -1064
  484. data/ext/{ggml → sources/ggml}/src/ggml-metal/ggml-metal.metal +1575 -1218
  485. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +113 -0
  486. data/ext/sources/ggml/src/ggml-musa/mudnn.cu +112 -0
  487. data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +12 -0
  488. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +96 -0
  489. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +5124 -0
  490. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +83 -0
  491. data/ext/sources/ggml/src/ggml-opencl/kernels/clamp.cl +20 -0
  492. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +184 -0
  493. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +118 -0
  494. data/ext/sources/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl +58 -0
  495. data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +26 -0
  496. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +62 -0
  497. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle.cl +268 -0
  498. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general.cl +274 -0
  499. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +163 -0
  500. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +57 -0
  501. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +57 -0
  502. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +79 -0
  503. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_Ab_Bi_8x4.cl +139 -0
  504. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f16.cl +118 -0
  505. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32.cl +118 -0
  506. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_1row.cl +94 -0
  507. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_l4.cl +84 -0
  508. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f32_f32.cl +118 -0
  509. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32.cl +192 -0
  510. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_16x_flat.cl +307 -0
  511. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_8x_flat.cl +265 -0
  512. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_8x_flat.cl +272 -0
  513. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_v.cl +254 -0
  514. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl +190 -0
  515. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +81 -0
  516. data/ext/sources/ggml/src/ggml-opencl/kernels/relu.cl +16 -0
  517. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +96 -0
  518. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +721 -0
  519. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +16 -0
  520. data/ext/sources/ggml/src/ggml-opencl/kernels/silu.cl +30 -0
  521. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +87 -0
  522. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +87 -0
  523. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +86 -0
  524. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +86 -0
  525. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +84 -0
  526. data/ext/{ggml → sources/ggml}/src/ggml-opt.cpp +373 -190
  527. data/ext/{ggml → sources/ggml}/src/ggml-quants.c +114 -120
  528. data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +9 -0
  529. data/ext/{ggml → sources/ggml}/src/ggml-rpc/ggml-rpc.cpp +480 -73
  530. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +189 -0
  531. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +37 -0
  532. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +345 -0
  533. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +39 -0
  534. data/ext/{ggml → sources/ggml}/src/ggml-sycl/common.cpp +20 -32
  535. data/ext/sources/ggml/src/ggml-sycl/common.hpp +589 -0
  536. data/ext/{ggml → sources/ggml}/src/ggml-sycl/concat.cpp +32 -33
  537. data/ext/sources/ggml/src/ggml-sycl/concat.hpp +20 -0
  538. data/ext/{ggml → sources/ggml}/src/ggml-sycl/conv.cpp +4 -2
  539. data/ext/sources/ggml/src/ggml-sycl/conv.hpp +20 -0
  540. data/ext/{ggml → sources/ggml}/src/ggml-sycl/convert.cpp +104 -28
  541. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +34 -0
  542. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +700 -0
  543. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +11 -0
  544. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +791 -0
  545. data/ext/{ggml → sources/ggml}/src/ggml-sycl/dmmv.cpp +156 -17
  546. data/ext/sources/ggml/src/ggml-sycl/dmmv.hpp +27 -0
  547. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +2957 -0
  548. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +1511 -0
  549. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +75 -0
  550. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +99 -0
  551. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +309 -0
  552. data/ext/sources/ggml/src/ggml-sycl/getrows.hpp +20 -0
  553. data/ext/{ggml → sources/ggml}/src/ggml-sycl/ggml-sycl.cpp +1004 -1240
  554. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +106 -0
  555. data/ext/sources/ggml/src/ggml-sycl/gla.hpp +8 -0
  556. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +136 -0
  557. data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +21 -0
  558. data/ext/{ggml → sources/ggml}/src/ggml-sycl/mmq.cpp +0 -1
  559. data/ext/sources/ggml/src/ggml-sycl/mmq.hpp +33 -0
  560. data/ext/{ggml → sources/ggml}/src/ggml-sycl/mmvq.cpp +261 -166
  561. data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +27 -0
  562. data/ext/{ggml → sources/ggml}/src/ggml-sycl/norm.cpp +204 -81
  563. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +26 -0
  564. data/ext/{ggml → sources/ggml}/src/ggml-sycl/outprod.cpp +8 -17
  565. data/ext/sources/ggml/src/ggml-sycl/outprod.hpp +10 -0
  566. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +74 -0
  567. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +83 -0
  568. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +361 -0
  569. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +20 -0
  570. data/ext/{ggml → sources/ggml}/src/ggml-sycl/softmax.cpp +35 -25
  571. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +20 -0
  572. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +13 -0
  573. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +23 -0
  574. data/ext/{ggml → sources/ggml}/src/ggml-sycl/tsembd.cpp +3 -3
  575. data/ext/sources/ggml/src/ggml-sycl/tsembd.hpp +20 -0
  576. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +1215 -0
  577. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +293 -0
  578. data/ext/sources/ggml/src/ggml-sycl/wkv.hpp +10 -0
  579. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +196 -0
  580. data/ext/sources/ggml/src/ggml-vulkan/cmake/host-toolchain.cmake.in +15 -0
  581. data/ext/{ggml → sources/ggml}/src/ggml-vulkan/ggml-vulkan.cpp +3130 -1087
  582. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +39 -0
  583. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +29 -0
  584. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +29 -0
  585. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +51 -0
  586. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +69 -0
  587. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +17 -0
  588. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +41 -0
  589. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +49 -0
  590. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +105 -0
  591. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +23 -0
  592. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +51 -0
  593. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +242 -0
  594. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +17 -0
  595. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +31 -0
  596. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +20 -0
  597. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +462 -0
  598. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +699 -0
  599. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.comp +13 -0
  600. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +42 -0
  601. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +35 -0
  602. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +44 -0
  603. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +43 -0
  604. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +48 -0
  605. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +39 -0
  606. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +49 -0
  607. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +32 -0
  608. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +34 -0
  609. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +34 -0
  610. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +42 -0
  611. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +30 -0
  612. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +32 -0
  613. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +68 -0
  614. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +34 -0
  615. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +35 -0
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +70 -0
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +33 -0
  618. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +31 -0
  619. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +34 -0
  620. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +27 -0
  621. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +337 -0
  622. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +162 -0
  623. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +360 -0
  624. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +267 -0
  625. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +59 -0
  626. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +25 -0
  627. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +23 -0
  628. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +64 -0
  629. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.comp +9 -0
  630. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp +76 -0
  631. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +33 -0
  632. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +41 -0
  633. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +66 -0
  634. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +100 -0
  635. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +41 -0
  636. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +22 -0
  637. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +27 -0
  638. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_split_k_reduce.comp +48 -0
  639. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +169 -0
  640. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +118 -0
  641. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +82 -0
  642. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +79 -0
  643. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +90 -0
  644. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +87 -0
  645. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +87 -0
  646. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +90 -0
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +88 -0
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +118 -0
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +154 -0
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +130 -0
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +132 -0
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +136 -0
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +167 -0
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +130 -0
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +868 -0
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +441 -0
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +442 -0
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +99 -0
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +44 -0
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +42 -0
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +28 -0
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +74 -0
  663. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +77 -0
  664. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +21 -0
  665. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +26 -0
  666. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +37 -0
  667. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +52 -0
  668. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +55 -0
  669. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +58 -0
  670. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +60 -0
  671. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +43 -0
  672. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +43 -0
  673. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +47 -0
  674. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +24 -0
  675. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +20 -0
  676. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +22 -0
  677. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +26 -0
  678. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +17 -0
  679. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +173 -0
  680. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +50 -0
  681. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +17 -0
  682. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +29 -0
  683. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +37 -0
  684. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +20 -0
  685. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_bfloat16_support.comp +7 -0
  686. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat2_support.comp +7 -0
  687. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat_support.comp +7 -0
  688. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_integer_dot_support.comp +7 -0
  689. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +41 -0
  690. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +1373 -0
  691. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +36 -0
  692. data/ext/{ggml → sources/ggml}/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +193 -35
  693. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp +87 -0
  694. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp +91 -0
  695. data/ext/{ggml → sources/ggml}/src/ggml.c +676 -1820
  696. data/ext/sources/ggml/src/gguf.cpp +1330 -0
  697. data/ext/{include → sources/include}/whisper.h +68 -2
  698. data/ext/sources/src/CMakeLists.txt +143 -0
  699. data/ext/{src → sources/src}/coreml/whisper-decoder-impl.h +27 -15
  700. data/ext/{src → sources/src}/coreml/whisper-decoder-impl.m +35 -10
  701. data/ext/{src → sources/src}/coreml/whisper-encoder-impl.h +21 -9
  702. data/ext/{src → sources/src}/coreml/whisper-encoder-impl.m +28 -3
  703. data/ext/sources/src/coreml/whisper-encoder.mm +73 -0
  704. data/ext/sources/src/whisper-arch.h +197 -0
  705. data/ext/{src → sources/src}/whisper.cpp +1905 -374
  706. data/ext/sources/tests/CMakeLists.txt +105 -0
  707. data/ext/sources/tests/earnings21/eval.mk +58 -0
  708. data/ext/sources/tests/earnings21/eval.py +68 -0
  709. data/ext/sources/tests/earnings21/normalizers/__init__.py +2 -0
  710. data/ext/sources/tests/earnings21/normalizers/basic.py +80 -0
  711. data/ext/sources/tests/earnings21/normalizers/english.json +1741 -0
  712. data/ext/sources/tests/earnings21/normalizers/english.py +550 -0
  713. data/ext/sources/tests/earnings21/requirements.txt +6 -0
  714. data/ext/sources/tests/en-0-ref.txt +1 -0
  715. data/ext/sources/tests/en-1-ref.txt +1 -0
  716. data/ext/sources/tests/en-2-ref.txt +1 -0
  717. data/ext/sources/tests/es-0-ref.txt +1 -0
  718. data/ext/sources/tests/librispeech/eval.mk +39 -0
  719. data/ext/sources/tests/librispeech/eval.py +47 -0
  720. data/ext/sources/tests/librispeech/normalizers/__init__.py +2 -0
  721. data/ext/sources/tests/librispeech/normalizers/basic.py +80 -0
  722. data/ext/sources/tests/librispeech/normalizers/english.json +1741 -0
  723. data/ext/sources/tests/librispeech/normalizers/english.py +550 -0
  724. data/ext/sources/tests/librispeech/requirements.txt +6 -0
  725. data/ext/sources/tests/run-tests.sh +130 -0
  726. data/ext/sources/tests/test-c.c +3 -0
  727. data/ext/sources/tests/test-vad-full.cpp +54 -0
  728. data/ext/sources/tests/test-vad.cpp +83 -0
  729. data/ext/sources/tests/test-whisper.js +58 -0
  730. data/extsources.rb +33 -5
  731. data/lib/whisper/model/uri.rb +149 -128
  732. data/sig/whisper.rbs +480 -0
  733. data/tests/helper.rb +28 -0
  734. data/tests/test_callback.rb +45 -3
  735. data/tests/test_error.rb +2 -2
  736. data/tests/test_model.rb +38 -0
  737. data/tests/test_package.rb +18 -3
  738. data/tests/test_params.rb +145 -8
  739. data/tests/test_segment.rb +10 -19
  740. data/tests/test_vad.rb +19 -0
  741. data/tests/test_vad_params.rb +103 -0
  742. data/tests/test_whisper.rb +37 -37
  743. data/whispercpp.gemspec +5 -4
  744. metadata +766 -111
  745. data/ext/cpu.mk +0 -9
  746. data/ext/examples/dr_wav.h +0 -8815
  747. data/ext/ggml/src/ggml-cann/aclnn_ops.h +0 -592
  748. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -4262
  749. data/ext/ggml/src/ggml-cpu/ggml-cpu.c +0 -14123
  750. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.cpp +0 -1884
  751. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.h +0 -14
  752. data/ext/ggml/src/ggml-metal/ggml-metal-impl.h +0 -288
  753. data/ext/ggml/src/ggml-sycl/element_wise.cpp +0 -1030
  754. data/ext/ggml/src/ggml-sycl/im2col.cpp +0 -126
  755. data/ext/ggml/src/ggml-sycl/rope.cpp +0 -276
  756. data/ext/ggml/src/ggml-sycl/wkv6.cpp +0 -141
  757. data/ext/metal-embed.mk +0 -17
  758. data/ext/metal.mk +0 -6
  759. data/ext/ruby_whisper.cpp +0 -1909
  760. data/ext/scripts/get-flags.mk +0 -38
  761. data/lib/whisper.rb +0 -2
  762. /data/ext/{ggml → sources/ggml}/include/ggml-blas.h +0 -0
  763. /data/ext/{ggml → sources/ggml}/include/ggml-cann.h +0 -0
  764. /data/ext/{ggml → sources/ggml}/include/ggml-cuda.h +0 -0
  765. /data/ext/{ggml → sources/ggml}/include/ggml-kompute.h +0 -0
  766. /data/ext/{ggml → sources/ggml}/include/ggml-opencl.h +0 -0
  767. /data/ext/{ggml → sources/ggml}/include/ggml-sycl.h +0 -0
  768. /data/ext/{ggml → sources/ggml}/src/ggml-amx/common.h +0 -0
  769. /data/ext/{ggml → sources/ggml}/src/ggml-amx/ggml-amx.cpp +0 -0
  770. /data/ext/{ggml → sources/ggml}/src/ggml-amx/mmq.cpp +0 -0
  771. /data/ext/{ggml → sources/ggml}/src/ggml-amx/mmq.h +0 -0
  772. /data/ext/{ggml → sources/ggml}/src/ggml-blas/ggml-blas.cpp +0 -0
  773. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/ascendc_kernels.h +0 -0
  774. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_f16.cpp +0 -0
  775. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_f32.cpp +0 -0
  776. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -0
  777. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -0
  778. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -0
  779. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -0
  780. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -0
  781. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/amx.h +0 -0
  782. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/common.h +0 -0
  783. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/mmq.cpp +0 -0
  784. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/mmq.h +0 -0
  785. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-aarch64.h +0 -0
  786. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-hbm.cpp +0 -0
  787. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-hbm.h +0 -0
  788. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-quants.h +0 -0
  789. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-traits.cpp +0 -0
  790. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-traits.h +0 -0
  791. /data/ext/{ggml → sources/ggml}/src/ggml-kompute/ggml-kompute.cpp +0 -0
  792. /data/ext/{ggml → sources/ggml}/src/ggml-quants.h +0 -0
  793. /data/ext/{ggml → sources/ggml}/src/ggml-threading.cpp +0 -0
  794. /data/ext/{ggml → sources/ggml}/src/ggml-threading.h +0 -0
  795. /data/ext/{src → sources/src}/coreml/whisper-encoder.h +0 -0
  796. /data/ext/{src → sources/src}/openvino/whisper-openvino-encoder.cpp +0 -0
  797. /data/ext/{src → sources/src}/openvino/whisper-openvino-encoder.h +0 -0
@@ -1,46 +1,99 @@
1
1
  #include "mmvq.hpp"
2
+
3
+ #include "ggml.h"
4
+ #include "common.hpp"
5
+ #include "quants.hpp"
2
6
  #include "vecdotq.hpp"
3
- #include <cassert>
7
+
8
+ template <typename reorder_vec_dot_q_sycl>
9
+ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
10
+ const int ncols, const int nrows, const sycl::nd_item<3> & nd_item) {
11
+ using block_type = ggml_sycl_reordered::block_q_t<reorder_vec_dot_q_sycl::gtype>;
12
+ using block_traits = typename block_type::traits;
13
+
14
+ const auto sg = nd_item.get_sub_group();
15
+ const int sg_range = sg.get_group_linear_range();
16
+ const int workgroup_id = nd_item.get_group_linear_id();
17
+ const int sg_id = sg.get_group_linear_id();
18
+ const int row = workgroup_id * sg_range + sg_id;
19
+
20
+ if (row >= nrows) {
21
+ return;
22
+ }
23
+
24
+ const int blocks_per_row = ncols / block_traits::qk;
25
+ constexpr int blocks_per_subgroup = ceil_div(block_traits::vdr_mmvq * WARP_SIZE, block_traits::qi);
26
+ constexpr int block_elements_per_subgroup = block_traits::qi / block_traits::vdr_mmvq;
27
+ const int nblocks = nrows * (ncols / block_traits::qk);
28
+
29
+ static_assert(blocks_per_subgroup > 0);
30
+ static_assert(block_elements_per_subgroup > 0);
31
+
32
+ const block_q8_1 * y = (const block_q8_1 *) vy;
33
+
34
+ float partial_sum = 0.0f;
35
+ for (int i = sg.get_local_linear_id() / block_elements_per_subgroup; i < blocks_per_row; i += blocks_per_subgroup) {
36
+ const int ibx = row * blocks_per_row + i; // x block index
37
+ // TODO: Generalize offsets, right now only works for quantizations that don't split high and low bits
38
+ const int bx_offset = block_type::get_block_offset(ibx);
39
+ const int d_offset = block_type::get_d_offset(nrows, ncols, ibx);
40
+
41
+ // Y block index that aligns with ibx
42
+ const int iby = i * block_type::block_to_q8_1_ratio();
43
+
44
+ #pragma unroll
45
+ for (int elem = 0; elem < block_elements_per_subgroup; elem += WARP_SIZE) {
46
+ // x block quant index when casting the quants to int
47
+ const int iqs = elem + block_traits::vdr_mmvq * (sg.get_local_linear_id() % block_elements_per_subgroup);
48
+
49
+ partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, &y[iby], iqs, nblocks);
50
+ }
51
+ }
52
+
53
+ auto sum = sycl::reduce_over_group(nd_item.get_sub_group(), partial_sum, std::plus<>());
54
+
55
+ if (sg.leader()) {
56
+ dst[row] = sum;
57
+ }
58
+ }
4
59
 
5
60
  template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_sycl_t vec_dot_q_sycl>
6
- static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows,
7
- const sycl::nd_item<3> &item_ct1) {
8
- const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
9
- item_ct1.get_local_id(1);
61
+ static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
62
+ const int ncols, const int nrows, const sycl::nd_item<3> & item_ct1) {
63
+ const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1);
10
64
 
11
65
  if (row >= nrows) {
12
66
  return;
13
67
  }
14
68
 
15
- const int blocks_per_row = ncols / qk;
16
- const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
17
- assert(blocks_per_warp>0);
69
+ const int blocks_per_row = ncols / qk;
70
+ constexpr int blocks_per_warp = (vdr * WARP_SIZE + qi - 1) / qi; // Ensuring blocks_per_warp > 0
18
71
 
19
- // partial sum for each thread
72
+ assert(blocks_per_warp > 0);
73
+
74
+ // partial sum for each thread
20
75
  float tmp = 0.0f;
21
76
 
22
- const block_q_t * x = (const block_q_t *) vx;
77
+ const block_q_t * x = (const block_q_t *) vx;
23
78
  const block_q8_1 * y = (const block_q8_1 *) vy;
24
79
 
25
- for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
26
- i += blocks_per_warp) {
27
- const int ibx = row*blocks_per_row + i; // x block index
80
+ for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row; i += blocks_per_warp) {
81
+ const int ibx = row * blocks_per_row + i; // x block index
28
82
 
29
- const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
83
+ const int iby = i * (qk / QK8_1); // y block index that aligns with ibx
30
84
 
31
- const int iqs =
32
- vdr *
33
- (item_ct1.get_local_id(2) %
34
- (qi / vdr)); // x block quant index when casting the quants to int
85
+ for (size_t elem = 0; elem < qi / vdr; elem += WARP_SIZE) {
86
+ const int iqs = elem + vdr * (item_ct1.get_local_id(2) %
87
+ (qi / vdr)); // x block quant index when casting the quants to int
35
88
 
36
- tmp += vec_dot_q_sycl(&x[ibx], &y[iby], iqs);
89
+ tmp += vec_dot_q_sycl(&x[ibx], &y[iby], iqs);
90
+ }
37
91
  }
38
92
 
39
93
  // sum up partial sums and write back result
40
94
  #pragma unroll
41
- for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
42
- tmp +=
43
- dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
95
+ for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
96
+ tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
44
97
  }
45
98
 
46
99
  if (item_ct1.get_local_id(2) == 0) {
@@ -62,7 +115,7 @@ static void mul_mat_vec_q_iq2_xxs_q8_1(const void *__restrict__ vx,
62
115
  }
63
116
 
64
117
  const int blocks_per_row = ncols / qk;
65
- const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
118
+ const int blocks_per_warp = vdr * WARP_SIZE / qi;
66
119
  assert(blocks_per_warp>0);
67
120
 
68
121
  // partial sum for each thread
@@ -87,7 +140,7 @@ static void mul_mat_vec_q_iq2_xxs_q8_1(const void *__restrict__ vx,
87
140
 
88
141
  // sum up partial sums and write back result
89
142
  #pragma unroll
90
- for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
143
+ for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
91
144
  tmp +=
92
145
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
93
146
  }
@@ -111,7 +164,7 @@ static void mul_mat_vec_q_iq2_xs_q8_1(const void *__restrict__ vx,
111
164
  }
112
165
 
113
166
  const int blocks_per_row = ncols / qk;
114
- const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
167
+ const int blocks_per_warp = vdr * WARP_SIZE / qi;
115
168
  assert(blocks_per_warp>0);
116
169
  // partial sum for each thread
117
170
  float tmp = 0.0f;
@@ -135,7 +188,7 @@ static void mul_mat_vec_q_iq2_xs_q8_1(const void *__restrict__ vx,
135
188
 
136
189
  // sum up partial sums and write back result
137
190
  #pragma unroll
138
- for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
191
+ for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
139
192
  tmp +=
140
193
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
141
194
  }
@@ -159,7 +212,7 @@ static void mul_mat_vec_q_iq2_s_q8_1(const void *__restrict__ vx,
159
212
  }
160
213
 
161
214
  const int blocks_per_row = ncols / qk;
162
- const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
215
+ const int blocks_per_warp = vdr * WARP_SIZE / qi;
163
216
  assert(blocks_per_warp>0);
164
217
  // partial sum for each thread
165
218
  float tmp = 0.0f;
@@ -183,7 +236,7 @@ static void mul_mat_vec_q_iq2_s_q8_1(const void *__restrict__ vx,
183
236
 
184
237
  // sum up partial sums and write back result
185
238
  #pragma unroll
186
- for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
239
+ for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
187
240
  tmp +=
188
241
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
189
242
  }
@@ -207,7 +260,7 @@ static void mul_mat_vec_q_iq3_xxs_q8_1(const void *__restrict__ vx,
207
260
  }
208
261
 
209
262
  const int blocks_per_row = ncols / qk;
210
- const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
263
+ const int blocks_per_warp = vdr * WARP_SIZE / qi;
211
264
  assert(blocks_per_warp>0);
212
265
  // partial sum for each thread
213
266
  float tmp = 0.0f;
@@ -231,7 +284,7 @@ static void mul_mat_vec_q_iq3_xxs_q8_1(const void *__restrict__ vx,
231
284
 
232
285
  // sum up partial sums and write back result
233
286
  #pragma unroll
234
- for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
287
+ for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
235
288
  tmp +=
236
289
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
237
290
  }
@@ -255,7 +308,7 @@ static void mul_mat_vec_q_iq3_s_q8_1(const void *__restrict__ vx,
255
308
  }
256
309
 
257
310
  const int blocks_per_row = ncols / qk;
258
- const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
311
+ const int blocks_per_warp = vdr * WARP_SIZE / qi;
259
312
  assert(blocks_per_warp>0);
260
313
  // partial sum for each thread
261
314
  float tmp = 0.0f;
@@ -279,7 +332,7 @@ static void mul_mat_vec_q_iq3_s_q8_1(const void *__restrict__ vx,
279
332
 
280
333
  // sum up partial sums and write back result
281
334
  #pragma unroll
282
- for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
335
+ for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
283
336
  tmp +=
284
337
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
285
338
  }
@@ -303,7 +356,7 @@ static void mul_mat_vec_q_iq1_s_q8_1(const void *__restrict__ vx,
303
356
  }
304
357
 
305
358
  const int blocks_per_row = ncols / qk;
306
- const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
359
+ const int blocks_per_warp = vdr * WARP_SIZE / qi;
307
360
  assert(blocks_per_warp>0);
308
361
  // partial sum for each thread
309
362
  float tmp = 0.0f;
@@ -327,7 +380,7 @@ static void mul_mat_vec_q_iq1_s_q8_1(const void *__restrict__ vx,
327
380
 
328
381
  // sum up partial sums and write back result
329
382
  #pragma unroll
330
- for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
383
+ for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
331
384
  tmp +=
332
385
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
333
386
  }
@@ -351,7 +404,7 @@ static void mul_mat_vec_q_iq1_m_q8_1(const void *__restrict__ vx,
351
404
  }
352
405
 
353
406
  const int blocks_per_row = ncols / qk;
354
- const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
407
+ const int blocks_per_warp = vdr * WARP_SIZE / qi;
355
408
  assert(blocks_per_warp>0);
356
409
  // partial sum for each thread
357
410
  float tmp = 0.0f;
@@ -375,7 +428,7 @@ static void mul_mat_vec_q_iq1_m_q8_1(const void *__restrict__ vx,
375
428
 
376
429
  // sum up partial sums and write back result
377
430
  #pragma unroll
378
- for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
431
+ for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
379
432
  tmp +=
380
433
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
381
434
  }
@@ -399,7 +452,7 @@ static void mul_mat_vec_q_iq4_nl_q8_1(const void *__restrict__ vx,
399
452
  }
400
453
 
401
454
  const int blocks_per_row = ncols / qk;
402
- const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
455
+ const int blocks_per_warp = vdr * WARP_SIZE / qi;
403
456
  assert(blocks_per_warp>0);
404
457
  // partial sum for each thread
405
458
  float tmp = 0.0f;
@@ -423,7 +476,7 @@ static void mul_mat_vec_q_iq4_nl_q8_1(const void *__restrict__ vx,
423
476
 
424
477
  // sum up partial sums and write back result
425
478
  #pragma unroll
426
- for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
479
+ for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
427
480
  tmp +=
428
481
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
429
482
  }
@@ -448,7 +501,7 @@ static void mul_mat_vec_q_iq4_xs_q8_1(const void *__restrict__ vx,
448
501
  }
449
502
 
450
503
  const int blocks_per_row = ncols / qk;
451
- const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
504
+ const int blocks_per_warp = vdr * WARP_SIZE / qi;
452
505
  assert(blocks_per_warp>0);
453
506
  // partial sum for each thread
454
507
  float tmp = 0.0f;
@@ -472,7 +525,7 @@ static void mul_mat_vec_q_iq4_xs_q8_1(const void *__restrict__ vx,
472
525
 
473
526
  // sum up partial sums and write back result
474
527
  #pragma unroll
475
- for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
528
+ for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
476
529
  tmp +=
477
530
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
478
531
  }
@@ -482,26 +535,39 @@ static void mul_mat_vec_q_iq4_xs_q8_1(const void *__restrict__ vx,
482
535
  }
483
536
  }
484
537
 
485
- static void mul_mat_vec_q4_0_q8_1_sycl(const void *vx, const void *vy,
486
- float *dst, const int ncols,
487
- const int nrows,
538
+ static void reorder_mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
539
+ const int nrows, dpct::queue_ptr stream) {
540
+ GGML_ASSERT(ncols % QK4_0 == 0);
541
+ const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
542
+ constexpr size_t num_subgroups = 16;
543
+ GGML_ASSERT(block_num_y % num_subgroups == 0);
544
+
545
+ const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, (block_num_y * WARP_SIZE));
546
+ const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
547
+
548
+ stream->submit([&](sycl::handler & cgh) {
549
+ cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
550
+ [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
551
+ mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0>>(vx, vy, dst, ncols, nrows,
552
+ nd_item);
553
+ });
554
+ });
555
+ }
556
+
557
+ static void mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows,
488
558
  dpct::queue_ptr stream) {
489
559
  GGML_ASSERT(ncols % QK4_0 == 0);
490
560
  const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
491
561
  const sycl::range<3> block_nums(1, 1, block_num_y);
492
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
493
- {
494
-
495
- stream->submit([&](sycl::handler &cgh) {
562
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
496
563
 
497
- cgh.parallel_for(
498
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
499
- [=](sycl::nd_item<3> item_ct1)
500
- [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
501
- mul_mat_vec_q<QK4_0, QI4_0, block_q4_0,
502
- VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>(
503
- vx, vy, dst, ncols, nrows, item_ct1);
504
- });
564
+ {
565
+ stream->submit([&](sycl::handler & cgh) {
566
+ cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
567
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
568
+ mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>(
569
+ vx, vy, dst, ncols, nrows, item_ct1);
570
+ });
505
571
  });
506
572
  }
507
573
  }
@@ -513,7 +579,7 @@ static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy,
513
579
  GGML_ASSERT(ncols % QK4_1 == 0);
514
580
  const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
515
581
  const sycl::range<3> block_nums(1, 1, block_num_y);
516
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
582
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
517
583
  {
518
584
 
519
585
  stream->submit([&](sycl::handler &cgh) {
@@ -521,7 +587,7 @@ static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy,
521
587
  cgh.parallel_for(
522
588
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
523
589
  [=](sycl::nd_item<3> item_ct1)
524
- [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
590
+ [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
525
591
  mul_mat_vec_q<QK4_0, QI4_1, block_q4_1,
526
592
  VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>(
527
593
  vx, vy, dst, ncols, nrows, item_ct1);
@@ -537,7 +603,7 @@ static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy,
537
603
  GGML_ASSERT(ncols % QK5_0 == 0);
538
604
  const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
539
605
  const sycl::range<3> block_nums(1, 1, block_num_y);
540
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
606
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
541
607
  {
542
608
 
543
609
  stream->submit([&](sycl::handler &cgh) {
@@ -545,7 +611,7 @@ static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy,
545
611
  cgh.parallel_for(
546
612
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
547
613
  [=](sycl::nd_item<3> item_ct1)
548
- [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
614
+ [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
549
615
  mul_mat_vec_q<QK5_0, QI5_0, block_q5_0,
550
616
  VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>(
551
617
  vx, vy, dst, ncols, nrows, item_ct1);
@@ -561,7 +627,7 @@ static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy,
561
627
  GGML_ASSERT(ncols % QK5_1 == 0);
562
628
  const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
563
629
  const sycl::range<3> block_nums(1, 1, block_num_y);
564
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
630
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
565
631
  {
566
632
 
567
633
  stream->submit([&](sycl::handler &cgh) {
@@ -569,7 +635,7 @@ static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy,
569
635
  cgh.parallel_for(
570
636
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
571
637
  [=](sycl::nd_item<3> item_ct1)
572
- [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
638
+ [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
573
639
  mul_mat_vec_q<QK5_1, QI5_1, block_q5_1,
574
640
  VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>(
575
641
  vx, vy, dst, ncols, nrows, item_ct1);
@@ -585,7 +651,7 @@ static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy,
585
651
  GGML_ASSERT(ncols % QK8_0 == 0);
586
652
  const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
587
653
  const sycl::range<3> block_nums(1, 1, block_num_y);
588
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
654
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
589
655
  {
590
656
 
591
657
  stream->submit([&](sycl::handler &cgh) {
@@ -593,7 +659,7 @@ static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy,
593
659
  cgh.parallel_for(
594
660
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
595
661
  [=](sycl::nd_item<3> item_ct1)
596
- [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
662
+ [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
597
663
  mul_mat_vec_q<QK8_0, QI8_0, block_q8_0,
598
664
  VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>(
599
665
  vx, vy, dst, ncols, nrows, item_ct1);
@@ -609,7 +675,7 @@ static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy,
609
675
  GGML_ASSERT(ncols % QK_K == 0);
610
676
  const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
611
677
  const sycl::range<3> block_nums(1, 1, block_num_y);
612
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
678
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
613
679
  {
614
680
 
615
681
  stream->submit([&](sycl::handler &cgh) {
@@ -617,7 +683,7 @@ static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy,
617
683
  cgh.parallel_for(
618
684
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
619
685
  [=](sycl::nd_item<3> item_ct1)
620
- [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
686
+ [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
621
687
  mul_mat_vec_q<QK_K, QI2_K, block_q2_K,
622
688
  VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>(
623
689
  vx, vy, dst, ncols, nrows, item_ct1);
@@ -633,7 +699,7 @@ static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy,
633
699
  GGML_ASSERT(ncols % QK_K == 0);
634
700
  const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
635
701
  const sycl::range<3> block_nums(1, 1, block_num_y);
636
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
702
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
637
703
  {
638
704
 
639
705
  stream->submit([&](sycl::handler &cgh) {
@@ -641,7 +707,7 @@ static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy,
641
707
  cgh.parallel_for(
642
708
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
643
709
  [=](sycl::nd_item<3> item_ct1)
644
- [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
710
+ [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
645
711
  mul_mat_vec_q<QK_K, QI3_K, block_q3_K,
646
712
  VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>(
647
713
  vx, vy, dst, ncols, nrows, item_ct1);
@@ -657,7 +723,7 @@ static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy,
657
723
  GGML_ASSERT(ncols % QK_K == 0);
658
724
  const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
659
725
  const sycl::range<3> block_nums(1, 1, block_num_y);
660
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
726
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
661
727
  {
662
728
 
663
729
  stream->submit([&](sycl::handler &cgh) {
@@ -665,7 +731,7 @@ static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy,
665
731
  cgh.parallel_for(
666
732
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
667
733
  [=](sycl::nd_item<3> item_ct1)
668
- [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
734
+ [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
669
735
  mul_mat_vec_q<QK_K, QI4_K, block_q4_K,
670
736
  VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>(
671
737
  vx, vy, dst, ncols, nrows, item_ct1);
@@ -674,6 +740,27 @@ static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy,
674
740
  }
675
741
  }
676
742
 
743
+ static void reorder_mul_mat_vec_q4_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
744
+ const int nrows, dpct::queue_ptr stream) {
745
+ GGML_ASSERT(ncols % QK_K == 0);
746
+
747
+ const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
748
+ constexpr size_t num_subgroups = 16;
749
+ GGML_ASSERT(block_num_y % num_subgroups == 0);
750
+
751
+ const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
752
+ const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
753
+
754
+ stream->submit([&](sycl::handler & cgh) {
755
+ cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
756
+ [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
757
+ mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K>>(vx, vy, dst, ncols,
758
+ nrows, nd_item);
759
+ });
760
+ });
761
+ }
762
+
763
+
677
764
  static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
678
765
  float *dst, const int ncols,
679
766
  const int nrows,
@@ -681,7 +768,7 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
681
768
  GGML_ASSERT(ncols % QK_K == 0);
682
769
  const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
683
770
  const sycl::range<3> block_nums(1, 1, block_num_y);
684
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
771
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
685
772
  {
686
773
 
687
774
  stream->submit([&](sycl::handler &cgh) {
@@ -689,7 +776,7 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
689
776
  cgh.parallel_for(
690
777
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
691
778
  [=](sycl::nd_item<3> item_ct1)
692
- [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
779
+ [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
693
780
  mul_mat_vec_q<QK_K, QI5_K, block_q5_K,
694
781
  VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>(
695
782
  vx, vy, dst, ncols, nrows, item_ct1);
@@ -705,7 +792,7 @@ static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy,
705
792
  GGML_ASSERT(ncols % QK_K == 0);
706
793
  const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
707
794
  const sycl::range<3> block_nums(1, 1, block_num_y);
708
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
795
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
709
796
  {
710
797
 
711
798
  stream->submit([&](sycl::handler &cgh) {
@@ -713,7 +800,7 @@ static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy,
713
800
  cgh.parallel_for(
714
801
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
715
802
  [=](sycl::nd_item<3> item_ct1)
716
- [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
803
+ [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
717
804
  mul_mat_vec_q<QK_K, QI6_K, block_q6_K,
718
805
  VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>(
719
806
  vx, vy, dst, ncols, nrows, item_ct1);
@@ -730,13 +817,13 @@ static void mul_mat_vec_iq2_xxs_q8_1_sycl(const void *vx, const void *vy,
730
817
  GGML_ASSERT(ncols % QK_K == 0);
731
818
  const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
732
819
  const sycl::range<3> block_nums(1, 1, block_num_y);
733
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
820
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
734
821
  {
735
822
  stream->submit([&](sycl::handler &cgh) {
736
823
  cgh.parallel_for(
737
824
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
738
825
  [=](sycl::nd_item<3> item_ct1)
739
- [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
826
+ [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
740
827
  mul_mat_vec_q_iq2_xxs_q8_1<QK_K, QI2_XXS/2, block_iq2_xxs, 1>(
741
828
  vx, vy, dst, ncols, nrows, item_ct1);
742
829
  });
@@ -751,13 +838,13 @@ static void mul_mat_vec_iq2_xs_q8_1_sycl(const void *vx, const void *vy,
751
838
  GGML_ASSERT(ncols % QK_K == 0);
752
839
  const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
753
840
  const sycl::range<3> block_nums(1, 1, block_num_y);
754
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
841
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
755
842
  {
756
843
  stream->submit([&](sycl::handler & cgh) {
757
844
  cgh.parallel_for(
758
845
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
759
846
  [=](sycl::nd_item<3> item_ct1)
760
- [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
847
+ [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
761
848
  mul_mat_vec_q_iq2_xs_q8_1<QK_K, QI2_XS/2, block_iq2_xs, 1>(
762
849
  vx, vy, dst, ncols, nrows, item_ct1);
763
850
  });
@@ -772,14 +859,14 @@ static void mul_mat_vec_iq2_s_q8_1_sycl(const void *vx, const void *vy,
772
859
  GGML_ASSERT(ncols % QK_K == 0);
773
860
  const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
774
861
  const sycl::range<3> block_nums(1, 1, block_num_y);
775
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
862
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
776
863
  {
777
864
 
778
865
  stream->submit([&](sycl::handler &cgh) {
779
866
  cgh.parallel_for(
780
867
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
781
868
  [=](sycl::nd_item<3> item_ct1)
782
- [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
869
+ [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
783
870
  mul_mat_vec_q_iq2_s_q8_1<QK_K, QI2_S/2, block_iq2_s, 1>(
784
871
  vx, vy, dst, ncols, nrows, item_ct1);
785
872
  });
@@ -794,14 +881,14 @@ static void mul_mat_vec_iq3_xxs_q8_1_sycl(const void *vx, const void *vy,
794
881
  GGML_ASSERT(ncols % QK_K == 0);
795
882
  const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
796
883
  const sycl::range<3> block_nums(1, 1, block_num_y);
797
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
884
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
798
885
  {
799
886
 
800
887
  stream->submit([&](sycl::handler &cgh) {
801
888
  cgh.parallel_for(
802
889
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
803
890
  [=](sycl::nd_item<3> item_ct1)
804
- [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
891
+ [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
805
892
  mul_mat_vec_q_iq3_xxs_q8_1<QK_K, QI3_XXS/2, block_iq3_xxs, 1>(
806
893
  vx, vy, dst, ncols, nrows, item_ct1);
807
894
  });
@@ -816,14 +903,14 @@ static void mul_mat_vec_iq3_s_q8_1_sycl(const void *vx, const void *vy,
816
903
  GGML_ASSERT(ncols % QK_K == 0);
817
904
  const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
818
905
  const sycl::range<3> block_nums(1, 1, block_num_y);
819
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
906
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
820
907
  {
821
908
 
822
909
  stream->submit([&](sycl::handler &cgh) {
823
910
  cgh.parallel_for(
824
911
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
825
912
  [=](sycl::nd_item<3> item_ct1)
826
- [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
913
+ [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
827
914
  mul_mat_vec_q_iq3_s_q8_1<QK_K, QI3_S/2, block_iq3_s, 1>(
828
915
  vx, vy, dst, ncols, nrows, item_ct1);
829
916
  });
@@ -838,14 +925,14 @@ static void mul_mat_vec_iq1_s_q8_1_sycl(const void *vx, const void *vy,
838
925
  GGML_ASSERT(ncols % QK_K == 0);
839
926
  const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
840
927
  const sycl::range<3> block_nums(1, 1, block_num_y);
841
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
928
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
842
929
  {
843
930
 
844
931
  stream->submit([&](sycl::handler &cgh) {
845
932
  cgh.parallel_for(
846
933
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
847
934
  [=](sycl::nd_item<3> item_ct1)
848
- [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
935
+ [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
849
936
  mul_mat_vec_q_iq1_s_q8_1<QK_K, QI1_S, block_iq1_s, 1>(
850
937
  vx, vy, dst, ncols, nrows, item_ct1);
851
938
  });
@@ -860,13 +947,13 @@ static void mul_mat_vec_iq1_m_q8_1_sycl(const void *vx, const void *vy,
860
947
  GGML_ASSERT(ncols % QK_K == 0);
861
948
  const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
862
949
  const sycl::range<3> block_nums(1, 1, block_num_y);
863
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
950
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
864
951
  {
865
952
  stream->submit([&](sycl::handler &cgh) {
866
953
  cgh.parallel_for(
867
954
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
868
955
  [=](sycl::nd_item<3> item_ct1)
869
- [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
956
+ [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
870
957
  mul_mat_vec_q_iq1_m_q8_1<QK_K, QI1_S, block_iq1_m, 1>(
871
958
  vx, vy, dst, ncols, nrows, item_ct1);
872
959
  });
@@ -881,14 +968,14 @@ static void mul_mat_vec_iq4_nl_q8_1_sycl(const void *vx, const void *vy,
881
968
  GGML_ASSERT(ncols % QK4_NL == 0);
882
969
  const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
883
970
  const sycl::range<3> block_nums(1, 1, block_num_y);
884
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
971
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
885
972
  {
886
973
 
887
974
  stream->submit([&](sycl::handler &cgh) {
888
975
  cgh.parallel_for(
889
976
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
890
977
  [=](sycl::nd_item<3> item_ct1)
891
- [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
978
+ [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
892
979
  mul_mat_vec_q_iq4_nl_q8_1<QK4_NL, QI4_NL, block_iq4_nl, 2>(
893
980
  vx, vy, dst, ncols, nrows, item_ct1);
894
981
  });
@@ -903,14 +990,14 @@ static void mul_mat_vec_iq4_xs_q8_1_sycl(const void *vx, const void *vy,
903
990
  GGML_ASSERT(ncols % QK_K == 0);
904
991
  const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
905
992
  const sycl::range<3> block_nums(1, 1, block_num_y);
906
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
993
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
907
994
  {
908
995
 
909
996
  stream->submit([&](sycl::handler &cgh) {
910
997
  cgh.parallel_for(
911
998
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
912
999
  [=](sycl::nd_item<3> item_ct1)
913
- [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
1000
+ [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
914
1001
  mul_mat_vec_q_iq4_xs_q8_1<QK_K, QI4_XS/4, block_iq4_xs, 1>(
915
1002
  vx, vy, dst, ncols, nrows, item_ct1);
916
1003
  });
@@ -918,94 +1005,102 @@ static void mul_mat_vec_iq4_xs_q8_1_sycl(const void *vx, const void *vy,
918
1005
  }
919
1006
  }
920
1007
 
921
- void ggml_sycl_op_mul_mat_vec_q(
922
- ggml_backend_sycl_context & ctx,
923
- const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
924
- const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,
925
- float *dst_dd_i, const int64_t row_low, const int64_t row_high,
926
- const int64_t src1_ncols, const int64_t src1_padded_col_size,
927
- const dpct::queue_ptr &stream) {
928
-
1008
+ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1,
1009
+ ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
1010
+ const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low,
1011
+ const int64_t row_high, const int64_t src1_ncols, const int64_t src1_padded_col_size,
1012
+ const dpct::queue_ptr & stream) {
929
1013
  const int64_t ne10 = src1->ne[0];
930
1014
  GGML_ASSERT(ne10 % QK8_1 == 0);
931
1015
 
932
- const int64_t ne00 = src0->ne[0];
1016
+ const int64_t ne00 = src0->ne[0];
933
1017
  const int64_t row_diff = row_high - row_low;
934
1018
 
935
1019
  int id;
936
- SYCL_CHECK(
937
- CHECK_TRY_ERROR(id = get_current_device_id()));
1020
+ SYCL_CHECK(CHECK_TRY_ERROR(id = get_current_device_id()));
938
1021
  const size_t q8_1_ts = sizeof(block_q8_1);
939
1022
  const size_t q8_1_bs = QK8_1;
940
1023
  // the main device has a larger memory buffer to hold the results from all GPUs
941
1024
  // nrows_dst == nrows of the matrix that the kernel writes into
942
1025
 
943
- for (int i = 0; i < src1_ncols; i++)
944
- {
1026
+ for (int i = 0; i < src1_ncols; i++) {
945
1027
  const size_t src1_ddq_i_offset = i * src1_padded_col_size * q8_1_ts / q8_1_bs;
946
- const char* src1_ddq_i_bs = src1_ddq_i + src1_ddq_i_offset;
947
- float* dst_dd_i_bs = dst_dd_i + i * dst->ne[0];
1028
+ const char * src1_ddq_i_bs = src1_ddq_i + src1_ddq_i_offset;
1029
+ float * dst_dd_i_bs = dst_dd_i + i * dst->ne[0];
948
1030
  switch (src0->type) {
949
- case GGML_TYPE_Q4_0:
950
- mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
951
- break;
952
- case GGML_TYPE_Q4_1:
953
- mul_mat_vec_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
954
- break;
955
- case GGML_TYPE_Q5_0:
956
- mul_mat_vec_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
957
- break;
958
- case GGML_TYPE_Q5_1:
959
- mul_mat_vec_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
960
- break;
961
- case GGML_TYPE_Q8_0:
962
- mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
963
- break;
964
- case GGML_TYPE_Q2_K:
965
- mul_mat_vec_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
966
- break;
967
- case GGML_TYPE_Q3_K:
968
- mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
969
- break;
970
- case GGML_TYPE_Q4_K:
971
- mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
972
- break;
973
- case GGML_TYPE_Q5_K:
974
- mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
975
- break;
976
- case GGML_TYPE_Q6_K:
977
- mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
978
- break;
979
- case GGML_TYPE_IQ1_S:
980
- mul_mat_vec_iq1_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
981
- break;
982
- case GGML_TYPE_IQ1_M:
983
- mul_mat_vec_iq1_m_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
984
- break;
985
- case GGML_TYPE_IQ2_XXS:
986
- mul_mat_vec_iq2_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
987
- break;
988
- case GGML_TYPE_IQ2_XS:
989
- mul_mat_vec_iq2_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
990
- break;
991
- case GGML_TYPE_IQ2_S:
992
- mul_mat_vec_iq2_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
993
- break;
994
- case GGML_TYPE_IQ3_XXS:
995
- mul_mat_vec_iq3_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
996
- break;
997
- case GGML_TYPE_IQ3_S:
998
- mul_mat_vec_iq3_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
999
- break;
1000
- case GGML_TYPE_IQ4_NL:
1001
- mul_mat_vec_iq4_nl_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
1002
- break;
1003
- case GGML_TYPE_IQ4_XS:
1004
- mul_mat_vec_iq4_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
1005
- break;
1006
- default:
1007
- GGML_ABORT("fatal error");
1008
- break;
1031
+ case GGML_TYPE_Q4_0:
1032
+ if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
1033
+ ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
1034
+ GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q4_0_q8_1_sycl\n");
1035
+ reorder_mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
1036
+ } else {
1037
+ GGML_SYCL_DEBUG("Calling mul_mat_vec_q4_0_q8_1_sycl\n");
1038
+ mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
1039
+ }
1040
+ break;
1041
+ case GGML_TYPE_Q4_1:
1042
+ mul_mat_vec_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
1043
+ break;
1044
+ case GGML_TYPE_Q5_0:
1045
+ mul_mat_vec_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
1046
+ break;
1047
+ case GGML_TYPE_Q5_1:
1048
+ mul_mat_vec_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
1049
+ break;
1050
+ case GGML_TYPE_Q8_0:
1051
+ mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
1052
+ break;
1053
+ case GGML_TYPE_Q2_K:
1054
+ mul_mat_vec_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
1055
+ break;
1056
+ case GGML_TYPE_Q3_K:
1057
+ mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
1058
+ break;
1059
+ case GGML_TYPE_Q4_K:
1060
+ if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
1061
+ ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
1062
+ GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q4_k_q8_1_sycl\n");
1063
+ reorder_mul_mat_vec_q4_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
1064
+ } else {
1065
+ GGML_SYCL_DEBUG("Calling mul_mat_vec_q4_K_q8_1_sycl\n");
1066
+ mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
1067
+ }
1068
+ break;
1069
+ case GGML_TYPE_Q5_K:
1070
+ mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
1071
+ break;
1072
+ case GGML_TYPE_Q6_K:
1073
+ mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
1074
+ break;
1075
+ case GGML_TYPE_IQ1_S:
1076
+ mul_mat_vec_iq1_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
1077
+ break;
1078
+ case GGML_TYPE_IQ1_M:
1079
+ mul_mat_vec_iq1_m_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
1080
+ break;
1081
+ case GGML_TYPE_IQ2_XXS:
1082
+ mul_mat_vec_iq2_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
1083
+ break;
1084
+ case GGML_TYPE_IQ2_XS:
1085
+ mul_mat_vec_iq2_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
1086
+ break;
1087
+ case GGML_TYPE_IQ2_S:
1088
+ mul_mat_vec_iq2_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
1089
+ break;
1090
+ case GGML_TYPE_IQ3_XXS:
1091
+ mul_mat_vec_iq3_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
1092
+ break;
1093
+ case GGML_TYPE_IQ3_S:
1094
+ mul_mat_vec_iq3_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
1095
+ break;
1096
+ case GGML_TYPE_IQ4_NL:
1097
+ mul_mat_vec_iq4_nl_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
1098
+ break;
1099
+ case GGML_TYPE_IQ4_XS:
1100
+ mul_mat_vec_iq4_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
1101
+ break;
1102
+ default:
1103
+ GGML_ABORT("fatal error");
1009
1104
  }
1010
1105
  }
1011
1106
  GGML_UNUSED(src1);