whispercpp 1.3.2 → 1.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (244)
  1. checksums.yaml +4 -4
  2. data/.gitignore +6 -3
  3. data/README.md +71 -14
  4. data/Rakefile +20 -7
  5. data/ext/.gitignore +4 -6
  6. data/ext/dependencies.rb +36 -24
  7. data/ext/extconf.rb +1 -1
  8. data/ext/options.rb +48 -184
  9. data/ext/ruby_whisper.c +18 -0
  10. data/ext/ruby_whisper_context.c +43 -12
  11. data/ext/ruby_whisper_model.c +1 -1
  12. data/ext/ruby_whisper_params.c +4 -2
  13. data/ext/ruby_whisper_segment.c +81 -4
  14. data/ext/ruby_whisper_transcribe.cpp +13 -7
  15. data/ext/ruby_whisper_vad_params.c +1 -1
  16. data/ext/sources/CMakeLists.txt +5 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +120 -24
  19. data/ext/sources/examples/addon.node/addon.cpp +150 -31
  20. data/ext/sources/examples/addon.node/index.js +3 -0
  21. data/ext/sources/examples/addon.node/vad-example.js +132 -0
  22. data/ext/sources/examples/bench/bench.cpp +3 -2
  23. data/ext/sources/examples/cli/cli.cpp +3 -2
  24. data/ext/sources/examples/command/command.cpp +32 -8
  25. data/ext/sources/examples/common-whisper.cpp +14 -7
  26. data/ext/sources/examples/lsp/lsp.cpp +2 -0
  27. data/ext/sources/examples/quantize/quantize.cpp +3 -0
  28. data/ext/sources/examples/server/CMakeLists.txt +3 -0
  29. data/ext/sources/examples/server/server.cpp +169 -22
  30. data/ext/sources/examples/stream/stream.cpp +6 -0
  31. data/ext/sources/examples/talk-llama/CMakeLists.txt +4 -1
  32. data/ext/sources/examples/talk-llama/llama-arch.cpp +171 -3
  33. data/ext/sources/examples/talk-llama/llama-arch.h +28 -1
  34. data/ext/sources/examples/talk-llama/llama-batch.cpp +741 -272
  35. data/ext/sources/examples/talk-llama/llama-batch.h +112 -54
  36. data/ext/sources/examples/talk-llama/llama-chat.cpp +30 -8
  37. data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
  38. data/ext/sources/examples/talk-llama/llama-context.cpp +520 -351
  39. data/ext/sources/examples/talk-llama/llama-context.h +38 -17
  40. data/ext/sources/examples/talk-llama/llama-cparams.cpp +1 -1
  41. data/ext/sources/examples/talk-llama/llama-cparams.h +1 -1
  42. data/ext/sources/examples/talk-llama/llama-graph.cpp +447 -372
  43. data/ext/sources/examples/talk-llama/llama-graph.h +128 -58
  44. data/ext/sources/examples/talk-llama/llama-hparams.cpp +10 -2
  45. data/ext/sources/examples/talk-llama/llama-hparams.h +19 -2
  46. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +279 -0
  47. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.h +128 -0
  48. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +1841 -0
  49. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +303 -0
  50. data/ext/sources/examples/talk-llama/llama-kv-cache.h +14 -472
  51. data/ext/sources/examples/talk-llama/llama-kv-cells.h +86 -26
  52. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +246 -0
  53. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +138 -0
  54. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1125 -0
  55. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +183 -0
  56. data/ext/sources/examples/talk-llama/llama-memory.cpp +58 -0
  57. data/ext/sources/examples/talk-llama/llama-memory.h +88 -4
  58. data/ext/sources/examples/talk-llama/llama-mmap.cpp +1 -1
  59. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +42 -17
  60. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +1 -0
  61. data/ext/sources/examples/talk-llama/llama-model.cpp +1863 -563
  62. data/ext/sources/examples/talk-llama/llama-model.h +27 -0
  63. data/ext/sources/examples/talk-llama/llama-quant.cpp +89 -6
  64. data/ext/sources/examples/talk-llama/llama-vocab.cpp +65 -28
  65. data/ext/sources/examples/talk-llama/llama-vocab.h +1 -0
  66. data/ext/sources/examples/talk-llama/llama.cpp +11 -7
  67. data/ext/sources/examples/talk-llama/llama.h +147 -40
  68. data/ext/sources/examples/talk-llama/talk-llama.cpp +2 -0
  69. data/ext/sources/examples/talk-llama/unicode.cpp +5 -0
  70. data/ext/sources/examples/vad-speech-segments/speech.cpp +6 -0
  71. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +2 -0
  72. data/ext/sources/ggml/CMakeLists.txt +48 -3
  73. data/ext/sources/ggml/cmake/common.cmake +24 -0
  74. data/ext/sources/ggml/include/ggml-backend.h +1 -1
  75. data/ext/sources/ggml/include/ggml-cpu.h +2 -0
  76. data/ext/sources/ggml/include/ggml.h +144 -5
  77. data/ext/sources/ggml/src/CMakeLists.txt +82 -24
  78. data/ext/sources/ggml/src/ggml-backend-reg.cpp +5 -0
  79. data/ext/sources/ggml/src/ggml-backend.cpp +46 -23
  80. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +3 -3
  81. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +1 -0
  82. data/ext/sources/ggml/src/ggml-cann/common.h +6 -1
  83. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
  84. data/ext/sources/ggml/src/ggml-common.h +4 -0
  85. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +133 -40
  86. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
  87. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
  88. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  89. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +4114 -0
  90. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2163 -0
  91. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2639 -0
  92. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  93. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2732 -0
  94. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +2069 -0
  95. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +397 -0
  96. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1300 -0
  97. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1481 -0
  98. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +4311 -0
  99. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +79 -3225
  100. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +184 -0
  101. data/ext/sources/ggml/src/ggml-cpu/common.h +4 -3
  102. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +16 -7
  103. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +146 -105
  104. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +12 -8
  105. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  106. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
  107. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +58 -8
  108. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  109. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1057 -174
  110. data/ext/sources/ggml/src/ggml-cpu/ops.h +8 -0
  111. data/ext/sources/ggml/src/ggml-cpu/quants.c +1158 -0
  112. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
  113. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1571 -0
  114. data/ext/sources/ggml/src/ggml-cpu/repack.h +98 -0
  115. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +330 -38
  116. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
  117. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +111 -18
  118. data/ext/sources/ggml/src/ggml-cpu/vec.h +303 -94
  119. data/ext/sources/ggml/src/ggml-cuda/common.cuh +60 -37
  120. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  121. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  122. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  123. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  124. data/ext/sources/ggml/src/ggml-cuda/convert.cu +22 -0
  125. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +5 -0
  126. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +2 -2
  127. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +5 -2
  128. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
  129. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +265 -123
  130. data/ext/sources/ggml/src/ggml-cuda/mean.cu +19 -0
  131. data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
  132. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +257 -87
  133. data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +2 -3
  134. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
  135. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +5 -18
  136. data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  137. data/ext/sources/ggml/src/ggml-cuda/unary.cu +89 -0
  138. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +7 -0
  139. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +4 -0
  140. data/ext/sources/ggml/src/ggml-impl.h +127 -183
  141. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +11 -10
  142. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +27 -0
  143. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +331 -49
  144. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +564 -282
  145. data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +2 -2
  146. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +14 -0
  147. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1859 -489
  148. data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
  149. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  150. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +72 -0
  151. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +201 -0
  152. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +72 -0
  153. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  154. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  155. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  156. data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
  157. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +72 -0
  158. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
  159. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  160. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  161. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
  162. data/ext/sources/ggml/src/ggml-quants.c +6 -8
  163. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
  164. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
  165. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +5 -6
  166. data/ext/sources/ggml/src/ggml-sycl/common.hpp +20 -48
  167. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +28 -41
  168. data/ext/sources/ggml/src/ggml-sycl/conv.cpp +4 -10
  169. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +117 -165
  170. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +192 -53
  171. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +32 -0
  172. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +49 -67
  173. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
  174. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +648 -1039
  175. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +18 -9
  176. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
  177. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +8 -105
  178. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +238 -100
  179. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
  180. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +1 -1
  181. data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +60 -80
  182. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +158 -203
  183. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +55 -74
  184. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +38 -10
  185. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +138 -27
  186. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
  187. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  188. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  189. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +3 -8
  190. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
  191. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +12 -16
  192. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +36 -32
  193. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +726 -282
  194. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
  195. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  196. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  197. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
  198. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  199. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  200. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
  201. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  202. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +10 -1
  203. data/ext/sources/ggml/src/ggml.c +328 -48
  204. data/ext/sources/ggml/src/ggml.cpp +26 -0
  205. data/ext/sources/ggml/src/gguf.cpp +24 -3
  206. data/ext/sources/include/whisper.h +2 -0
  207. data/ext/sources/src/CMakeLists.txt +2 -0
  208. data/ext/sources/src/coreml/whisper-compat.h +10 -0
  209. data/ext/sources/src/coreml/whisper-compat.m +35 -0
  210. data/ext/sources/src/coreml/whisper-decoder-impl.m +1 -0
  211. data/ext/sources/src/coreml/whisper-encoder-impl.m +1 -0
  212. data/ext/sources/src/whisper.cpp +218 -169
  213. data/extsources.rb +15 -9
  214. data/lib/whisper/context.rb +15 -0
  215. data/lib/whisper/model/uri.rb +56 -1
  216. data/lib/whisper/segment.rb +58 -0
  217. data/sig/whisper.rbs +68 -38
  218. data/{tests → test}/helper.rb +1 -12
  219. data/{tests → test}/test_model.rb +9 -0
  220. data/test/test_package.rb +51 -0
  221. data/test/test_segment.rb +146 -0
  222. data/{tests → test}/test_whisper.rb +70 -0
  223. data/whispercpp.gemspec +2 -3
  224. metadata +91 -43
  225. data/ext/sources/.dockerignore +0 -3
  226. data/ext/sources/.github/workflows/bindings-ruby.yml +0 -21
  227. data/ext/sources/ci/run.sh +0 -336
  228. data/ext/sources/close-issue.yml +0 -28
  229. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2739
  230. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  231. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13747
  232. data/tests/test_package.rb +0 -46
  233. data/tests/test_segment.rb +0 -74
  234. /data/ext/sources/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  235. /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  236. /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
  237. /data/{tests → test}/jfk_reader/.gitignore +0 -0
  238. /data/{tests → test}/jfk_reader/extconf.rb +0 -0
  239. /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
  240. /data/{tests → test}/test_callback.rb +0 -0
  241. /data/{tests → test}/test_error.rb +0 -0
  242. /data/{tests → test}/test_params.rb +0 -0
  243. /data/{tests → test}/test_vad.rb +0 -0
  244. /data/{tests → test}/test_vad_params.rb +0 -0
data/ext/sources/ggml/src/ggml-sycl/norm.cpp
@@ -254,14 +254,13 @@ static void norm_f32_sycl(const float * x, float * dst, const int ncols, const i
     GGML_ASSERT(ncols % WARP_SIZE == 0);
     if (ncols < 1024) {
         const sycl::range<3> block_dims(1, 1, WARP_SIZE);
-        stream->submit([&](sycl::handler& cgh) {
-            cgh.parallel_for(
-                sycl::nd_range<3>(global_dims * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1)
-                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                        norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr, WARP_SIZE);
-                    });
-        });
+        sycl_launch(stream, [&](sycl::handler & cgh) {
+            sycl_parallel_for(cgh, sycl::nd_range<3>(global_dims * block_dims, block_dims),
+                              [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                                  norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1,
+                                           nullptr, WARP_SIZE);
+                              });
+        });
     }
     else {
         const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
@@ -272,16 +271,15 @@ static void norm_f32_sycl(const float * x, float * dst, const int ncols, const i
         the limit. To get the device limit, query
         info::device::max_work_group_size. Adjust the work-group size if needed.
         */
-        stream->submit([&](sycl::handler& cgh) {
+        sycl_launch(stream, [&](sycl::handler & cgh) {
             sycl::local_accessor<sycl::float2, 1> s_sum_acc_ct1(
                 sycl::range<1>(work_group_size / WARP_SIZE), cgh);
-            cgh.parallel_for(
-                sycl::nd_range<3>(global_dims * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1)
-                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                        norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
-                    });
-        });
+            sycl_parallel_for(cgh, sycl::nd_range<3>(global_dims * block_dims, block_dims),
+                              [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                                  norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1,
+                                           get_pointer(s_sum_acc_ct1), work_group_size);
+                              });
+        });
     }
 }
 
@@ -290,18 +288,14 @@ static void group_norm_f32_sycl(const float* x, float* dst,
                                 const int ne_elements, queue_ptr stream, int device) {
     if (group_size < 1024) {
         const sycl::range<3> block_dims(1, 1, WARP_SIZE);
-        stream->submit([&](sycl::handler& cgh) {
+        sycl_launch(stream, [&](sycl::handler & cgh) {
             const float eps_ct4 = eps;
-            cgh.parallel_for(
-                sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims,
-                                  block_dims),
-                [=](sycl::nd_item<3> item_ct1)
-                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                        group_norm_f32(
-                            x, dst, group_size, ne_elements, eps_ct4, item_ct1,
-                            nullptr, WARP_SIZE);
-                    });
-        });
+            sycl_parallel_for(cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims, block_dims),
+                              [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                                  group_norm_f32(x, dst, group_size, ne_elements, eps_ct4, item_ct1, nullptr,
+                                                 WARP_SIZE);
+                              });
+        });
     }
     else {
         const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
@@ -313,22 +307,18 @@ static void group_norm_f32_sycl(const float* x, float* dst,
         info::device::max_work_group_size. Adjust the work-group size if needed.
         */
 
-        stream->submit([&](sycl::handler& cgh) {
+        sycl_launch(stream, [&](sycl::handler & cgh) {
             sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
                                                          cgh);
 
             const float eps_ct4 = eps;
 
-            cgh.parallel_for(
-                sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims,
-                                  block_dims),
-                [=](sycl::nd_item<3> item_ct1)
-                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                        group_norm_f32(x, dst, group_size, ne_elements,
-                                       eps_ct4, item_ct1,
-                                       get_pointer(s_sum_acc_ct1), work_group_size);
-                    });
-        });
+            sycl_parallel_for(cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims, block_dims),
+                              [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                                  group_norm_f32(x, dst, group_size, ne_elements, eps_ct4, item_ct1,
+                                                 get_pointer(s_sum_acc_ct1), work_group_size);
+                              });
+        });
     }
 }
 
@@ -340,14 +330,13 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const
     const sycl::range<3> global_dims(nsamples, nchannels, nrows);
     if (ncols < 1024) {
        const sycl::range<3> block_dims(1, 1, WARP_SIZE);
-        stream->submit([&](sycl::handler& cgh) {
-            cgh.parallel_for(
-                sycl::nd_range<3>(global_dims * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1)
-                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                        rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr, WARP_SIZE);
-                    });
-        });
+        sycl_launch(stream, [&](sycl::handler & cgh) {
+            sycl_parallel_for(cgh, sycl::nd_range<3>(global_dims * block_dims, block_dims),
+                              [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                                  rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1,
+                                               nullptr, WARP_SIZE);
+                              });
+        });
     }
     else {
         const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
@@ -358,16 +347,15 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const
         the limit. To get the device limit, query
         info::device::max_work_group_size. Adjust the work-group size if needed.
         */
-        stream->submit([&](sycl::handler& cgh) {
+        sycl_launch(stream, [&](sycl::handler & cgh) {
             sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
                                                          cgh);
-            cgh.parallel_for(
-                sycl::nd_range<3>(global_dims * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1)
-                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                        rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
-                    });
-        });
+            sycl_parallel_for(cgh, sycl::nd_range<3>(global_dims * block_dims, block_dims),
+                              [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                                  rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1,
+                                               get_pointer(s_sum_acc_ct1), work_group_size);
+                              });
+        });
     }
 }
 
@@ -378,16 +366,12 @@ static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
     // printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
     if (ncols < 1024) {
         const sycl::range<3> block_dims(1, 1, WARP_SIZE);
-        stream->submit([&](sycl::handler& cgh) {
-            cgh.parallel_for(
-                sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
-                                  block_dims),
-                [=](sycl::nd_item<3> item_ct1)
-                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                        l2_norm_f32(x, dst, ncols, eps, item_ct1,
-                                    nullptr, WARP_SIZE);
-                    });
-        });
+        sycl_launch(stream, [&](sycl::handler & cgh) {
+            sycl_parallel_for(cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),
+                              [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                                  l2_norm_f32(x, dst, ncols, eps, item_ct1, nullptr, WARP_SIZE);
+                              });
+        });
     }
     else {
         const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
@@ -398,18 +382,15 @@ static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
         the limit. To get the device limit, query
         info::device::max_work_group_size. Adjust the work-group size if needed.
         */
-        stream->submit([&](sycl::handler& cgh) {
+        sycl_launch(stream, [&](sycl::handler & cgh) {
             sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
                                                          cgh);
-            cgh.parallel_for(
-                sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
-                                  block_dims),
-                [=](sycl::nd_item<3> item_ct1)
-                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                        l2_norm_f32(x, dst, ncols, eps, item_ct1,
-                                    get_pointer(s_sum_acc_ct1), work_group_size);
-                    });
-        });
+            sycl_parallel_for(cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),
+                              [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                                  l2_norm_f32(x, dst, ncols, eps, item_ct1, get_pointer(s_sum_acc_ct1),
+                                              work_group_size);
+                              });
+        });
     }
 }
 
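Note: the hunks above, like the rope, softmax, and timestep-embedding hunks further down, swap raw stream->submit(...) / cgh.parallel_for(...) calls for sycl_launch and sycl_parallel_for helpers. The sketch below shows what such thin wrappers could look like; it is only an illustration with signatures inferred from the call sites in this diff, not the actual definitions shipped in the ggml SYCL backend, which may add profiling or error handling.

    // Sketch only: wrapper shapes inferred from the call sites in this diff.
    // They simply forward to the underlying SYCL calls.
    #include <sycl/sycl.hpp>
    #include <utility>

    template <typename CGF>
    void sycl_launch(sycl::queue * stream, CGF && cgf) {
        stream->submit(std::forward<CGF>(cgf));                 // was: stream->submit(...)
    }

    template <int Dims, typename Kernel>
    void sycl_parallel_for(sycl::handler & cgh, const sycl::nd_range<Dims> & range, Kernel && k) {
        cgh.parallel_for(range, std::forward<Kernel>(k));       // was: cgh.parallel_for(...)
    }

    template <int Dims, typename Kernel>
    void sycl_parallel_for(sycl::queue * stream, const sycl::nd_range<Dims> & range, Kernel && k) {
        stream->parallel_for(range, std::forward<Kernel>(k));   // was: stream->parallel_for(...)
    }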
data/ext/sources/ggml/src/ggml-sycl/quants.hpp
@@ -14,12 +14,13 @@
 #ifndef GGML_SYCL_QUANTS_HPP
 #define GGML_SYCL_QUANTS_HPP
 
+#include <utility>
+
 #include "ggml-common.h"
 #include "ggml.h"
 
 namespace ggml_sycl_reordered {
 
-
 // The reordered block moves quants (qs) and scales(d) to two
 // uniform regions of memory that is contiguous in the same tensor.
 // What this means is that instead of having:
@@ -32,7 +33,6 @@ namespace ggml_sycl_reordered {
 
 template <ggml_type type> struct block_q_t;
 
-
 // qk number of weights / quants in a block
 // qr number of weights in a byte (described as 'before dequantization')
 // for quantization types that has low and high bits split, qr is calculated with
@@ -47,10 +47,12 @@ template <> struct block_q_t<GGML_TYPE_Q4_0> {
         static constexpr uint32_t vdr_mmvq = 2;
     };
 
-    static constexpr int get_block_offset(const int block_index) { return block_index * (traits::qk / traits::qr); }
+    static constexpr std::pair<int, int> get_block_offset(const int block_index, const int /* nblocks */) {
+        return { block_index * (traits::qk / traits::qr), 0 };
+    }
 
-    static constexpr int get_d_offset(int nrows, int ncols, const int block_index) {
-        return (ncols / traits::qr * nrows) + block_index * sizeof(ggml_half);
+    static constexpr std::pair<int, int> get_d_offset(int nrows, int ncols, const int block_index) {
+        return { (ncols / traits::qr * nrows) + block_index * sizeof(ggml_half), 0 };
     }
 
     static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
@@ -64,20 +66,46 @@ template <> struct block_q_t<GGML_TYPE_Q4_K> {
         static constexpr uint32_t vdr_mmvq = 2;
     };
 
-    static constexpr int get_block_offset(const int block_index) { return block_index * (traits::qk / traits::qr); }
+    static constexpr std::pair<int, int> get_block_offset(const int block_index, const int /* nblocks */) {
+        return { block_index * (traits::qk / traits::qr), 0 };
+    }
 
-    static constexpr int get_d_offset(int nrows, int ncols, const int block_index) {
+    static constexpr std::pair<int, int> get_d_offset(int nrows, int ncols, const int block_index) {
         auto nblocks = (nrows * (ncols / traits::qk));
-        return (nblocks * QK_K / 2) + (nblocks * K_SCALE_SIZE) + (block_index * sizeof(ggml_half2));
+        return { nblocks * (QK_K / 2),
+                 (nblocks * QK_K / 2) + (nblocks * K_SCALE_SIZE) + (block_index * sizeof(ggml_half2)) };
     }
 
     static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
 
     constexpr size_t get_total_qs_bytes(int nblocks) { return nblocks * QK_K / 2; }
-
-    constexpr size_t get_dm_offset(int nblocks) { return get_total_qs_bytes(nblocks) + nblocks * K_SCALE_SIZE; }
 };
 
+template <> struct block_q_t<GGML_TYPE_Q6_K> {
+    struct traits {
+        static constexpr uint32_t qk = QK_K;
+        static constexpr uint32_t qi = QI6_K;
+        static constexpr uint32_t qr = QR6_K;
+        static constexpr uint32_t vdr_mmvq = 1;
+    };
+
+    static constexpr std::pair<int, int> get_block_offset(const int block_index, const int n_blocks) {
+        auto low_bits_index = block_index * (traits::qk / traits::qr);
+        // the index of high bits it's after all low bits
+        auto high_bits_index = n_blocks * (QK_K / 2) + (block_index * (QK_K / 4));
+        return { low_bits_index, high_bits_index };
+    }
+
+    static constexpr std::pair<int, int> get_d_offset(int nrows, int ncols, const int block_index) {
+        auto nblocks = (nrows * (ncols / traits::qk));
+        auto total_qs_bytes = nblocks * (QK_K / 2) + nblocks * (QK_K / 4);
+        auto block_scales = total_qs_bytes + block_index * (QK_K / 16);
+        auto sb_scale = total_qs_bytes + nblocks * (QK_K / 16);
+        return { block_scales, sb_scale };
+    }
+
+    static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
+};
 } // namespace ggml_sycl_reordered
 
 #endif // GGML_SYCL_QUANTS_HPP
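To make the reordered Q6_K layout introduced above concrete: all low 4-bit quants come first, then all high 2-bit quants, then the per-block scales, with the super-block scales after them. The standalone sketch below only mirrors the get_block_offset arithmetic from the hunk, assuming the usual ggml constants (QK_K = 256, QR6_K = 2); it is an illustration, not code from the package.

    // Sketch only: mirrors block_q_t<GGML_TYPE_Q6_K>::get_block_offset above.
    // QK_K and QR6_K values are assumed to match the usual ggml-common.h definitions.
    #include <cstdio>
    #include <utility>

    constexpr int QK_K  = 256;  // weights per super-block (assumed)
    constexpr int QR6_K = 2;    // low-bit quants packed per byte (assumed)

    constexpr std::pair<int, int> q6k_block_offset(int block_index, int n_blocks) {
        const int low_bits  = block_index * (QK_K / QR6_K);                     // 128 bytes of low bits per block
        const int high_bits = n_blocks * (QK_K / 2) + block_index * (QK_K / 4); // high bits start after all low bits
        return { low_bits, high_bits };
    }

    int main() {
        const auto [lo, hi] = q6k_block_offset(/*block_index=*/1, /*n_blocks=*/4);
        std::printf("block 1 of 4: low bits at byte %d, high bits at byte %d\n", lo, hi);  // 128, 576
        return 0;
    }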
data/ext/sources/ggml/src/ggml-sycl/rope.cpp
@@ -49,10 +49,7 @@ static void rope_norm(const T * x, T * dst, const int ne0, const int ne1, const
 
     if (i0 >= n_dims) {
         const int i = row * ne0 + i0;
-
-        dst[i + 0] = x[i + 0];
-        dst[i + 1] = x[i + 1];
-
+        *reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i);
         return;
     }
 
@@ -93,10 +90,7 @@ static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const
 
     if (i0 >= n_dims) {
         const int i = row * ne0 + i0;
-
-        dst[i + 0] = x[i + 0];
-        dst[i + 1] = x[i + 1];
-
+        *reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i);
         return;
     }
 
@@ -122,6 +116,63 @@
     dst[i + n_dims / 2] = x0 * sin_theta + x1 * cos_theta;
 }
 
+template <typename T, bool has_ff>
+static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
+                       const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale,
+                       const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
+                       const float theta_scale, const float * freq_factors, const mrope_sections sections,
+                       const sycl::nd_item<3> & item_ct1) {
+    // get index pos
+    const int i0 = 2 * (item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1));
+    if (i0 >= ne0) {
+        return;
+    }
+    const int row_dst = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2);
+
+    if (i0 >= n_dims) {
+        const int i = row_dst*ne0 + i0;
+        *reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i);
+        return;
+    }
+
+    const int row_x = row_dst % ne1;
+    const int channel_x = row_dst / ne1;
+    const int idst = (row_dst * ne0) + (i0 / 2);
+    const size_t ix = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2);
+
+    const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
+    const int sec_w = sections.v[1] + sections.v[0];
+    const int sector = (i0 / 2) % sect_dims;
+
+
+    float theta_base = 0.0;
+    if (sector < sections.v[0]) {
+        theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f);
+    }
+    else if (sector >= sections.v[0] && sector < sec_w) {
+        theta_base = pos[channel_x + ne2 * 1]*sycl::pow(theta_scale, i0/2.0f);
+    }
+    else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
+        theta_base = pos[channel_x + ne2 * 2]*sycl::pow(theta_scale, i0/2.0f);
+    }
+    else if (sector >= sec_w + sections.v[2]) {
+        theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f);
+    }
+
+    const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
+    float cos_theta;
+    float sin_theta;
+    rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+    const float x0 = x[ix + 0];
+    const float x1 = x[ix + n_dims/2];
+
+    // store results in dst
+    dst[idst + 0] = x0 * cos_theta - x1 * sin_theta;
+    dst[idst + n_dims/2] = x0 * sin_theta + x1 * cos_theta;
+}
+
+
+
 template <typename T, bool has_ff>
 static void rope_vision(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
                         const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale,
@@ -171,7 +222,7 @@ static void rope_norm_sycl(const T * x, T * dst, const int ne0, const int ne1, c
                            const float * freq_factors, queue_ptr stream) {
     GGML_ASSERT(ne0 % 2 == 0);
     const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
-    const int num_blocks_x = (ne0 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
+    const int num_blocks_x = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
     const sycl::range<3> block_nums(1, num_blocks_x, nr);
 
     const float theta_scale = powf(freq_base, -2.0f / n_dims);
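The ceil_div(ne0, 2 * SYCL_ROPE_BLOCK_SIZE) call above (and in the neox, multi, and vision launchers below) replaces the manual rounding expression. A two-line sketch of the equivalence, assuming the helper is plain integer ceiling division as its name and usage suggest (the actual definition lives in the backend's common headers):

    // Sketch: integer ceiling division, equivalent to the old (x + y - 1) / y pattern.
    constexpr int ceil_div(int x, int y) { return (x + y - 1) / y; }
    static_assert(ceil_div(1000, 2 * 256) == (1000 + 2 * 256 - 1) / (2 * 256), "same grid size");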
@@ -184,20 +235,22 @@ static void rope_norm_sycl(const T * x, T * dst, const int ne0, const int ne1, c
         the limit. To get the device limit, query
         info::device::max_work_group_size. Adjust the work-group size if needed.
         */
-        stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
-            rope_norm<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
-                                theta_scale, freq_factors, item_ct1);
-        });
+        sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                          [=](sycl::nd_item<3> item_ct1) {
+                              rope_norm<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
+                                                  attn_factor, corr_dims, theta_scale, freq_factors, item_ct1);
+                          });
     } else {
         /*
         DPCT1049:41: The work-group size passed to the SYCL kernel may exceed
         the limit. To get the device limit, query
         info::device::max_work_group_size. Adjust the work-group size if needed.
         */
-        stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
-            rope_norm<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
-                               theta_scale, freq_factors, item_ct1);
-        });
+        sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                          [=](sycl::nd_item<3> item_ct1) {
+                              rope_norm<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
+                                                 attn_factor, corr_dims, theta_scale, freq_factors, item_ct1);
+                          });
     }
 }
 
@@ -208,7 +261,7 @@ static void rope_neox_sycl(const T * x, T * dst, const int ne0, const int ne1, c
                            const rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) {
     GGML_ASSERT(ne0 % 2 == 0);
     const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
-    const int num_blocks_x = (ne0 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
+    const int num_blocks_x = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
     const sycl::range<3> block_nums(1, num_blocks_x, nr);
 
     const float theta_scale = powf(freq_base, -2.0f / n_dims);
@@ -216,18 +269,54 @@ static void rope_neox_sycl(const T * x, T * dst, const int ne0, const int ne1, c
     dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
 
     if (freq_factors == nullptr) {
-        stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
-            rope_neox<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
-                                theta_scale, freq_factors, item_ct1);
+        sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                          [=](sycl::nd_item<3> item_ct1) {
+                              rope_neox<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
+                                                  attn_factor, corr_dims, theta_scale, freq_factors, item_ct1);
+                          });
+    } else {
+        sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                          [=](sycl::nd_item<3> item_ct1) {
+                              rope_neox<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
+                                                 attn_factor, corr_dims, theta_scale, freq_factors, item_ct1);
+                          });
+    }
+}
+
+template <typename T>
+static void rope_multi_sycl(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
+                            const size_t s2, const int n_dims, const int nr, const int32_t * pos,
+                            const float freq_scale, const float freq_base, const float ext_factor,
+                            const float attn_factor, const rope_corr_dims corr_dims, const float * freq_factors,
+                            const mrope_sections sections, queue_ptr stream) {
+    GGML_ASSERT(ne0 % 2 == 0);
+    const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
+    const int n_blocks_y = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
+    const sycl::range<3> grid_dims(1, n_blocks_y, nr);
+    const sycl::nd_range<3> nd_range(grid_dims * block_dims, block_dims);
+
+    const float theta_scale = std::pow(freq_base, -2.0f / n_dims);
+    // Add FP16 capability check if T could be sycl::half
+    if constexpr (std::is_same_v<T, sycl::half>) {
+        dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
+    }
+    // launch kernel
+    if (freq_factors == nullptr) {
+        sycl_parallel_for(stream, nd_range, [=](sycl::nd_item<3> item_ct1) {
+            rope_multi<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
+                                 corr_dims, theta_scale, freq_factors, sections, item_ct1);
         });
     } else {
-        stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
-            rope_neox<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
-                               theta_scale, freq_factors, item_ct1);
+        sycl_parallel_for(stream, nd_range, [=](sycl::nd_item<3> item_ct1) {
+            rope_multi<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
+                                corr_dims, theta_scale, freq_factors, sections, item_ct1);
         });
     }
 }
 
+
+
+
 // rope vision
 template <typename T>
 static void rope_vision_sycl(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
@@ -237,7 +326,7 @@ static void rope_vision_sycl(const T * x, T * dst, const int ne0, const int ne1,
                              const mrope_sections sections, queue_ptr stream) {
     GGML_ASSERT(ne0 % 2 == 0);
     const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
-    const int n_blocks_y = (ne0 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
+    const int n_blocks_y = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
     const sycl::range<3> grid_dims(1, n_blocks_y, nr);
     const sycl::nd_range<3> nd_range(grid_dims * block_dims, block_dims);
 
@@ -248,12 +337,12 @@ static void rope_vision_sycl(const T * x, T * dst, const int ne0, const int ne1,
     }
     // launch kernel
     if (freq_factors == nullptr) {
-        stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
+        sycl_parallel_for(stream, nd_range, [=](sycl::nd_item<3> item_ct1) {
             rope_vision<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
                                   corr_dims, theta_scale, freq_factors, sections, item_ct1);
         });
     } else {
-        stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
+        sycl_parallel_for(stream, nd_range, [=](sycl::nd_item<3> item_ct1) {
             rope_vision<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
                                   corr_dims, theta_scale, freq_factors, sections, item_ct1);
         });
@@ -298,8 +387,17 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
     memcpy(&sections.v, (int32_t *) dst->op_params + 11, sizeof(int)*4);
 
     const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
+    const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
     const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
 
+    if (is_mrope) {
+        GGML_ASSERT(sections.v[0] > 0 || sections.v[1] > 0 || sections.v[2] > 0);
+    }
+
+    if (is_vision) {
+        GGML_ASSERT(n_dims == ne00/2);
+    }
+
     const int32_t * pos = (const int32_t *) dst->src[1]->data;
 
     const float * freq_factors = nullptr;
@@ -326,6 +424,19 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
         } else {
             GGML_ABORT("fatal error");
         }
+    } else if (is_mrope && !is_vision) {
+        GGML_SYCL_DEBUG("%s: mrope path\n", __func__);
+        if (dst->src[0]->type == GGML_TYPE_F16) {
+            rope_multi_sycl((const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, ne01, ne02, s01,
+                            s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
+                            freq_factors, sections, main_stream);
+        } else if (dst->src[0]->type == GGML_TYPE_F32) {
+            rope_multi_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, ne02, s01, s02, n_dims,
+                            nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections,
+                            main_stream);
+        } else {
+            GGML_ABORT("Fatal error: Tensor type unsupported!");
+        }
     } else if (is_vision) {
         GGML_SYCL_DEBUG("%s: vision path\n", __func__);
         if (dst->src[0]->type == GGML_TYPE_F16) {
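For context on the mrope path added above: each rotated pair index i0 / 2 is assigned to one of up to four position streams by the cumulative section widths, which is what the sector tests in rope_multi implement. The standalone sketch below lifts that selection out for illustration; the section values are made-up example numbers, not anything shipped in the package.

    // Sketch only: the stream-selection logic of rope_multi, isolated for illustration.
    #include <array>
    #include <cstdio>

    static int position_stream(int pair_index, const std::array<int, 4> & sections) {
        const int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
        const int sector    = pair_index % sect_dims;               // wrap around the section pattern
        if (sector < sections[0])                               return 0;
        if (sector < sections[0] + sections[1])                 return 1;
        if (sector < sections[0] + sections[1] + sections[2])   return 2;
        return 3;
    }

    int main() {
        const std::array<int, 4> sections = { 16, 24, 24, 0 };      // hypothetical section widths
        std::printf("pair 20 -> position stream %d\n", position_stream(20, sections));  // prints 1
        return 0;
    }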
data/ext/sources/ggml/src/ggml-sycl/softmax.cpp
@@ -127,11 +127,11 @@ static void soft_max_f32_submitter(const float * x, const T * mask, float * dst,
                                    const int nrows_y, const float scale, const float max_bias, const float m0,
                                    const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims,
                                    const size_t n_local_scratch, queue_ptr stream) {
-    stream->submit([&](sycl::handler &cgh) {
+    sycl_launch(stream, [&](sycl::handler & cgh) {
         sycl::local_accessor<float, 1> local_buf_acc(n_local_scratch, cgh);
 
-        cgh.parallel_for(
-            sycl::nd_range<3>(block_nums * block_dims, block_dims),
+        sycl_parallel_for(
+            cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
             [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                 soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, dst, ncols_par,
                                                                              nrows_y, scale, max_bias, m0,
data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp
@@ -1,6 +1,7 @@
 #include "sycl_hw.hpp"
 
-
+// TODO: currently not used
+/*
 sycl_hw_info get_device_hw_info(sycl::device *device_ptr) {
   sycl_hw_info res;
   int32_t id = device_ptr->get_info<sycl::ext::intel::info::device::device_id>();
@@ -11,3 +12,4 @@ sycl_hw_info get_device_hw_info(sycl::device *device_ptr) {
 
   return res;
 }
+*/
data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp
@@ -10,6 +10,8 @@
 
 namespace syclex = sycl::ext::oneapi::experimental;
 
+// TODO: currently not used
+/*
 struct sycl_hw_info {
   syclex::architecture arch;
   int32_t device_id;
@@ -18,6 +20,7 @@ struct sycl_hw_info {
 bool is_in_vector(std::vector<int> &vec, int item);
 
 sycl_hw_info get_device_hw_info(sycl::device *device_ptr);
+*/
 
 
 #endif // SYCL_HW_HPP
data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp
@@ -45,14 +45,9 @@ static void timestep_embedding_f32_sycl(
     int num_blocks = (half_ceil + SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE - 1) / SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE;
     sycl::range<3> block_dims(1, 1, SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE);
     sycl::range<3> gridDim(1, ne00, num_blocks);
-    stream->parallel_for(
-        sycl::nd_range<3>(
-            gridDim * block_dims, block_dims),
-        [=](sycl::nd_item<3> item_ct1) {
-            timestep_embedding_f32(
-                x, dst, nb1, dim, max_period, item_ct1
-            );
-        });
+    sycl_parallel_for(stream, sycl::nd_range<3>(gridDim * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+        timestep_embedding_f32(x, dst, nb1, dim, max_period, item_ct1);
+    });
 }
 
 void ggml_sycl_op_timestep_embedding(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {