whispercpp 1.3.2 → 1.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (244)
  1. checksums.yaml +4 -4
  2. data/.gitignore +6 -3
  3. data/README.md +71 -14
  4. data/Rakefile +20 -7
  5. data/ext/.gitignore +4 -6
  6. data/ext/dependencies.rb +36 -24
  7. data/ext/extconf.rb +1 -1
  8. data/ext/options.rb +48 -184
  9. data/ext/ruby_whisper.c +18 -0
  10. data/ext/ruby_whisper_context.c +43 -12
  11. data/ext/ruby_whisper_model.c +1 -1
  12. data/ext/ruby_whisper_params.c +4 -2
  13. data/ext/ruby_whisper_segment.c +81 -4
  14. data/ext/ruby_whisper_transcribe.cpp +13 -7
  15. data/ext/ruby_whisper_vad_params.c +1 -1
  16. data/ext/sources/CMakeLists.txt +5 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +120 -24
  19. data/ext/sources/examples/addon.node/addon.cpp +150 -31
  20. data/ext/sources/examples/addon.node/index.js +3 -0
  21. data/ext/sources/examples/addon.node/vad-example.js +132 -0
  22. data/ext/sources/examples/bench/bench.cpp +3 -2
  23. data/ext/sources/examples/cli/cli.cpp +3 -2
  24. data/ext/sources/examples/command/command.cpp +32 -8
  25. data/ext/sources/examples/common-whisper.cpp +14 -7
  26. data/ext/sources/examples/lsp/lsp.cpp +2 -0
  27. data/ext/sources/examples/quantize/quantize.cpp +3 -0
  28. data/ext/sources/examples/server/CMakeLists.txt +3 -0
  29. data/ext/sources/examples/server/server.cpp +169 -22
  30. data/ext/sources/examples/stream/stream.cpp +6 -0
  31. data/ext/sources/examples/talk-llama/CMakeLists.txt +4 -1
  32. data/ext/sources/examples/talk-llama/llama-arch.cpp +171 -3
  33. data/ext/sources/examples/talk-llama/llama-arch.h +28 -1
  34. data/ext/sources/examples/talk-llama/llama-batch.cpp +741 -272
  35. data/ext/sources/examples/talk-llama/llama-batch.h +112 -54
  36. data/ext/sources/examples/talk-llama/llama-chat.cpp +30 -8
  37. data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
  38. data/ext/sources/examples/talk-llama/llama-context.cpp +520 -351
  39. data/ext/sources/examples/talk-llama/llama-context.h +38 -17
  40. data/ext/sources/examples/talk-llama/llama-cparams.cpp +1 -1
  41. data/ext/sources/examples/talk-llama/llama-cparams.h +1 -1
  42. data/ext/sources/examples/talk-llama/llama-graph.cpp +447 -372
  43. data/ext/sources/examples/talk-llama/llama-graph.h +128 -58
  44. data/ext/sources/examples/talk-llama/llama-hparams.cpp +10 -2
  45. data/ext/sources/examples/talk-llama/llama-hparams.h +19 -2
  46. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +279 -0
  47. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.h +128 -0
  48. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +1841 -0
  49. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +303 -0
  50. data/ext/sources/examples/talk-llama/llama-kv-cache.h +14 -472
  51. data/ext/sources/examples/talk-llama/llama-kv-cells.h +86 -26
  52. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +246 -0
  53. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +138 -0
  54. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1125 -0
  55. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +183 -0
  56. data/ext/sources/examples/talk-llama/llama-memory.cpp +58 -0
  57. data/ext/sources/examples/talk-llama/llama-memory.h +88 -4
  58. data/ext/sources/examples/talk-llama/llama-mmap.cpp +1 -1
  59. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +42 -17
  60. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +1 -0
  61. data/ext/sources/examples/talk-llama/llama-model.cpp +1863 -563
  62. data/ext/sources/examples/talk-llama/llama-model.h +27 -0
  63. data/ext/sources/examples/talk-llama/llama-quant.cpp +89 -6
  64. data/ext/sources/examples/talk-llama/llama-vocab.cpp +65 -28
  65. data/ext/sources/examples/talk-llama/llama-vocab.h +1 -0
  66. data/ext/sources/examples/talk-llama/llama.cpp +11 -7
  67. data/ext/sources/examples/talk-llama/llama.h +147 -40
  68. data/ext/sources/examples/talk-llama/talk-llama.cpp +2 -0
  69. data/ext/sources/examples/talk-llama/unicode.cpp +5 -0
  70. data/ext/sources/examples/vad-speech-segments/speech.cpp +6 -0
  71. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +2 -0
  72. data/ext/sources/ggml/CMakeLists.txt +48 -3
  73. data/ext/sources/ggml/cmake/common.cmake +24 -0
  74. data/ext/sources/ggml/include/ggml-backend.h +1 -1
  75. data/ext/sources/ggml/include/ggml-cpu.h +2 -0
  76. data/ext/sources/ggml/include/ggml.h +144 -5
  77. data/ext/sources/ggml/src/CMakeLists.txt +82 -24
  78. data/ext/sources/ggml/src/ggml-backend-reg.cpp +5 -0
  79. data/ext/sources/ggml/src/ggml-backend.cpp +46 -23
  80. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +3 -3
  81. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +1 -0
  82. data/ext/sources/ggml/src/ggml-cann/common.h +6 -1
  83. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
  84. data/ext/sources/ggml/src/ggml-common.h +4 -0
  85. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +133 -40
  86. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
  87. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
  88. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  89. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +4114 -0
  90. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2163 -0
  91. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2639 -0
  92. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  93. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2732 -0
  94. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +2069 -0
  95. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +397 -0
  96. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1300 -0
  97. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1481 -0
  98. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +4311 -0
  99. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +79 -3225
  100. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +184 -0
  101. data/ext/sources/ggml/src/ggml-cpu/common.h +4 -3
  102. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +16 -7
  103. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +146 -105
  104. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +12 -8
  105. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  106. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
  107. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +58 -8
  108. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  109. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1057 -174
  110. data/ext/sources/ggml/src/ggml-cpu/ops.h +8 -0
  111. data/ext/sources/ggml/src/ggml-cpu/quants.c +1158 -0
  112. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
  113. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1571 -0
  114. data/ext/sources/ggml/src/ggml-cpu/repack.h +98 -0
  115. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +330 -38
  116. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
  117. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +111 -18
  118. data/ext/sources/ggml/src/ggml-cpu/vec.h +303 -94
  119. data/ext/sources/ggml/src/ggml-cuda/common.cuh +60 -37
  120. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  121. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  122. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  123. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  124. data/ext/sources/ggml/src/ggml-cuda/convert.cu +22 -0
  125. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +5 -0
  126. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +2 -2
  127. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +5 -2
  128. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
  129. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +265 -123
  130. data/ext/sources/ggml/src/ggml-cuda/mean.cu +19 -0
  131. data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
  132. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +257 -87
  133. data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +2 -3
  134. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
  135. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +5 -18
  136. data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  137. data/ext/sources/ggml/src/ggml-cuda/unary.cu +89 -0
  138. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +7 -0
  139. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +4 -0
  140. data/ext/sources/ggml/src/ggml-impl.h +127 -183
  141. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +11 -10
  142. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +27 -0
  143. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +331 -49
  144. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +564 -282
  145. data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +2 -2
  146. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +14 -0
  147. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1859 -489
  148. data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
  149. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  150. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +72 -0
  151. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +201 -0
  152. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +72 -0
  153. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  154. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  155. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  156. data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
  157. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +72 -0
  158. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
  159. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  160. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  161. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
  162. data/ext/sources/ggml/src/ggml-quants.c +6 -8
  163. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
  164. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
  165. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +5 -6
  166. data/ext/sources/ggml/src/ggml-sycl/common.hpp +20 -48
  167. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +28 -41
  168. data/ext/sources/ggml/src/ggml-sycl/conv.cpp +4 -10
  169. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +117 -165
  170. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +192 -53
  171. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +32 -0
  172. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +49 -67
  173. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
  174. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +648 -1039
  175. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +18 -9
  176. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
  177. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +8 -105
  178. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +238 -100
  179. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
  180. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +1 -1
  181. data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +60 -80
  182. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +158 -203
  183. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +55 -74
  184. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +38 -10
  185. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +138 -27
  186. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
  187. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  188. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  189. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +3 -8
  190. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
  191. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +12 -16
  192. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +36 -32
  193. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +726 -282
  194. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
  195. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  196. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  197. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
  198. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  199. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  200. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
  201. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  202. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +10 -1
  203. data/ext/sources/ggml/src/ggml.c +328 -48
  204. data/ext/sources/ggml/src/ggml.cpp +26 -0
  205. data/ext/sources/ggml/src/gguf.cpp +24 -3
  206. data/ext/sources/include/whisper.h +2 -0
  207. data/ext/sources/src/CMakeLists.txt +2 -0
  208. data/ext/sources/src/coreml/whisper-compat.h +10 -0
  209. data/ext/sources/src/coreml/whisper-compat.m +35 -0
  210. data/ext/sources/src/coreml/whisper-decoder-impl.m +1 -0
  211. data/ext/sources/src/coreml/whisper-encoder-impl.m +1 -0
  212. data/ext/sources/src/whisper.cpp +218 -169
  213. data/extsources.rb +15 -9
  214. data/lib/whisper/context.rb +15 -0
  215. data/lib/whisper/model/uri.rb +56 -1
  216. data/lib/whisper/segment.rb +58 -0
  217. data/sig/whisper.rbs +68 -38
  218. data/{tests → test}/helper.rb +1 -12
  219. data/{tests → test}/test_model.rb +9 -0
  220. data/test/test_package.rb +51 -0
  221. data/test/test_segment.rb +146 -0
  222. data/{tests → test}/test_whisper.rb +70 -0
  223. data/whispercpp.gemspec +2 -3
  224. metadata +91 -43
  225. data/ext/sources/.dockerignore +0 -3
  226. data/ext/sources/.github/workflows/bindings-ruby.yml +0 -21
  227. data/ext/sources/ci/run.sh +0 -336
  228. data/ext/sources/close-issue.yml +0 -28
  229. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2739
  230. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  231. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13747
  232. data/tests/test_package.rb +0 -46
  233. data/tests/test_segment.rb +0 -74
  234. /data/ext/sources/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  235. /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  236. /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
  237. /data/{tests → test}/jfk_reader/.gitignore +0 -0
  238. /data/{tests → test}/jfk_reader/extconf.rb +0 -0
  239. /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
  240. /data/{tests → test}/test_callback.rb +0 -0
  241. /data/{tests → test}/test_error.rb +0 -0
  242. /data/{tests → test}/test_params.rb +0 -0
  243. /data/{tests → test}/test_vad.rb +0 -0
  244. /data/{tests → test}/test_vad_params.rb +0 -0
@@ -29,24 +29,23 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __r
29
29
  static_assert(blocks_per_subgroup > 0);
30
30
  static_assert(block_elements_per_subgroup > 0);
31
31
 
32
- const block_q8_1 * y = (const block_q8_1 *) vy;
33
-
34
32
  float partial_sum = 0.0f;
35
33
  for (int i = sg.get_local_linear_id() / block_elements_per_subgroup; i < blocks_per_row; i += blocks_per_subgroup) {
36
- const int ibx = row * blocks_per_row + i; // x block index
37
- // TODO: Generalize offsets, right now only works for quantizations that don't split high and low bits
38
- const int bx_offset = block_type::get_block_offset(ibx);
39
- const int d_offset = block_type::get_d_offset(nrows, ncols, ibx);
34
+ const int ibx = row * blocks_per_row + i; // x block index
40
35
 
36
+ const auto bx_offset = block_type::get_block_offset(ibx, nblocks);
37
+ const auto d_offset = block_type::get_d_offset(nrows, ncols, ibx);
41
38
  // Y block index that aligns with ibx
42
39
  const int iby = i * block_type::block_to_q8_1_ratio();
40
+ const int8_t* q8_1_quant_ptr = (const int8_t*)vy + iby * QK8_1;
41
+ const sycl::half2* q8_1_ds_ptr = (const sycl::half2*)((const char*)vy + ncols + iby * sizeof(sycl::half2));
43
42
 
44
43
  #pragma unroll
45
44
  for (int elem = 0; elem < block_elements_per_subgroup; elem += WARP_SIZE) {
46
45
  // x block quant index when casting the quants to int
47
46
  const int iqs = elem + block_traits::vdr_mmvq * (sg.get_local_linear_id() % block_elements_per_subgroup);
48
47
 
49
- partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, &y[iby], iqs, nblocks);
48
+ partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, q8_1_quant_ptr, q8_1_ds_ptr, iqs);
50
49
  }
51
50
  }
52
51
 
@@ -545,12 +544,12 @@ static void reorder_mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy,
545
544
  const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, (block_num_y * WARP_SIZE));
546
545
  const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
547
546
 
548
- stream->submit([&](sycl::handler & cgh) {
549
- cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
550
- [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
551
- mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0>>(vx, vy, dst, ncols, nrows,
552
- nd_item);
553
- });
547
+ sycl_launch(stream, [&](sycl::handler & cgh) {
548
+ sycl_parallel_for(cgh, sycl::nd_range<3>(global_size, workgroup_size),
549
+ [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
550
+ mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0>>(vx, vy, dst, ncols, nrows,
551
+ nd_item);
552
+ });
554
553
  });
555
554
  }
556
555
 
@@ -562,12 +561,12 @@ static void mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, float *
562
561
  const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
563
562
 
564
563
  {
565
- stream->submit([&](sycl::handler & cgh) {
566
- cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
567
- [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
568
- mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>(
569
- vx, vy, dst, ncols, nrows, item_ct1);
570
- });
564
+ sycl_launch(stream, [&](sycl::handler & cgh) {
565
+ sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
566
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
567
+ mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>(
568
+ vx, vy, dst, ncols, nrows, item_ct1);
569
+ });
571
570
  });
572
571
  }
573
572
  }
@@ -581,17 +580,12 @@ static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy,
581
580
  const sycl::range<3> block_nums(1, 1, block_num_y);
582
581
  const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
583
582
  {
584
-
585
- stream->submit([&](sycl::handler &cgh) {
586
-
587
- cgh.parallel_for(
588
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
589
- [=](sycl::nd_item<3> item_ct1)
590
- [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
591
- mul_mat_vec_q<QK4_0, QI4_1, block_q4_1,
592
- VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>(
593
- vx, vy, dst, ncols, nrows, item_ct1);
594
- });
583
+ sycl_launch(stream, [&](sycl::handler & cgh) {
584
+ sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
585
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
586
+ mul_mat_vec_q<QK4_0, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>(
587
+ vx, vy, dst, ncols, nrows, item_ct1);
588
+ });
595
589
  });
596
590
  }
597
591
  }
@@ -605,17 +599,12 @@ static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy,
605
599
  const sycl::range<3> block_nums(1, 1, block_num_y);
606
600
  const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
607
601
  {
608
-
609
- stream->submit([&](sycl::handler &cgh) {
610
-
611
- cgh.parallel_for(
612
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
613
- [=](sycl::nd_item<3> item_ct1)
614
- [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
615
- mul_mat_vec_q<QK5_0, QI5_0, block_q5_0,
616
- VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>(
617
- vx, vy, dst, ncols, nrows, item_ct1);
618
- });
602
+ sycl_launch(stream, [&](sycl::handler & cgh) {
603
+ sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
604
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
605
+ mul_mat_vec_q<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>(
606
+ vx, vy, dst, ncols, nrows, item_ct1);
607
+ });
619
608
  });
620
609
  }
621
610
  }
@@ -629,17 +618,12 @@ static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy,
629
618
  const sycl::range<3> block_nums(1, 1, block_num_y);
630
619
  const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
631
620
  {
632
-
633
- stream->submit([&](sycl::handler &cgh) {
634
-
635
- cgh.parallel_for(
636
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
637
- [=](sycl::nd_item<3> item_ct1)
638
- [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
639
- mul_mat_vec_q<QK5_1, QI5_1, block_q5_1,
640
- VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>(
641
- vx, vy, dst, ncols, nrows, item_ct1);
642
- });
621
+ sycl_launch(stream, [&](sycl::handler & cgh) {
622
+ sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
623
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
624
+ mul_mat_vec_q<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>(
625
+ vx, vy, dst, ncols, nrows, item_ct1);
626
+ });
643
627
  });
644
628
  }
645
629
  }
@@ -653,17 +637,12 @@ static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy,
653
637
  const sycl::range<3> block_nums(1, 1, block_num_y);
654
638
  const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
655
639
  {
656
-
657
- stream->submit([&](sycl::handler &cgh) {
658
-
659
- cgh.parallel_for(
660
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
661
- [=](sycl::nd_item<3> item_ct1)
662
- [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
663
- mul_mat_vec_q<QK8_0, QI8_0, block_q8_0,
664
- VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>(
665
- vx, vy, dst, ncols, nrows, item_ct1);
666
- });
640
+ sycl_launch(stream, [&](sycl::handler & cgh) {
641
+ sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
642
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
643
+ mul_mat_vec_q<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>(
644
+ vx, vy, dst, ncols, nrows, item_ct1);
645
+ });
667
646
  });
668
647
  }
669
648
  }
@@ -677,17 +656,12 @@ static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy,
677
656
  const sycl::range<3> block_nums(1, 1, block_num_y);
678
657
  const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
679
658
  {
680
-
681
- stream->submit([&](sycl::handler &cgh) {
682
-
683
- cgh.parallel_for(
684
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
685
- [=](sycl::nd_item<3> item_ct1)
686
- [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
687
- mul_mat_vec_q<QK_K, QI2_K, block_q2_K,
688
- VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>(
689
- vx, vy, dst, ncols, nrows, item_ct1);
690
- });
659
+ sycl_launch(stream, [&](sycl::handler & cgh) {
660
+ sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
661
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
662
+ mul_mat_vec_q<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>(
663
+ vx, vy, dst, ncols, nrows, item_ct1);
664
+ });
691
665
  });
692
666
  }
693
667
  }
@@ -701,17 +675,12 @@ static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy,
701
675
  const sycl::range<3> block_nums(1, 1, block_num_y);
702
676
  const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
703
677
  {
704
-
705
- stream->submit([&](sycl::handler &cgh) {
706
-
707
- cgh.parallel_for(
708
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
709
- [=](sycl::nd_item<3> item_ct1)
710
- [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
711
- mul_mat_vec_q<QK_K, QI3_K, block_q3_K,
712
- VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>(
713
- vx, vy, dst, ncols, nrows, item_ct1);
714
- });
678
+ sycl_launch(stream, [&](sycl::handler & cgh) {
679
+ sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
680
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
681
+ mul_mat_vec_q<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>(
682
+ vx, vy, dst, ncols, nrows, item_ct1);
683
+ });
715
684
  });
716
685
  }
717
686
  }
@@ -725,17 +694,12 @@ static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy,
725
694
  const sycl::range<3> block_nums(1, 1, block_num_y);
726
695
  const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
727
696
  {
728
-
729
- stream->submit([&](sycl::handler &cgh) {
730
-
731
- cgh.parallel_for(
732
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
733
- [=](sycl::nd_item<3> item_ct1)
734
- [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
735
- mul_mat_vec_q<QK_K, QI4_K, block_q4_K,
736
- VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>(
737
- vx, vy, dst, ncols, nrows, item_ct1);
738
- });
697
+ sycl_launch(stream, [&](sycl::handler & cgh) {
698
+ sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
699
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
700
+ mul_mat_vec_q<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>(
701
+ vx, vy, dst, ncols, nrows, item_ct1);
702
+ });
739
703
  });
740
704
  }
741
705
  }
@@ -751,12 +715,12 @@ static void reorder_mul_mat_vec_q4_k_q8_1_sycl(const void * vx, const void * vy,
751
715
  const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
752
716
  const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
753
717
 
754
- stream->submit([&](sycl::handler & cgh) {
755
- cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
756
- [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
757
- mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K>>(vx, vy, dst, ncols,
758
- nrows, nd_item);
759
- });
718
+ sycl_launch(stream, [&](sycl::handler & cgh) {
719
+ sycl_parallel_for(cgh, sycl::nd_range<3>(global_size, workgroup_size),
720
+ [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
721
+ mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K>>(vx, vy, dst, ncols, nrows,
722
+ nd_item);
723
+ });
760
724
  });
761
725
  }
762
726
 
@@ -770,21 +734,34 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
770
734
  const sycl::range<3> block_nums(1, 1, block_num_y);
771
735
  const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
772
736
  {
773
-
774
- stream->submit([&](sycl::handler &cgh) {
775
-
776
- cgh.parallel_for(
777
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
778
- [=](sycl::nd_item<3> item_ct1)
779
- [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
780
- mul_mat_vec_q<QK_K, QI5_K, block_q5_K,
781
- VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>(
782
- vx, vy, dst, ncols, nrows, item_ct1);
783
- });
737
+ sycl_launch(stream, [&](sycl::handler & cgh) {
738
+ sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
739
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
740
+ mul_mat_vec_q<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>(
741
+ vx, vy, dst, ncols, nrows, item_ct1);
742
+ });
784
743
  });
785
744
  }
786
745
  }
787
746
 
747
+ static void reorder_mul_mat_vec_q6_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
748
+ const int nrows, dpct::queue_ptr stream) {
749
+ GGML_ASSERT(ncols % QK_K == 0);
750
+ const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
751
+ constexpr size_t num_subgroups = 16;
752
+ GGML_ASSERT(block_num_y % num_subgroups == 0);
753
+
754
+ const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
755
+ const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
756
+
757
+ sycl_launch(stream, [&](sycl::handler & cgh) {
758
+ sycl_parallel_for(cgh, sycl::nd_range<3>(global_size, workgroup_size),
759
+ [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
760
+ mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q6_K>>(vx, vy, dst, ncols, nrows,
761
+ nd_item);
762
+ });
763
+ });
764
+ }
788
765
  static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy,
789
766
  float *dst, const int ncols,
790
767
  const int nrows,
@@ -794,17 +771,12 @@ static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy,
794
771
  const sycl::range<3> block_nums(1, 1, block_num_y);
795
772
  const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
796
773
  {
797
-
798
- stream->submit([&](sycl::handler &cgh) {
799
-
800
- cgh.parallel_for(
801
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
802
- [=](sycl::nd_item<3> item_ct1)
803
- [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
804
- mul_mat_vec_q<QK_K, QI6_K, block_q6_K,
805
- VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>(
806
- vx, vy, dst, ncols, nrows, item_ct1);
807
- });
774
+ sycl_launch(stream, [&](sycl::handler & cgh) {
775
+ sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
776
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
777
+ mul_mat_vec_q<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>(
778
+ vx, vy, dst, ncols, nrows, item_ct1);
779
+ });
808
780
  });
809
781
  }
810
782
  }
@@ -819,14 +791,12 @@ static void mul_mat_vec_iq2_xxs_q8_1_sycl(const void *vx, const void *vy,
819
791
  const sycl::range<3> block_nums(1, 1, block_num_y);
820
792
  const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
821
793
  {
822
- stream->submit([&](sycl::handler &cgh) {
823
- cgh.parallel_for(
824
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
825
- [=](sycl::nd_item<3> item_ct1)
826
- [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
827
- mul_mat_vec_q_iq2_xxs_q8_1<QK_K, QI2_XXS/2, block_iq2_xxs, 1>(
828
- vx, vy, dst, ncols, nrows, item_ct1);
829
- });
794
+ sycl_launch(stream, [&](sycl::handler & cgh) {
795
+ sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
796
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
797
+ mul_mat_vec_q_iq2_xxs_q8_1<QK_K, QI2_XXS / 2, block_iq2_xxs, 1>(vx, vy, dst, ncols,
798
+ nrows, item_ct1);
799
+ });
830
800
  });
831
801
  }
832
802
  }
@@ -840,14 +810,12 @@ static void mul_mat_vec_iq2_xs_q8_1_sycl(const void *vx, const void *vy,
840
810
  const sycl::range<3> block_nums(1, 1, block_num_y);
841
811
  const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
842
812
  {
843
- stream->submit([&](sycl::handler & cgh) {
844
- cgh.parallel_for(
845
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
846
- [=](sycl::nd_item<3> item_ct1)
847
- [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
848
- mul_mat_vec_q_iq2_xs_q8_1<QK_K, QI2_XS/2, block_iq2_xs, 1>(
849
- vx, vy, dst, ncols, nrows, item_ct1);
850
- });
813
+ sycl_launch(stream, [&](sycl::handler & cgh) {
814
+ sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
815
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
816
+ mul_mat_vec_q_iq2_xs_q8_1<QK_K, QI2_XS / 2, block_iq2_xs, 1>(vx, vy, dst, ncols,
817
+ nrows, item_ct1);
818
+ });
851
819
  });
852
820
  }
853
821
  }
@@ -861,15 +829,12 @@ static void mul_mat_vec_iq2_s_q8_1_sycl(const void *vx, const void *vy,
861
829
  const sycl::range<3> block_nums(1, 1, block_num_y);
862
830
  const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
863
831
  {
864
-
865
- stream->submit([&](sycl::handler &cgh) {
866
- cgh.parallel_for(
867
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
868
- [=](sycl::nd_item<3> item_ct1)
869
- [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
870
- mul_mat_vec_q_iq2_s_q8_1<QK_K, QI2_S/2, block_iq2_s, 1>(
871
- vx, vy, dst, ncols, nrows, item_ct1);
872
- });
832
+ sycl_launch(stream, [&](sycl::handler & cgh) {
833
+ sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
834
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
835
+ mul_mat_vec_q_iq2_s_q8_1<QK_K, QI2_S / 2, block_iq2_s, 1>(vx, vy, dst, ncols, nrows,
836
+ item_ct1);
837
+ });
873
838
  });
874
839
  }
875
840
  }
@@ -883,15 +848,12 @@ static void mul_mat_vec_iq3_xxs_q8_1_sycl(const void *vx, const void *vy,
883
848
  const sycl::range<3> block_nums(1, 1, block_num_y);
884
849
  const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
885
850
  {
886
-
887
- stream->submit([&](sycl::handler &cgh) {
888
- cgh.parallel_for(
889
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
890
- [=](sycl::nd_item<3> item_ct1)
891
- [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
892
- mul_mat_vec_q_iq3_xxs_q8_1<QK_K, QI3_XXS/2, block_iq3_xxs, 1>(
893
- vx, vy, dst, ncols, nrows, item_ct1);
894
- });
851
+ sycl_launch(stream, [&](sycl::handler & cgh) {
852
+ sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
853
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
854
+ mul_mat_vec_q_iq3_xxs_q8_1<QK_K, QI3_XXS / 2, block_iq3_xxs, 1>(vx, vy, dst, ncols,
855
+ nrows, item_ct1);
856
+ });
895
857
  });
896
858
  }
897
859
  }
@@ -905,15 +867,12 @@ static void mul_mat_vec_iq3_s_q8_1_sycl(const void *vx, const void *vy,
905
867
  const sycl::range<3> block_nums(1, 1, block_num_y);
906
868
  const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
907
869
  {
908
-
909
- stream->submit([&](sycl::handler &cgh) {
910
- cgh.parallel_for(
911
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
912
- [=](sycl::nd_item<3> item_ct1)
913
- [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
914
- mul_mat_vec_q_iq3_s_q8_1<QK_K, QI3_S/2, block_iq3_s, 1>(
915
- vx, vy, dst, ncols, nrows, item_ct1);
916
- });
870
+ sycl_launch(stream, [&](sycl::handler & cgh) {
871
+ sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
872
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
873
+ mul_mat_vec_q_iq3_s_q8_1<QK_K, QI3_S / 2, block_iq3_s, 1>(vx, vy, dst, ncols, nrows,
874
+ item_ct1);
875
+ });
917
876
  });
918
877
  }
919
878
  }
@@ -927,15 +886,12 @@ static void mul_mat_vec_iq1_s_q8_1_sycl(const void *vx, const void *vy,
927
886
  const sycl::range<3> block_nums(1, 1, block_num_y);
928
887
  const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
929
888
  {
930
-
931
- stream->submit([&](sycl::handler &cgh) {
932
- cgh.parallel_for(
933
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
934
- [=](sycl::nd_item<3> item_ct1)
935
- [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
936
- mul_mat_vec_q_iq1_s_q8_1<QK_K, QI1_S, block_iq1_s, 1>(
937
- vx, vy, dst, ncols, nrows, item_ct1);
938
- });
889
+ sycl_launch(stream, [&](sycl::handler & cgh) {
890
+ sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
891
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
892
+ mul_mat_vec_q_iq1_s_q8_1<QK_K, QI1_S, block_iq1_s, 1>(vx, vy, dst, ncols, nrows,
893
+ item_ct1);
894
+ });
939
895
  });
940
896
  }
941
897
  }
@@ -949,14 +905,12 @@ static void mul_mat_vec_iq1_m_q8_1_sycl(const void *vx, const void *vy,
949
905
  const sycl::range<3> block_nums(1, 1, block_num_y);
950
906
  const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
951
907
  {
952
- stream->submit([&](sycl::handler &cgh) {
953
- cgh.parallel_for(
954
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
955
- [=](sycl::nd_item<3> item_ct1)
956
- [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
957
- mul_mat_vec_q_iq1_m_q8_1<QK_K, QI1_S, block_iq1_m, 1>(
958
- vx, vy, dst, ncols, nrows, item_ct1);
959
- });
908
+ sycl_launch(stream, [&](sycl::handler & cgh) {
909
+ sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
910
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
911
+ mul_mat_vec_q_iq1_m_q8_1<QK_K, QI1_S, block_iq1_m, 1>(vx, vy, dst, ncols, nrows,
912
+ item_ct1);
913
+ });
960
914
  });
961
915
  }
962
916
  }
@@ -970,15 +924,12 @@ static void mul_mat_vec_iq4_nl_q8_1_sycl(const void *vx, const void *vy,
970
924
  const sycl::range<3> block_nums(1, 1, block_num_y);
971
925
  const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
972
926
  {
973
-
974
- stream->submit([&](sycl::handler &cgh) {
975
- cgh.parallel_for(
976
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
977
- [=](sycl::nd_item<3> item_ct1)
978
- [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
979
- mul_mat_vec_q_iq4_nl_q8_1<QK4_NL, QI4_NL, block_iq4_nl, 2>(
980
- vx, vy, dst, ncols, nrows, item_ct1);
981
- });
927
+ sycl_launch(stream, [&](sycl::handler & cgh) {
928
+ sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
929
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
930
+ mul_mat_vec_q_iq4_nl_q8_1<QK4_NL, QI4_NL, block_iq4_nl, 2>(vx, vy, dst, ncols, nrows,
931
+ item_ct1);
932
+ });
982
933
  });
983
934
  }
984
935
  }
@@ -992,15 +943,12 @@ static void mul_mat_vec_iq4_xs_q8_1_sycl(const void *vx, const void *vy,
992
943
  const sycl::range<3> block_nums(1, 1, block_num_y);
993
944
  const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
994
945
  {
995
-
996
- stream->submit([&](sycl::handler &cgh) {
997
- cgh.parallel_for(
998
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
999
- [=](sycl::nd_item<3> item_ct1)
1000
- [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
1001
- mul_mat_vec_q_iq4_xs_q8_1<QK_K, QI4_XS/4, block_iq4_xs, 1>(
1002
- vx, vy, dst, ncols, nrows, item_ct1);
1003
- });
946
+ sycl_launch(stream, [&](sycl::handler & cgh) {
947
+ sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
948
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
949
+ mul_mat_vec_q_iq4_xs_q8_1<QK_K, QI4_XS / 4, block_iq4_xs, 1>(vx, vy, dst, ncols,
950
+ nrows, item_ct1);
951
+ });
1004
952
  });
1005
953
  }
1006
954
  }
@@ -1070,7 +1018,14 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens
1070
1018
  mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
1071
1019
  break;
1072
1020
  case GGML_TYPE_Q6_K:
1073
- mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
1021
+ if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
1022
+ ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
1023
+ GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q6_k_q8_1_sycl\n");
1024
+ reorder_mul_mat_vec_q6_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
1025
+ } else {
1026
+ GGML_SYCL_DEBUG("Calling mul_mat_vec_q6_k_q8_1_sycl\n");
1027
+ mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
1028
+ }
1074
1029
  break;
1075
1030
  case GGML_TYPE_IQ1_S:
1076
1031
  mul_mat_vec_iq1_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);