whispercpp 1.3.2 → 1.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (244)
  1. checksums.yaml +4 -4
  2. data/.gitignore +6 -3
  3. data/README.md +71 -14
  4. data/Rakefile +20 -7
  5. data/ext/.gitignore +4 -6
  6. data/ext/dependencies.rb +36 -24
  7. data/ext/extconf.rb +1 -1
  8. data/ext/options.rb +48 -184
  9. data/ext/ruby_whisper.c +18 -0
  10. data/ext/ruby_whisper_context.c +43 -12
  11. data/ext/ruby_whisper_model.c +1 -1
  12. data/ext/ruby_whisper_params.c +4 -2
  13. data/ext/ruby_whisper_segment.c +81 -4
  14. data/ext/ruby_whisper_transcribe.cpp +13 -7
  15. data/ext/ruby_whisper_vad_params.c +1 -1
  16. data/ext/sources/CMakeLists.txt +5 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +120 -24
  19. data/ext/sources/examples/addon.node/addon.cpp +150 -31
  20. data/ext/sources/examples/addon.node/index.js +3 -0
  21. data/ext/sources/examples/addon.node/vad-example.js +132 -0
  22. data/ext/sources/examples/bench/bench.cpp +3 -2
  23. data/ext/sources/examples/cli/cli.cpp +3 -2
  24. data/ext/sources/examples/command/command.cpp +32 -8
  25. data/ext/sources/examples/common-whisper.cpp +14 -7
  26. data/ext/sources/examples/lsp/lsp.cpp +2 -0
  27. data/ext/sources/examples/quantize/quantize.cpp +3 -0
  28. data/ext/sources/examples/server/CMakeLists.txt +3 -0
  29. data/ext/sources/examples/server/server.cpp +169 -22
  30. data/ext/sources/examples/stream/stream.cpp +6 -0
  31. data/ext/sources/examples/talk-llama/CMakeLists.txt +4 -1
  32. data/ext/sources/examples/talk-llama/llama-arch.cpp +171 -3
  33. data/ext/sources/examples/talk-llama/llama-arch.h +28 -1
  34. data/ext/sources/examples/talk-llama/llama-batch.cpp +741 -272
  35. data/ext/sources/examples/talk-llama/llama-batch.h +112 -54
  36. data/ext/sources/examples/talk-llama/llama-chat.cpp +30 -8
  37. data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
  38. data/ext/sources/examples/talk-llama/llama-context.cpp +520 -351
  39. data/ext/sources/examples/talk-llama/llama-context.h +38 -17
  40. data/ext/sources/examples/talk-llama/llama-cparams.cpp +1 -1
  41. data/ext/sources/examples/talk-llama/llama-cparams.h +1 -1
  42. data/ext/sources/examples/talk-llama/llama-graph.cpp +447 -372
  43. data/ext/sources/examples/talk-llama/llama-graph.h +128 -58
  44. data/ext/sources/examples/talk-llama/llama-hparams.cpp +10 -2
  45. data/ext/sources/examples/talk-llama/llama-hparams.h +19 -2
  46. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +279 -0
  47. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.h +128 -0
  48. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +1841 -0
  49. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +303 -0
  50. data/ext/sources/examples/talk-llama/llama-kv-cache.h +14 -472
  51. data/ext/sources/examples/talk-llama/llama-kv-cells.h +86 -26
  52. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +246 -0
  53. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +138 -0
  54. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1125 -0
  55. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +183 -0
  56. data/ext/sources/examples/talk-llama/llama-memory.cpp +58 -0
  57. data/ext/sources/examples/talk-llama/llama-memory.h +88 -4
  58. data/ext/sources/examples/talk-llama/llama-mmap.cpp +1 -1
  59. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +42 -17
  60. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +1 -0
  61. data/ext/sources/examples/talk-llama/llama-model.cpp +1863 -563
  62. data/ext/sources/examples/talk-llama/llama-model.h +27 -0
  63. data/ext/sources/examples/talk-llama/llama-quant.cpp +89 -6
  64. data/ext/sources/examples/talk-llama/llama-vocab.cpp +65 -28
  65. data/ext/sources/examples/talk-llama/llama-vocab.h +1 -0
  66. data/ext/sources/examples/talk-llama/llama.cpp +11 -7
  67. data/ext/sources/examples/talk-llama/llama.h +147 -40
  68. data/ext/sources/examples/talk-llama/talk-llama.cpp +2 -0
  69. data/ext/sources/examples/talk-llama/unicode.cpp +5 -0
  70. data/ext/sources/examples/vad-speech-segments/speech.cpp +6 -0
  71. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +2 -0
  72. data/ext/sources/ggml/CMakeLists.txt +48 -3
  73. data/ext/sources/ggml/cmake/common.cmake +24 -0
  74. data/ext/sources/ggml/include/ggml-backend.h +1 -1
  75. data/ext/sources/ggml/include/ggml-cpu.h +2 -0
  76. data/ext/sources/ggml/include/ggml.h +144 -5
  77. data/ext/sources/ggml/src/CMakeLists.txt +82 -24
  78. data/ext/sources/ggml/src/ggml-backend-reg.cpp +5 -0
  79. data/ext/sources/ggml/src/ggml-backend.cpp +46 -23
  80. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +3 -3
  81. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +1 -0
  82. data/ext/sources/ggml/src/ggml-cann/common.h +6 -1
  83. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
  84. data/ext/sources/ggml/src/ggml-common.h +4 -0
  85. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +133 -40
  86. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
  87. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
  88. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  89. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +4114 -0
  90. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2163 -0
  91. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2639 -0
  92. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  93. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2732 -0
  94. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +2069 -0
  95. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +397 -0
  96. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1300 -0
  97. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1481 -0
  98. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +4311 -0
  99. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +79 -3225
  100. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +184 -0
  101. data/ext/sources/ggml/src/ggml-cpu/common.h +4 -3
  102. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +16 -7
  103. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +146 -105
  104. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +12 -8
  105. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  106. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
  107. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +58 -8
  108. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  109. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1057 -174
  110. data/ext/sources/ggml/src/ggml-cpu/ops.h +8 -0
  111. data/ext/sources/ggml/src/ggml-cpu/quants.c +1158 -0
  112. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
  113. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1571 -0
  114. data/ext/sources/ggml/src/ggml-cpu/repack.h +98 -0
  115. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +330 -38
  116. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
  117. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +111 -18
  118. data/ext/sources/ggml/src/ggml-cpu/vec.h +303 -94
  119. data/ext/sources/ggml/src/ggml-cuda/common.cuh +60 -37
  120. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  121. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  122. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  123. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  124. data/ext/sources/ggml/src/ggml-cuda/convert.cu +22 -0
  125. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +5 -0
  126. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +2 -2
  127. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +5 -2
  128. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
  129. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +265 -123
  130. data/ext/sources/ggml/src/ggml-cuda/mean.cu +19 -0
  131. data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
  132. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +257 -87
  133. data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +2 -3
  134. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
  135. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +5 -18
  136. data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  137. data/ext/sources/ggml/src/ggml-cuda/unary.cu +89 -0
  138. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +7 -0
  139. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +4 -0
  140. data/ext/sources/ggml/src/ggml-impl.h +127 -183
  141. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +11 -10
  142. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +27 -0
  143. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +331 -49
  144. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +564 -282
  145. data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +2 -2
  146. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +14 -0
  147. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1859 -489
  148. data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
  149. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  150. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +72 -0
  151. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +201 -0
  152. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +72 -0
  153. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  154. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  155. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  156. data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
  157. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +72 -0
  158. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
  159. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  160. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  161. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
  162. data/ext/sources/ggml/src/ggml-quants.c +6 -8
  163. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
  164. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
  165. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +5 -6
  166. data/ext/sources/ggml/src/ggml-sycl/common.hpp +20 -48
  167. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +28 -41
  168. data/ext/sources/ggml/src/ggml-sycl/conv.cpp +4 -10
  169. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +117 -165
  170. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +192 -53
  171. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +32 -0
  172. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +49 -67
  173. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
  174. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +648 -1039
  175. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +18 -9
  176. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
  177. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +8 -105
  178. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +238 -100
  179. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
  180. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +1 -1
  181. data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +60 -80
  182. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +158 -203
  183. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +55 -74
  184. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +38 -10
  185. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +138 -27
  186. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
  187. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  188. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  189. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +3 -8
  190. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
  191. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +12 -16
  192. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +36 -32
  193. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +726 -282
  194. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
  195. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  196. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  197. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
  198. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  199. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  200. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
  201. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  202. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +10 -1
  203. data/ext/sources/ggml/src/ggml.c +328 -48
  204. data/ext/sources/ggml/src/ggml.cpp +26 -0
  205. data/ext/sources/ggml/src/gguf.cpp +24 -3
  206. data/ext/sources/include/whisper.h +2 -0
  207. data/ext/sources/src/CMakeLists.txt +2 -0
  208. data/ext/sources/src/coreml/whisper-compat.h +10 -0
  209. data/ext/sources/src/coreml/whisper-compat.m +35 -0
  210. data/ext/sources/src/coreml/whisper-decoder-impl.m +1 -0
  211. data/ext/sources/src/coreml/whisper-encoder-impl.m +1 -0
  212. data/ext/sources/src/whisper.cpp +218 -169
  213. data/extsources.rb +15 -9
  214. data/lib/whisper/context.rb +15 -0
  215. data/lib/whisper/model/uri.rb +56 -1
  216. data/lib/whisper/segment.rb +58 -0
  217. data/sig/whisper.rbs +68 -38
  218. data/{tests → test}/helper.rb +1 -12
  219. data/{tests → test}/test_model.rb +9 -0
  220. data/test/test_package.rb +51 -0
  221. data/test/test_segment.rb +146 -0
  222. data/{tests → test}/test_whisper.rb +70 -0
  223. data/whispercpp.gemspec +2 -3
  224. metadata +91 -43
  225. data/ext/sources/.dockerignore +0 -3
  226. data/ext/sources/.github/workflows/bindings-ruby.yml +0 -21
  227. data/ext/sources/ci/run.sh +0 -336
  228. data/ext/sources/close-issue.yml +0 -28
  229. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2739
  230. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  231. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13747
  232. data/tests/test_package.rb +0 -46
  233. data/tests/test_segment.rb +0 -74
  234. /data/ext/sources/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  235. /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  236. /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
  237. /data/{tests → test}/jfk_reader/.gitignore +0 -0
  238. /data/{tests → test}/jfk_reader/extconf.rb +0 -0
  239. /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
  240. /data/{tests → test}/test_callback.rb +0 -0
  241. /data/{tests → test}/test_error.rb +0 -0
  242. /data/{tests → test}/test_params.rb +0 -0
  243. /data/{tests → test}/test_vad.rb +0 -0
  244. /data/{tests → test}/test_vad_params.rb +0 -0
@@ -284,22 +284,23 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0> {
284
284
  return d4 * (sumi * ds8f.x() - (8 * q4_0_traits::vdr_mmvq / q4_0_traits::qi) * ds8f.y());
285
285
  }
286
286
 
287
- __dpct_inline__ float operator()(const void * __restrict__ vbq, const int ibx_offset, const int d_offset,
288
- const block_q8_1 * __restrict__ bq8_1, const int & iqs, int /* nblocks */) {
289
- const uint8_t * bq4_0 = static_cast<const uint8_t *>(vbq) + ibx_offset;
290
- const ggml_half d = *(reinterpret_cast<const ggml_half *>(static_cast<const uint8_t *>(vbq) + d_offset));
287
+ __dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair<int, int> ibx_offset,
288
+ const std::pair<int, int> d_offset, const int8_t * q8_1_quant_ptr,
289
+ const sycl::half2 * q8_1_ds, const int & iqs) {
290
+ const uint8_t * bq4_0 = static_cast<const uint8_t *>(vbq) + ibx_offset.first;
291
+ const ggml_half d = *(reinterpret_cast<const ggml_half *>(static_cast<const uint8_t *>(vbq) + d_offset.first));
291
292
  int v[q4_0_traits::vdr_mmvq];
292
293
  int u[2 * q4_0_traits::vdr_mmvq];
293
294
 
294
- #pragma unroll
295
295
 
296
+ #pragma unroll
296
297
  for (size_t i = 0; i < q4_0_traits::vdr_mmvq; ++i) {
297
298
  v[i] = get_int_from_uint8(bq4_0, iqs + i);
298
- u[2 * i + 0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
299
- u[2 * i + 1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + q4_0_traits::qi);
299
+ u[2 * i + 0] = get_int_from_int8_aligned(q8_1_quant_ptr, iqs + i);
300
+ u[2 * i + 1] = get_int_from_int8_aligned(q8_1_quant_ptr, iqs + i + q4_0_traits::qi);
300
301
  }
301
302
 
302
- return vec_dot_q4_0_q8_1_impl(v, u, d, bq8_1->ds);
303
+ return vec_dot_q4_0_q8_1_impl(v, u, d, *q8_1_ds);
303
304
  };
304
305
  };
305
306
 
@@ -346,24 +347,115 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K> {
346
347
  using q4_k_block = ggml_sycl_reordered::block_q_t<GGML_TYPE_Q4_K>;
347
348
  using q4_k_traits = typename q4_k_block::traits;
348
349
 
349
- float operator()(const void * __restrict__ vbq, const int ibx_offset, const int d_offset,
350
- const block_q8_1 * __restrict__ bq8_1, const int & iqs, int nblocks) {
351
- const int ib = ibx_offset / (QK_K / 2);
350
+ __dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair<int, int> ibx_offset,
351
+ const std::pair<int, int> d_offset, const int8_t * q8_1_quant_ptr,
352
+ const sycl::half2 * q8_1_ds, const int & iqs) {
353
+ const int ib = ibx_offset.first / (QK_K / 2);
352
354
 
353
355
  const uint8_t * base = static_cast<const uint8_t *>(vbq);
354
- const uint8_t * qs = base + ibx_offset;
355
- const int total_qs_bytes = nblocks * (QK_K / 2);
356
- const uint8_t * scs = base + total_qs_bytes + ib * K_SCALE_SIZE;
357
- const ggml_half2 * dms = reinterpret_cast<const ggml_half2 *>(base + d_offset);
356
+ const uint8_t * qs = base + ibx_offset.first;
357
+ const uint8_t * scs = base + d_offset.first + ib * K_SCALE_SIZE;
358
+ const ggml_half2 * dms = reinterpret_cast<const ggml_half2 *>(base + d_offset.second);
358
359
 
359
360
  const int bq8_offset = QR4_K * ((iqs / 2) / (QI8_1 / 2));
360
361
  const int * q4 = (const int *) (qs + 16 * bq8_offset + 4 * ((iqs / 2) % 4));
361
362
  const uint16_t * scales = (const uint16_t *) scs;
362
363
 
363
- return vec_dot_q4_K_q8_1_common(q4, scales, *dms, bq8_1, iqs);
364
+ int v[2];
365
+ int u[2 * QR4_K];
366
+ float d8[QR4_K];
367
+
368
+ v[0] = q4[0];
369
+ v[1] = q4[4];
370
+
371
+ uint16_t aux[2];
372
+ const int j = (QR4_K * ((iqs / 2) / (QI8_1 / 2))) / 2;
373
+ if (j < 2) {
374
+ aux[0] = scales[j + 0] & 0x3f3f;
375
+ aux[1] = scales[j + 2] & 0x3f3f;
376
+ } else {
377
+ aux[0] = ((scales[j + 2] >> 0) & 0x0f0f) | ((scales[j - 2] & 0xc0c0) >> 2);
378
+ aux[1] = ((scales[j + 2] >> 4) & 0x0f0f) | ((scales[j - 0] & 0xc0c0) >> 2);
379
+ }
380
+
381
+ const uint8_t * sc = (const uint8_t *) aux;
382
+ const uint8_t * m = sc + 2;
383
+
384
+ for (int i = 0; i < QR4_K; ++i) {
385
+ const int8_t* quant_base_ptr = q8_1_quant_ptr + (bq8_offset + i) * QK8_1;
386
+ sycl::half2 ds_values = *(q8_1_ds + bq8_offset + i);
387
+
388
+ d8[i] = ds_values[0];
389
+
390
+ const int * q8 = (const int *) quant_base_ptr + ((iqs / 2) % 4);
391
+ u[2 * i + 0] = q8[0];
392
+ u[2 * i + 1] = q8[4];
393
+ }
394
+
395
+ return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, *dms, d8);
364
396
  }
365
397
  };
366
398
 
399
+ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q6_K> {
400
+ static constexpr ggml_type gtype = GGML_TYPE_Q6_K;
401
+
402
+ using q6_k_block = ggml_sycl_reordered::block_q_t<GGML_TYPE_Q6_K>;
403
+ using q6_k_traits = typename q6_k_block::traits;
404
+
405
+ __dpct_inline__ float vec_dot_q6_K_q8_1_impl_mmvq(const int vl, const int vh, const int * __restrict__ u,
406
+ const int8_t * __restrict__ scales, const float d,
407
+ const float * __restrict__ d8) {
408
+ float sumf = 0.0f;
409
+
410
+ #pragma unroll
411
+ for (int i = 0; i < QR6_K; ++i) {
412
+ const int sc = scales[4 * i];
413
+
414
+ const int vil = (vl >> (4 * i)) & 0x0F0F0F0F;
415
+
416
+ const int vih = ((vh >> (4 * i)) << 4) & 0x30303030;
417
+
418
+ const int vi = dpct::vectorized_binary<sycl::char4>((vil | vih), 0x20202020,
419
+ dpct::sub_sat()); // vi = (vil | vih) - 32
420
+
421
+ sumf += d8[i] * (dpct::dp4a(vi, u[i], 0) * sc); // SIMD dot product
422
+ }
423
+
424
+ return d * sumf;
425
+ }
426
+
427
+ __dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair<int, int> ibx_offset,
428
+ const std::pair<int, int> d_offset, const int8_t * q8_1_quant_ptr, const sycl::half2 * q8_1_ds,
429
+ const int iqs) {
430
+ const int ib = ibx_offset.first / (QK_K / 2);
431
+
432
+ const uint8_t * base = static_cast<const uint8_t *>(vbq);
433
+ const uint8_t * ql = base + ibx_offset.first;
434
+ const uint8_t * qh = base + ibx_offset.second;
435
+ const int8_t * scales = reinterpret_cast<const int8_t *>(base + d_offset.first);
436
+ const ggml_half * d = (const ggml_half *) (base + d_offset.second) + ib;
437
+
438
+ const int bq8_offset = 2 * QR6_K * (iqs / (QI6_K / 2)) + (iqs % (QI6_K / 2)) / (QI6_K / 4);
439
+ const int scale_offset = (QI6_K / 4) * (iqs / (QI6_K / 2)) + (iqs % (QI6_K / 2)) / (QI6_K / 8);
440
+ const int vh_shift = 2 * ((iqs % (QI6_K / 2)) / (QI6_K / 4));
441
+
442
+ const int vl = get_int_from_uint8(ql, iqs);
443
+ const int vh = get_int_from_uint8(qh, (QI6_K / 4) * (iqs / (QI6_K / 2)) + iqs % (QI6_K / 4)) >> vh_shift;
444
+
445
+ const int8_t * scs = scales + scale_offset;
446
+
447
+ int u[QR6_K];
448
+ float d8[QR6_K];
449
+
450
+ #pragma unroll
451
+ for (int i = 0; i < QR6_K; ++i) {
452
+ u[i] = get_int_from_int8_aligned(q8_1_quant_ptr + (bq8_offset + 2 * i) * QK8_1, iqs % QI8_1);
453
+ const sycl::half2 ds_values = *(q8_1_ds + bq8_offset + 2 * i);
454
+ d8[i] = ds_values[0];
455
+ }
456
+ return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scs, *d, d8);
457
+ }
458
+ };
367
459
  #define VDR_Q4_0_Q8_1_MMVQ 2
368
460
  #define VDR_Q4_0_Q8_1_MMQ 4
369
461
 
@@ -207,12 +207,11 @@ void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
207
207
 
208
208
  // Submit kernel
209
209
  if (C / H == WKV_BLOCK_SIZE) {
210
- stream->submit([&](sycl::handler& cgh) {
210
+ sycl_launch(stream, [&](sycl::handler & cgh) {
211
211
  sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
212
212
 
213
- cgh.parallel_for(
214
- sycl::nd_range<3>(grid_dims * block_dims, block_dims),
215
- [=](sycl::nd_item<3> item_ct1) {
213
+ sycl_parallel_for(
214
+ cgh, sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
216
215
  rwkv_wkv6_f32_kernel<WKV_BLOCK_SIZE>(
217
216
  B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
218
217
  item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
@@ -220,12 +219,11 @@ void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
220
219
  });
221
220
  });
222
221
  } else {
223
- stream->submit([&](sycl::handler& cgh) {
222
+ sycl_launch(stream, [&](sycl::handler & cgh) {
224
223
  sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
225
224
 
226
- cgh.parallel_for(
227
- sycl::nd_range<3>(grid_dims * block_dims, block_dims),
228
- [=](sycl::nd_item<3> item_ct1) {
225
+ sycl_parallel_for(
226
+ cgh, sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
229
227
  rwkv_wkv6_f32_kernel<WKV_BLOCK_SIZE * 2>(
230
228
  B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
231
229
  item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
@@ -264,12 +262,11 @@ void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
264
262
 
265
263
  // Submit kernel
266
264
  if (C / H == WKV_BLOCK_SIZE) {
267
- stream->submit([&](sycl::handler& cgh) {
265
+ sycl_launch(stream, [&](sycl::handler & cgh) {
268
266
  sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
269
267
 
270
- cgh.parallel_for(
271
- sycl::nd_range<3>(grid_dims * block_dims, block_dims),
272
- [=](sycl::nd_item<3> item_ct1) {
268
+ sycl_parallel_for(
269
+ cgh, sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
273
270
  rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE>(
274
271
  B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
275
272
  item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
@@ -277,12 +274,11 @@ void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
277
274
  });
278
275
  });
279
276
  } else {
280
- stream->submit([&](sycl::handler& cgh) {
277
+ sycl_launch(stream, [&](sycl::handler & cgh) {
281
278
  sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
282
279
 
283
- cgh.parallel_for(
284
- sycl::nd_range<3>(grid_dims * block_dims, block_dims),
285
- [=](sycl::nd_item<3> item_ct1) {
280
+ sycl_parallel_for(
281
+ cgh, sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
286
282
  rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE * 2>(
287
283
  B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
288
284
  item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
@@ -49,15 +49,7 @@ if (Vulkan_FOUND)
49
49
  ../../include/ggml-vulkan.h
50
50
  )
51
51
 
52
- set(VULKAN_SHADER_GEN_CMAKE_ARGS
53
- -DCMAKE_INSTALL_PREFIX=${CMAKE_BINARY_DIR}
54
- -DCMAKE_RUNTIME_OUTPUT_DIRECTORY=${CMAKE_RUNTIME_OUTPUT_DIRECTORY}
55
- )
56
-
57
- set(VULKAN_SHADER_GEN_CMAKE_BUILD_ARGS "")
58
- if (CMAKE_BUILD_TYPE AND CMAKE_BUILD_TYPE MATCHES "Debug|Release|MinSizeRel|RelWithDebInfo")
59
- list(APPEND VULKAN_SHADER_GEN_CMAKE_BUILD_ARGS --config=${CMAKE_BUILD_TYPE})
60
- endif()
52
+ set(VULKAN_SHADER_GEN_CMAKE_ARGS "")
61
53
 
62
54
  # Test all shader extensions
63
55
  test_shader_extension_support(
@@ -107,10 +99,7 @@ if (Vulkan_FOUND)
107
99
 
108
100
  if (GGML_VULKAN_SHADER_DEBUG_INFO)
109
101
  add_compile_definitions(GGML_VULKAN_SHADER_DEBUG_INFO)
110
- endif()
111
-
112
- if (GGML_VULKAN_PERF)
113
- add_compile_definitions(GGML_VULKAN_PERF)
102
+ list(APPEND VULKAN_SHADER_GEN_CMAKE_ARGS -DGGML_VULKAN_SHADER_DEBUG_INFO=ON)
114
103
  endif()
115
104
 
116
105
  if (GGML_VULKAN_VALIDATE)
@@ -140,42 +129,54 @@ if (Vulkan_FOUND)
140
129
  set(HOST_CMAKE_TOOLCHAIN_FILE "")
141
130
  endif()
142
131
 
143
- # Always use ExternalProject_Add approach
144
132
  include(ExternalProject)
145
133
 
146
- # Add toolchain file if cross-compiling
147
134
  if (CMAKE_CROSSCOMPILING)
148
135
  list(APPEND VULKAN_SHADER_GEN_CMAKE_ARGS -DCMAKE_TOOLCHAIN_FILE=${HOST_CMAKE_TOOLCHAIN_FILE})
149
136
  message(STATUS "vulkan-shaders-gen toolchain file: ${HOST_CMAKE_TOOLCHAIN_FILE}")
150
137
  endif()
151
138
 
152
- # Native build through ExternalProject_Add
153
139
  ExternalProject_Add(
154
140
  vulkan-shaders-gen
155
141
  SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders
156
- CMAKE_ARGS ${VULKAN_SHADER_GEN_CMAKE_ARGS}
157
- BUILD_COMMAND ${CMAKE_COMMAND} --build . ${VULKAN_SHADER_GEN_CMAKE_BUILD_ARGS}
158
- INSTALL_COMMAND ${CMAKE_COMMAND} --install .
159
- INSTALL_DIR ${CMAKE_BINARY_DIR}
142
+ CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${CMAKE_BINARY_DIR}/$<CONFIG>
143
+ -DCMAKE_INSTALL_BINDIR=.
144
+ -DCMAKE_BUILD_TYPE=$<CONFIG>
145
+ ${VULKAN_SHADER_GEN_CMAKE_ARGS}
146
+
147
+ BUILD_COMMAND ${CMAKE_COMMAND} --build . --config $<CONFIG>
148
+ BUILD_ALWAYS TRUE
149
+
150
+ # NOTE: When DESTDIR is set using Makefile generators and
151
+ # "make install" triggers the build step, vulkan-shaders-gen
152
+ # would be installed into the DESTDIR prefix, so it is unset
153
+ # to ensure that does not happen.
154
+
155
+ INSTALL_COMMAND ${CMAKE_COMMAND} -E env --unset=DESTDIR
156
+ ${CMAKE_COMMAND} --install . --config $<CONFIG>
160
157
  )
161
- ExternalProject_Add_StepTargets(vulkan-shaders-gen build install)
162
158
 
163
159
  set (_ggml_vk_host_suffix $<IF:$<STREQUAL:${CMAKE_HOST_SYSTEM_NAME},Windows>,.exe,>)
164
- set (_ggml_vk_genshaders_cmd ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/vulkan-shaders-gen${_ggml_vk_host_suffix})
165
- set (_ggml_vk_header ${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.hpp)
166
- set (_ggml_vk_source ${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.cpp)
167
- set (_ggml_vk_input_dir ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders)
168
- set (_ggml_vk_output_dir ${CMAKE_CURRENT_BINARY_DIR}/vulkan-shaders.spv)
160
+ set (_ggml_vk_genshaders_dir "${CMAKE_BINARY_DIR}/$<CONFIG>")
161
+ set (_ggml_vk_genshaders_cmd "${_ggml_vk_genshaders_dir}/vulkan-shaders-gen${_ggml_vk_host_suffix}")
162
+ set (_ggml_vk_header "${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.hpp")
163
+ set (_ggml_vk_source "${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.cpp")
164
+ set (_ggml_vk_input_dir "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders")
165
+ set (_ggml_vk_output_dir "${CMAKE_CURRENT_BINARY_DIR}/vulkan-shaders.spv")
169
166
 
170
- file(GLOB _ggml_vk_shader_deps "${_ggml_vk_input_dir}/*.comp")
171
- set (_ggml_vk_shader_deps ${_ggml_vk_shader_deps} vulkan-shaders-gen)
167
+ file(GLOB _ggml_vk_shader_files CONFIGURE_DEPENDS "${_ggml_vk_input_dir}/*.comp")
172
168
 
173
- # Add build and install dependencies for all builds
174
- set(_ggml_vk_shader_deps ${_ggml_vk_shader_deps} vulkan-shaders-gen-build vulkan-shaders-gen-install)
169
+ # Because external projects do not provide source-level tracking,
170
+ # the vulkan-shaders-gen sources need to be explicitly added to
171
+ # ensure that changes will cascade into shader re-generation.
172
+
173
+ file(GLOB _ggml_vk_shaders_gen_sources
174
+ CONFIGURE_DEPENDS "${_ggml_vk_input_dir}/*.cpp"
175
+ "${_ggml_vk_input_dir}/*.h")
175
176
 
176
177
  add_custom_command(
177
178
  OUTPUT ${_ggml_vk_header}
178
- ${_ggml_vk_source}
179
+ ${_ggml_vk_source}
179
180
 
180
181
  COMMAND ${_ggml_vk_genshaders_cmd}
181
182
  --glslc ${Vulkan_GLSLC_EXECUTABLE}
@@ -185,7 +186,10 @@ if (Vulkan_FOUND)
185
186
  --target-cpp ${_ggml_vk_source}
186
187
  --no-clean
187
188
 
188
- DEPENDS ${_ggml_vk_shader_deps}
189
+ DEPENDS ${_ggml_vk_shader_files}
190
+ ${_ggml_vk_shaders_gen_sources}
191
+ vulkan-shaders-gen
192
+
189
193
  COMMENT "Generate vulkan shaders"
190
194
  )
191
195