whispercpp 1.3.2 → 1.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (244) hide show
  1. checksums.yaml +4 -4
  2. data/.gitignore +6 -3
  3. data/README.md +71 -14
  4. data/Rakefile +20 -7
  5. data/ext/.gitignore +4 -6
  6. data/ext/dependencies.rb +36 -24
  7. data/ext/extconf.rb +1 -1
  8. data/ext/options.rb +48 -184
  9. data/ext/ruby_whisper.c +18 -0
  10. data/ext/ruby_whisper_context.c +43 -12
  11. data/ext/ruby_whisper_model.c +1 -1
  12. data/ext/ruby_whisper_params.c +4 -2
  13. data/ext/ruby_whisper_segment.c +81 -4
  14. data/ext/ruby_whisper_transcribe.cpp +13 -7
  15. data/ext/ruby_whisper_vad_params.c +1 -1
  16. data/ext/sources/CMakeLists.txt +5 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +120 -24
  19. data/ext/sources/examples/addon.node/addon.cpp +150 -31
  20. data/ext/sources/examples/addon.node/index.js +3 -0
  21. data/ext/sources/examples/addon.node/vad-example.js +132 -0
  22. data/ext/sources/examples/bench/bench.cpp +3 -2
  23. data/ext/sources/examples/cli/cli.cpp +3 -2
  24. data/ext/sources/examples/command/command.cpp +32 -8
  25. data/ext/sources/examples/common-whisper.cpp +14 -7
  26. data/ext/sources/examples/lsp/lsp.cpp +2 -0
  27. data/ext/sources/examples/quantize/quantize.cpp +3 -0
  28. data/ext/sources/examples/server/CMakeLists.txt +3 -0
  29. data/ext/sources/examples/server/server.cpp +169 -22
  30. data/ext/sources/examples/stream/stream.cpp +6 -0
  31. data/ext/sources/examples/talk-llama/CMakeLists.txt +4 -1
  32. data/ext/sources/examples/talk-llama/llama-arch.cpp +171 -3
  33. data/ext/sources/examples/talk-llama/llama-arch.h +28 -1
  34. data/ext/sources/examples/talk-llama/llama-batch.cpp +741 -272
  35. data/ext/sources/examples/talk-llama/llama-batch.h +112 -54
  36. data/ext/sources/examples/talk-llama/llama-chat.cpp +30 -8
  37. data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
  38. data/ext/sources/examples/talk-llama/llama-context.cpp +520 -351
  39. data/ext/sources/examples/talk-llama/llama-context.h +38 -17
  40. data/ext/sources/examples/talk-llama/llama-cparams.cpp +1 -1
  41. data/ext/sources/examples/talk-llama/llama-cparams.h +1 -1
  42. data/ext/sources/examples/talk-llama/llama-graph.cpp +447 -372
  43. data/ext/sources/examples/talk-llama/llama-graph.h +128 -58
  44. data/ext/sources/examples/talk-llama/llama-hparams.cpp +10 -2
  45. data/ext/sources/examples/talk-llama/llama-hparams.h +19 -2
  46. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +279 -0
  47. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.h +128 -0
  48. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +1841 -0
  49. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +303 -0
  50. data/ext/sources/examples/talk-llama/llama-kv-cache.h +14 -472
  51. data/ext/sources/examples/talk-llama/llama-kv-cells.h +86 -26
  52. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +246 -0
  53. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +138 -0
  54. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1125 -0
  55. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +183 -0
  56. data/ext/sources/examples/talk-llama/llama-memory.cpp +58 -0
  57. data/ext/sources/examples/talk-llama/llama-memory.h +88 -4
  58. data/ext/sources/examples/talk-llama/llama-mmap.cpp +1 -1
  59. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +42 -17
  60. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +1 -0
  61. data/ext/sources/examples/talk-llama/llama-model.cpp +1863 -563
  62. data/ext/sources/examples/talk-llama/llama-model.h +27 -0
  63. data/ext/sources/examples/talk-llama/llama-quant.cpp +89 -6
  64. data/ext/sources/examples/talk-llama/llama-vocab.cpp +65 -28
  65. data/ext/sources/examples/talk-llama/llama-vocab.h +1 -0
  66. data/ext/sources/examples/talk-llama/llama.cpp +11 -7
  67. data/ext/sources/examples/talk-llama/llama.h +147 -40
  68. data/ext/sources/examples/talk-llama/talk-llama.cpp +2 -0
  69. data/ext/sources/examples/talk-llama/unicode.cpp +5 -0
  70. data/ext/sources/examples/vad-speech-segments/speech.cpp +6 -0
  71. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +2 -0
  72. data/ext/sources/ggml/CMakeLists.txt +48 -3
  73. data/ext/sources/ggml/cmake/common.cmake +24 -0
  74. data/ext/sources/ggml/include/ggml-backend.h +1 -1
  75. data/ext/sources/ggml/include/ggml-cpu.h +2 -0
  76. data/ext/sources/ggml/include/ggml.h +144 -5
  77. data/ext/sources/ggml/src/CMakeLists.txt +82 -24
  78. data/ext/sources/ggml/src/ggml-backend-reg.cpp +5 -0
  79. data/ext/sources/ggml/src/ggml-backend.cpp +46 -23
  80. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +3 -3
  81. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +1 -0
  82. data/ext/sources/ggml/src/ggml-cann/common.h +6 -1
  83. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
  84. data/ext/sources/ggml/src/ggml-common.h +4 -0
  85. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +133 -40
  86. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
  87. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
  88. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  89. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +4114 -0
  90. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2163 -0
  91. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2639 -0
  92. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  93. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2732 -0
  94. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +2069 -0
  95. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +397 -0
  96. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1300 -0
  97. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1481 -0
  98. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +4311 -0
  99. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +79 -3225
  100. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +184 -0
  101. data/ext/sources/ggml/src/ggml-cpu/common.h +4 -3
  102. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +16 -7
  103. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +146 -105
  104. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +12 -8
  105. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  106. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
  107. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +58 -8
  108. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  109. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1057 -174
  110. data/ext/sources/ggml/src/ggml-cpu/ops.h +8 -0
  111. data/ext/sources/ggml/src/ggml-cpu/quants.c +1158 -0
  112. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
  113. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1571 -0
  114. data/ext/sources/ggml/src/ggml-cpu/repack.h +98 -0
  115. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +330 -38
  116. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
  117. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +111 -18
  118. data/ext/sources/ggml/src/ggml-cpu/vec.h +303 -94
  119. data/ext/sources/ggml/src/ggml-cuda/common.cuh +60 -37
  120. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  121. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  122. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  123. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  124. data/ext/sources/ggml/src/ggml-cuda/convert.cu +22 -0
  125. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +5 -0
  126. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +2 -2
  127. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +5 -2
  128. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
  129. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +265 -123
  130. data/ext/sources/ggml/src/ggml-cuda/mean.cu +19 -0
  131. data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
  132. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +257 -87
  133. data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +2 -3
  134. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
  135. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +5 -18
  136. data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  137. data/ext/sources/ggml/src/ggml-cuda/unary.cu +89 -0
  138. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +7 -0
  139. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +4 -0
  140. data/ext/sources/ggml/src/ggml-impl.h +127 -183
  141. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +11 -10
  142. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +27 -0
  143. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +331 -49
  144. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +564 -282
  145. data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +2 -2
  146. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +14 -0
  147. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1859 -489
  148. data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
  149. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  150. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +72 -0
  151. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +201 -0
  152. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +72 -0
  153. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  154. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  155. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  156. data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
  157. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +72 -0
  158. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
  159. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  160. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  161. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
  162. data/ext/sources/ggml/src/ggml-quants.c +6 -8
  163. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
  164. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
  165. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +5 -6
  166. data/ext/sources/ggml/src/ggml-sycl/common.hpp +20 -48
  167. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +28 -41
  168. data/ext/sources/ggml/src/ggml-sycl/conv.cpp +4 -10
  169. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +117 -165
  170. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +192 -53
  171. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +32 -0
  172. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +49 -67
  173. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
  174. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +648 -1039
  175. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +18 -9
  176. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
  177. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +8 -105
  178. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +238 -100
  179. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
  180. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +1 -1
  181. data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +60 -80
  182. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +158 -203
  183. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +55 -74
  184. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +38 -10
  185. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +138 -27
  186. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
  187. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  188. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  189. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +3 -8
  190. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
  191. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +12 -16
  192. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +36 -32
  193. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +726 -282
  194. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
  195. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  196. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  197. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
  198. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  199. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  200. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
  201. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  202. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +10 -1
  203. data/ext/sources/ggml/src/ggml.c +328 -48
  204. data/ext/sources/ggml/src/ggml.cpp +26 -0
  205. data/ext/sources/ggml/src/gguf.cpp +24 -3
  206. data/ext/sources/include/whisper.h +2 -0
  207. data/ext/sources/src/CMakeLists.txt +2 -0
  208. data/ext/sources/src/coreml/whisper-compat.h +10 -0
  209. data/ext/sources/src/coreml/whisper-compat.m +35 -0
  210. data/ext/sources/src/coreml/whisper-decoder-impl.m +1 -0
  211. data/ext/sources/src/coreml/whisper-encoder-impl.m +1 -0
  212. data/ext/sources/src/whisper.cpp +218 -169
  213. data/extsources.rb +15 -9
  214. data/lib/whisper/context.rb +15 -0
  215. data/lib/whisper/model/uri.rb +56 -1
  216. data/lib/whisper/segment.rb +58 -0
  217. data/sig/whisper.rbs +68 -38
  218. data/{tests → test}/helper.rb +1 -12
  219. data/{tests → test}/test_model.rb +9 -0
  220. data/test/test_package.rb +51 -0
  221. data/test/test_segment.rb +146 -0
  222. data/{tests → test}/test_whisper.rb +70 -0
  223. data/whispercpp.gemspec +2 -3
  224. metadata +91 -43
  225. data/ext/sources/.dockerignore +0 -3
  226. data/ext/sources/.github/workflows/bindings-ruby.yml +0 -21
  227. data/ext/sources/ci/run.sh +0 -336
  228. data/ext/sources/close-issue.yml +0 -28
  229. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2739
  230. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  231. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13747
  232. data/tests/test_package.rb +0 -46
  233. data/tests/test_segment.rb +0 -74
  234. /data/ext/sources/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  235. /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  236. /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
  237. /data/{tests → test}/jfk_reader/.gitignore +0 -0
  238. /data/{tests → test}/jfk_reader/extconf.rb +0 -0
  239. /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
  240. /data/{tests → test}/test_callback.rb +0 -0
  241. /data/{tests → test}/test_error.rb +0 -0
  242. /data/{tests → test}/test_params.rb +0 -0
  243. /data/{tests → test}/test_vad.rb +0 -0
  244. /data/{tests → test}/test_vad_params.rb +0 -0
@@ -1,8 +1,12 @@
1
1
  #include "cpy.hpp"
2
2
 
3
3
  #include <float.h>
4
+ #include <string>
4
5
 
5
6
  #include "dequantize.hpp"
7
+ #include "ggml-sycl/common.hpp"
8
+ #include "ggml-sycl/presets.hpp"
9
+ #include "ggml.h"
6
10
 
7
11
  static __dpct_inline__ int best_index_int8(int n, const int8_t * val, float x) {
8
12
  if (x <= val[0]) {
@@ -116,6 +120,15 @@ static void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
116
120
  }
117
121
  }
118
122
 
123
+ /* quantized type same copy */
124
+ template<typename T>
125
+ static void cpy_blck_q_q(const char * cxi, char * cdsti) {
126
+ const T * xi = (const T *) cxi;
127
+ T * dsti = (T *) cdsti;
128
+ *dsti = *xi;
129
+ }
130
+
131
+
119
132
  static void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
120
133
  float * cdstf = (float *) (cdsti);
121
134
 
@@ -311,6 +324,34 @@ template <dequantize_kernel_t dequant, int qk> static void cpy_blck_q_f32(const
311
324
  }
312
325
  }
313
326
 
327
+
328
+ template <typename T, int qk>
329
+ static void cpy_q_q(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02,
330
+ const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11,
331
+ const int ne12, const int nb10, const int nb11, const int nb12, const int nb13,
332
+ const sycl::nd_item<3> & item_ct1) {
333
+ const int i = (item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2)) * qk;
334
+
335
+ if (i >= ne) {
336
+ return;
337
+ }
338
+
339
+ const int i03 = i / (ne00 * ne01 * ne02);
340
+ const int i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
341
+ const int i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00) / ne00;
342
+ const int i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00 - i01 * ne00;
343
+ const int x_offset = (i00 / qk) * nb00 + i01 * nb01 + i02 * nb02 + i03 * nb03;
344
+
345
+
346
+ const int i13 = i / (ne10 * ne11 * ne12);
347
+ const int i12 = (i - i13 * ne10 * ne11 * ne12) / (ne10 * ne11);
348
+ const int i11 = (i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11) / ne10;
349
+ const int i10 = i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11 - i11 * ne10;
350
+ const int dst_offset = (i10 / qk) * nb10 + i11 * nb11 + i12 * nb12 + i13 * nb13;
351
+
352
+ cpy_blck_q_q<T>(cx + x_offset, cdst + dst_offset);
353
+ }
354
+
314
355
  template <cpy_kernel_t cpy_blck, int qk>
315
356
  static void cpy_f32_q(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02,
316
357
  const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11,
@@ -322,6 +363,7 @@ static void cpy_f32_q(const char * cx, char * cdst, const int ne, const int ne00
322
363
  return;
323
364
  }
324
365
 
366
+
325
367
  const int i03 = i / (ne00 * ne01 * ne02);
326
368
  const int i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
327
369
  const int i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00) / ne00;
@@ -371,7 +413,8 @@ static void ggml_cpy_f16_f32_sycl(const char * cx, char * cdst, const int ne, co
371
413
  {
372
414
  dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
373
415
 
374
- stream->parallel_for(
416
+ sycl_parallel_for(
417
+ stream,
375
418
  sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
376
419
  sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
377
420
  [=](sycl::nd_item<3> item_ct1) {
@@ -389,7 +432,8 @@ static void ggml_cpy_f32_f32_sycl(const char * cx, char * cdst, const int ne, co
389
432
  {
390
433
  dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
391
434
 
392
- stream->parallel_for(
435
+ sycl_parallel_for(
436
+ stream,
393
437
  sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
394
438
  sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
395
439
  [=](sycl::nd_item<3> item_ct1) {
@@ -407,7 +451,8 @@ static void ggml_cpy_f32_f16_sycl(const char * cx, char * cdst, const int ne, co
407
451
  {
408
452
  dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
409
453
 
410
- stream->parallel_for(
454
+ sycl_parallel_for(
455
+ stream,
411
456
  sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
412
457
  sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
413
458
  [=](sycl::nd_item<3> item_ct1) {
@@ -423,11 +468,11 @@ static void ggml_cpy_f32_q8_0_sycl(const char * cx, char * cdst, const int ne, c
423
468
  const int nb12, const int nb13, queue_ptr stream) {
424
469
  GGML_ASSERT(ne % QK8_0 == 0);
425
470
  const int num_blocks = ne / QK8_0;
426
- stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
427
- [=](sycl::nd_item<3> item_ct1) {
428
- cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
429
- ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
430
- });
471
+ sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
472
+ [=](sycl::nd_item<3> item_ct1) {
473
+ cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
474
+ ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
475
+ });
431
476
  }
432
477
 
433
478
  static void ggml_cpy_q8_0_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
@@ -435,11 +480,11 @@ static void ggml_cpy_q8_0_f32_sycl(const char * cx, char * cdst, const int ne, c
435
480
  const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
436
481
  const int nb12, const int nb13, queue_ptr stream) {
437
482
  const int num_blocks = ne;
438
- stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
439
- [=](sycl::nd_item<3> item_ct1) {
440
- cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
441
- ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
442
- });
483
+ sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
484
+ [=](sycl::nd_item<3> item_ct1) {
485
+ cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
486
+ ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
487
+ });
443
488
  }
444
489
 
445
490
  static void ggml_cpy_f32_q4_0_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
@@ -448,11 +493,11 @@ static void ggml_cpy_f32_q4_0_sycl(const char * cx, char * cdst, const int ne, c
448
493
  const int nb12, const int nb13, queue_ptr stream) {
449
494
  GGML_ASSERT(ne % QK4_0 == 0);
450
495
  const int num_blocks = ne / QK4_0;
451
- stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
452
- [=](sycl::nd_item<3> item_ct1) {
453
- cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
454
- ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
455
- });
496
+ sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
497
+ [=](sycl::nd_item<3> item_ct1) {
498
+ cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
499
+ ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
500
+ });
456
501
  }
457
502
 
458
503
  static void ggml_cpy_q4_0_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
@@ -460,8 +505,9 @@ static void ggml_cpy_q4_0_f32_sycl(const char * cx, char * cdst, const int ne, c
460
505
  const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
461
506
  const int nb12, const int nb13, queue_ptr stream) {
462
507
  const int num_blocks = ne;
463
- stream->parallel_for(
464
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
508
+ sycl_parallel_for(
509
+ stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
510
+ [=](sycl::nd_item<3> item_ct1) {
465
511
  cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
466
512
  nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
467
513
  item_ct1);
@@ -474,11 +520,11 @@ static void ggml_cpy_f32_q4_1_sycl(const char * cx, char * cdst, const int ne, c
474
520
  const int nb12, const int nb13, queue_ptr stream) {
475
521
  GGML_ASSERT(ne % QK4_1 == 0);
476
522
  const int num_blocks = ne / QK4_1;
477
- stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
478
- [=](sycl::nd_item<3> item_ct1) {
479
- cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
480
- ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
481
- });
523
+ sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
524
+ [=](sycl::nd_item<3> item_ct1) {
525
+ cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
526
+ ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
527
+ });
482
528
  }
483
529
 
484
530
  static void ggml_cpy_q4_1_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
@@ -486,8 +532,9 @@ static void ggml_cpy_q4_1_f32_sycl(const char * cx, char * cdst, const int ne, c
486
532
  const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
487
533
  const int nb12, const int nb13, queue_ptr stream) {
488
534
  const int num_blocks = ne;
489
- stream->parallel_for(
490
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
535
+ sycl_parallel_for(
536
+ stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
537
+ [=](sycl::nd_item<3> item_ct1) {
491
538
  cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
492
539
  nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
493
540
  item_ct1);
@@ -500,11 +547,11 @@ static void ggml_cpy_f32_q5_0_sycl(const char * cx, char * cdst, const int ne, c
500
547
  const int nb12, const int nb13, queue_ptr stream) {
501
548
  GGML_ASSERT(ne % QK5_0 == 0);
502
549
  const int num_blocks = ne / QK5_0;
503
- stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
504
- [=](sycl::nd_item<3> item_ct1) {
505
- cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
506
- ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
507
- });
550
+ sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
551
+ [=](sycl::nd_item<3> item_ct1) {
552
+ cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
553
+ ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
554
+ });
508
555
  }
509
556
 
510
557
  static void ggml_cpy_q5_0_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
@@ -512,8 +559,9 @@ static void ggml_cpy_q5_0_f32_sycl(const char * cx, char * cdst, const int ne, c
512
559
  const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
513
560
  const int nb12, const int nb13, queue_ptr stream) {
514
561
  const int num_blocks = ne;
515
- stream->parallel_for(
516
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
562
+ sycl_parallel_for(
563
+ stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
564
+ [=](sycl::nd_item<3> item_ct1) {
517
565
  cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
518
566
  nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
519
567
  item_ct1);
@@ -526,11 +574,11 @@ static void ggml_cpy_f32_q5_1_sycl(const char * cx, char * cdst, const int ne, c
526
574
  const int nb12, const int nb13, queue_ptr stream) {
527
575
  GGML_ASSERT(ne % QK5_1 == 0);
528
576
  const int num_blocks = ne / QK5_1;
529
- stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
530
- [=](sycl::nd_item<3> item_ct1) {
531
- cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
532
- ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
533
- });
577
+ sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
578
+ [=](sycl::nd_item<3> item_ct1) {
579
+ cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
580
+ ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
581
+ });
534
582
  }
535
583
 
536
584
  static void ggml_cpy_q5_1_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
@@ -538,8 +586,9 @@ static void ggml_cpy_q5_1_f32_sycl(const char * cx, char * cdst, const int ne, c
538
586
  const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
539
587
  const int nb12, const int nb13, queue_ptr stream) {
540
588
  const int num_blocks = ne;
541
- stream->parallel_for(
542
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
589
+ sycl_parallel_for(
590
+ stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
591
+ [=](sycl::nd_item<3> item_ct1) {
543
592
  cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
544
593
  nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
545
594
  item_ct1);
@@ -552,11 +601,11 @@ static void ggml_cpy_f32_iq4_nl_sycl(const char * cx, char * cdst, const int ne,
552
601
  const int nb12, const int nb13, queue_ptr stream) {
553
602
  GGML_ASSERT(ne % QK4_NL == 0);
554
603
  const int num_blocks = ne / QK4_NL;
555
- stream->parallel_for(
556
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
557
- cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11,
558
- ne12, nb10, nb11, nb12, nb13, item_ct1);
559
- });
604
+ sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
605
+ [=](sycl::nd_item<3> item_ct1) {
606
+ cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
607
+ ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
608
+ });
560
609
  }
561
610
 
562
611
  static void ggml_cpy_f16_f16_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
@@ -567,7 +616,8 @@ static void ggml_cpy_f16_f16_sycl(const char * cx, char * cdst, const int ne, co
567
616
  {
568
617
  dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
569
618
 
570
- stream->parallel_for(
619
+ sycl_parallel_for(
620
+ stream,
571
621
  sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
572
622
  sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
573
623
  [=](sycl::nd_item<3> item_ct1) {
@@ -586,7 +636,8 @@ static void ggml_cpy_i16_i16_sycl(const char * cx, char * cdst, const int ne, co
586
636
  // dpct::has_capability_or_fail(stream->get_device(),
587
637
  // {sycl::aspect::fp16});
588
638
 
589
- stream->parallel_for(
639
+ sycl_parallel_for(
640
+ stream,
590
641
  sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
591
642
  sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
592
643
  [=](sycl::nd_item<3> item_ct1) {
@@ -605,7 +656,8 @@ static void ggml_cpy_i32_i32_sycl(const char * cx, char * cdst, const int ne, co
605
656
  // dpct::has_capability_or_fail(stream->get_device(),
606
657
  // {sycl::aspect::fp16});
607
658
 
608
- stream->parallel_for(
659
+ sycl_parallel_for(
660
+ stream,
609
661
  sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
610
662
  sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
611
663
  [=](sycl::nd_item<3> item_ct1) {
@@ -615,10 +667,85 @@ static void ggml_cpy_i32_i32_sycl(const char * cx, char * cdst, const int ne, co
615
667
  }
616
668
  }
617
669
 
670
+ static void ggml_cpy_q8_0_q8_0(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
671
+ const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
672
+ const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
673
+ const int nb12, const int nb13, queue_ptr stream) {
674
+ const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);
675
+ sycl_parallel_for(stream,
676
+ sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
677
+ sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
678
+ [=](sycl::nd_item<3> item_ct1) {
679
+ cpy_q_q<block_q8_0, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11,
680
+ ne12, nb10, nb11, nb12, nb13, item_ct1);
681
+ });
682
+ }
683
+
684
+
685
+ static void ggml_cpy_q5_0_q5_0(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
686
+ const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
687
+ const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
688
+ const int nb12, const int nb13, queue_ptr stream) {
689
+ const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);
690
+ sycl_parallel_for(stream,
691
+ sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
692
+ sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
693
+ [=](sycl::nd_item<3> item_ct1) {
694
+ cpy_q_q<block_q5_0, QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11,
695
+ ne12, nb10, nb11, nb12, nb13, item_ct1);
696
+ });
697
+ }
698
+
699
+
700
+ static void ggml_cpy_q5_1_q5_1(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
701
+ const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
702
+ const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
703
+ const int nb12, const int nb13, queue_ptr stream) {
704
+ const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);
705
+
706
+ sycl_parallel_for(stream,
707
+ sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
708
+ sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
709
+ [=](sycl::nd_item<3> item_ct1) {
710
+ cpy_q_q<block_q5_1, QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11,
711
+ ne12, nb10, nb11, nb12, nb13, item_ct1);
712
+ });
713
+ }
714
+
715
+
716
+ static void ggml_cpy_q4_0_q4_0(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
717
+ const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
718
+ const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
719
+ const int nb12, const int nb13, queue_ptr stream) {
720
+ const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);
721
+ sycl_parallel_for(stream,
722
+ sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
723
+ sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
724
+ [=](sycl::nd_item<3> item_ct1) {
725
+ cpy_q_q<block_q4_0, QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11,
726
+ ne12, nb10, nb11, nb12, nb13, item_ct1);
727
+ });
728
+ }
729
+
730
+
731
+ static void ggml_cpy_q4_1_q4_1(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
732
+ const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
733
+ const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
734
+ const int nb12, const int nb13, queue_ptr stream) {
735
+
736
+ const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);
737
+ sycl_parallel_for(stream,
738
+ sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
739
+ sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
740
+ [=](sycl::nd_item<3> item_ct1) {
741
+ cpy_q_q<block_q4_1, QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11,
742
+ ne12, nb10, nb11, nb12, nb13, item_ct1);
743
+ });
744
+ }
745
+
618
746
  void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1) try {
619
747
  // Unlike other operators ggml_sycl_cpy takes 2 distinct tensors instead of a dst ggml_tensor and rely on its src field
620
- scope_op_debug_print scope_dbg_print(__func__, src1, /*num_src=*/0,
621
- std::string(" src0 type=") + ggml_type_name(src0->type));
748
+ scope_op_debug_print scope_dbg_print(__func__, src1, /*num_src=*/0, debug_get_tensor_str("\tsrc0", src0));
622
749
  const int64_t ne = ggml_nelements(src0);
623
750
  GGML_ASSERT(ne == ggml_nelements(src1));
624
751
 
@@ -632,8 +759,10 @@ void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, co
632
759
 
633
760
  char * src0_ddc = (char *) src0->data;
634
761
  char * src1_ddc = (char *) src1->data;
635
-
636
- if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
762
+ if ((src0->type == src1->type) && (ggml_is_contiguous(src0) && ggml_is_contiguous(src1))) {
763
+ GGML_SYCL_DEBUG("%s: memcpy path\n", __func__);
764
+ main_stream->memcpy(src1_ddc, src0_ddc, ggml_nbytes(src0));
765
+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
637
766
  ggml_cpy_f32_f32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
638
767
  nb11, nb12, nb13, main_stream);
639
768
  } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
@@ -684,6 +813,16 @@ void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, co
684
813
  } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
685
814
  ggml_cpy_f32_iq4_nl_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
686
815
  nb10, nb11, nb12, nb13, main_stream);
816
+ } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_Q8_0) {
817
+ ggml_cpy_q8_0_q8_0(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
818
+ } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_Q5_0) {
819
+ ggml_cpy_q5_0_q5_0(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
820
+ } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_Q5_1) {
821
+ ggml_cpy_q5_1_q5_1(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
822
+ } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_Q4_0) {
823
+ ggml_cpy_q4_0_q4_0(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
824
+ } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_Q4_1) {
825
+ ggml_cpy_q4_1_q4_1(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
687
826
  } else {
688
827
  GGML_LOG_ERROR("%s: unsupported type combination (%s to %s)\n", __func__, ggml_type_name(src0->type),
689
828
  ggml_type_name(src1->type));
@@ -538,6 +538,38 @@ static void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restri
538
538
  #endif
539
539
  }
540
540
 
541
+ template <typename dst_t>
542
+ static void dequantize_block_q6_K_reorder(const void * __restrict__ vx, dst_t * __restrict__ yy,
543
+ const sycl::nd_item<3> & item_ct1, int64_t n_blocks) {
544
+ const int64_t ib = item_ct1.get_group(2);
545
+
546
+ const int64_t tid = item_ct1.get_local_id(2);
547
+ const int64_t ip = tid / 32; // ip is 0 or 1
548
+ const int64_t il = tid - 32 * ip; // 0...32
549
+ const int64_t is = 8 * ip + il / 16;
550
+
551
+ const uint8_t * base_ptr = static_cast<const uint8_t *>(vx);
552
+ const auto ql_offset = ib * (QK_K / 2);
553
+ const auto qh_offset = (QK_K / 2) * n_blocks + (QK_K / 4) * ib;
554
+ const auto base_scales_offset = (QK_K / 2) * n_blocks + (QK_K / 4) * n_blocks + (QK_K / 16) * ib;
555
+ const auto base_d_offset = ((QK_K / 2) + (QK_K / 4) + (QK_K / 16)) * n_blocks;
556
+ const uint8_t * ql_ptr = base_ptr + ql_offset;
557
+ const uint8_t * qh_ptr = base_ptr + qh_offset;
558
+ const uint8_t * scales_ptr = base_ptr + base_scales_offset;
559
+ const ggml_half * d = (const ggml_half *) (base_ptr + base_d_offset) + ib;
560
+
561
+ dst_t * y = yy + ib * QK_K + 128 * ip + il;
562
+
563
+ const uint8_t * ql = ql_ptr + 64 * ip + il;
564
+ const uint8_t qh = *(qh_ptr + 32 * ip + il);
565
+ const int8_t * sc = reinterpret_cast<const int8_t *>(scales_ptr + is);
566
+
567
+ y[0] = *d * sc[0] * ((int8_t) ((ql[0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
568
+ y[32] = *d * sc[2] * ((int8_t) ((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32);
569
+ y[64] = *d * sc[4] * ((int8_t) ((ql[0] >> 4) | (((qh >> 4) & 3) << 4)) - 32);
570
+ y[96] = *d * sc[6] * ((int8_t) ((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32);
571
+ }
572
+
541
573
  template<typename dst_t>
542
574
  static void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy,
543
575
  const sycl::nd_item<3> &item_ct1,