whispercpp 1.3.2 → 1.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (244) hide show
  1. checksums.yaml +4 -4
  2. data/.gitignore +6 -3
  3. data/README.md +71 -14
  4. data/Rakefile +20 -7
  5. data/ext/.gitignore +4 -6
  6. data/ext/dependencies.rb +36 -24
  7. data/ext/extconf.rb +1 -1
  8. data/ext/options.rb +48 -184
  9. data/ext/ruby_whisper.c +18 -0
  10. data/ext/ruby_whisper_context.c +43 -12
  11. data/ext/ruby_whisper_model.c +1 -1
  12. data/ext/ruby_whisper_params.c +4 -2
  13. data/ext/ruby_whisper_segment.c +81 -4
  14. data/ext/ruby_whisper_transcribe.cpp +13 -7
  15. data/ext/ruby_whisper_vad_params.c +1 -1
  16. data/ext/sources/CMakeLists.txt +5 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +120 -24
  19. data/ext/sources/examples/addon.node/addon.cpp +150 -31
  20. data/ext/sources/examples/addon.node/index.js +3 -0
  21. data/ext/sources/examples/addon.node/vad-example.js +132 -0
  22. data/ext/sources/examples/bench/bench.cpp +3 -2
  23. data/ext/sources/examples/cli/cli.cpp +3 -2
  24. data/ext/sources/examples/command/command.cpp +32 -8
  25. data/ext/sources/examples/common-whisper.cpp +14 -7
  26. data/ext/sources/examples/lsp/lsp.cpp +2 -0
  27. data/ext/sources/examples/quantize/quantize.cpp +3 -0
  28. data/ext/sources/examples/server/CMakeLists.txt +3 -0
  29. data/ext/sources/examples/server/server.cpp +169 -22
  30. data/ext/sources/examples/stream/stream.cpp +6 -0
  31. data/ext/sources/examples/talk-llama/CMakeLists.txt +4 -1
  32. data/ext/sources/examples/talk-llama/llama-arch.cpp +171 -3
  33. data/ext/sources/examples/talk-llama/llama-arch.h +28 -1
  34. data/ext/sources/examples/talk-llama/llama-batch.cpp +741 -272
  35. data/ext/sources/examples/talk-llama/llama-batch.h +112 -54
  36. data/ext/sources/examples/talk-llama/llama-chat.cpp +30 -8
  37. data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
  38. data/ext/sources/examples/talk-llama/llama-context.cpp +520 -351
  39. data/ext/sources/examples/talk-llama/llama-context.h +38 -17
  40. data/ext/sources/examples/talk-llama/llama-cparams.cpp +1 -1
  41. data/ext/sources/examples/talk-llama/llama-cparams.h +1 -1
  42. data/ext/sources/examples/talk-llama/llama-graph.cpp +447 -372
  43. data/ext/sources/examples/talk-llama/llama-graph.h +128 -58
  44. data/ext/sources/examples/talk-llama/llama-hparams.cpp +10 -2
  45. data/ext/sources/examples/talk-llama/llama-hparams.h +19 -2
  46. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +279 -0
  47. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.h +128 -0
  48. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +1841 -0
  49. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +303 -0
  50. data/ext/sources/examples/talk-llama/llama-kv-cache.h +14 -472
  51. data/ext/sources/examples/talk-llama/llama-kv-cells.h +86 -26
  52. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +246 -0
  53. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +138 -0
  54. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1125 -0
  55. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +183 -0
  56. data/ext/sources/examples/talk-llama/llama-memory.cpp +58 -0
  57. data/ext/sources/examples/talk-llama/llama-memory.h +88 -4
  58. data/ext/sources/examples/talk-llama/llama-mmap.cpp +1 -1
  59. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +42 -17
  60. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +1 -0
  61. data/ext/sources/examples/talk-llama/llama-model.cpp +1863 -563
  62. data/ext/sources/examples/talk-llama/llama-model.h +27 -0
  63. data/ext/sources/examples/talk-llama/llama-quant.cpp +89 -6
  64. data/ext/sources/examples/talk-llama/llama-vocab.cpp +65 -28
  65. data/ext/sources/examples/talk-llama/llama-vocab.h +1 -0
  66. data/ext/sources/examples/talk-llama/llama.cpp +11 -7
  67. data/ext/sources/examples/talk-llama/llama.h +147 -40
  68. data/ext/sources/examples/talk-llama/talk-llama.cpp +2 -0
  69. data/ext/sources/examples/talk-llama/unicode.cpp +5 -0
  70. data/ext/sources/examples/vad-speech-segments/speech.cpp +6 -0
  71. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +2 -0
  72. data/ext/sources/ggml/CMakeLists.txt +48 -3
  73. data/ext/sources/ggml/cmake/common.cmake +24 -0
  74. data/ext/sources/ggml/include/ggml-backend.h +1 -1
  75. data/ext/sources/ggml/include/ggml-cpu.h +2 -0
  76. data/ext/sources/ggml/include/ggml.h +144 -5
  77. data/ext/sources/ggml/src/CMakeLists.txt +82 -24
  78. data/ext/sources/ggml/src/ggml-backend-reg.cpp +5 -0
  79. data/ext/sources/ggml/src/ggml-backend.cpp +46 -23
  80. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +3 -3
  81. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +1 -0
  82. data/ext/sources/ggml/src/ggml-cann/common.h +6 -1
  83. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
  84. data/ext/sources/ggml/src/ggml-common.h +4 -0
  85. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +133 -40
  86. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
  87. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
  88. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  89. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +4114 -0
  90. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2163 -0
  91. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2639 -0
  92. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  93. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2732 -0
  94. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +2069 -0
  95. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +397 -0
  96. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1300 -0
  97. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1481 -0
  98. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +4311 -0
  99. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +79 -3225
  100. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +184 -0
  101. data/ext/sources/ggml/src/ggml-cpu/common.h +4 -3
  102. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +16 -7
  103. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +146 -105
  104. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +12 -8
  105. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  106. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
  107. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +58 -8
  108. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  109. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1057 -174
  110. data/ext/sources/ggml/src/ggml-cpu/ops.h +8 -0
  111. data/ext/sources/ggml/src/ggml-cpu/quants.c +1158 -0
  112. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
  113. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1571 -0
  114. data/ext/sources/ggml/src/ggml-cpu/repack.h +98 -0
  115. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +330 -38
  116. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
  117. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +111 -18
  118. data/ext/sources/ggml/src/ggml-cpu/vec.h +303 -94
  119. data/ext/sources/ggml/src/ggml-cuda/common.cuh +60 -37
  120. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  121. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  122. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  123. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  124. data/ext/sources/ggml/src/ggml-cuda/convert.cu +22 -0
  125. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +5 -0
  126. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +2 -2
  127. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +5 -2
  128. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
  129. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +265 -123
  130. data/ext/sources/ggml/src/ggml-cuda/mean.cu +19 -0
  131. data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
  132. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +257 -87
  133. data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +2 -3
  134. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
  135. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +5 -18
  136. data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  137. data/ext/sources/ggml/src/ggml-cuda/unary.cu +89 -0
  138. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +7 -0
  139. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +4 -0
  140. data/ext/sources/ggml/src/ggml-impl.h +127 -183
  141. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +11 -10
  142. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +27 -0
  143. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +331 -49
  144. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +564 -282
  145. data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +2 -2
  146. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +14 -0
  147. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1859 -489
  148. data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
  149. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  150. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +72 -0
  151. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +201 -0
  152. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +72 -0
  153. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  154. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  155. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  156. data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
  157. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +72 -0
  158. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
  159. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  160. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  161. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
  162. data/ext/sources/ggml/src/ggml-quants.c +6 -8
  163. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
  164. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
  165. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +5 -6
  166. data/ext/sources/ggml/src/ggml-sycl/common.hpp +20 -48
  167. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +28 -41
  168. data/ext/sources/ggml/src/ggml-sycl/conv.cpp +4 -10
  169. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +117 -165
  170. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +192 -53
  171. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +32 -0
  172. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +49 -67
  173. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
  174. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +648 -1039
  175. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +18 -9
  176. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
  177. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +8 -105
  178. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +238 -100
  179. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
  180. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +1 -1
  181. data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +60 -80
  182. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +158 -203
  183. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +55 -74
  184. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +38 -10
  185. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +138 -27
  186. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
  187. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  188. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  189. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +3 -8
  190. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
  191. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +12 -16
  192. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +36 -32
  193. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +726 -282
  194. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
  195. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  196. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  197. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
  198. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  199. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  200. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
  201. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  202. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +10 -1
  203. data/ext/sources/ggml/src/ggml.c +328 -48
  204. data/ext/sources/ggml/src/ggml.cpp +26 -0
  205. data/ext/sources/ggml/src/gguf.cpp +24 -3
  206. data/ext/sources/include/whisper.h +2 -0
  207. data/ext/sources/src/CMakeLists.txt +2 -0
  208. data/ext/sources/src/coreml/whisper-compat.h +10 -0
  209. data/ext/sources/src/coreml/whisper-compat.m +35 -0
  210. data/ext/sources/src/coreml/whisper-decoder-impl.m +1 -0
  211. data/ext/sources/src/coreml/whisper-encoder-impl.m +1 -0
  212. data/ext/sources/src/whisper.cpp +218 -169
  213. data/extsources.rb +15 -9
  214. data/lib/whisper/context.rb +15 -0
  215. data/lib/whisper/model/uri.rb +56 -1
  216. data/lib/whisper/segment.rb +58 -0
  217. data/sig/whisper.rbs +68 -38
  218. data/{tests → test}/helper.rb +1 -12
  219. data/{tests → test}/test_model.rb +9 -0
  220. data/test/test_package.rb +51 -0
  221. data/test/test_segment.rb +146 -0
  222. data/{tests → test}/test_whisper.rb +70 -0
  223. data/whispercpp.gemspec +2 -3
  224. metadata +91 -43
  225. data/ext/sources/.dockerignore +0 -3
  226. data/ext/sources/.github/workflows/bindings-ruby.yml +0 -21
  227. data/ext/sources/ci/run.sh +0 -336
  228. data/ext/sources/close-issue.yml +0 -28
  229. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2739
  230. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  231. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13747
  232. data/tests/test_package.rb +0 -46
  233. data/tests/test_segment.rb +0 -74
  234. /data/ext/sources/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  235. /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  236. /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
  237. /data/{tests → test}/jfk_reader/.gitignore +0 -0
  238. /data/{tests → test}/jfk_reader/extconf.rb +0 -0
  239. /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
  240. /data/{tests → test}/test_callback.rb +0 -0
  241. /data/{tests → test}/test_error.rb +0 -0
  242. /data/{tests → test}/test_params.rb +0 -0
  243. /data/{tests → test}/test_vad.rb +0 -0
  244. /data/{tests → test}/test_vad_params.rb +0 -0
@@ -0,0 +1,121 @@
1
+ kernel void kernel_upscale(
2
+ global const void * p_src0,
3
+ ulong off_src0,
4
+ global void * p_dst,
5
+ ulong off_dst,
6
+ ulong nb00,
7
+ ulong nb01,
8
+ ulong nb02,
9
+ ulong nb03,
10
+ int ne10,
11
+ int ne11,
12
+ int ne12,
13
+ int ne13,
14
+ float sf0,
15
+ float sf1,
16
+ float sf2,
17
+ float sf3
18
+ ) {
19
+ global const char * src_base = (global const char *)p_src0 + off_src0;
20
+ global float * dst_base = (global float *)((global char *)p_dst + off_dst);
21
+
22
+ int index = get_global_id(0);
23
+ int dst_total_elements = ne10 * ne11 * ne12 * ne13;
24
+
25
+ if (index >= dst_total_elements) {
26
+ return;
27
+ }
28
+
29
+ int i10 = index % ne10;
30
+ int i11 = (index / ne10) % ne11;
31
+ int i12 = (index / (ne10 * ne11)) % ne12;
32
+ int i13 = index / (ne10 * ne11 * ne12);
33
+
34
+ int i00 = (int)(i10 / sf0);
35
+ int i01 = (int)(i11 / sf1);
36
+ int i02 = (int)(i12 / sf2);
37
+ int i03 = (int)(i13 / sf3);
38
+
39
+ ulong offset_src_element = (ulong)i03 * nb03 + (ulong)i02 * nb02 + (ulong)i01 * nb01 + (ulong)i00 * nb00;
40
+ global const float * src_element_ptr = (global const float *)(src_base + offset_src_element);
41
+
42
+ dst_base[index] = *src_element_ptr;
43
+ }
44
+
45
+ kernel void kernel_upscale_bilinear(
46
+ global const void * p_src0,
47
+ ulong off_src0,
48
+ global void * p_dst,
49
+ ulong off_dst,
50
+ ulong nb00,
51
+ ulong nb01,
52
+ ulong nb02,
53
+ ulong nb03,
54
+ int ne00_src,
55
+ int ne01_src,
56
+ int ne10_dst,
57
+ int ne11_dst,
58
+ int ne12_dst,
59
+ int ne13_dst,
60
+ float sf0,
61
+ float sf1,
62
+ float sf2,
63
+ float sf3
64
+ ) {
65
+ global const char * src_base = (global const char *)p_src0 + off_src0;
66
+ global float * dst_base = (global float *)((global char *)p_dst + off_dst);
67
+
68
+ int index = get_global_id(0);
69
+ int dst_total_elements = ne10_dst * ne11_dst * ne12_dst * ne13_dst;
70
+
71
+ if (index >= dst_total_elements) {
72
+ return;
73
+ }
74
+
75
+ int i10_dst = index % ne10_dst;
76
+ int i11_dst = (index / ne10_dst) % ne11_dst;
77
+ int i12_dst = (index / (ne10_dst * ne11_dst)) % ne12_dst;
78
+ int i13_dst = index / (ne10_dst * ne11_dst * ne12_dst);
79
+
80
+ int i02_src = (int)(i12_dst / sf2);
81
+ int i03_src = (int)(i13_dst / sf3);
82
+
83
+ const float pixel_offset = 0.5f;
84
+
85
+ float y_src_f = ((float)i11_dst + pixel_offset) / sf1 - pixel_offset;
86
+ long y0_src = (long)floor(y_src_f);
87
+ long y1_src = y0_src + 1;
88
+
89
+ y0_src = max(0L, min(y0_src, (long)ne01_src - 1));
90
+ y1_src = max(0L, min(y1_src, (long)ne01_src - 1));
91
+
92
+ float dy = y_src_f - (float)y0_src;
93
+ dy = max(0.0f, min(dy, 1.0f));
94
+
95
+ float x_src_f = ((float)i10_dst + pixel_offset) / sf0 - pixel_offset;
96
+ long x0_src = (long)floor(x_src_f);
97
+ long x1_src = x0_src + 1;
98
+
99
+ x0_src = max(0L, min(x0_src, (long)ne00_src - 1));
100
+ x1_src = max(0L, min(x1_src, (long)ne00_src - 1));
101
+
102
+ float dx = x_src_f - (float)x0_src;
103
+ dx = max(0.0f, min(dx, 1.0f));
104
+
105
+ global const float * p_a = (global const float *)(src_base + (ulong)x0_src * nb00 + (ulong)y0_src * nb01 + (ulong)i02_src * nb02 + (ulong)i03_src * nb03);
106
+ global const float * p_b = (global const float *)(src_base + (ulong)x1_src * nb00 + (ulong)y0_src * nb01 + (ulong)i02_src * nb02 + (ulong)i03_src * nb03);
107
+ global const float * p_c = (global const float *)(src_base + (ulong)x0_src * nb00 + (ulong)y1_src * nb01 + (ulong)i02_src * nb02 + (ulong)i03_src * nb03);
108
+ global const float * p_d = (global const float *)(src_base + (ulong)x1_src * nb00 + (ulong)y1_src * nb01 + (ulong)i02_src * nb02 + (ulong)i03_src * nb03);
109
+
110
+ const float val_a = *p_a;
111
+ const float val_b = *p_b;
112
+ const float val_c = *p_c;
113
+ const float val_d = *p_d;
114
+
115
+ float result = val_a * (1.0f - dx) * (1.0f - dy) +
116
+ val_b * dx * (1.0f - dy) +
117
+ val_c * (1.0f - dx) * dy +
118
+ val_d * dx * dy;
119
+
120
+ dst_base[index] = result;
121
+ }
@@ -568,14 +568,14 @@ static float make_qkx2_quants(int n, int nmax, const float * GGML_RESTRICT x, co
568
568
  }
569
569
  float iscale = nmax/(max - min);
570
570
  float scale = 1/iscale;
571
- float best_mad = 0;
571
+ float best_error = 0;
572
572
  for (int i = 0; i < n; ++i) {
573
573
  int l = nearest_int(iscale*(x[i] - min));
574
574
  L[i] = MAX(0, MIN(nmax, l));
575
575
  float diff = scale * L[i] + min - x[i];
576
576
  diff = use_mad ? fabsf(diff) : diff * diff;
577
577
  float w = weights[i];
578
- best_mad += w * diff;
578
+ best_error += w * diff;
579
579
  }
580
580
  if (nstep < 1) {
581
581
  *the_min = -min;
@@ -601,18 +601,18 @@ static float make_qkx2_quants(int n, int nmax, const float * GGML_RESTRICT x, co
601
601
  this_min = 0;
602
602
  this_scale = sum_xl / sum_l2;
603
603
  }
604
- float mad = 0;
604
+ float cur_error = 0;
605
605
  for (int i = 0; i < n; ++i) {
606
606
  float diff = this_scale * Laux[i] + this_min - x[i];
607
607
  diff = use_mad ? fabsf(diff) : diff * diff;
608
608
  float w = weights[i];
609
- mad += w * diff;
609
+ cur_error += w * diff;
610
610
  }
611
- if (mad < best_mad) {
611
+ if (cur_error < best_error) {
612
612
  for (int i = 0; i < n; ++i) {
613
613
  L[i] = Laux[i];
614
614
  }
615
- best_mad = mad;
615
+ best_error = cur_error;
616
616
  scale = this_scale;
617
617
  min = this_min;
618
618
  }
@@ -2425,8 +2425,6 @@ void dequantize_row_iq1_m(const block_iq1_m * GGML_RESTRICT x, float * GGML_REST
2425
2425
  }
2426
2426
  }
2427
2427
 
2428
- static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
2429
-
2430
2428
  void dequantize_row_iq4_nl(const block_iq4_nl * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
2431
2429
  assert(k % QK4_NL == 0);
2432
2430
  const int64_t nb = k / QK4_NL;
@@ -53,6 +53,9 @@ struct socket_t {
53
53
  }
54
54
  };
55
55
 
56
+ // macro for nicer error messages on server crash
57
+ #define RPC_STATUS_ASSERT(x) if (!(x)) GGML_ABORT("Remote RPC server crashed or returned malformed response")
58
+
56
59
  // all RPC structures must be packed
57
60
  #pragma pack(push, 1)
58
61
  // ggml_tensor is serialized into rpc_tensor
@@ -425,7 +428,7 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm
425
428
  static bool check_server_version(const std::shared_ptr<socket_t> & sock) {
426
429
  rpc_msg_hello_rsp response;
427
430
  bool status = send_rpc_cmd(sock, RPC_CMD_HELLO, nullptr, 0, &response, sizeof(response));
428
- GGML_ASSERT(status);
431
+ RPC_STATUS_ASSERT(status);
429
432
  if (response.major != RPC_PROTO_MAJOR_VERSION || response.minor > RPC_PROTO_MINOR_VERSION) {
430
433
  fprintf(stderr, "RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
431
434
  return false;
@@ -481,7 +484,7 @@ static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) {
481
484
  ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
482
485
  rpc_msg_free_buffer_req request = {ctx->remote_ptr};
483
486
  bool status = send_rpc_cmd(ctx->sock, RPC_CMD_FREE_BUFFER, &request, sizeof(request), nullptr, 0);
484
- GGML_ASSERT(status);
487
+ RPC_STATUS_ASSERT(status);
485
488
  delete ctx;
486
489
  }
487
490
 
@@ -493,7 +496,7 @@ static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
493
496
  rpc_msg_buffer_get_base_req request = {ctx->remote_ptr};
494
497
  rpc_msg_buffer_get_base_rsp response;
495
498
  bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, &request, sizeof(request), &response, sizeof(response));
496
- GGML_ASSERT(status);
499
+ RPC_STATUS_ASSERT(status);
497
500
  ctx->base_ptr = reinterpret_cast<void *>(response.base_ptr);
498
501
  return ctx->base_ptr;
499
502
  }
@@ -545,7 +548,7 @@ static enum ggml_status ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_
545
548
  request.tensor = serialize_tensor(tensor);
546
549
 
547
550
  bool status = send_rpc_cmd(ctx->sock, RPC_CMD_INIT_TENSOR, &request, sizeof(request), nullptr, 0);
548
- GGML_ASSERT(status);
551
+ RPC_STATUS_ASSERT(status);
549
552
  }
550
553
  return GGML_STATUS_SUCCESS;
551
554
  }
@@ -560,7 +563,7 @@ static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggm
560
563
  request.hash = fnv_hash((const uint8_t*)data, size);
561
564
  rpc_msg_set_tensor_hash_rsp response;
562
565
  bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR_HASH, &request, sizeof(request), &response, sizeof(response));
563
- GGML_ASSERT(status);
566
+ RPC_STATUS_ASSERT(status);
564
567
  if (response.result) {
565
568
  // the server has the same data, no need to send it
566
569
  return;
@@ -573,7 +576,7 @@ static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggm
573
576
  memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
574
577
  memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size);
575
578
  bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input.data(), input.size());
576
- GGML_ASSERT(status);
579
+ RPC_STATUS_ASSERT(status);
577
580
  }
578
581
 
579
582
  static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
@@ -583,7 +586,7 @@ static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, con
583
586
  request.offset = offset;
584
587
  request.size = size;
585
588
  bool status = send_rpc_cmd(ctx->sock, RPC_CMD_GET_TENSOR, &request, sizeof(request), data, size);
586
- GGML_ASSERT(status);
589
+ RPC_STATUS_ASSERT(status);
587
590
  }
588
591
 
589
592
  static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
@@ -601,7 +604,7 @@ static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, con
601
604
  request.dst = serialize_tensor(dst);
602
605
  rpc_msg_copy_tensor_rsp response;
603
606
  bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response));
604
- GGML_ASSERT(status);
607
+ RPC_STATUS_ASSERT(status);
605
608
  return response.result;
606
609
  }
607
610
 
@@ -609,7 +612,7 @@ static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t
609
612
  ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
610
613
  rpc_msg_buffer_clear_req request = {ctx->remote_ptr, value};
611
614
  bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_CLEAR, &request, sizeof(request), nullptr, 0);
612
- GGML_ASSERT(status);
615
+ RPC_STATUS_ASSERT(status);
613
616
  }
614
617
 
615
618
  static ggml_backend_buffer_i ggml_backend_rpc_buffer_interface = {
@@ -635,7 +638,7 @@ static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_back
635
638
  rpc_msg_alloc_buffer_rsp response;
636
639
  auto sock = get_socket(buft_ctx->endpoint);
637
640
  bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, &request, sizeof(request), &response, sizeof(response));
638
- GGML_ASSERT(status);
641
+ RPC_STATUS_ASSERT(status);
639
642
  if (response.remote_ptr != 0) {
640
643
  ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
641
644
  ggml_backend_rpc_buffer_interface,
@@ -650,7 +653,7 @@ static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_back
650
653
  static size_t get_alignment(const std::shared_ptr<socket_t> & sock) {
651
654
  rpc_msg_get_alignment_rsp response;
652
655
  bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, nullptr, 0, &response, sizeof(response));
653
- GGML_ASSERT(status);
656
+ RPC_STATUS_ASSERT(status);
654
657
  return response.alignment;
655
658
  }
656
659
 
@@ -662,7 +665,7 @@ static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_typ
662
665
  static size_t get_max_size(const std::shared_ptr<socket_t> & sock) {
663
666
  rpc_msg_get_max_size_rsp response;
664
667
  bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, nullptr, 0, &response, sizeof(response));
665
- GGML_ASSERT(status);
668
+ RPC_STATUS_ASSERT(status);
666
669
  return response.max_size;
667
670
  }
668
671
 
@@ -683,7 +686,7 @@ static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_ty
683
686
 
684
687
  rpc_msg_get_alloc_size_rsp response;
685
688
  bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALLOC_SIZE, &request, sizeof(request), &response, sizeof(response));
686
- GGML_ASSERT(status);
689
+ RPC_STATUS_ASSERT(status);
687
690
 
688
691
  return response.alloc_size;
689
692
  } else {
@@ -761,7 +764,7 @@ static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, g
761
764
  rpc_msg_graph_compute_rsp response;
762
765
  auto sock = get_socket(rpc_ctx->endpoint);
763
766
  bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size(), &response, sizeof(response));
764
- GGML_ASSERT(status);
767
+ RPC_STATUS_ASSERT(status);
765
768
  return (enum ggml_status)response.result;
766
769
  }
767
770
 
@@ -835,7 +838,7 @@ bool ggml_backend_is_rpc(ggml_backend_t backend) {
835
838
  static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * free, size_t * total) {
836
839
  rpc_msg_get_device_memory_rsp response;
837
840
  bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, nullptr, 0, &response, sizeof(response));
838
- GGML_ASSERT(status);
841
+ RPC_STATUS_ASSERT(status);
839
842
  *free = response.free_mem;
840
843
  *total = response.total_mem;
841
844
  }
@@ -13,7 +13,7 @@ elseif(SUPPORTS_SYCL)
13
13
  If you expected the oneAPI Release compiler, please install oneAPI & source it, like:
14
14
  source /opt/intel/oneapi/setvars.sh")
15
15
  else()
16
- message(FATAL_ERROR, "C++ compiler lacks SYCL support.")
16
+ message(FATAL_ERROR "C++ compiler lacks SYCL support.")
17
17
  endif()
18
18
  message(STATUS "SYCL found")
19
19
  #todo: AOT
@@ -142,7 +142,7 @@ else()
142
142
  FetchContent_Declare(
143
143
  ONEMATH
144
144
  GIT_REPOSITORY https://github.com/uxlfoundation/oneMath.git
145
- GIT_TAG c255b1b4c41e2ee3059455c1f96a965d6a62568a
145
+ GIT_TAG 8efe85f5aaebb37f1d8c503b7af66315feabf142
146
146
  )
147
147
  FetchContent_MakeAvailable(ONEMATH)
148
148
  # Create alias to match with find_package targets name
@@ -170,7 +170,7 @@ else()
170
170
  target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_NVIDIA)
171
171
  elseif (GGML_SYCL_TARGET STREQUAL "AMD")
172
172
  if (NOT GGML_SYCL_DEVICE_ARCH)
173
- message(ERROR "Can't enable SYCL hip backend, GGML_SYCL_DEVICE_ARCH has not been set.")
173
+ message(FATAL_ERROR "Can't enable SYCL hip backend, GGML_SYCL_DEVICE_ARCH has not been set.")
174
174
  endif()
175
175
  target_link_libraries(ggml-sycl PRIVATE ONEMATH::onemath_blas_rocblas)
176
176
  target_compile_options(ggml-sycl PRIVATE "-fsycl-targets=amdgcn-amd-amdhsa")
@@ -225,9 +225,9 @@ struct bin_bcast_sycl {
225
225
  dpct::has_capability_or_fail(stream->get_device(),
226
226
  {sycl::aspect::fp16});
227
227
 
228
- stream->parallel_for(
229
- sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) *
230
- sycl::range<3>(1, 1, block_size),
228
+ sycl_parallel_for(
229
+ stream,
230
+ sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) * sycl::range<3>(1, 1, block_size),
231
231
  sycl::range<3>(1, 1, block_size)),
232
232
  [=](sycl::nd_item<3> item_ct1) {
233
233
  k_bin_bcast_unravel<bin_op>(
@@ -246,9 +246,8 @@ struct bin_bcast_sycl {
246
246
  dpct::has_capability_or_fail(stream->get_device(),
247
247
  {sycl::aspect::fp16});
248
248
 
249
- stream->parallel_for(
250
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
251
- [=](sycl::nd_item<3> item_ct1) {
249
+ sycl_parallel_for(
250
+ stream, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
252
251
  k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1,
253
252
  ne2, ne3, ne10, ne11, ne12, ne13,
254
253
  s1, s2, s3, s01, s02, s03, s11, s12, s13,
@@ -149,8 +149,6 @@ typedef sycl::float2 dfloat2;
149
149
 
150
150
  #define MMVQ_MAX_BATCH_SIZE 8
151
151
 
152
- static const int8_t kvalues_iq4nl[16]={-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
153
-
154
152
  static int g_all_sycl_device_count = -1;
155
153
  static bool g_ggml_backend_sycl_buffer_type_initialized = false;
156
154
 
@@ -201,7 +199,7 @@ struct sycl_device_info {
201
199
  // size_t smpb; // max. shared memory per block
202
200
  bool vmm; // virtual memory support
203
201
  size_t total_vram;
204
- sycl_hw_info hw_info;
202
+ //sycl_hw_info hw_info; \\ device id and aarch, currently not used
205
203
  optimize_feature opt_feature;
206
204
  };
207
205
 
@@ -288,29 +286,6 @@ struct ggml_tensor_extra_gpu {
288
286
 
289
287
  void release_extra_gpu(ggml_tensor_extra_gpu * extra, std::vector<queue_ptr> streams={});
290
288
 
291
- inline optimize_feature check_gpu_optimize_feature(syclex::architecture &arch) {
292
- optimize_feature opt;
293
-
294
- opt.reorder =
295
- (arch == syclex::architecture::intel_gpu_dg1 ||
296
- arch == syclex::architecture::intel_gpu_acm_g10 ||
297
- arch == syclex::architecture::intel_gpu_acm_g11 ||
298
- arch == syclex::architecture::intel_gpu_acm_g12 ||
299
- arch == syclex::architecture::intel_gpu_pvc ||
300
- arch == syclex::architecture::intel_gpu_pvc_vg ||
301
- arch == syclex::architecture::intel_gpu_mtl_u ||
302
- arch == syclex::architecture::intel_gpu_mtl_s ||
303
- arch == syclex::architecture::intel_gpu_mtl_h ||
304
- arch == syclex::architecture::intel_gpu_arl_u ||
305
- arch == syclex::architecture::intel_gpu_arl_s ||
306
- arch == syclex::architecture::intel_gpu_arl_h ||
307
- arch == syclex::architecture::intel_gpu_bmg_g21 ||
308
- arch == syclex::architecture::intel_gpu_lnl_m
309
- );
310
-
311
- return opt;
312
- }
313
-
314
289
  namespace sycl_ex = sycl::ext::oneapi::experimental;
315
290
  struct ggml_backend_sycl_context {
316
291
  int device;
@@ -515,9 +490,9 @@ constexpr size_t ceil_div(const size_t m, const size_t n) {
515
490
 
516
491
  bool gpu_has_xmx(sycl::device &dev);
517
492
 
518
- template <int N, class T> void debug_print_array(const std::string & prefix, const T array[N]) {
493
+ template <int N, class T> std::string debug_get_array_str(const std::string & prefix, const T array[N]) {
519
494
  if (LIKELY(!g_ggml_sycl_debug)) {
520
- return;
495
+ return "";
521
496
  }
522
497
  std::stringstream ss;
523
498
  ss << prefix << "=[";
@@ -528,29 +503,26 @@ template <int N, class T> void debug_print_array(const std::string & prefix, con
528
503
  ss << array[N - 1];
529
504
  }
530
505
  ss << "]";
531
- GGML_SYCL_DEBUG("%s", ss.str().c_str());
506
+ return ss.str();
532
507
  }
533
508
 
534
- inline void debug_print_tensor(const std::string & prefix, const ggml_tensor * tensor,
535
- const std::string & suffix = "") {
536
- if (LIKELY(!g_ggml_sycl_debug)) {
537
- return;
538
- }
539
- GGML_SYCL_DEBUG("%s=", prefix.c_str());
509
+ inline std::string debug_get_tensor_str(const std::string &prefix,
510
+ const ggml_tensor *tensor, const std::string &suffix = "") {
511
+ std::stringstream ss;
512
+ if (LIKELY(!g_ggml_sycl_debug)) { return ss.str(); }
513
+ ss << prefix.c_str() << "=";
540
514
  if (tensor) {
541
- GGML_SYCL_DEBUG("'%s':type=%s", tensor->name, ggml_type_name(tensor->type));
542
- debug_print_array<GGML_MAX_DIMS>(";ne", tensor->ne);
543
- debug_print_array<GGML_MAX_DIMS>(";nb", tensor->nb);
544
- if (!ggml_is_contiguous(tensor)) {
545
- GGML_SYCL_DEBUG(";strided");
546
- }
547
- if (ggml_is_permuted(tensor)) {
548
- GGML_SYCL_DEBUG(";permuted");
549
- }
515
+ ss << "'" << tensor->name << "':type=" << ggml_type_name(tensor->type);
516
+ ss << debug_get_array_str<GGML_MAX_DIMS>(";ne", tensor->ne);
517
+ ss << debug_get_array_str<GGML_MAX_DIMS>(";nb", tensor->nb);
518
+
519
+ if (!ggml_is_contiguous(tensor)) { ss << ";strided"; }
520
+ if (ggml_is_permuted(tensor)) { ss << ";permuted"; }
550
521
  } else {
551
- GGML_SYCL_DEBUG("nullptr");
522
+ ss << "nullptr";
552
523
  }
553
- GGML_SYCL_DEBUG("%s", suffix.c_str());
524
+ ss << suffix;
525
+ return ss.str();
554
526
  }
555
527
 
556
528
  // Use scope_op_debug_print to log operations coming from running a model
@@ -566,10 +538,10 @@ struct scope_op_debug_print {
566
538
  return;
567
539
  }
568
540
  GGML_SYCL_DEBUG("[SYCL][OP] call %s%s:", func.data(), func_suffix.data());
569
- debug_print_tensor(" dst", dst);
541
+ GGML_SYCL_DEBUG("%s", debug_get_tensor_str(" dst", dst).c_str());
570
542
  if (dst) {
571
543
  for (std::size_t i = 0; i < num_src; ++i) {
572
- debug_print_tensor("\tsrc" + std::to_string(i), dst->src[i]);
544
+ GGML_SYCL_DEBUG("%s", debug_get_tensor_str("\tsrc" + std::to_string(i), dst->src[i]).c_str());
573
545
  }
574
546
  }
575
547
  GGML_SYCL_DEBUG("%s\n", suffix.data());
@@ -89,33 +89,24 @@ static void concat_f32_sycl(const float *x, const float *y, float *dst,
89
89
  sycl::range<3> gridDim(ne2, ne1, num_blocks);
90
90
  switch (dim) {
91
91
  case 0:
92
- stream->parallel_for(
93
- sycl::nd_range<3>(gridDim *
94
- sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
95
- sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
96
- [=](sycl::nd_item<3> item_ct1) {
97
- concat_f32_dim0(x, y, dst, ne0, ne00, item_ct1);
98
- });
99
- break;
92
+ sycl_parallel_for(stream,
93
+ sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
94
+ sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
95
+ [=](sycl::nd_item<3> item_ct1) { concat_f32_dim0(x, y, dst, ne0, ne00, item_ct1); });
96
+ break;
100
97
  case 1:
101
- stream->parallel_for(
102
- sycl::nd_range<3>(gridDim *
103
- sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
104
- sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
105
- [=](sycl::nd_item<3> item_ct1) {
106
- concat_f32_dim1(x, y, dst, ne0, ne01, item_ct1);
107
- });
108
- break;
98
+ sycl_parallel_for(stream,
99
+ sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
100
+ sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
101
+ [=](sycl::nd_item<3> item_ct1) { concat_f32_dim1(x, y, dst, ne0, ne01, item_ct1); });
102
+ break;
109
103
  // dim >=2 will be dispatched to the default path
110
104
  default:
111
- stream->parallel_for(
112
- sycl::nd_range<3>(gridDim *
113
- sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
114
- sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
115
- [=](sycl::nd_item<3> item_ct1) {
116
- concat_f32_dim2(x, y, dst, ne0, ne02, item_ct1);
117
- });
118
- break;
105
+ sycl_parallel_for(stream,
106
+ sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
107
+ sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
108
+ [=](sycl::nd_item<3> item_ct1) { concat_f32_dim2(x, y, dst, ne0, ne02, item_ct1); });
109
+ break;
119
110
  }
120
111
  }
121
112
 
@@ -129,33 +120,29 @@ static void concat_f32_sycl_non_cont(
129
120
  int64_t ne2, int64_t ne3, uint64_t nb0, uint64_t nb1, uint64_t nb2,
130
121
  uint64_t nb3, int32_t dim) {
131
122
  sycl::range<3> gridDim(ne3, ne2, ne1);
132
- stream->parallel_for(
133
- sycl::nd_range<3>(gridDim, sycl::range<3>(1, 1, 1)),
134
- [=](sycl::nd_item<3> item_ct1) {
135
- int64_t i3 = item_ct1.get_group(0);
136
- int64_t i2 = item_ct1.get_group(1);
137
- int64_t i1 = item_ct1.get_group(2);
123
+ sycl_parallel_for(stream, sycl::nd_range<3>(gridDim, sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
124
+ int64_t i3 = item_ct1.get_group(0);
125
+ int64_t i2 = item_ct1.get_group(1);
126
+ int64_t i1 = item_ct1.get_group(2);
138
127
 
139
- int64_t o[4] = {0, 0, 0, 0};
140
- o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
128
+ int64_t o[4] = { 0, 0, 0, 0 };
129
+ o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
141
130
 
142
- const float *x;
131
+ const float * x;
143
132
 
144
- for (int i0 = item_ct1.get_local_id(2); i0 < ne0;
145
- i0 += item_ct1.get_local_range(2)) {
133
+ for (int i0 = item_ct1.get_local_id(2); i0 < ne0; i0 += item_ct1.get_local_range(2)) {
146
134
  if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
147
- x = (const float *)(src0 + (i3)*nb03 + (i2)*nb02 + (i1)*nb01 +
148
- (i0)*nb00);
135
+ x = (const float *) (src0 + (i3) *nb03 + (i2) *nb02 + (i1) *nb01 + (i0) *nb00);
149
136
  } else {
150
- x = (const float *)(src1 + (i3 - o[3]) * nb13 + (i2 - o[2]) * nb12 +
151
- (i1 - o[1]) * nb11 + (i0 - o[0]) * nb10);
137
+ x = (const float *) (src1 + (i3 - o[3]) * nb13 + (i2 - o[2]) * nb12 + (i1 - o[1]) * nb11 +
138
+ (i0 - o[0]) * nb10);
152
139
  }
153
140
 
154
141
  float *y = (float *)(dst + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0);
155
142
 
156
143
  *y = *x;
157
- }
158
- });
144
+ }
145
+ });
159
146
  }
160
147
 
161
148
  void ggml_sycl_op_concat(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
@@ -59,16 +59,10 @@ static void conv_transpose_1d_f32_f32_sycl(
59
59
  const int num_blocks = (output_size + SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE - 1) / SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE;
60
60
  const sycl::range<3> block_dims(1, 1, SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE);
61
61
  const sycl::range<3> block_nums(1, 1, num_blocks);
62
- stream->parallel_for(
63
- sycl::nd_range<3>(
64
- block_nums * block_dims, block_dims),
65
- [=](sycl::nd_item<3> item_ct1) {
66
- conv_transpose_1d_kernel(
67
- s0, output_size,
68
- src0_ne0, src0_ne1, src0_ne2,
69
- src1_ne0, dst_ne0,
70
- src0, src1, dst, item_ct1);
71
- });
62
+ sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
63
+ conv_transpose_1d_kernel(s0, output_size, src0_ne0, src0_ne1, src0_ne2, src1_ne0, dst_ne0, src0, src1, dst,
64
+ item_ct1);
65
+ });
72
66
  }
73
67
 
74
68
  void ggml_sycl_op_conv_transpose_1d(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {