whispercpp 1.3.2 → 1.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (244)
  1. checksums.yaml +4 -4
  2. data/.gitignore +6 -3
  3. data/README.md +71 -14
  4. data/Rakefile +20 -7
  5. data/ext/.gitignore +4 -6
  6. data/ext/dependencies.rb +36 -24
  7. data/ext/extconf.rb +1 -1
  8. data/ext/options.rb +48 -184
  9. data/ext/ruby_whisper.c +18 -0
  10. data/ext/ruby_whisper_context.c +43 -12
  11. data/ext/ruby_whisper_model.c +1 -1
  12. data/ext/ruby_whisper_params.c +4 -2
  13. data/ext/ruby_whisper_segment.c +81 -4
  14. data/ext/ruby_whisper_transcribe.cpp +13 -7
  15. data/ext/ruby_whisper_vad_params.c +1 -1
  16. data/ext/sources/CMakeLists.txt +5 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +120 -24
  19. data/ext/sources/examples/addon.node/addon.cpp +150 -31
  20. data/ext/sources/examples/addon.node/index.js +3 -0
  21. data/ext/sources/examples/addon.node/vad-example.js +132 -0
  22. data/ext/sources/examples/bench/bench.cpp +3 -2
  23. data/ext/sources/examples/cli/cli.cpp +3 -2
  24. data/ext/sources/examples/command/command.cpp +32 -8
  25. data/ext/sources/examples/common-whisper.cpp +14 -7
  26. data/ext/sources/examples/lsp/lsp.cpp +2 -0
  27. data/ext/sources/examples/quantize/quantize.cpp +3 -0
  28. data/ext/sources/examples/server/CMakeLists.txt +3 -0
  29. data/ext/sources/examples/server/server.cpp +169 -22
  30. data/ext/sources/examples/stream/stream.cpp +6 -0
  31. data/ext/sources/examples/talk-llama/CMakeLists.txt +4 -1
  32. data/ext/sources/examples/talk-llama/llama-arch.cpp +171 -3
  33. data/ext/sources/examples/talk-llama/llama-arch.h +28 -1
  34. data/ext/sources/examples/talk-llama/llama-batch.cpp +741 -272
  35. data/ext/sources/examples/talk-llama/llama-batch.h +112 -54
  36. data/ext/sources/examples/talk-llama/llama-chat.cpp +30 -8
  37. data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
  38. data/ext/sources/examples/talk-llama/llama-context.cpp +520 -351
  39. data/ext/sources/examples/talk-llama/llama-context.h +38 -17
  40. data/ext/sources/examples/talk-llama/llama-cparams.cpp +1 -1
  41. data/ext/sources/examples/talk-llama/llama-cparams.h +1 -1
  42. data/ext/sources/examples/talk-llama/llama-graph.cpp +447 -372
  43. data/ext/sources/examples/talk-llama/llama-graph.h +128 -58
  44. data/ext/sources/examples/talk-llama/llama-hparams.cpp +10 -2
  45. data/ext/sources/examples/talk-llama/llama-hparams.h +19 -2
  46. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +279 -0
  47. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.h +128 -0
  48. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +1841 -0
  49. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +303 -0
  50. data/ext/sources/examples/talk-llama/llama-kv-cache.h +14 -472
  51. data/ext/sources/examples/talk-llama/llama-kv-cells.h +86 -26
  52. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +246 -0
  53. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +138 -0
  54. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1125 -0
  55. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +183 -0
  56. data/ext/sources/examples/talk-llama/llama-memory.cpp +58 -0
  57. data/ext/sources/examples/talk-llama/llama-memory.h +88 -4
  58. data/ext/sources/examples/talk-llama/llama-mmap.cpp +1 -1
  59. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +42 -17
  60. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +1 -0
  61. data/ext/sources/examples/talk-llama/llama-model.cpp +1863 -563
  62. data/ext/sources/examples/talk-llama/llama-model.h +27 -0
  63. data/ext/sources/examples/talk-llama/llama-quant.cpp +89 -6
  64. data/ext/sources/examples/talk-llama/llama-vocab.cpp +65 -28
  65. data/ext/sources/examples/talk-llama/llama-vocab.h +1 -0
  66. data/ext/sources/examples/talk-llama/llama.cpp +11 -7
  67. data/ext/sources/examples/talk-llama/llama.h +147 -40
  68. data/ext/sources/examples/talk-llama/talk-llama.cpp +2 -0
  69. data/ext/sources/examples/talk-llama/unicode.cpp +5 -0
  70. data/ext/sources/examples/vad-speech-segments/speech.cpp +6 -0
  71. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +2 -0
  72. data/ext/sources/ggml/CMakeLists.txt +48 -3
  73. data/ext/sources/ggml/cmake/common.cmake +24 -0
  74. data/ext/sources/ggml/include/ggml-backend.h +1 -1
  75. data/ext/sources/ggml/include/ggml-cpu.h +2 -0
  76. data/ext/sources/ggml/include/ggml.h +144 -5
  77. data/ext/sources/ggml/src/CMakeLists.txt +82 -24
  78. data/ext/sources/ggml/src/ggml-backend-reg.cpp +5 -0
  79. data/ext/sources/ggml/src/ggml-backend.cpp +46 -23
  80. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +3 -3
  81. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +1 -0
  82. data/ext/sources/ggml/src/ggml-cann/common.h +6 -1
  83. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
  84. data/ext/sources/ggml/src/ggml-common.h +4 -0
  85. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +133 -40
  86. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
  87. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
  88. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  89. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +4114 -0
  90. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2163 -0
  91. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2639 -0
  92. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  93. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2732 -0
  94. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +2069 -0
  95. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +397 -0
  96. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1300 -0
  97. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1481 -0
  98. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +4311 -0
  99. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +79 -3225
  100. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +184 -0
  101. data/ext/sources/ggml/src/ggml-cpu/common.h +4 -3
  102. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +16 -7
  103. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +146 -105
  104. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +12 -8
  105. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  106. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
  107. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +58 -8
  108. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  109. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1057 -174
  110. data/ext/sources/ggml/src/ggml-cpu/ops.h +8 -0
  111. data/ext/sources/ggml/src/ggml-cpu/quants.c +1158 -0
  112. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
  113. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1571 -0
  114. data/ext/sources/ggml/src/ggml-cpu/repack.h +98 -0
  115. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +330 -38
  116. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
  117. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +111 -18
  118. data/ext/sources/ggml/src/ggml-cpu/vec.h +303 -94
  119. data/ext/sources/ggml/src/ggml-cuda/common.cuh +60 -37
  120. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  121. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  122. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  123. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  124. data/ext/sources/ggml/src/ggml-cuda/convert.cu +22 -0
  125. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +5 -0
  126. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +2 -2
  127. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +5 -2
  128. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
  129. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +265 -123
  130. data/ext/sources/ggml/src/ggml-cuda/mean.cu +19 -0
  131. data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
  132. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +257 -87
  133. data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +2 -3
  134. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
  135. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +5 -18
  136. data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  137. data/ext/sources/ggml/src/ggml-cuda/unary.cu +89 -0
  138. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +7 -0
  139. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +4 -0
  140. data/ext/sources/ggml/src/ggml-impl.h +127 -183
  141. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +11 -10
  142. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +27 -0
  143. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +331 -49
  144. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +564 -282
  145. data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +2 -2
  146. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +14 -0
  147. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1859 -489
  148. data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
  149. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  150. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +72 -0
  151. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +201 -0
  152. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +72 -0
  153. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  154. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  155. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  156. data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
  157. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +72 -0
  158. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
  159. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  160. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  161. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
  162. data/ext/sources/ggml/src/ggml-quants.c +6 -8
  163. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
  164. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
  165. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +5 -6
  166. data/ext/sources/ggml/src/ggml-sycl/common.hpp +20 -48
  167. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +28 -41
  168. data/ext/sources/ggml/src/ggml-sycl/conv.cpp +4 -10
  169. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +117 -165
  170. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +192 -53
  171. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +32 -0
  172. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +49 -67
  173. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
  174. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +648 -1039
  175. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +18 -9
  176. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
  177. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +8 -105
  178. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +238 -100
  179. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
  180. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +1 -1
  181. data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +60 -80
  182. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +158 -203
  183. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +55 -74
  184. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +38 -10
  185. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +138 -27
  186. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
  187. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  188. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  189. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +3 -8
  190. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
  191. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +12 -16
  192. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +36 -32
  193. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +726 -282
  194. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
  195. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  196. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  197. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
  198. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  199. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  200. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
  201. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  202. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +10 -1
  203. data/ext/sources/ggml/src/ggml.c +328 -48
  204. data/ext/sources/ggml/src/ggml.cpp +26 -0
  205. data/ext/sources/ggml/src/gguf.cpp +24 -3
  206. data/ext/sources/include/whisper.h +2 -0
  207. data/ext/sources/src/CMakeLists.txt +2 -0
  208. data/ext/sources/src/coreml/whisper-compat.h +10 -0
  209. data/ext/sources/src/coreml/whisper-compat.m +35 -0
  210. data/ext/sources/src/coreml/whisper-decoder-impl.m +1 -0
  211. data/ext/sources/src/coreml/whisper-encoder-impl.m +1 -0
  212. data/ext/sources/src/whisper.cpp +218 -169
  213. data/extsources.rb +15 -9
  214. data/lib/whisper/context.rb +15 -0
  215. data/lib/whisper/model/uri.rb +56 -1
  216. data/lib/whisper/segment.rb +58 -0
  217. data/sig/whisper.rbs +68 -38
  218. data/{tests → test}/helper.rb +1 -12
  219. data/{tests → test}/test_model.rb +9 -0
  220. data/test/test_package.rb +51 -0
  221. data/test/test_segment.rb +146 -0
  222. data/{tests → test}/test_whisper.rb +70 -0
  223. data/whispercpp.gemspec +2 -3
  224. metadata +91 -43
  225. data/ext/sources/.dockerignore +0 -3
  226. data/ext/sources/.github/workflows/bindings-ruby.yml +0 -21
  227. data/ext/sources/ci/run.sh +0 -336
  228. data/ext/sources/close-issue.yml +0 -28
  229. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2739
  230. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  231. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13747
  232. data/tests/test_package.rb +0 -46
  233. data/tests/test_segment.rb +0 -74
  234. /data/ext/sources/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  235. /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  236. /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
  237. /data/{tests → test}/jfk_reader/.gitignore +0 -0
  238. /data/{tests → test}/jfk_reader/extconf.rb +0 -0
  239. /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
  240. /data/{tests → test}/test_callback.rb +0 -0
  241. /data/{tests → test}/test_error.rb +0 -0
  242. /data/{tests → test}/test_params.rb +0 -0
  243. /data/{tests → test}/test_vad.rb +0 -0
  244. /data/{tests → test}/test_vad_params.rb +0 -0
data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +49 -67
@@ -208,12 +208,10 @@ static void convert_mul_mat_vec_f16_sycl(const void *vx, const dfloat *y,
208
208
  dpct::has_capability_or_fail(stream->get_device(),
209
209
  {sycl::aspect::fp16});
210
210
 
211
- stream->parallel_for(
212
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
213
- [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
214
- dequantize_mul_mat_vec<1, 1, convert_f16>(vx, y, dst, ncols,
215
- nrows, item_ct1);
216
- });
211
+ sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
212
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
213
+ dequantize_mul_mat_vec<1, 1, convert_f16>(vx, y, dst, ncols, nrows, item_ct1);
214
+ });
217
215
  }
218
216
  }
219
217
 
@@ -877,12 +875,11 @@ static void dequantize_mul_mat_vec_q4_0_sycl_reorder(const void *vx, const dfloa
877
875
  dpct::has_capability_or_fail(stream->get_device(),
878
876
  {sycl::aspect::fp16});
879
877
 
880
- stream->parallel_for(
881
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
882
- [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
883
- dequantize_mul_mat_vec_reorder<QK4_0, QR4_0, dequantize_q4_0_reorder>(
884
- vx, y, dst, ncols, nrows, item_ct1);
885
- });
878
+ sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
879
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
880
+ dequantize_mul_mat_vec_reorder<QK4_0, QR4_0, dequantize_q4_0_reorder>(vx, y, dst, ncols,
881
+ nrows, item_ct1);
882
+ });
886
883
  }
887
884
  }
888
885
 
@@ -900,12 +897,10 @@ static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y,
900
897
  dpct::has_capability_or_fail(stream->get_device(),
901
898
  {sycl::aspect::fp16});
902
899
 
903
- stream->parallel_for(
904
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
905
- [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
906
- dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>(
907
- vx, y, dst, ncols, nrows, item_ct1);
908
- });
900
+ sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
901
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
902
+ dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>(vx, y, dst, ncols, nrows, item_ct1);
903
+ });
909
904
  }
910
905
  }
911
906
 
@@ -921,12 +916,10 @@ static void dequantize_mul_mat_vec_q4_1_sycl(const void *vx, const dfloat *y,
921
916
  dpct::has_capability_or_fail(stream->get_device(),
922
917
  {sycl::aspect::fp16});
923
918
 
924
- stream->parallel_for(
925
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
926
- [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
927
- dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>(
928
- vx, y, dst, ncols, nrows, item_ct1);
929
- });
919
+ sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
920
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
921
+ dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>(vx, y, dst, ncols, nrows, item_ct1);
922
+ });
930
923
  }
931
924
  }
932
925
 
@@ -942,12 +935,10 @@ static void dequantize_mul_mat_vec_q5_0_sycl(const void *vx, const dfloat *y,
942
935
  dpct::has_capability_or_fail(stream->get_device(),
943
936
  {sycl::aspect::fp16});
944
937
 
945
- stream->parallel_for(
946
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
947
- [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
948
- dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>(
949
- vx, y, dst, ncols, nrows, item_ct1);
950
- });
938
+ sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
939
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
940
+ dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>(vx, y, dst, ncols, nrows, item_ct1);
941
+ });
951
942
  }
952
943
  }
953
944
 
@@ -963,12 +954,10 @@ static void dequantize_mul_mat_vec_q5_1_sycl(const void *vx, const dfloat *y,
963
954
  dpct::has_capability_or_fail(stream->get_device(),
964
955
  {sycl::aspect::fp16});
965
956
 
966
- stream->parallel_for(
967
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
968
- [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
969
- dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>(
970
- vx, y, dst, ncols, nrows, item_ct1);
971
- });
957
+ sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
958
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
959
+ dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>(vx, y, dst, ncols, nrows, item_ct1);
960
+ });
972
961
  }
973
962
  }
974
963
 
@@ -984,12 +973,10 @@ static void dequantize_mul_mat_vec_q8_0_sycl(const void *vx, const dfloat *y,
984
973
  dpct::has_capability_or_fail(stream->get_device(),
985
974
  {sycl::aspect::fp16});
986
975
 
987
- stream->parallel_for(
988
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
989
- [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
990
- dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>(
991
- vx, y, dst, ncols, nrows, item_ct1);
992
- });
976
+ sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
977
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
978
+ dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>(vx, y, dst, ncols, nrows, item_ct1);
979
+ });
993
980
  }
994
981
  }
995
982
 
@@ -1002,11 +989,10 @@ static void dequantize_mul_mat_vec_q2_K_sycl(const void *vx, const float *y,
1002
989
  const int block_num_y = (nrows + ny - 1) / ny;
1003
990
  const sycl::range<3> block_nums(1, 1, block_num_y);
1004
991
  const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
1005
- stream->parallel_for(
1006
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
1007
- [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
1008
- dequantize_mul_mat_vec_q2_k(vx, y, dst, ncols, nrows, item_ct1);
1009
- });
992
+ sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
993
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
994
+ dequantize_mul_mat_vec_q2_k(vx, y, dst, ncols, nrows, item_ct1);
995
+ });
1010
996
  }
1011
997
 
1012
998
  static void dequantize_mul_mat_vec_q3_K_sycl(const void *vx, const float *y,
@@ -1018,11 +1004,10 @@ static void dequantize_mul_mat_vec_q3_K_sycl(const void *vx, const float *y,
1018
1004
  const int block_num_y = (nrows + ny - 1) / ny;
1019
1005
  const sycl::range<3> block_nums(1, 1, block_num_y);
1020
1006
  const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
1021
- stream->parallel_for(
1022
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
1023
- [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
1024
- dequantize_mul_mat_vec_q3_k(vx, y, dst, ncols, nrows, item_ct1);
1025
- });
1007
+ sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
1008
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
1009
+ dequantize_mul_mat_vec_q3_k(vx, y, dst, ncols, nrows, item_ct1);
1010
+ });
1026
1011
  }
1027
1012
 
1028
1013
  static void dequantize_mul_mat_vec_q4_K_sycl(const void *vx, const float *y,
@@ -1034,11 +1019,10 @@ static void dequantize_mul_mat_vec_q4_K_sycl(const void *vx, const float *y,
1034
1019
  const int block_num_y = (nrows + ny - 1) / ny;
1035
1020
  const sycl::range<3> block_nums(1, 1, block_num_y);
1036
1021
  const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
1037
- stream->parallel_for(
1038
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
1039
- [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
1040
- dequantize_mul_mat_vec_q4_k(vx, y, dst, ncols, nrows, item_ct1);
1041
- });
1022
+ sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
1023
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
1024
+ dequantize_mul_mat_vec_q4_k(vx, y, dst, ncols, nrows, item_ct1);
1025
+ });
1042
1026
  }
1043
1027
 
1044
1028
  static void dequantize_mul_mat_vec_q5_K_sycl(const void *vx, const float *y,
@@ -1047,11 +1031,10 @@ static void dequantize_mul_mat_vec_q5_K_sycl(const void *vx, const float *y,
1047
1031
  dpct::queue_ptr stream) {
1048
1032
  GGML_ASSERT(ncols % QK_K == 0);
1049
1033
  const sycl::range<3> block_dims(1, 1, QK_WARP_SIZE);
1050
- stream->parallel_for(
1051
- sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),
1052
- [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
1053
- dequantize_mul_mat_vec_q5_k(vx, y, dst, ncols, item_ct1);
1054
- });
1034
+ sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),
1035
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
1036
+ dequantize_mul_mat_vec_q5_k(vx, y, dst, ncols, item_ct1);
1037
+ });
1055
1038
  }
1056
1039
 
1057
1040
  static void dequantize_mul_mat_vec_q6_K_sycl(const void *vx, const float *y,
@@ -1063,11 +1046,10 @@ static void dequantize_mul_mat_vec_q6_K_sycl(const void *vx, const float *y,
1063
1046
  const int block_num_y = (nrows + ny - 1) / ny;
1064
1047
  const sycl::range<3> block_nums(1, 1, block_num_y);
1065
1048
  const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
1066
- stream->parallel_for(
1067
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
1068
- [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
1069
- dequantize_mul_mat_vec_q6_k(vx, y, dst, ncols, nrows, item_ct1);
1070
- });
1049
+ sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
1050
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
1051
+ dequantize_mul_mat_vec_q6_k(vx, y, dst, ncols, nrows, item_ct1);
1052
+ });
1071
1053
  }
1072
1054
 
1073
1055
  void ggml_sycl_op_dequantize_mul_mat_vec(
data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
@@ -13,10 +13,10 @@
13
13
  #ifndef GGML_SYCL_DPCT_HELPER_HPP
14
14
  #define GGML_SYCL_DPCT_HELPER_HPP
15
15
 
16
+ #include <map>
16
17
  #include <sycl/sycl.hpp>
17
18
  #include <sycl/half_type.hpp>
18
19
  #include <syclcompat/math.hpp>
19
- #include <map>
20
20
 
21
21
  #ifdef GGML_SYCL_USE_INTEL_ONEMKL
22
22
  #include <oneapi/mkl.hpp>
@@ -118,6 +118,36 @@ inline auto get_onemath_backend(sycl::queue& queue)
118
118
  #endif
119
119
  }
120
120
 
121
+ #ifdef SYCL_EXT_ONEAPI_ENQUEUE_FUNCTIONS
122
+ namespace syclex = sycl::ext::oneapi::experimental;
123
+ #endif
124
+
125
+ template <int NR, typename Func>
126
+ __dpct_inline__ void sycl_parallel_for(sycl::handler & cgh, sycl::nd_range<NR> nd_range, Func && func) {
127
+ #ifdef SYCL_EXT_ONEAPI_ENQUEUE_FUNCTIONS
128
+ syclex::nd_launch(cgh, nd_range, func);
129
+ #else
130
+ cgh.parallel_for(nd_range, func);
131
+ #endif
132
+ }
133
+
134
+ template <int NR, typename Func>
135
+ __dpct_inline__ void sycl_parallel_for(sycl::queue * q, sycl::nd_range<NR> nd_range, Func && func) {
136
+ #ifdef SYCL_EXT_ONEAPI_ENQUEUE_FUNCTIONS
137
+ syclex::nd_launch(*q, nd_range, func);
138
+ #else
139
+ q->parallel_for(nd_range, func);
140
+ #endif
141
+ }
142
+
143
+ template <typename Func> __dpct_inline__ void sycl_launch(sycl::queue * stream, Func && func) {
144
+ #ifdef SYCL_EXT_ONEAPI_ENQUEUE_FUNCTIONS
145
+ syclex::submit(*stream, func);
146
+ #else
147
+ stream->submit(func);
148
+ #endif
149
+ }
150
+
121
151
  namespace dpct
122
152
  {
123
153
  typedef sycl::queue *queue_ptr;