whispercpp 1.3.2 → 1.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (244)
  1. checksums.yaml +4 -4
  2. data/.gitignore +6 -3
  3. data/README.md +71 -14
  4. data/Rakefile +20 -7
  5. data/ext/.gitignore +4 -6
  6. data/ext/dependencies.rb +36 -24
  7. data/ext/extconf.rb +1 -1
  8. data/ext/options.rb +48 -184
  9. data/ext/ruby_whisper.c +18 -0
  10. data/ext/ruby_whisper_context.c +43 -12
  11. data/ext/ruby_whisper_model.c +1 -1
  12. data/ext/ruby_whisper_params.c +4 -2
  13. data/ext/ruby_whisper_segment.c +81 -4
  14. data/ext/ruby_whisper_transcribe.cpp +13 -7
  15. data/ext/ruby_whisper_vad_params.c +1 -1
  16. data/ext/sources/CMakeLists.txt +5 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +120 -24
  19. data/ext/sources/examples/addon.node/addon.cpp +150 -31
  20. data/ext/sources/examples/addon.node/index.js +3 -0
  21. data/ext/sources/examples/addon.node/vad-example.js +132 -0
  22. data/ext/sources/examples/bench/bench.cpp +3 -2
  23. data/ext/sources/examples/cli/cli.cpp +3 -2
  24. data/ext/sources/examples/command/command.cpp +32 -8
  25. data/ext/sources/examples/common-whisper.cpp +14 -7
  26. data/ext/sources/examples/lsp/lsp.cpp +2 -0
  27. data/ext/sources/examples/quantize/quantize.cpp +3 -0
  28. data/ext/sources/examples/server/CMakeLists.txt +3 -0
  29. data/ext/sources/examples/server/server.cpp +169 -22
  30. data/ext/sources/examples/stream/stream.cpp +6 -0
  31. data/ext/sources/examples/talk-llama/CMakeLists.txt +4 -1
  32. data/ext/sources/examples/talk-llama/llama-arch.cpp +171 -3
  33. data/ext/sources/examples/talk-llama/llama-arch.h +28 -1
  34. data/ext/sources/examples/talk-llama/llama-batch.cpp +741 -272
  35. data/ext/sources/examples/talk-llama/llama-batch.h +112 -54
  36. data/ext/sources/examples/talk-llama/llama-chat.cpp +30 -8
  37. data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
  38. data/ext/sources/examples/talk-llama/llama-context.cpp +520 -351
  39. data/ext/sources/examples/talk-llama/llama-context.h +38 -17
  40. data/ext/sources/examples/talk-llama/llama-cparams.cpp +1 -1
  41. data/ext/sources/examples/talk-llama/llama-cparams.h +1 -1
  42. data/ext/sources/examples/talk-llama/llama-graph.cpp +447 -372
  43. data/ext/sources/examples/talk-llama/llama-graph.h +128 -58
  44. data/ext/sources/examples/talk-llama/llama-hparams.cpp +10 -2
  45. data/ext/sources/examples/talk-llama/llama-hparams.h +19 -2
  46. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +279 -0
  47. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.h +128 -0
  48. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +1841 -0
  49. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +303 -0
  50. data/ext/sources/examples/talk-llama/llama-kv-cache.h +14 -472
  51. data/ext/sources/examples/talk-llama/llama-kv-cells.h +86 -26
  52. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +246 -0
  53. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +138 -0
  54. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1125 -0
  55. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +183 -0
  56. data/ext/sources/examples/talk-llama/llama-memory.cpp +58 -0
  57. data/ext/sources/examples/talk-llama/llama-memory.h +88 -4
  58. data/ext/sources/examples/talk-llama/llama-mmap.cpp +1 -1
  59. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +42 -17
  60. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +1 -0
  61. data/ext/sources/examples/talk-llama/llama-model.cpp +1863 -563
  62. data/ext/sources/examples/talk-llama/llama-model.h +27 -0
  63. data/ext/sources/examples/talk-llama/llama-quant.cpp +89 -6
  64. data/ext/sources/examples/talk-llama/llama-vocab.cpp +65 -28
  65. data/ext/sources/examples/talk-llama/llama-vocab.h +1 -0
  66. data/ext/sources/examples/talk-llama/llama.cpp +11 -7
  67. data/ext/sources/examples/talk-llama/llama.h +147 -40
  68. data/ext/sources/examples/talk-llama/talk-llama.cpp +2 -0
  69. data/ext/sources/examples/talk-llama/unicode.cpp +5 -0
  70. data/ext/sources/examples/vad-speech-segments/speech.cpp +6 -0
  71. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +2 -0
  72. data/ext/sources/ggml/CMakeLists.txt +48 -3
  73. data/ext/sources/ggml/cmake/common.cmake +24 -0
  74. data/ext/sources/ggml/include/ggml-backend.h +1 -1
  75. data/ext/sources/ggml/include/ggml-cpu.h +2 -0
  76. data/ext/sources/ggml/include/ggml.h +144 -5
  77. data/ext/sources/ggml/src/CMakeLists.txt +82 -24
  78. data/ext/sources/ggml/src/ggml-backend-reg.cpp +5 -0
  79. data/ext/sources/ggml/src/ggml-backend.cpp +46 -23
  80. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +3 -3
  81. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +1 -0
  82. data/ext/sources/ggml/src/ggml-cann/common.h +6 -1
  83. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
  84. data/ext/sources/ggml/src/ggml-common.h +4 -0
  85. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +133 -40
  86. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
  87. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
  88. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  89. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +4114 -0
  90. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2163 -0
  91. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2639 -0
  92. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  93. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2732 -0
  94. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +2069 -0
  95. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +397 -0
  96. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1300 -0
  97. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1481 -0
  98. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +4311 -0
  99. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +79 -3225
  100. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +184 -0
  101. data/ext/sources/ggml/src/ggml-cpu/common.h +4 -3
  102. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +16 -7
  103. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +146 -105
  104. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +12 -8
  105. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  106. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
  107. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +58 -8
  108. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  109. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1057 -174
  110. data/ext/sources/ggml/src/ggml-cpu/ops.h +8 -0
  111. data/ext/sources/ggml/src/ggml-cpu/quants.c +1158 -0
  112. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
  113. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1571 -0
  114. data/ext/sources/ggml/src/ggml-cpu/repack.h +98 -0
  115. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +330 -38
  116. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
  117. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +111 -18
  118. data/ext/sources/ggml/src/ggml-cpu/vec.h +303 -94
  119. data/ext/sources/ggml/src/ggml-cuda/common.cuh +60 -37
  120. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  121. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  122. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  123. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  124. data/ext/sources/ggml/src/ggml-cuda/convert.cu +22 -0
  125. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +5 -0
  126. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +2 -2
  127. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +5 -2
  128. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
  129. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +265 -123
  130. data/ext/sources/ggml/src/ggml-cuda/mean.cu +19 -0
  131. data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
  132. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +257 -87
  133. data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +2 -3
  134. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
  135. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +5 -18
  136. data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  137. data/ext/sources/ggml/src/ggml-cuda/unary.cu +89 -0
  138. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +7 -0
  139. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +4 -0
  140. data/ext/sources/ggml/src/ggml-impl.h +127 -183
  141. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +11 -10
  142. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +27 -0
  143. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +331 -49
  144. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +564 -282
  145. data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +2 -2
  146. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +14 -0
  147. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1859 -489
  148. data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
  149. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  150. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +72 -0
  151. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +201 -0
  152. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +72 -0
  153. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  154. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  155. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  156. data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
  157. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +72 -0
  158. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
  159. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  160. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  161. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
  162. data/ext/sources/ggml/src/ggml-quants.c +6 -8
  163. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
  164. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
  165. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +5 -6
  166. data/ext/sources/ggml/src/ggml-sycl/common.hpp +20 -48
  167. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +28 -41
  168. data/ext/sources/ggml/src/ggml-sycl/conv.cpp +4 -10
  169. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +117 -165
  170. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +192 -53
  171. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +32 -0
  172. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +49 -67
  173. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
  174. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +648 -1039
  175. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +18 -9
  176. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
  177. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +8 -105
  178. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +238 -100
  179. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
  180. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +1 -1
  181. data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +60 -80
  182. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +158 -203
  183. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +55 -74
  184. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +38 -10
  185. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +138 -27
  186. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
  187. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  188. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  189. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +3 -8
  190. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
  191. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +12 -16
  192. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +36 -32
  193. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +726 -282
  194. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
  195. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  196. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  197. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
  198. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  199. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  200. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
  201. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  202. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +10 -1
  203. data/ext/sources/ggml/src/ggml.c +328 -48
  204. data/ext/sources/ggml/src/ggml.cpp +26 -0
  205. data/ext/sources/ggml/src/gguf.cpp +24 -3
  206. data/ext/sources/include/whisper.h +2 -0
  207. data/ext/sources/src/CMakeLists.txt +2 -0
  208. data/ext/sources/src/coreml/whisper-compat.h +10 -0
  209. data/ext/sources/src/coreml/whisper-compat.m +35 -0
  210. data/ext/sources/src/coreml/whisper-decoder-impl.m +1 -0
  211. data/ext/sources/src/coreml/whisper-encoder-impl.m +1 -0
  212. data/ext/sources/src/whisper.cpp +218 -169
  213. data/extsources.rb +15 -9
  214. data/lib/whisper/context.rb +15 -0
  215. data/lib/whisper/model/uri.rb +56 -1
  216. data/lib/whisper/segment.rb +58 -0
  217. data/sig/whisper.rbs +68 -38
  218. data/{tests → test}/helper.rb +1 -12
  219. data/{tests → test}/test_model.rb +9 -0
  220. data/test/test_package.rb +51 -0
  221. data/test/test_segment.rb +146 -0
  222. data/{tests → test}/test_whisper.rb +70 -0
  223. data/whispercpp.gemspec +2 -3
  224. metadata +91 -43
  225. data/ext/sources/.dockerignore +0 -3
  226. data/ext/sources/.github/workflows/bindings-ruby.yml +0 -21
  227. data/ext/sources/ci/run.sh +0 -336
  228. data/ext/sources/close-issue.yml +0 -28
  229. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2739
  230. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  231. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13747
  232. data/tests/test_package.rb +0 -46
  233. data/tests/test_segment.rb +0 -74
  234. /data/ext/sources/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  235. /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  236. /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
  237. /data/{tests → test}/jfk_reader/.gitignore +0 -0
  238. /data/{tests → test}/jfk_reader/extconf.rb +0 -0
  239. /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
  240. /data/{tests → test}/test_callback.rb +0 -0
  241. /data/{tests → test}/test_error.rb +0 -0
  242. /data/{tests → test}/test_params.rb +0 -0
  243. /data/{tests → test}/test_vad.rb +0 -0
  244. /data/{tests → test}/test_vad_params.rb +0 -0
data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp
@@ -83,9 +83,7 @@ static ggml_sycl_device_info ggml_sycl_init() {
 
         info.devices[i].cc =
             100 * prop.get_major_version() + 10 * prop.get_minor_version();
-        info.devices[i].hw_info = get_device_hw_info(&device);
-        info.devices[i].opt_feature = check_gpu_optimize_feature(info.devices[i].hw_info.arch);
-
+        info.devices[i].opt_feature.reorder = !device.ext_oneapi_architecture_is(syclex::arch_category::intel_gpu);
         info.max_work_group_sizes[i] = prop.get_max_work_group_size();
     }
 
@@ -195,7 +193,7 @@ static void ggml_check_sycl() try {
 
     if (!initialized) {
         g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0);
-        g_ggml_sycl_disable_optimize= get_sycl_env("GGML_SYCL_DISABLE_OPT", 1);
+        g_ggml_sycl_disable_optimize = get_sycl_env("GGML_SYCL_DISABLE_OPT", 0);
         g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1);
         g_ggml_sycl_disable_dnn = get_sycl_env("GGML_SYCL_DISABLE_DNN", 0);
         g_ggml_sycl_prioritize_dmmv = get_sycl_env("GGML_SYCL_PRIORITIZE_DMMV", 0);
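Note the default flip above: the reorder optimization is now on by default (GGML_SYCL_DISABLE_OPT defaults to 0 instead of 1), so exporting GGML_SYCL_DISABLE_OPT=1 restores the previous behavior. A minimal stand-alone sketch of how an integer env-var helper like this typically works (hypothetical get_env_int; upstream's get_sycl_env may differ in detail):

```cpp
#include <cstdlib>

// Hypothetical equivalent of get_sycl_env(name, default_val): read an
// integer environment variable, falling back to the default when unset.
static int get_env_int(const char * name, int default_val) {
    const char * val = std::getenv(name);
    return val ? std::atoi(val) : default_val;
}

// Usage mirroring the hunk above (0 == optimization enabled by default):
// int g_ggml_sycl_disable_optimize = get_env_int("GGML_SYCL_DISABLE_OPT", 0);
```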
@@ -347,14 +345,15 @@ static enum ggml_status
 ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer,
                                      ggml_tensor *tensor) try {
     GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
-    debug_print_tensor(": tensor=", tensor, "\n");
+    GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor, "\n").c_str());
     ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *)buffer->context;
 
     if (tensor->view_src != NULL) {
         assert(tensor->view_src->buffer->buft == buffer->buft);
         return GGML_STATUS_SUCCESS;
     }
-    if ((tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q4_K) && !g_ggml_sycl_disable_optimize) {
+    if ((tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q4_K || tensor->type == GGML_TYPE_Q6_K) &&
+        !g_ggml_sycl_disable_optimize) {
         ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
         tensor->extra = extra;
         ctx->tensor_extras.push_back(extra); //used to release it when destroy ctx.
@@ -384,7 +383,7 @@ static void ggml_backend_sycl_buffer_set_tensor(ggml_backend_buffer_t buffer,
                                                 const void *data, size_t offset,
                                                 size_t size) try {
     GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
-    debug_print_tensor(": tensor=", tensor);
+    GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str());
     GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset);
     ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context;
     ggml_sycl_set_device(ctx->device);
@@ -412,7 +411,7 @@ static void ggml_backend_sycl_buffer_get_tensor(ggml_backend_buffer_t buffer,
                                                 void *data, size_t offset,
                                                 size_t size) try {
     GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
-    debug_print_tensor(": tensor=", tensor);
+    GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str());
     GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset);
     ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context;
 
@@ -443,8 +442,8 @@ ggml_backend_sycl_buffer_cpy_tensor(ggml_backend_buffer_t buffer,
                                     ggml_tensor *dst) try {
     bool is_cpy_supported = ggml_backend_buffer_is_sycl(src->buffer);
     GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
-    debug_print_tensor(": dst=", dst);
-    debug_print_tensor(" src=", src);
+    GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": dst", dst).c_str());
+    GGML_SYCL_DEBUG("%s", debug_get_tensor_str(" src", src).c_str());
     GGML_SYCL_DEBUG(" is_cpy_supported=%d\n", is_cpy_supported);
     if (is_cpy_supported) {
         ggml_backend_sycl_buffer_context * src_ctx = (ggml_backend_sycl_buffer_context *)src->buffer->context;
@@ -524,7 +523,7 @@ catch (sycl::exception const &exc) {
 static void ggml_backend_sycl_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value,
                                                    size_t offset, size_t size) {
     GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
-    debug_print_tensor(": tensor=", tensor);
+    GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str());
     GGML_SYCL_DEBUG(" size=%zu offset=%zu value=%u\n", size, offset, value);
     ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *) buffer->context;
     SYCL_CHECK(ggml_sycl_set_device(ctx->device));
@@ -804,7 +803,7 @@ static enum ggml_status
 ggml_backend_sycl_split_buffer_init_tensor(ggml_backend_buffer_t buffer,
                                            ggml_tensor *tensor) try {
     GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
-    debug_print_tensor(": tensor=", tensor, "\n");
+    GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor, "\n").c_str());
     GGML_ASSERT(tensor->view_src == nullptr); // views of split tensors are not supported
 
     ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context;
@@ -890,7 +889,7 @@ ggml_backend_sycl_split_buffer_set_tensor(ggml_backend_buffer_t buffer,
                                           ggml_tensor *tensor, const void *data,
                                           size_t offset, size_t size) try {
     GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
-    debug_print_tensor(": tensor=", tensor);
+    GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str());
     GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset);
     // split tensors must always be set in their entirety at once
     GGML_ASSERT(offset == 0);
@@ -946,7 +945,7 @@ ggml_backend_sycl_split_buffer_get_tensor(ggml_backend_buffer_t buffer,
                                           const ggml_tensor *tensor, void *data,
                                           size_t offset, size_t size) try {
     GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
-    debug_print_tensor(": tensor=", tensor);
+    GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str());
     GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset);
     // split tensors must always be set in their entirety at once
     GGML_ASSERT(offset == 0);
@@ -1434,6 +1433,59 @@ static void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy,
     reinterpret_cast<sycl::half &>(y[ib].ds.y()) = sum;
 }
 
+template <int ElementsPerWI>
+static __dpct_inline__ void quantize_and_reorder_q8_1(const float * __restrict__ x, void * reordered_q8_tensor,
+                                                      const int kx, const int kx_padded, const sycl::nd_item<1> & it) {
+    /*
+    Quantizes and reorders the resultant q8 tensor in a per row fashion
+    Each sub-group calculates one quant block. i.e. QK8_1 quant values and the d and sum values
+    */
+
+    auto subgroup_id = it.get_group(0);
+    auto wi_id       = it.get_local_id(0);
+
+    const int num_blocks_per_row = kx / QK8_1;
+    auto      row                = subgroup_id / num_blocks_per_row;
+    auto      col                = subgroup_id % num_blocks_per_row;
+
+    auto row_offset = row * (kx_padded / QK8_1) * sizeof(block_q8_1);
+    auto col_offset = QK8_1 * col + wi_id * ElementsPerWI;
+
+    auto quant_ptr = (int8_t *) ((char *) reordered_q8_tensor + row_offset + col_offset);
+    auto ds_ptr    = (sycl::half2 *) ((char *) reordered_q8_tensor + row_offset + kx + col * sizeof(sycl::half2));
+
+    sycl::vec<float, ElementsPerWI>  wi_f32_vals;
+    sycl::vec<int8_t, ElementsPerWI> quantized_values;
+
+    auto float_ptr_offset = subgroup_id * QK8_1 + ElementsPerWI * wi_id;
+    wi_f32_vals           = *reinterpret_cast<const sycl::vec<float, ElementsPerWI> *>(x + float_ptr_offset);
+
+    float sum  = 0.0f;
+    float amax = 0.0f;
+
+#pragma unroll(ElementsPerWI)
+    for (int i = 0; i < ElementsPerWI; i++) {
+        sum += wi_f32_vals[i];
+        amax = sycl::fmax(amax, sycl::fabs(wi_f32_vals[i]));
+        quantized_values[i] = 0;
+    }
+    sum  = sycl::reduce_over_group(it.get_group(), sum, sycl::plus<float>());
+    amax = sycl::reduce_over_group(it.get_group(), amax, sycl::maximum<float>());
+    float d = amax == 0 ? 1 : amax / 127;
+
+#pragma unroll(ElementsPerWI)
+    for (int i = 0; i < ElementsPerWI; i++) {
+        quantized_values[i] = sycl::round(wi_f32_vals[i] / d);
+    }
+
+    d = amax == 0 ? 0 : d;
+
+    *reinterpret_cast<sycl::vec<int8_t, ElementsPerWI> *>(quant_ptr) = quantized_values;
+    if (wi_id == 0) {
+        *ds_ptr = sycl::half2(sycl::half(d), sycl::half(sum));
+    }
+}
+
 static void mul_mat_p021_f16_f32(
     const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y,
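The reorder in the new kernel above changes the row layout: instead of the array-of-structs block_q8_1 layout (each 36-byte block holding its own (d, sum) pair followed by QK8_1 quants), all of a row's int8 quants come first, followed by the row's (d, sum) half2 pairs. A minimal host-side sketch of the resulting offset arithmetic (plain C++, no SYCL; QK8_1 = 32 and the 36-byte block size are the values used upstream, and the helper names here are illustrative only):

```cpp
#include <cstddef>
#include <cstdint>

constexpr int    QK8_1         = 32;
constexpr size_t BLOCK_Q8_1_SZ = QK8_1 * sizeof(int8_t) + 2 * sizeof(uint16_t); // 36 bytes

// Byte offset of quant value `elem` of block `col` within row `row`.
size_t quant_offset(int row, int col, int elem, int kx_padded) {
    size_t row_offset = size_t(row) * (kx_padded / QK8_1) * BLOCK_Q8_1_SZ;
    return row_offset + size_t(QK8_1) * col + elem; // quants are packed first
}

// Byte offset of block `col`'s (d, sum) pair: the pairs start right after
// the row's kx quant bytes, one 4-byte half2 per block.
size_t ds_offset(int row, int col, int kx, int kx_padded) {
    size_t row_offset = size_t(row) * (kx_padded / QK8_1) * BLOCK_Q8_1_SZ;
    return row_offset + size_t(kx) + col * (2 * sizeof(uint16_t));
}
```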
@@ -1718,23 +1770,30 @@ static void pool2d_nchw_kernel(
         o_ptr[cur_oh * ow + cur_ow] = res;
 }
 
-static void quantize_row_q8_1_sycl(const float *x, void *vy, const int kx,
-                                   const int ky, const int kx_padded,
-                                   queue_ptr stream) {
-    const int block_num_x = (kx_padded + SYCL_QUANTIZE_BLOCK_SIZE - 1) / SYCL_QUANTIZE_BLOCK_SIZE;
-    const sycl::range<3> num_blocks(1, ky, block_num_x);
-    int constexpr QUANT_BLOCK_TILE = QK8_1 / WARP_SIZE;
-    static_assert(QK8_1 % WARP_SIZE == 0);
-    const sycl::range<3> block_size(1, 1, SYCL_QUANTIZE_BLOCK_SIZE / QUANT_BLOCK_TILE);
-    {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
+static void quantize_row_q8_1_sycl(const float * x, void * vy, const int kx, const int ky, const int kx_padded,
+                                   bool reorder_q8_tensor, queue_ptr stream) {
+    if (reorder_q8_tensor) {
+        auto local_range      = std::size_t(WARP_SIZE);
+        auto num_quant_blocks = ky * (kx / QK8_1);
+        auto global_range     = num_quant_blocks * local_range;
+        stream->parallel_for(sycl::nd_range<1>({ global_range }, { local_range }),
+                             [=](sycl::nd_item<1> it) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                                 quantize_and_reorder_q8_1<QK8_1 / WARP_SIZE>(x, vy, kx, kx_padded, it);
+                             });
+    } else {
+        const int block_num_x = (kx_padded + SYCL_QUANTIZE_BLOCK_SIZE - 1) / SYCL_QUANTIZE_BLOCK_SIZE;
+        const sycl::range<3> num_blocks(1, ky, block_num_x);
+        int constexpr QUANT_BLOCK_TILE = QK8_1 / WARP_SIZE;
+        static_assert(QK8_1 % WARP_SIZE == 0);
+        const sycl::range<3> block_size(1, 1, SYCL_QUANTIZE_BLOCK_SIZE / QUANT_BLOCK_TILE);
+        {
+            dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
 
-        stream->parallel_for(
-            sycl::nd_range<3>(num_blocks * block_size, block_size),
-            [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                quantize_q8_1<QUANT_BLOCK_TILE>(x, vy, kx, kx_padded, item_ct1);
-            });
+            stream->parallel_for(sycl::nd_range<3>(num_blocks * block_size, block_size),
+                                 [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                                     quantize_q8_1<QUANT_BLOCK_TILE>(x, vy, kx, kx_padded, item_ct1);
+                                 });
+        }
     }
 }
 
@@ -1826,13 +1885,12 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
     const size_t shared_mem = ncols_pad * sizeof(int);
 
     if (order == GGML_SORT_ORDER_ASC) {
-        stream->submit([&](sycl::handler &cgh) {
+        sycl_launch(stream, [&](sycl::handler & cgh) {
             sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
                 sycl::range<1>(shared_mem), cgh);
 
-            cgh.parallel_for(
-                sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1) {
+            sycl_parallel_for(
+                cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
                     k_argsort_f32_i32<GGML_SORT_ORDER_ASC>(
                         x, dst, ncols, ncols_pad, item_ct1,
                         dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
@@ -1840,13 +1898,12 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
                 });
         });
     } else if (order == GGML_SORT_ORDER_DESC) {
-        stream->submit([&](sycl::handler &cgh) {
+        sycl_launch(stream, [&](sycl::handler & cgh) {
             sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
                 sycl::range<1>(shared_mem), cgh);
 
-            cgh.parallel_for(
-                sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1) {
+            sycl_parallel_for(
+                cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
                     k_argsort_f32_i32<GGML_SORT_ORDER_DESC>(
                         x, dst, ncols, ncols_pad, item_ct1,
                         dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
@@ -1864,50 +1921,47 @@ static void argmax_f32_i32_sycl(const float *x, int *dst, const int ncols,
     const sycl::range<3> block_nums(1, nrows, 1);
     const size_t shared_mem = 256 * sizeof(float);
 
-    stream->submit([&](sycl::handler &cgh) {
+    sycl_launch(stream, [&](sycl::handler & cgh) {
         sycl::local_accessor<float, 1> shared_data(
             sycl::range<1>(shared_mem/sizeof(float)), cgh);
         sycl::local_accessor<int, 1> shared_indices(
             sycl::range<1>(shared_mem/sizeof(float)), cgh);
 
-        cgh.parallel_for(
-            sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) {
-                const int tid = item_ct1.get_local_id(2);
-                const int row = item_ct1.get_global_id(1);
-
-                float max_val = -INFINITY;
-                int max_idx = -1;
-
-                for (int col = tid; col < ncols; col += 256) {
-                    float val = x[row * ncols + col];
-                    if (val > max_val) {
-                        max_val = val;
-                        max_idx = col;
-                    }
-                }
+        sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+            const int tid = item_ct1.get_local_id(2);
+            const int row = item_ct1.get_global_id(1);
 
-                shared_data[tid] = max_val;
-                shared_indices[tid] = max_idx;
-                item_ct1.barrier(sycl::access::fence_space::local_space);
+            float max_val = -INFINITY;
+            int max_idx = -1;
 
-                for (int stride = 256/2; stride > 0; stride >>= 1) {
-                    if (tid < stride) {
-                        float val1 = shared_data[tid];
-                        float val2 = shared_data[tid + stride];
-                        if (val2 > val1) {
-                            shared_data[tid] = val2;
-                            shared_indices[tid] = shared_indices[tid + stride];
-                        }
-                    }
-                    item_ct1.barrier(sycl::access::fence_space::local_space);
+            for (int col = tid; col < ncols; col += 256) {
+                float val = x[row * ncols + col];
+                if (val > max_val) {
+                    max_val = val;
+                    max_idx = col;
                 }
+            }
 
+            shared_data[tid] = max_val;
+            shared_indices[tid] = max_idx;
+            item_ct1.barrier(sycl::access::fence_space::local_space);
 
-                if (tid == 0) {
-                    dst[row] = shared_indices[0];
+            for (int stride = 256 / 2; stride > 0; stride >>= 1) {
+                if (tid < stride) {
+                    float val1 = shared_data[tid];
+                    float val2 = shared_data[tid + stride];
+                    if (val2 > val1) {
+                        shared_data[tid] = val2;
+                        shared_indices[tid] = shared_indices[tid + stride];
+                    }
                 }
-            });
+                item_ct1.barrier(sycl::access::fence_space::local_space);
+            }
+
+            if (tid == 0) {
+                dst[row] = shared_indices[0];
+            }
+        });
     });
 }
 
 static void diag_mask_inf_f32_sycl(const float *x, float *dst,
@@ -2066,21 +2120,18 @@ inline void ggml_sycl_op_mul_mat_sycl(
         const sycl::half *src1_ptr = src1->type == GGML_TYPE_F16
                                          ? (const sycl::half *)src1->data + src1_padded_row_size
                                          : src1_as_f16.get();
-        ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool(), row_diff * src1_ncols);
 
 #if GGML_SYCL_DNNL
         if (!g_ggml_sycl_disable_dnn) {
             DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ptr,
                                       DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
-                                      dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>(), stream);
-            scope_op_debug_print scope_dbg_print(__func__, "/to_fp32_sycl", dst, /*num_src=*/2,
-                                                 " : converting dst to fp32");
-            const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
-            to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
+                                      dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
         }
         else
 #endif
         {
+            ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool(), row_diff * src1_ncols);
+
             const sycl::half alpha_f16 = 1.0f;
             const sycl::half beta_f16 = 0.0f;
             SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
@@ -2446,9 +2497,10 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
                 dev[i].src1_ddq = dev[i].src1_ddq_alloc.alloc(ctx.pool(i), nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs);
 
                 if (src1_on_device && src1_is_contiguous) {
+                    bool reorder_q8_tensor = src0->extra && ((ggml_tensor_extra_gpu *)src0->extra)->optimized_feature.reorder;
                     scope_op_debug_print scope_dbg_print(__func__, "/quantize_row_q8_1_sycl", dst,
                                                          /*num_src=*/2, " : converting src1 to Q8_1");
-                    quantize_row_q8_1_sycl(dev[i].src1_ddf, dev[i].src1_ddq, ne10, nrows1, src1_padded_col_size, stream);
+                    quantize_row_q8_1_sycl(dev[i].src1_ddf, dev[i].src1_ddq, ne10, nrows1, src1_padded_col_size, reorder_q8_tensor, stream);
                     /*
                     DPCT1010:90: SYCL uses exceptions to report errors and does not
                     use the error codes. The call was replaced with 0. You need to
@@ -2554,7 +2606,7 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
                 if (convert_src1_to_q8_1 && !src1_is_contiguous) {
                     scope_op_debug_print scope_dbg_print(__func__, "/quantize_row_q8_1_sycl", dst,
                                                          /*num_src=*/2, " : converting src1 to Q8_1");
-                    quantize_row_q8_1_sycl(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, stream);
+                    quantize_row_q8_1_sycl(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, false, stream);
                     /*
                     DPCT1010:92: SYCL uses exceptions to report errors and does
                     not use the error codes. The call was replaced with 0. You
@@ -2893,7 +2945,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
             void ** ptrs_dst_get = ptrs_dst.get();
             size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : s12 * sizeof(sycl::half);
             size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : s13 * sizeof(sycl::half);
-            cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+            sycl_parallel_for(cgh, sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
                 k_compute_batched_ptrs(src0_f16, src1_f16, dst_ddf, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02,
                                        nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1);
             });
@@ -2928,6 +2980,7 @@ inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) {
         case GGML_TYPE_Q4_0:
             return true;
         case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q6_K:
             return !g_ggml_sycl_prioritize_dmmv;
         default:
             return false;
@@ -2947,6 +3000,7 @@ inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) {
     switch (type) {
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q6_K:
             return true;
         default:
             return false;
@@ -3031,6 +3085,50 @@ static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, d
     sycl::free(tmp_buf, *stream);
 }
 
+static void reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
+    GGML_ASSERT(size % sizeof(block_q6_K) == 0);
+    GGML_ASSERT(offset % sizeof(block_q6_K) == 0);
+
+    const int nblocks = size / sizeof(block_q6_K);
+
+    auto * tmp_buf = sycl::malloc_shared<uint8_t>(size, *stream);
+    SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size).wait()));
+
+    auto * ql_ptr     = data_device;
+    auto * qh_ptr     = ql_ptr + (QK_K / 2) * nblocks;
+    auto * scales_ptr = qh_ptr + (QK_K / 4) * nblocks;
+    sycl::half * dm_ptr = (sycl::half *) (scales_ptr + (QK_K / 16) * nblocks);
+
+    stream
+        ->parallel_for(nblocks,
+                       [=](auto i) {
+                           const block_q6_K * x = (const block_q6_K *) tmp_buf;
+                           const int ib = i;
+
+                           const uint8_t * ql = x[ib].ql;
+                           const uint8_t * qh = x[ib].qh;
+                           uint8_t * base_ql_ptr     = ql_ptr + (QK_K / 2) * ib;
+                           uint8_t * base_qh_ptr     = qh_ptr + (QK_K / 4) * ib;
+                           uint8_t * base_scales_ptr = scales_ptr + (QK_K / 16) * ib;
+
+                           for (int j = 0; j < QK_K / 2; ++j) {
+                               base_ql_ptr[j] = ql[j];
+                           }
+                           for (int j = 0; j < QK_K / 4; ++j) {
+                               base_qh_ptr[j] = qh[j];
+                           }
+
+                           for (int j = 0; j < QK_K / 16; ++j) {
+                               base_scales_ptr[j] = x[ib].scales[j];
+                           }
+
+                           dm_ptr[ib] = x[ib].d;
+                       })
+        .wait_and_throw();
+
+    sycl::free(tmp_buf, *stream);
+}
+
 static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
     uint8_t * data_device = (uint8_t *) src0->data;
     size_t ncols = src0->ne[0];
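For context: each block_q6_K super-block stores QK_K/2 bytes of low quant bits (ql), QK_K/4 bytes of high bits (qh), QK_K/16 int8 scales, and one fp16 d; the reorder above converts that array-of-structs into struct-of-arrays so each field is contiguous across blocks. A hedged host-side sketch of the resulting section offsets (QK_K = 256 as upstream; the struct and function names here are illustrative):

```cpp
#include <cstddef>

constexpr int QK_K = 256; // K-quant super-block size

// Byte offsets of each field's section in the reordered buffer,
// mirroring the ql_ptr/qh_ptr/scales_ptr/dm_ptr arithmetic above.
struct Q6KSections {
    size_t ql, qh, scales, d;
};

Q6KSections q6_k_sections(size_t nblocks) {
    Q6KSections s;
    s.ql     = 0;
    s.qh     = s.ql + (QK_K / 2) * nblocks;      // low 4 bits, 128 B/block
    s.scales = s.qh + (QK_K / 4) * nblocks;      // high 2 bits, 64 B/block
    s.d      = s.scales + (QK_K / 16) * nblocks; // 16 scales/block, then fp16 d array
    return s;
}
```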
@@ -3044,6 +3142,9 @@ static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
         case GGML_TYPE_Q4_K:
             reorder_qw_q4_k(data_device, size, 0, stream);
             break;
+        case GGML_TYPE_Q6_K:
+            reorder_qw_q6_k(data_device, size, 0, stream);
+            break;
         default:
             GGML_ABORT("reorder_qw() called with unsupported type");
             break;
@@ -3348,7 +3449,7 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
         {
             sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, 768u));
             sycl::range<3> grid_dims(1, n_ids, ids->ne[1]);
-            stream->submit([&](sycl::handler &cgh) {
+            sycl_launch(stream, [&](sycl::handler & cgh) {
                 sycl::local_accessor<int, 0> src1_row_acc(cgh);
 
                 char *__restrict src1_contiguous_get =
@@ -3360,9 +3461,8 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
                 size_t ids_nb_ct6 = ids->nb[1];
                 size_t ids_nb_ct7 = ids->nb[0];
 
-                cgh.parallel_for(
-                    sycl::nd_range<3>(grid_dims * block_dims, block_dims),
-                    [=](sycl::nd_item<3> item_ct1) {
+                sycl_parallel_for(
+                    cgh, sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
                         k_copy_src1_to_contiguous(
                             src1_original, src1_contiguous_get,
                             dev_cur_src1_row_get,
@@ -3393,15 +3493,14 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
         {
             sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, 768u));
             sycl::range<3> grid_dims(1, 1, num_src1_rows);
-            stream->submit([&](sycl::handler &cgh) {
+            sycl_launch(stream, [&](sycl::handler & cgh) {
                 const char *__restrict dst_contiguous_get =
                     dst_contiguous.get();
                 const mmid_row_mapping *__restrict dev_row_mapping_get =
                     dev_row_mapping.get();
 
-                cgh.parallel_for(
-                    sycl::nd_range<3>(grid_dims * block_dims, block_dims),
-                    [=](sycl::nd_item<3> item_ct1) {
+                sycl_parallel_for(
+                    cgh, sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
                         k_copy_dst_from_contiguous(dst_original,
                                                    dst_contiguous_get,
                                                    dev_row_mapping_get,
@@ -3543,6 +3642,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
                 case GGML_UNARY_OP_GELU_QUICK:
                     ggml_sycl_gelu_quick(ctx, dst);
                     break;
+                case GGML_UNARY_OP_GELU_ERF:
+                    ggml_sycl_gelu_erf(ctx, dst);
+                    break;
                 case GGML_UNARY_OP_TANH:
                     ggml_sycl_tanh(ctx, dst);
                     break;
@@ -3574,6 +3676,21 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
                     return false;
             }
             break;
+        case GGML_OP_GLU:
+            switch (ggml_get_glu_op(dst)) {
+                case GGML_GLU_OP_REGLU:
+                    ggml_sycl_reglu(ctx, dst);
+                    break;
+                case GGML_GLU_OP_GEGLU:
+                    ggml_sycl_geglu(ctx, dst);
+                    break;
+                case GGML_GLU_OP_SWIGLU:
+                    ggml_sycl_swiglu(ctx, dst);
+                    break;
+                default:
+                    return false;
+            }
+            break;
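For reference, the three gated-linear-unit variants dispatched above pair a gate input a with a value input b (typically the two halves of a split row) and compute activation(a) * b element-wise. A minimal scalar sketch of the math (not the SYCL kernels themselves; the erf form of GELU is shown, while ggml may use the tanh approximation):

```cpp
#include <cmath>

// out = activation(a) * b, applied element-wise across the split row.
float reglu_ref(float a, float b) {            // ReLU gate
    return std::fmax(a, 0.0f) * b;
}
float geglu_ref(float a, float b) {            // GELU gate (erf form)
    return 0.5f * a * (1.0f + std::erf(a / std::sqrt(2.0f))) * b;
}
float swiglu_ref(float a, float b) {           // SiLU/swish gate
    return (a / (1.0f + std::exp(-a))) * b;
}
```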
@@ -3752,7 +3869,7 @@ static void ggml_backend_sycl_set_tensor_async(ggml_backend_t backend,
                                                const void *data, size_t offset,
                                                size_t size) try {
     GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
-    debug_print_tensor(": tensor=", tensor);
+    GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str());
     GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset);
     ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
     ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
@@ -3773,7 +3890,7 @@ static void ggml_backend_sycl_get_tensor_async(ggml_backend_t backend,
                                                void *data, size_t offset,
                                                size_t size) try {
     GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
-    debug_print_tensor(": tensor=", tensor);
+    GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str());
     GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset);
     ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
     ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
@@ -3796,8 +3913,8 @@ static bool ggml_backend_sycl_cpy_tensor_async(ggml_backend_t backend,
     bool is_cpy_supported = dst->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) &&
                             ggml_backend_buffer_is_sycl(src->buffer);
     GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
-    debug_print_tensor(": dst=", dst);
-    debug_print_tensor(" src=", src);
+    GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": dst", dst).c_str());
+    GGML_SYCL_DEBUG("%s", debug_get_tensor_str(" src", src).c_str());
     GGML_SYCL_DEBUG(" is_cpy_supported=%d\n", is_cpy_supported);
     if (is_cpy_supported) {
         /*
@@ -4096,6 +4213,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
                 case GGML_UNARY_OP_HARDSIGMOID:
                 case GGML_UNARY_OP_HARDSWISH:
                 case GGML_UNARY_OP_GELU_QUICK:
+                case GGML_UNARY_OP_GELU_ERF:
                 case GGML_UNARY_OP_TANH:
                 case GGML_UNARY_OP_EXP:
                 case GGML_UNARY_OP_SGN:
@@ -4109,6 +4227,16 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
                 default:
                     return false;
             }
+        case GGML_OP_GLU:
+            switch (ggml_get_glu_op(op)) {
+                case GGML_GLU_OP_REGLU:
+                case GGML_GLU_OP_GEGLU:
+                case GGML_GLU_OP_SWIGLU:
+                    return ggml_is_contiguous_1(op->src[0]);
+                default:
+                    return false;
+            }
+            break;
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
             {
@@ -4161,6 +4289,9 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
             {
                 ggml_type src0_type = op->src[0]->type;
                 ggml_type src1_type = op->src[1]->type;
+                if (src0_type == src1_type && (ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1])) && src0_type != GGML_TYPE_BF16) {
+                    return true;
+                }
                 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
                     return true;
                 }
@@ -4206,6 +4337,21 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
                 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) {
                     return true;
                 }
+                if(src0_type == GGML_TYPE_Q8_0 && src1_type == GGML_TYPE_Q8_0) {
+                    return true;
+                }
+                if(src0_type == GGML_TYPE_Q5_0 && src1_type == GGML_TYPE_Q5_0) {
+                    return true;
+                }
+                if(src0_type == GGML_TYPE_Q5_1 && src1_type == GGML_TYPE_Q5_1) {
+                    return true;
+                }
+                if(src0_type == GGML_TYPE_Q4_0 && src1_type == GGML_TYPE_Q4_0) {
+                    return true;
+                }
+                if(src0_type == GGML_TYPE_Q4_1 && src1_type == GGML_TYPE_Q4_1) {
+                    return true;
+                }
                 return false;
             }
         case GGML_OP_CONCAT:
@@ -4253,14 +4399,6 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_SOFT_MAX:
             return true;
         case GGML_OP_ROPE:
-            {
-                const int mode = ((const int32_t *) op->op_params)[2];
-                // mode is not used as a bitmask in practice, the various rope type modes are independent implementations
-                if (mode == GGML_ROPE_TYPE_MROPE) {
-                    return false;
-                }
-                return true;
-            }
         case GGML_OP_IM2COL:
             return true;
         case GGML_OP_UPSCALE:
data/ext/sources/ggml/src/ggml-sycl/gla.cpp
@@ -11,13 +11,13 @@ static void gated_linear_attn_f32_kernel(const dpct::queue_ptr stream, u_int B,
     const u_int n_seq_tokens = T / B;
     sycl::range<1> block_dims((C / H));
     sycl::range<1> grid_dims((B * H));
-    stream->submit([&](sycl::handler & cgh) {
+    sycl_launch(stream, [&](sycl::handler & cgh) {
         /* local memory accessors*/
         auto _k  = sycl::local_accessor<float, 1>(sycl::range<1>(head_size), cgh);
         auto _r  = sycl::local_accessor<float, 1>(sycl::range<1>(head_size), cgh);
         auto _td = sycl::local_accessor<float, 1>(sycl::range<1>(head_size), cgh);
 
-        cgh.parallel_for(sycl::nd_range<1>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<1> item) {
+        sycl_parallel_for<1>(cgh, sycl::nd_range<1>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<1> item) {
             u_int tid = item.get_local_id(0);
             u_int bid = item.get_group(0);
 
data/ext/sources/ggml/src/ggml-sycl/im2col.cpp
@@ -70,7 +70,7 @@ static void im2col_sycl_internal(const float * x, T * dst, int64_t IW, int64_t I
 
     const int64_t CHW = IC * KH * KW;
 
-    stream->parallel_for(sycl::nd_range<3>(block_nums * local_range, local_range), [=](sycl::nd_item<3> item_ct1) {
+    sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * local_range, local_range), [=](sycl::nd_item<3> item_ct1) {
         im2col_kernel<T>(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, CHW, s0, s1,
                          p0, p1, d0, d1, item_ct1);
     });