whispercpp 1.3.3 → 1.3.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (586)
  1. checksums.yaml +4 -4
  2. data/ext/ruby_whisper_params.c +55 -25
  3. data/ext/sources/CMakeLists.txt +1 -1
  4. data/ext/sources/bindings/javascript/package.json +1 -1
  5. data/ext/sources/build-xcframework.sh +24 -0
  6. data/ext/sources/examples/CMakeLists.txt +1 -0
  7. data/ext/sources/examples/addon.node/addon.cpp +19 -19
  8. data/ext/sources/examples/addon.node/index.js +7 -5
  9. data/ext/sources/examples/bench/bench.cpp +26 -16
  10. data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
  11. data/ext/sources/examples/cli/cli.cpp +4 -2
  12. data/ext/sources/examples/command/command.cpp +26 -24
  13. data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
  14. data/ext/sources/examples/common-ggml.cpp +2 -0
  15. data/ext/sources/examples/lsp/lsp.cpp +19 -17
  16. data/ext/sources/examples/server/server.cpp +24 -13
  17. data/ext/sources/examples/server.py +6 -1
  18. data/ext/sources/examples/stream/stream.cpp +4 -2
  19. data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
  20. data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
  21. data/ext/sources/examples/talk-llama/CMakeLists.txt +2 -2
  22. data/ext/sources/examples/talk-llama/llama-adapter.cpp +101 -4
  23. data/ext/sources/examples/talk-llama/llama-adapter.h +6 -0
  24. data/ext/sources/examples/talk-llama/llama-arch.cpp +588 -15
  25. data/ext/sources/examples/talk-llama/llama-arch.h +58 -1
  26. data/ext/sources/examples/talk-llama/llama-batch.cpp +103 -71
  27. data/ext/sources/examples/talk-llama/llama-batch.h +31 -18
  28. data/ext/sources/examples/talk-llama/llama-chat.cpp +120 -5
  29. data/ext/sources/examples/talk-llama/llama-chat.h +7 -0
  30. data/ext/sources/examples/talk-llama/llama-context.cpp +460 -357
  31. data/ext/sources/examples/talk-llama/llama-context.h +44 -29
  32. data/ext/sources/examples/talk-llama/llama-cparams.h +4 -4
  33. data/ext/sources/examples/talk-llama/llama-graph.cpp +543 -271
  34. data/ext/sources/examples/talk-llama/llama-graph.h +278 -168
  35. data/ext/sources/examples/talk-llama/llama-hparams.cpp +118 -4
  36. data/ext/sources/examples/talk-llama/llama-hparams.h +61 -15
  37. data/ext/sources/examples/talk-llama/llama-impl.h +2 -0
  38. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +326 -0
  39. data/ext/sources/examples/talk-llama/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +38 -29
  40. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2020 -0
  41. data/ext/sources/examples/talk-llama/llama-kv-cache.h +358 -27
  42. data/ext/sources/examples/talk-llama/llama-kv-cells.h +80 -28
  43. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +56 -36
  44. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +30 -29
  45. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +48 -19
  46. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +13 -14
  47. data/ext/sources/examples/talk-llama/llama-memory.h +16 -10
  48. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +2 -0
  49. data/ext/sources/examples/talk-llama/llama-model-loader.h +3 -2
  50. data/ext/sources/examples/talk-llama/llama-model.cpp +7165 -2336
  51. data/ext/sources/examples/talk-llama/llama-model.h +60 -9
  52. data/ext/sources/examples/talk-llama/llama-quant.cpp +48 -10
  53. data/ext/sources/examples/talk-llama/llama-sampling.cpp +226 -126
  54. data/ext/sources/examples/talk-llama/llama-vocab.cpp +440 -13
  55. data/ext/sources/examples/talk-llama/llama-vocab.h +45 -0
  56. data/ext/sources/examples/talk-llama/llama.cpp +65 -10
  57. data/ext/sources/examples/talk-llama/llama.h +95 -177
  58. data/ext/sources/examples/talk-llama/talk-llama.cpp +9 -6
  59. data/ext/sources/examples/talk-llama/unicode.cpp +207 -0
  60. data/ext/sources/examples/talk-llama/unicode.h +45 -0
  61. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +4 -2
  62. data/ext/sources/examples/whisper.wasm/index-tmpl.html +17 -16
  63. data/ext/sources/ggml/CMakeLists.txt +59 -31
  64. data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
  65. data/ext/sources/ggml/include/ggml-backend.h +17 -1
  66. data/ext/sources/ggml/include/ggml-cpu.h +1 -1
  67. data/ext/sources/ggml/include/ggml-metal.h +1 -6
  68. data/ext/sources/ggml/include/ggml-opt.h +25 -6
  69. data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
  70. data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
  71. data/ext/sources/ggml/include/ggml.h +221 -16
  72. data/ext/sources/ggml/src/CMakeLists.txt +17 -2
  73. data/ext/sources/ggml/src/ggml-alloc.c +265 -141
  74. data/ext/sources/ggml/src/ggml-backend-impl.h +4 -1
  75. data/ext/sources/ggml/src/ggml-backend-reg.cpp +30 -13
  76. data/ext/sources/ggml/src/ggml-backend.cpp +221 -38
  77. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  78. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -4
  79. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +14 -0
  80. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
  81. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +903 -717
  82. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +143 -25
  83. data/ext/sources/ggml/src/ggml-cann/common.h +143 -1
  84. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +488 -69
  85. data/ext/sources/ggml/src/ggml-common.h +17 -0
  86. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +40 -18
  87. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +4 -2
  88. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +132 -596
  89. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +14 -286
  90. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +103 -582
  91. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
  92. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +265 -437
  93. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
  94. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +521 -353
  95. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
  96. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
  97. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +4679 -1657
  98. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +32 -2
  99. data/ext/sources/ggml/src/ggml-cpu/common.h +14 -0
  100. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +13 -6
  101. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +70 -42
  102. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +35 -28
  103. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +152 -18
  104. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +7 -1
  105. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +227 -97
  106. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +474 -1116
  107. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1587 -1177
  108. data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -8
  109. data/ext/sources/ggml/src/ggml-cpu/quants.c +35 -0
  110. data/ext/sources/ggml/src/ggml-cpu/quants.h +8 -0
  111. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +458 -47
  112. data/ext/sources/ggml/src/ggml-cpu/repack.h +22 -0
  113. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +89 -60
  114. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1024 -0
  115. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
  116. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
  117. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
  118. data/ext/sources/ggml/src/ggml-cpu/traits.cpp +2 -2
  119. data/ext/sources/ggml/src/ggml-cpu/traits.h +1 -1
  120. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +170 -26
  121. data/ext/sources/ggml/src/ggml-cpu/vec.h +506 -63
  122. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +20 -16
  123. data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
  124. data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
  125. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
  126. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
  127. data/ext/sources/ggml/src/ggml-cuda/common.cuh +250 -63
  128. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
  129. data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
  130. data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
  131. data/ext/sources/ggml/src/ggml-cuda/convert.cu +95 -22
  132. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +15 -0
  133. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
  134. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +64 -307
  135. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  136. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
  137. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +498 -367
  138. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +137 -91
  139. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +755 -0
  140. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +3 -0
  141. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +593 -0
  142. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +86 -50
  143. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +185 -198
  144. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  145. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
  146. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +379 -107
  147. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
  148. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
  149. data/ext/sources/ggml/src/ggml-cuda/mean.cu +56 -2
  150. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +198 -45
  151. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +123 -0
  152. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +496 -0
  153. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +206 -57
  154. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1262 -721
  155. data/ext/sources/ggml/src/ggml-cuda/{mmv.cu → mmvf.cu} +53 -53
  156. data/ext/sources/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +3 -3
  157. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +64 -73
  158. data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
  159. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
  160. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
  161. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
  162. data/ext/sources/ggml/src/ggml-cuda/pad.cu +46 -23
  163. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
  164. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
  165. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +12 -10
  166. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
  167. data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
  168. data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
  169. data/ext/sources/ggml/src/ggml-cuda/rope.cu +21 -27
  170. data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
  171. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +276 -0
  172. data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  173. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
  174. data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
  175. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +126 -59
  176. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  177. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +322 -100
  178. data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
  179. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +21 -4
  180. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
  181. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
  182. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
  183. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
  184. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
  185. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
  186. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
  187. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
  188. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
  189. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
  190. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
  191. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
  192. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
  193. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
  194. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
  195. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
  196. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
  197. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
  198. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
  199. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
  200. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
  201. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
  202. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
  203. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
  204. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
  205. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
  206. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
  207. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
  208. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
  209. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
  210. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
  211. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
  212. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
  213. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
  214. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
  215. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
  216. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +21 -18
  217. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
  218. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
  219. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
  220. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
  221. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
  222. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
  223. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
  224. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
  225. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
  226. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
  227. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
  228. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
  229. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
  230. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
  231. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
  232. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
  233. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
  234. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +259 -0
  235. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +14 -0
  236. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
  237. data/ext/sources/ggml/src/ggml-cuda/unary.cu +90 -0
  238. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +8 -0
  239. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +92 -6
  240. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
  241. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  242. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +58 -36
  243. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +4 -3
  244. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +10 -2
  245. data/ext/sources/ggml/src/ggml-impl.h +119 -9
  246. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -7
  247. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
  248. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
  249. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
  250. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +600 -0
  251. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1376 -0
  252. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +226 -0
  253. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1308 -0
  254. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +136 -63
  255. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +3158 -0
  256. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +82 -0
  257. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +718 -0
  258. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +2854 -1503
  259. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +18 -8
  260. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +18 -0
  261. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +2510 -242
  262. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
  263. data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  264. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  265. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  266. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +84 -0
  267. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +66 -0
  268. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
  269. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +370 -0
  270. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
  271. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  272. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +177 -0
  273. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
  274. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  275. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  276. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
  277. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  278. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
  279. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
  280. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  281. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
  282. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
  283. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
  284. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  285. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
  286. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
  287. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
  288. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
  289. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
  290. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  291. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +189 -0
  292. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
  293. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
  294. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
  295. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
  296. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
  297. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
  298. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +2 -2
  299. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  300. data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
  301. data/ext/sources/ggml/src/ggml-quants.c +111 -16
  302. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  303. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +67 -47
  304. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +2 -0
  305. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +15 -5
  306. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +6 -0
  307. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +25 -16
  308. data/ext/sources/ggml/src/ggml-sycl/conv.cpp +10 -4
  309. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +166 -99
  310. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +72 -306
  311. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
  312. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +67 -49
  313. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +1 -31
  314. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +79 -29
  315. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
  316. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +14 -26
  317. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +9 -6
  318. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +328 -323
  319. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
  320. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +2 -2
  321. data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +80 -60
  322. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +201 -132
  323. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +74 -55
  324. data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
  325. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +8 -9
  326. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +35 -42
  327. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
  328. data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  329. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
  330. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +12 -6
  331. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
  332. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +16 -12
  333. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +3492 -883
  334. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
  335. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
  336. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
  337. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
  338. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +349 -0
  339. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
  340. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +66 -12
  341. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +154 -0
  342. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
  343. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  344. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +2 -1
  345. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +6 -5
  346. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +4 -2
  347. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
  348. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
  349. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  350. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
  351. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
  352. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  353. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
  354. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +69 -24
  355. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +60 -20
  356. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +98 -42
  357. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +64 -27
  358. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +74 -13
  359. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  360. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  361. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  362. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +4 -17
  363. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +19 -10
  364. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +25 -15
  365. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +4 -0
  366. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
  367. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
  368. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +18 -14
  369. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +126 -0
  370. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +65 -1
  371. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
  372. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +140 -0
  373. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -531
  374. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +206 -38
  375. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp +556 -0
  376. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +12 -5
  377. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +15 -9
  378. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
  379. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
  380. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +24 -3
  381. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +53 -3
  382. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +55 -11
  383. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
  384. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  385. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
  386. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  387. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  388. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  389. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
  390. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  391. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +29 -7
  392. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +4 -0
  393. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
  394. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
  395. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
  396. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +4 -3
  397. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +101 -9
  398. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  399. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
  400. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -77
  401. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
  402. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1558 -0
  403. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +44 -0
  404. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +41 -0
  405. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
  406. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
  407. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
  408. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +124 -0
  409. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
  410. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  411. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +44 -0
  412. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +41 -0
  413. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
  414. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +57 -0
  415. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +48 -0
  416. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
  417. data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  418. data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
  419. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
  420. data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
  421. data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
  422. data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
  423. data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
  424. data/ext/sources/ggml/src/ggml.c +478 -98
  425. data/ext/sources/ggml/src/gguf.cpp +8 -1
  426. data/ext/sources/src/whisper.cpp +23 -46
  427. data/ext/sources/tests/CMakeLists.txt +8 -1
  428. data/ext/sources/tests/test-vad-full.cpp +3 -3
  429. data/ext/sources/tests/test-vad.cpp +2 -2
  430. data/lib/whisper/model/uri.rb +1 -1
  431. data/sig/whisper.rbs +7 -0
  432. data/test/test_params.rb +8 -0
  433. data/test/test_whisper.rb +1 -1
  434. data/whispercpp.gemspec +1 -1
  435. metadata +164 -157
  436. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +0 -279
  437. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +0 -1841
  438. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +0 -303
  439. data/ext/sources/ggml/include/ggml-kompute.h +0 -50
  440. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
  441. data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
  442. data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
  443. data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
  444. data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
  445. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  446. data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  447. data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  448. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  449. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  450. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  451. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  452. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  453. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  454. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
  455. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
  456. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
  457. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
  458. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
  459. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
  460. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
  461. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
  462. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
  463. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
  464. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
  465. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
  466. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
  467. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
  468. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
  469. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
  470. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
  471. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
  472. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
  473. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
  474. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
  475. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
  476. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
  477. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
  478. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
  479. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
  480. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
  481. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
  482. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
  483. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
  484. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
  485. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
  486. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
  487. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
  488. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
  489. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
  490. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
  491. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
  492. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
  493. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
  494. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
  495. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
  496. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
  497. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
  498. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
  499. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
  500. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
  501. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
  502. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
  503. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
  504. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
  505. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
  506. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
  507. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
  508. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
  509. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
  510. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
  511. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
  512. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
  513. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
  514. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
  515. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
  516. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
  517. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
  518. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
  519. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
  520. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
  521. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
  522. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
  523. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
  524. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
  525. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
  526. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
  527. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
  528. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
  529. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
  530. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
  531. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
  532. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
  533. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
  534. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
  535. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
  536. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
  537. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
  538. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
  539. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
  540. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
  541. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
  542. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
  543. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
  544. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
  545. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
  546. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
  547. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  548. data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  549. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  550. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  551. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  552. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  553. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  554. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  555. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  556. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  557. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  558. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  559. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  560. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  561. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  562. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  563. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  564. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  565. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  566. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  567. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  568. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  569. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  570. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  571. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  572. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  573. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  574. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  575. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  576. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  577. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  578. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  579. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  580. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  581. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  582. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  583. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  584. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  585. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
  586. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -6280
@@ -41,8 +41,10 @@
 #include "ggml-sycl/element_wise.hpp"
 #include "ggml-sycl/presets.hpp"
 #include "ggml-sycl/gemm.hpp"
+#include "ggml-sycl/set_rows.hpp"
 #include "ggml-sycl/sycl_hw.hpp"
 #include "ggml-sycl/getrows.hpp"
+#include "ggml-sycl/quantize.hpp"
 #include "ggml.h"
 
 static bool g_sycl_loaded = false;
@@ -83,7 +85,7 @@ static ggml_sycl_device_info ggml_sycl_init() {
 
         info.devices[i].cc =
             100 * prop.get_major_version() + 10 * prop.get_minor_version();
-        info.devices[i].opt_feature.reorder = !device.ext_oneapi_architecture_is(syclex::arch_category::intel_gpu);
+        info.devices[i].opt_feature.reorder = device.ext_oneapi_architecture_is(syclex::arch_category::intel_gpu);
         info.max_work_group_sizes[i] = prop.get_max_work_group_size();
     }
 
@@ -1372,120 +1374,6 @@ typedef void (*ggml_sycl_op_mul_mat_t)(
 
 
 
-template<int QUANT_BLOCK_TILE>
-static void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded,
-                          const sycl::nd_item<3> &item_ct1) {
-    const int ix = (item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                    item_ct1.get_local_id(2)) * QUANT_BLOCK_TILE;
-
-    if (ix >= kx_padded) {
-        return;
-    }
-
-    const int iy = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                   item_ct1.get_local_id(1);
-
-    const int i_padded = iy*kx_padded + ix;
-
-    block_q8_1 * y = (block_q8_1 *) vy;
-
-    const int ib = i_padded / QK8_1; // block index
-    const int iqs = i_padded % QK8_1; // quant index
-    typedef sycl::vec<float, QUANT_BLOCK_TILE> TC;
-    typedef sycl::vec<int8_t, QUANT_BLOCK_TILE> TQ;
-    TC zeros;
-    TQ qzeros;
-#pragma unroll
-    for (int i = 0; i < QUANT_BLOCK_TILE; i++)
-    {
-        zeros[i] = 0.f;
-        qzeros[i] = 0;
-    }
-    const TC xi = ix < kx ? *(const TC *)&x[iy * kx + ix] : zeros;
-    float sum = xi[0];
-    float amax = sycl::fabs(xi[0]);
-#pragma unroll
-    for (int i = 1; i < QUANT_BLOCK_TILE; i++)
-    {
-        sum += xi[i];
-        amax = sycl::fmax(sycl::fabs(xi[i]), amax);
-    }
-    sum = warp_reduce_sum(sum, item_ct1);
-    amax = warp_reduce_max(amax, item_ct1);
-
-    const float d = amax / 127;
-    TQ q = qzeros;
-    if (amax != 0.0f)
-    {
-#pragma unroll
-        for (int i = 0; i < QUANT_BLOCK_TILE; i++) {
-            q[i] = sycl::round(xi[i] / d);
-        }
-    }
-
-    *(TQ *)&y[ib].qs[iqs] = q;
-
-    if (iqs > 0) {
-        return;
-    }
-
-    reinterpret_cast<sycl::half &>(y[ib].ds.x()) = d;
-    reinterpret_cast<sycl::half &>(y[ib].ds.y()) = sum;
-}
-
-template <int ElementsPerWI>
-static __dpct_inline__ void quantize_and_reorder_q8_1(const float * __restrict__ x, void * reordered_q8_tensor,
-                                                      const int kx, const int kx_padded, const sycl::nd_item<1> & it) {
-    /*
-        Quantizes and reorders the resultant q8 tensor in a per row fashion
-        Each sub-group calculates one quant block. i.e. QK8_1 quant values and the d and sum values
-    */
-
-    auto subgroup_id = it.get_group(0);
-    auto wi_id = it.get_local_id(0);
-
-    const int num_blocks_per_row = kx / QK8_1;
-    auto row = subgroup_id / num_blocks_per_row;
-    auto col = subgroup_id % num_blocks_per_row;
-
-    auto row_offset = row * (kx_padded / QK8_1) * sizeof(block_q8_1);
-    auto col_offset = QK8_1 * col + wi_id * ElementsPerWI;
-
-    auto quant_ptr = (int8_t *) ((char *) reordered_q8_tensor + row_offset + col_offset);
-    auto ds_ptr = (sycl::half2 *) ((char *) reordered_q8_tensor + row_offset + kx + col * sizeof(sycl::half2));
-
-    sycl::vec<float, ElementsPerWI> wi_f32_vals;
-    sycl::vec<int8_t, ElementsPerWI> quantized_values;
-
-    auto float_ptr_offset = subgroup_id * QK8_1 + ElementsPerWI * wi_id;
-    wi_f32_vals = *reinterpret_cast<const sycl::vec<float, ElementsPerWI> *>(x + float_ptr_offset);
-
-    float sum = 0.0f;
-    float amax = 0.0f;
-
-#pragma unroll(ElementsPerWI)
-    for (int i = 0; i < ElementsPerWI; i++) {
-        sum += wi_f32_vals[i];
-        amax = sycl::fmax(amax, sycl::fabs(wi_f32_vals[i]));
-        quantized_values[i] = 0;
-    }
-    sum = sycl::reduce_over_group(it.get_group(), sum, sycl::plus<float>());
-    amax = sycl::reduce_over_group(it.get_group(), amax, sycl::maximum<float>());
-    float d = amax == 0 ? 1 : amax / 127;
-
-#pragma unroll(ElementsPerWI)
-    for (int i = 0; i < ElementsPerWI; i++) {
-        quantized_values[i] = sycl::round(wi_f32_vals[i] / d);
-    }
-
-    d = amax == 0 ? 0 : d;
-
-    *reinterpret_cast<sycl::vec<int8_t, ElementsPerWI> *>(quant_ptr) = quantized_values;
-    if (wi_id == 0) {
-        *ds_ptr = sycl::half2(sycl::half(d), sycl::half(sum));
-    }
-}
-
 static void mul_mat_p021_f16_f32(
     const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y,
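
Note: the two kernels removed in the hunk above (quantize_q8_1 and quantize_and_reorder_q8_1) implement q8_1 block quantization, and in 1.3.4 the SYCL backend pulls that logic in through the new ggml-sycl/quantize.hpp include added at the top of the file. As a rough scalar reference of what one block computes (a hypothetical sketch, not the SYCL kernel): the largest magnitude in a block of QK8_1 values is mapped to 127, and the scale d plus the block sum are kept alongside the quantized bytes.

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <vector>

    // Hypothetical scalar reference for one q8_1 block, mirroring the removed
    // SYCL kernels: d = max|x| / 127, q[i] = round(x[i] / d), plus the block sum.
    struct BlockQ8_1Ref {
        float d;                 // dequantization scale
        float sum;               // sum of the original values
        std::vector<int8_t> qs;  // quantized values
    };

    static BlockQ8_1Ref quantize_block_q8_1_ref(const float * x, int n) {
        float amax = 0.0f, sum = 0.0f;
        for (int i = 0; i < n; ++i) {
            sum  += x[i];
            amax  = std::max(amax, std::fabs(x[i]));
        }
        const float d = amax == 0.0f ? 0.0f : amax / 127.0f;
        BlockQ8_1Ref out{d, sum, std::vector<int8_t>(n, 0)};
        if (d != 0.0f) {
            for (int i = 0; i < n; ++i) {
                out.qs[i] = static_cast<int8_t>(std::lround(x[i] / d));
            }
        }
        return out;
    }
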
@@ -1545,7 +1433,7 @@ static void mul_mat_p021_f16_f32(
 
 static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
     const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x,
-    const int row_stride_x, const int channel_stride_x, const int channel_x_divisor,
+    const int row_stride_x, const int channel_stride_x,const int channel_stride_y, const int channel_x_divisor,
     const sycl::nd_item<3> &item_ct1) {
 
     const sycl::half *x = (const sycl::half *)vx;
@@ -1556,7 +1444,6 @@ static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
                           item_ct1.get_local_id(0);
     const int channel_x = channel / channel_x_divisor;
 
-    const int nrows_y = ncols_x;
     const int nrows_dst = nrows_x;
     const int row_dst = row_x;
 
@@ -1575,7 +1462,7 @@ static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
         const int row_y = col_x;
 
         const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
-        const int iy = channel*nrows_y + row_y;
+        const int iy = channel * channel_stride_y + row_y;
 
         const float xi =
             sycl::vec<sycl::half, 1>(x[ix])
@@ -1695,7 +1582,7 @@ static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, con
     dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
 }
 
-static void scale_f32(const float * x, float * dst, const float scale, const int k,
+static void scale_f32(const float * x, float * dst, const float scale, const float bias, const int k,
                       const sycl::nd_item<3> &item_ct1) {
     const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
                   item_ct1.get_local_id(2);
@@ -1704,7 +1591,7 @@ static void scale_f32(const float * x, float * dst, const float scale, const int
         return;
     }
 
-    dst[i] = scale * x[i];
+    dst[i] = scale * x[i] + bias;
 }
 
 
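
Note: scale_f32 (together with its host wrapper scale_f32_sycl further down) now takes a bias in addition to the scale, so the op computes an affine transform rather than a pure scaling. A scalar sketch of the new per-element behavior (hypothetical helper, not the SYCL kernel):

    #include <cstddef>

    // Reference semantics of the updated kernel: previously dst[i] = scale * x[i],
    // now dst[i] = scale * x[i] + bias.
    static void scale_f32_ref(const float * x, float * dst, float scale, float bias, size_t k) {
        for (size_t i = 0; i < k; ++i) {
            dst[i] = scale * x[i] + bias;
        }
    }
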
@@ -1770,32 +1657,6 @@ static void pool2d_nchw_kernel(
         o_ptr[cur_oh * ow + cur_ow] = res;
 }
 
-static void quantize_row_q8_1_sycl(const float * x, void * vy, const int kx, const int ky, const int kx_padded,
-                                   bool reorder_q8_tensor, queue_ptr stream) {
-    if (reorder_q8_tensor) {
-        auto local_range = std::size_t(WARP_SIZE);
-        auto num_quant_blocks = ky * (kx / QK8_1);
-        auto global_range = num_quant_blocks * local_range;
-        stream->parallel_for(sycl::nd_range<1>({ global_range }, { local_range }),
-                             [=](sycl::nd_item<1> it) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                                 quantize_and_reorder_q8_1<QK8_1 / WARP_SIZE>(x, vy, kx, kx_padded, it);
-                             });
-    } else {
-        const int block_num_x = (kx_padded + SYCL_QUANTIZE_BLOCK_SIZE - 1) / SYCL_QUANTIZE_BLOCK_SIZE;
-        const sycl::range<3> num_blocks(1, ky, block_num_x);
-        int constexpr QUANT_BLOCK_TILE = QK8_1 / WARP_SIZE;
-        static_assert(QK8_1 % WARP_SIZE == 0);
-        const sycl::range<3> block_size(1, 1, SYCL_QUANTIZE_BLOCK_SIZE / QUANT_BLOCK_TILE);
-        {
-            dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
-
-            stream->parallel_for(sycl::nd_range<3>(num_blocks * block_size, block_size),
-                                 [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                                     quantize_q8_1<QUANT_BLOCK_TILE>(x, vy, kx, kx_padded, item_ct1);
-                                 });
-        }
-    }
-}
 
 static void ggml_mul_mat_p021_f16_f32_sycl(const void *vx, const float *y,
                                            float *dst, const int ncols_x,
@@ -1822,7 +1683,7 @@ static void ggml_mul_mat_p021_f16_f32_sycl(const void *vx, const float *y,
 static void ggml_mul_mat_vec_nc_f16_f32_sycl(
     const void *vx, const float *y, float *dst, const int ncols_x,
     const int nrows_x, const int row_stride_x, const int nchannels_x,
-    const int nchannels_y, const int channel_stride_x, queue_ptr stream) {
+    const int nchannels_y, const int channel_stride_x, const int channel_stride_y, queue_ptr stream) {
 
     const sycl::range<3> block_nums(nchannels_y, nrows_x, 1);
     const sycl::range<3> block_dims(1, 1, WARP_SIZE);
@@ -1834,7 +1695,7 @@ static void ggml_mul_mat_vec_nc_f16_f32_sycl(
             sycl::nd_range<3>(block_nums * block_dims, block_dims),
             [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                 mul_mat_vec_nc_f16_f32(vx, y, dst, ncols_x, nrows_x,
-                                       row_stride_x, channel_stride_x,
+                                       row_stride_x, channel_stride_x, channel_stride_y,
                                        nchannels_y / nchannels_x, item_ct1);
             });
 }
@@ -1842,7 +1703,7 @@ static void ggml_mul_mat_vec_nc_f16_f32_sycl(
 
 
 
-static void scale_f32_sycl(const float *x, float *dst, const float scale,
+static void scale_f32_sycl(const float *x, float *dst, const float scale, const float bias,
                            const int k, queue_ptr stream) {
     const int num_blocks = (k + SYCL_SCALE_BLOCK_SIZE - 1) / SYCL_SCALE_BLOCK_SIZE;
     stream->parallel_for(
@@ -1850,7 +1711,7 @@ static void scale_f32_sycl(const float *x, float *dst, const float scale,
             sycl::range<3>(1, 1, SYCL_SCALE_BLOCK_SIZE),
             sycl::range<3>(1, 1, SYCL_SCALE_BLOCK_SIZE)),
         [=](sycl::nd_item<3> item_ct1) {
-            scale_f32(x, dst, scale, k, item_ct1);
+            scale_f32(x, dst, scale, bias, k, item_ct1);
         });
 }
 
@@ -1885,12 +1746,13 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
  const size_t shared_mem = ncols_pad * sizeof(int);

  if (order == GGML_SORT_ORDER_ASC) {
- sycl_launch(stream, [&](sycl::handler & cgh) {
+ stream->submit([&](sycl::handler &cgh) {
  sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
  sycl::range<1>(shared_mem), cgh);

- sycl_parallel_for(
- cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+ cgh.parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) {
  k_argsort_f32_i32<GGML_SORT_ORDER_ASC>(
  x, dst, ncols, ncols_pad, item_ct1,
  dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
@@ -1898,12 +1760,13 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
  });
  });
  } else if (order == GGML_SORT_ORDER_DESC) {
- sycl_launch(stream, [&](sycl::handler & cgh) {
+ stream->submit([&](sycl::handler &cgh) {
  sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
  sycl::range<1>(shared_mem), cgh);

- sycl_parallel_for(
- cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+ cgh.parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) {
  k_argsort_f32_i32<GGML_SORT_ORDER_DESC>(
  x, dst, ncols, ncols_pad, item_ct1,
  dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
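The two argsort hunks above replace the project's sycl_launch/sycl_parallel_for helpers with direct queue submission. A minimal, self-contained sketch of that submit + parallel_for shape with a local (work-group shared) accessor; the kernel body is a placeholder, not the real k_argsort_f32_i32:

    #include <sycl/sycl.hpp>

    // Placeholder launch showing the stream->submit / cgh.parallel_for pattern used above;
    // shared_bytes plays the role of the argsort scratch buffer.
    static void launch_with_local_scratch(sycl::queue & q, float * data, size_t n, size_t shared_bytes) {
        q.submit([&](sycl::handler & cgh) {
            sycl::local_accessor<uint8_t, 1> scratch(sycl::range<1>(shared_bytes), cgh);
            cgh.parallel_for(
                sycl::nd_range<1>(sycl::range<1>(n), sycl::range<1>(n)),
                [=](sycl::nd_item<1> it) {
                    const size_t i = it.get_local_id(0);
                    (void) scratch;      // a real kernel would stage data here
                    data[i] += 0.0f;     // placeholder work
                });
        });
    }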
@@ -1921,47 +1784,50 @@ static void argmax_f32_i32_sycl(const float *x, int *dst, const int ncols,
  const sycl::range<3> block_nums(1, nrows, 1);
  const size_t shared_mem = 256 * sizeof(float);

- sycl_launch(stream, [&](sycl::handler & cgh) {
+ stream->submit([&](sycl::handler &cgh) {
  sycl::local_accessor<float, 1> shared_data(
  sycl::range<1>(shared_mem/sizeof(float)), cgh);
  sycl::local_accessor<int, 1> shared_indices(
  sycl::range<1>(shared_mem/sizeof(float)), cgh);

- sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
- const int tid = item_ct1.get_local_id(2);
- const int row = item_ct1.get_global_id(1);
-
- float max_val = -INFINITY;
- int max_idx = -1;
-
- for (int col = tid; col < ncols; col += 256) {
- float val = x[row * ncols + col];
- if (val > max_val) {
- max_val = val;
- max_idx = col;
+ cgh.parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) {
+ const int tid = item_ct1.get_local_id(2);
+ const int row = item_ct1.get_global_id(1);
+
+ float max_val = -INFINITY;
+ int max_idx = -1;
+
+ for (int col = tid; col < ncols; col += 256) {
+ float val = x[row * ncols + col];
+ if (val > max_val) {
+ max_val = val;
+ max_idx = col;
+ }
  }
- }

- shared_data[tid] = max_val;
- shared_indices[tid] = max_idx;
- item_ct1.barrier(sycl::access::fence_space::local_space);
+ shared_data[tid] = max_val;
+ shared_indices[tid] = max_idx;
+ item_ct1.barrier(sycl::access::fence_space::local_space);

- for (int stride = 256 / 2; stride > 0; stride >>= 1) {
- if (tid < stride) {
- float val1 = shared_data[tid];
- float val2 = shared_data[tid + stride];
- if (val2 > val1) {
- shared_data[tid] = val2;
- shared_indices[tid] = shared_indices[tid + stride];
+ for (int stride = 256/2; stride > 0; stride >>= 1) {
+ if (tid < stride) {
+ float val1 = shared_data[tid];
+ float val2 = shared_data[tid + stride];
+ if (val2 > val1) {
+ shared_data[tid] = val2;
+ shared_indices[tid] = shared_indices[tid + stride];
+ }
  }
+ item_ct1.barrier(sycl::access::fence_space::local_space);
  }
- item_ct1.barrier(sycl::access::fence_space::local_space);
- }

- if (tid == 0) {
- dst[row] = shared_indices[0];
- }
- });
+
+ if (tid == 0) {
+ dst[row] = shared_indices[0];
+ }
+ });
  });
  }
  static void diag_mask_inf_f32_sycl(const float *x, float *dst,
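For reference, what the reworked argmax kernel computes per row: each of the 256 work-items scans a strided slice of the row, then the per-thread candidates are merged with a shared-memory tree reduction. A serial C++ analogue of that two-phase reduction (logic only, no SYCL):

    #include <cmath>

    // Serial analogue of the kernel above: phase 1 is the strided per-thread scan,
    // phase 2 mirrors the shared_data/shared_indices tree reduction.
    static int argmax_row_blocked(const float * row, int ncols) {
        constexpr int THREADS = 256;
        float best_val[THREADS];
        int   best_idx[THREADS];
        for (int tid = 0; tid < THREADS; ++tid) {
            best_val[tid] = -INFINITY;
            best_idx[tid] = -1;
            for (int col = tid; col < ncols; col += THREADS) {
                if (row[col] > best_val[tid]) {
                    best_val[tid] = row[col];
                    best_idx[tid] = col;
                }
            }
        }
        for (int stride = THREADS / 2; stride > 0; stride >>= 1) {   // tree reduction
            for (int tid = 0; tid < stride; ++tid) {
                if (best_val[tid + stride] > best_val[tid]) {
                    best_val[tid] = best_val[tid + stride];
                    best_idx[tid] = best_idx[tid + stride];
                }
            }
        }
        return best_idx[0];
    }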
@@ -2123,8 +1989,8 @@ inline void ggml_sycl_op_mul_mat_sycl(

  #if GGML_SYCL_DNNL
  if (!g_ggml_sycl_disable_dnn) {
- DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ptr,
- DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
+ DnnlGemmWrapper::row_gemm(ctx,row_diff, src1_ncols , ne10, src0_ptr,
+ DnnlGemmWrapper::to_dt<sycl::half>(), src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
  dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
  }
  else
@@ -2170,8 +2036,8 @@ inline void ggml_sycl_op_mul_mat_sycl(

  #if GGML_SYCL_DNNL
  if (!g_ggml_sycl_disable_dnn) {
- DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ddf1_i,
- DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(),
+ DnnlGemmWrapper::row_gemm(ctx, row_diff, src1_ncols, ne10, src0_ddf_i,
+ DnnlGemmWrapper::to_dt<float>(), src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(),
  dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
  }
  else
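Both hunks above swap the operands handed to DnnlGemmWrapper::row_gemm so that src0 supplies the output rows and src1 the output columns, i.e. the product is shaped row_diff x src1_ncols with inner dimension ne10. A plain C++ reference of that shape convention, assuming dense row-major storage (the oneDNN wrapper handles the real strides and data types):

    // Reference of C[row_diff x src1_ncols] = A[row_diff x ne10] * B[ne10 x src1_ncols],
    // matching the (M, N, K, A, B) order of the new row_gemm call; row-major, f32 only.
    static void row_gemm_reference(int M, int N, int K, const float * A, const float * B, float * C) {
        for (int m = 0; m < M; ++m) {
            for (int n = 0; n < N; ++n) {
                float acc = 0.0f;
                for (int k = 0; k < K; ++k) {
                    acc += A[m * K + k] * B[k * N + n];
                }
                C[m * N + n] = acc;
            }
        }
    }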
@@ -2319,9 +2185,11 @@ inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor * ds
  float * dst_dd = static_cast<float *>(dst->data);

  float scale;
- memcpy(&scale, dst->op_params, sizeof(float));
+ float bias;
+ memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
+ memcpy(&bias, (float *) dst->op_params + 1, sizeof(float));

- scale_f32_sycl(src0_dd, dst_dd, scale, ggml_nelements(dst->src[0]), main_stream);
+ scale_f32_sycl(src0_dd, dst_dd, scale, bias, ggml_nelements(dst->src[0]), main_stream);
  /*
  DPCT1010:87: SYCL uses exceptions to report errors and does not use the
  error codes. The call was replaced with 0. You need to rewrite this code.
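GGML_OP_SCALE now carries two floats in op_params, (scale, bias), and the backend applies them in a single fused kernel. A host-side sketch of the unpacking and the element-wise result, mirroring the hunk above without the SYCL launch:

    #include <cstdint>
    #include <cstring>

    // dst[i] = scale * x[i] + bias, with (scale, bias) packed as the first two floats
    // of the op's parameter block, as read in ggml_sycl_op_scale above.
    static void scale_bias_reference(const float * x, float * dst, int64_t n, const int32_t * op_params) {
        float scale, bias;
        std::memcpy(&scale, (const float *) op_params + 0, sizeof(float));
        std::memcpy(&bias,  (const float *) op_params + 1, sizeof(float));
        for (int64_t i = 0; i < n; ++i) {
            dst[i] = scale * x[i] + bias;
        }
    }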
@@ -2370,10 +2238,10 @@ static void ggml_sycl_set_peer_access(const int n_tokens, int main_device) {
  peer_access_enabled = enable_peer_access;
  }

+ template <template <int> typename quantize_f>
  static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
  const ggml_tensor *src1, ggml_tensor *dst,
- ggml_sycl_op_mul_mat_t op,
- const bool convert_src1_to_q8_1) try {
+ ggml_sycl_op_mul_mat_t op) try {

  GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne);

@@ -2468,6 +2336,8 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
  }
  }

+ constexpr bool quantize_enabled = !std::is_same_v<quantize_f<QK8_1 / WARP_SIZE>,
+ no_quantize_q8_1<QK8_1 / WARP_SIZE>>;
  for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
  if ((!split && i != ctx.device) || dev[i].row_low == dev[i].row_high) {
  continue;
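ggml_sycl_op_mul_mat is now templated on a quantization functor instead of taking a runtime convert_src1_to_q8_1 flag; whether the Q8_1 staging buffer exists is decided at compile time by comparing the functor against no_quantize_q8_1. A reduced sketch of that selection mechanism with hypothetical stand-in templates (not the real quantizers):

    #include <cstdio>
    #include <type_traits>

    template <int TILE> struct no_quantize { };
    template <int TILE> struct quantize_plain { };

    // Compile-time dispatch mirroring the quantize_enabled check above.
    template <template <int> typename quantize_f>
    static void op_mul_mat_sketch() {
        constexpr bool quantize_enabled = !std::is_same_v<quantize_f<8>, no_quantize<8>>;
        if constexpr (quantize_enabled) {
            std::puts("allocate Q8_1 staging buffer and quantize src1");
        } else {
            std::puts("use src1 as plain f32");
        }
    }

    int main() {
        op_mul_mat_sketch<quantize_plain>();  // quantizing path
        op_mul_mat_sketch<no_quantize>();     // plain f32 path
    }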
@@ -2493,20 +2363,19 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
  dev[i].src1_ddf = dev[i].src1_ddf_alloc.alloc(ctx.pool(i), ggml_nelements(src1));
  }

- if (convert_src1_to_q8_1) {
+ if constexpr(quantize_enabled) {
  dev[i].src1_ddq = dev[i].src1_ddq_alloc.alloc(ctx.pool(i), nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs);

  if (src1_on_device && src1_is_contiguous) {
- bool reorder_q8_tensor = src0->extra && ((ggml_tensor_extra_gpu *)src0->extra)->optimized_feature.reorder;
  scope_op_debug_print scope_dbg_print(__func__, "/quantize_row_q8_1_sycl", dst,
  /*num_src=*/2, " : converting src1 to Q8_1");
- quantize_row_q8_1_sycl(dev[i].src1_ddf, dev[i].src1_ddq, ne10, nrows1, src1_padded_col_size, reorder_q8_tensor, stream);
- /*
- DPCT1010:90: SYCL uses exceptions to report errors and does not
- use the error codes. The call was replaced with 0. You need to
- rewrite this code.
- */
- SYCL_CHECK(0);
+ try {
+ quantize_row_q8_1_sycl<quantize_f>(dev[i].src1_ddf, dev[i].src1_ddq, ne10, nrows1, src1_padded_col_size, stream);
+ } catch (sycl::exception const &exc) {
+ std::cerr << "Quantize_row_q8_1_sycl error" << exc.what() << "Exception caught at file:" << __FILE__
+ << ", line:" << __LINE__ << std::endl;
+ std::exit(1);
+ }
  }
  }

@@ -2522,11 +2391,6 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
  // here an event is recorded that signals that the main device has finished calculating the input data
  if (split && used_devices > 1) {
  ggml_sycl_set_device(ctx.device);
- /*
- DPCT1024:91: The original code returned the error code that was further
- consumed by the program logic. This original code was replaced with 0.
- You may need to rewrite the program logic consuming the error code.
- */
  SYCL_CHECK(CHECK_TRY_ERROR(
  *src0_extra->events[ctx.device][0] =
  ctx.stream()->ext_oneapi_submit_barrier()));
@@ -2550,11 +2414,6 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten

  // wait for main GPU data if necessary
  if (split && (i != ctx.device || is != 0)) {
- /*
- DPCT1009:163: SYCL uses exceptions to report errors and does not
- use the error codes. The original code was commented out and a
- warning string was inserted. You need to rewrite this code.
- */
  SYCL_CHECK(CHECK_TRY_ERROR(stream->ext_oneapi_submit_barrier(
  {*src0_extra->events[ctx.device][0]})));
  }
@@ -2580,39 +2439,42 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
  // copy src0, src1 to device if necessary
  if (src1_is_contiguous) {
  if (i != ctx.device) {
- if (convert_src1_to_q8_1) {
+ if constexpr (quantize_enabled) {
  char * src1_ddq_i_source = dev[ctx.device].src1_ddq + src1_ddq_i_offset;
- SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy(
- src1_ddq_i, src1_ddq_i_source,
- src1_ncols * src1_padded_col_size * q8_1_ts /
- q8_1_bs).wait()));
+ SYCL_CHECK(
+ CHECK_TRY_ERROR(stream
+ ->memcpy(src1_ddq_i, src1_ddq_i_source,
+ src1_ncols * src1_padded_col_size * q8_1_ts / q8_1_bs)
+ .wait()));
  } else {
-
  float * src1_ddf_i_source = (float *) src1_extra->data_device[ctx.device];
- src1_ddf_i_source += (i0*ne11 + src1_col_0) * ne10;
+ src1_ddf_i_source += (i0 * ne11 + src1_col_0) * ne10;

- SYCL_CHECK(CHECK_TRY_ERROR(dev2dev_memcpy(*stream, *main_stream,
- src1_ddf_i, src1_ddf_i_source,
- src1_ncols * ne10 * sizeof(float))));
+ SYCL_CHECK(
+ CHECK_TRY_ERROR(dev2dev_memcpy(*stream, *main_stream, src1_ddf_i, src1_ddf_i_source,
+ src1_ncols * ne10 * sizeof(float))));
  }
  }
- } else if (src1_on_device && !src1_is_contiguous) {
- SYCL_CHECK(ggml_sycl_cpy_tensor_2d(
- src1_ddf_i, src1, i03, i02, src1_col_0, src1_col_0+src1_ncols, stream));
  } else {
- GGML_ABORT("fatal error");
- }
+ if (src1_on_device) {
+ SYCL_CHECK(ggml_sycl_cpy_tensor_2d(src1_ddf_i, src1, i03, i02, src1_col_0,
+ src1_col_0 + src1_ncols, stream));
+ } else {
+ GGML_ABORT("src1 is non-contiguous and not on device");
+ }

- if (convert_src1_to_q8_1 && !src1_is_contiguous) {
- scope_op_debug_print scope_dbg_print(__func__, "/quantize_row_q8_1_sycl", dst,
- /*num_src=*/2, " : converting src1 to Q8_1");
- quantize_row_q8_1_sycl(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, false, stream);
- /*
- DPCT1010:92: SYCL uses exceptions to report errors and does
- not use the error codes. The call was replaced with 0. You
- need to rewrite this code.
- */
- SYCL_CHECK(0);
+ if constexpr (quantize_enabled) {
+ scope_op_debug_print scope_dbg_print(__func__, "/quantize_row_q8_1_sycl", dst,
+ /*num_src=*/2, " : converting src1 to Q8_1");
+ try {
+ quantize_row_q8_1_sycl<quantize_q8_1>(src1_ddf_i, src1_ddq_i, ne10, src1_ncols,
+ src1_padded_col_size, stream);
+ } catch (const sycl::exception & exc) {
+ std::cerr << "Quantize_row_q8_1_sycl error" << exc.what()
+ << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
+ std::exit(1);
+ }
+ }
  }

  if (src1_col_0 == 0 && !src0_is_contiguous && i02 % i02_divisor == 0) {
@@ -2624,12 +2486,6 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
  // do the computation
  SYCL_CHECK(CHECK_TRY_ERROR(op(ctx, src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i,
  dev[i].row_low, dev[i].row_high, src1_ncols, src1_padded_col_size, stream)));
- /*
- DPCT1010:93: SYCL uses exceptions to report errors and does not
- use the error codes. The call was replaced with 0. You need to
- rewrite this code.
- */
- SYCL_CHECK(0);

  // copy dst to host or other device if necessary
  if (!dst_on_device) {
@@ -2660,12 +2516,6 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten

  // add event for the main device to wait on until other device is done
  if (split && (i != ctx.device || is != 0)) {
- /*
- DPCT1024:94: The original code returned the error code that
- was further consumed by the program logic. This original
- code was replaced with 0. You may need to rewrite the
- program logic consuming the error code.
- */
  SYCL_CHECK(CHECK_TRY_ERROR(
  *src0_extra->events[i][is] =
  stream->ext_oneapi_submit_barrier()));
@@ -2764,6 +2614,8 @@ static void ggml_sycl_mul_mat_vec_nc(ggml_backend_sycl_context & ctx, const ggml
  GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer));
  GGML_ASSERT(src0->type == GGML_TYPE_F16);
  GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT(src1->ne[1] == 1);
+ GGML_ASSERT(src1->ne[3] == 1);

  const int64_t ne00 = src0->ne[0];
  const int64_t ne01 = src0->ne[1];
@@ -2773,6 +2625,7 @@ static void ggml_sycl_mul_mat_vec_nc(ggml_backend_sycl_context & ctx, const ggml
  const int64_t nb02 = src0->nb[2];

  const int64_t ne12 = src1->ne[2];
+ const int64_t nb11 = src1->nb[1];

  SYCL_CHECK(ggml_sycl_set_device(ctx.device));
  queue_ptr main_stream = ctx.stream();
@@ -2783,8 +2636,9 @@ static void ggml_sycl_mul_mat_vec_nc(ggml_backend_sycl_context & ctx, const ggml

  const int64_t row_stride_x = nb01 / sizeof(sycl::half);
  const int64_t channel_stride_x = nb02 / sizeof(sycl::half);
+ const int64_t channel_stride_y = nb11 / sizeof(float);

- ggml_mul_mat_vec_nc_f16_f32_sycl(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
+ ggml_mul_mat_vec_nc_f16_f32_sycl(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x,channel_stride_y, main_stream);
  }
  catch (sycl::exception const &exc) {
  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
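The single-batch f16 x f32 mat-vec path now forwards an explicit channel stride for src1 instead of assuming contiguity. All strides are derived from ggml's byte strides (nb, in bytes) and converted to element counts; using src1->nb[1] as the channel stride is valid here because src1->ne[1] == 1 is asserted. A small sketch of that conversion:

    #include <cstdint>

    // Element strides derived from ggml byte strides: src0 is f16 (2-byte elements),
    // src1 is f32, as in the path above.
    struct nc_strides {
        int64_t row_stride_x;       // rows of src0
        int64_t channel_stride_x;   // channels of src0
        int64_t channel_stride_y;   // channels of src1
    };

    static nc_strides make_nc_strides(int64_t nb01, int64_t nb02, int64_t nb11) {
        return { nb01 / (int64_t) sizeof(uint16_t),
                 nb02 / (int64_t) sizeof(uint16_t),
                 nb11 / (int64_t) sizeof(float) };
    }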
@@ -2838,8 +2692,11 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
  float * dst_ddf = static_cast<float *>(dst->data);

  const sycl::half * src1_f16 = static_cast<const sycl::half *>(src1->data);
+ const size_t type_size_src0 = ggml_type_size(src0->type);
  const size_t type_size_src1 = ggml_type_size(src1->type);
- GGML_ASSERT(nb10 == type_size_src1);
+
+ bool is_src0_cont_2 = ggml_is_contiguous_2(src0);
+ bool is_src1_cont_2 = ggml_is_contiguous_2(src1);

  // SRC1 strides
  int64_t s11 = nb11 / type_size_src1;
@@ -2851,16 +2708,47 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
  if (src1->type != GGML_TYPE_F16) {
  scope_op_debug_print scope_dbg_print(__func__, "/to_fp16_nc_sycl", dst, /*num_src=*/2,
  " : converting src1 to fp16");
- const to_fp16_nc_sycl_t to_fp16_nc_sycl = get_to_fp16_nc_sycl(src1->type);
- GGML_ASSERT(to_fp16_nc_sycl != nullptr);
+
+ // iterate tensor dims and find the slowest moving dim and stride
+ int last_dim=0;
+ int last_str=0;
+ size_t largest_str=0;
+ for(int i = 0; i< 4; i++){
+ // last stride is always the largest
+ if(src1->nb[i] == largest_str){
+ if(src1->ne[last_dim] == 1){
+ last_str = i;
+ last_dim = i;
+ }
+ }
+ if(src1->nb[i] > largest_str){
+ largest_str = src1->nb[i];
+ last_str = i;
+ last_dim = i;
+ }
+
+ }
+ #if GGML_SYCL_DNNL
+ // oneDNN handles strided data and does not need overhead of get_to_fp16_nc_sycl
+ const int64_t ne_src1 = src1->nb[last_str] * src1->ne[last_dim] / type_size_src1;
+ src1_f16_alloc.alloc(ne_src1);
+ const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst);
+ GGML_ASSERT(to_fp16_sycl != nullptr);
+ to_fp16_sycl(src1_f16, src1_f16_alloc.get(), ne_src1, queue);
+ # else
  const int64_t ne_src1 = ggml_nelements(src1);
  src1_f16_alloc.alloc(ne_src1);
+ const to_fp16_nc_sycl_t to_fp16_nc_sycl = get_to_fp16_nc_sycl(src1->type);
+ GGML_ASSERT(to_fp16_nc_sycl != nullptr);
  to_fp16_nc_sycl(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, queue);
+ #endif

  src1_f16 = src1_f16_alloc.get();
  s11 = ne10;
  s12 = ne11 * s11;
  s13 = ne12 * s12;
+
+ is_src1_cont_2 = true;
  }

  ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool());
@@ -2889,48 +2777,115 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons

  #if GGML_SYCL_DNNL
  if (!g_ggml_sycl_disable_dnn) {
- auto dnn_gemm = [&ctx, queue, ne11, ne01, ne10, nb00, nb01, nb02, s11, s12]
- (const sycl::half* src1, const sycl::half* src0, float* dst, const dnnl_dim_t batches_a, const dnnl_dim_t batches_b) {
-
- DnnlGemmWrapper::gemm(ctx, ne11,ne01, ne10,
- src1, DnnlGemmWrapper::to_dt<sycl::half>(), s11, 1, s12,
- src0, DnnlGemmWrapper::to_dt<sycl::half>(), 1, nb01/nb00, nb02/nb00,
- dst, DnnlGemmWrapper::to_dt<float>(), queue, batches_a, batches_b);
- };
-
- if (r2 == 1 && r3 == 1) {
- if (ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
- dnn_gemm(src1_f16, src0_f16, dst_ddf, ne12*ne13, ne02 * ne03);
- }
- else {
- for (int64_t ie03 = 0; ie03 < ne03; ++ie03) {
- const sycl::half* src0_f16_shifted = src0_f16 + ((ie03*nb03)/sizeof(sycl::half)); // nb is in bytes
- const sycl::half* src1_f16_shifted = src1_f16 + ie03*s13;
- float* dst_shifted = dst_ddf + ((ie03*nb3)/sizeof(float));
- dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, ne12, ne02);
+ int64_t str_a0 = nb00 / type_size_src0;
+ int64_t str_a1 = nb01 / type_size_src0;
+ int64_t str_a2 = nb02 / type_size_src0;
+
+ int64_t str_b0 = nb10 / type_size_src1;
+ int64_t str_b1 = nb11 / type_size_src1;
+ int64_t str_b2 = nb12 / type_size_src1;
+
+ auto launch_gemm_for_batches = [&ctx, queue](const sycl::half *src0,
+ const sycl::half *src1, float *dst,
+ int64_t a0, int64_t a1, int64_t batcha,
+ int64_t /*b0*/, int64_t b1, int64_t batchb,
+ int64_t sa0, int64_t sa1, int64_t sa2,
+ int64_t sb0, int64_t sb1, int64_t sb2,
+ int64_t sd2) {
+ bool supported_broadcast = batchb == batcha ? true
+ : batchb == 1 || batcha == 1 ? true
+ : false;
+ if (supported_broadcast) {
+ DnnlGemmWrapper::gemm(ctx, a1, b1, a0, src0,
+ DnnlGemmWrapper::to_dt<sycl::half>(), sa0, sa1, sa2, src1,
+ DnnlGemmWrapper::to_dt<sycl::half>(), sb0, sb1, sb2, dst,
+ DnnlGemmWrapper::to_dt<float>(), queue, batcha, batchb);
+ } else {
+ // iterate over batches from smaller set of matrices (matrix 0)
+ int64_t batches0 = batcha;
+ int64_t batches1 = batchb;
+
+ if (batches0 > batches1) {
+ int64_t num_mul_mats = batches1;
+ int64_t sub_batch = batches0 / num_mul_mats;
+ // src0 is batched and bigger, shift and multiply with src1
+ for (int64_t i0 = 0; i0 < num_mul_mats; i0++) {
+ const sycl::half *src0_shifted = src0 + (sa2 * i0 * sub_batch);
+ const sycl::half *src1_shifted = src1 + (sb2 * i0);
+ float *dst_shifted = dst + (sd2 * i0 * sub_batch);
+ DnnlGemmWrapper::gemm(ctx, a1, b1, a0, src0_shifted,
+ DnnlGemmWrapper::to_dt<sycl::half>(), sa0, sa1, sa2,
+ src1_shifted, DnnlGemmWrapper::to_dt<sycl::half>(), sb0,
+ sb1, sb2, dst_shifted, DnnlGemmWrapper::to_dt<float>(),
+ queue, sub_batch, 1);
+ }
+ } else {
+ int64_t num_mul_mats = batches0;
+ int64_t sub_batch = batches1 / num_mul_mats;
+ // src1 is batched and bigger, shift and multiply with src0
+ for (int64_t i1 = 0; i1 < num_mul_mats; i1++) {
+ const sycl::half *src0_shifted = src0 + (sa2 * i1);
+ const sycl::half *src1_shifted = src1 + (sb2 * i1 * sub_batch);
+ float *dst_shifted = dst + (sd2 * i1 * sub_batch);
+ DnnlGemmWrapper::gemm(ctx, a1, b1, a0, src0_shifted,
+ DnnlGemmWrapper::to_dt<sycl::half>(), sa0, sa1, sa2,
+ src1_shifted, DnnlGemmWrapper::to_dt<sycl::half>(), sb0,
+ sb1, sb2, dst_shifted, DnnlGemmWrapper::to_dt<float>(),
+ queue, 1, sub_batch);
+ }
+ }
  }
- }
- } else {
- // iterate over batches from smaller set of matrices (matrix 0)
- for (int64_t ie02 = 0; ie02 < ne02; ++ie02) {
- for (int64_t ie03 = 0; ie03 < ne03; ++ie03) {
- const sycl::half* src0_f16_shifted = src0_f16 + ((ie02*nb02 + ie03*nb03)/sizeof(sycl::half));
- const sycl::half* src1_f16_shifted = src1_f16 + ie02*s12*r2 + ie03*s13*r3;
- float* dst_shifted = dst_ddf + ((ie02*nb2*r2 + ie03*nb3*r3)/sizeof(float));
- dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, r2*r3, 1);
+ };
+
+ const bool cont_batches_dim2_a = nb02 * ne02 == nb03;
+ const bool cont_batches_dim2_b = nb12 * ne12 == nb13;
+ const bool cont_batches_dim3_a = ne02 == 1 && nb02 * ne01 == nb03;
+ const bool cont_batches_dim3_b = ne12 == 1 && nb12 * ne11 == nb13;
+ if (cont_batches_dim2_a && cont_batches_dim2_b) {
+ // A batch is considered contiguous if the dimension 2 is not strided
+ int64_t batches0 = ne02 * ne03;
+ int64_t batches1 = ne12 * ne13;
+ launch_gemm_for_batches(src0_f16, src1_f16, dst_ddf, ne00, ne01, batches0,
+ ne10, ne11, batches1, str_a0, str_a1, str_a2, str_b0, str_b1,
+ str_b2, nb2 / sizeof(float));
+ } else if (cont_batches_dim3_a && cont_batches_dim3_b) {
+ // This case is similar to the one above with the difference that only the batch in dimension 3 is used and the dimension 2 is of size 1.
+ int64_t batches0 = ne02 * ne03;
+ int64_t batches1 = ne12 * ne13;
+ int64_t str_a3 = nb03 / type_size_src0;
+ int64_t str_b3 = nb13 / type_size_src1;
+ launch_gemm_for_batches(src0_f16, src1_f16, dst_ddf, ne00, ne01, batches0,
+ ne10, ne11, batches1, str_a0, str_a1, str_a3, str_b0, str_b1,
+ str_b3, nb2 / sizeof(float));
+ } else {
+ for (int64_t b_a = 0; b_a < ne03; b_a++) {
+ const sycl::half *src0_f16_shifted
+ = src0_f16 + (nb03 * b_a / type_size_src0);
+ const sycl::half *src1_f16_shifted
+ = src1_f16 + (nb13 * b_a / type_size_src1);
+ float *dst_shifted = dst_ddf + (nb3 * b_a / sizeof(float));
+ int64_t batches0 = ne02;
+ int64_t batches1 = ne12;
+ launch_gemm_for_batches(src0_f16_shifted, src1_f16_shifted, dst_shifted,
+ ne00, ne01, batches0, ne10, ne11, batches1, str_a0, str_a1,
+ str_a2, str_b0, str_b1, str_b2, nb2 / sizeof(float));
  }
  }
- }
+
  }
  else
  #endif
  {
- if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
+ if (r2 == 1 && r3 == 1 && is_src0_cont_2 && is_src1_cont_2) {
+ // with a [0, 2, 1, 3] perm. and ne02==1 the matrix strides need to be determined from dim 3:
+ const int64_t sma = ne02 == 1 ? nb03/nb00 : nb02/nb00;
+ const int64_t smb = ne12 == 1 ? s13 : s12;
+
  // there is no broadcast and src0, src1 are contiguous across dims 2, 3
  SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::math::transpose::trans,
  oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
- src0_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00,
- src1_f16, dpct::library_data_t::real_half, s11, s12, beta, dst_ddf,
+ src0_f16, dpct::library_data_t::real_half, nb01 / nb00, sma,
+ src1_f16, dpct::library_data_t::real_half, s11, smb, beta, dst_ddf,
  mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type)));
  } else {
  const int ne23 = ne12 * ne13;
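In the oneDNN batched path added above, launch_gemm_for_batches issues a single batched GEMM when the batch counts match or one side broadcasts a single batch, and otherwise splits the larger side into equal sub-batches keyed by the smaller side. A reduced sketch of just that decision rule (the GEMM calls are replaced by prints):

    #include <cstdint>
    #include <cstdio>

    // Batch-broadcast rule mirroring launch_gemm_for_batches: equal counts or 1-vs-N
    // go to one batched GEMM; otherwise iterate the smaller side in sub-batches.
    static void dispatch_batched_gemm(int64_t batcha, int64_t batchb) {
        const bool supported_broadcast = (batchb == batcha) || (batchb == 1) || (batcha == 1);
        if (supported_broadcast) {
            std::printf("one batched GEMM: batcha=%lld batchb=%lld\n", (long long) batcha, (long long) batchb);
            return;
        }
        const int64_t num_mul_mats = batcha < batchb ? batcha : batchb;
        const int64_t sub_batch    = (batcha > batchb ? batcha : batchb) / num_mul_mats;
        for (int64_t i = 0; i < num_mul_mats; ++i) {
            std::printf("GEMM group %lld covering %lld batches of the larger operand\n",
                        (long long) i, (long long) sub_batch);
        }
    }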
@@ -2945,7 +2900,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
  void ** ptrs_dst_get = ptrs_dst.get();
  size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : s12 * sizeof(sycl::half);
  size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : s13 * sizeof(sycl::half);
- sycl_parallel_for(cgh, sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+ cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
  k_compute_batched_ptrs(src0_f16, src1_f16, dst_ddf, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02,
  nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1);
  });
@@ -3260,26 +3215,27 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
  // The kernel from the if path is faster for that specific case, but does not support all mul mats.
  ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
  }
- } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
+ } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1 && src1->ne[3] == 1) {
  // KQV single-batch
  ggml_sycl_mul_mat_vec_nc(ctx, src0, src1, dst);
- } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
+ } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2] * src1->ne[3] > 1) {
  // KQ + KQV multi-batch
  ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
  } else if (use_dequantize_mul_mat_vec) {
- constexpr bool convert_src1_to_q8_1 = false;
  opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::DMMV);
- ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, convert_src1_to_q8_1);
+ ggml_sycl_op_mul_mat<no_quantize_q8_1>(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec);
  } else if (use_mul_mat_vec_q) {
- constexpr bool convert_src1_to_q8_1 = true;
  opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::MMVQ);
- ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, convert_src1_to_q8_1);
+ ggml_tensor_extra_gpu * extra = static_cast<ggml_tensor_extra_gpu *>(src0->extra);
+ if (extra && extra->optimized_feature.reorder) {
+ ggml_sycl_op_mul_mat<quantize_and_reorder_q8_1_soa>(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q);
+ } else {
+ ggml_sycl_op_mul_mat<quantize_q8_1>(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q);
+ }
  } else if (use_mul_mat_q) {
- constexpr bool convert_src1_to_q8_1 = true;
- ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, convert_src1_to_q8_1);
+ ggml_sycl_op_mul_mat<quantize_q8_1>(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q);
  } else {
- constexpr bool convert_src1_to_q8_1 = false;
- ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, convert_src1_to_q8_1);
+ ggml_sycl_op_mul_mat<no_quantize_q8_1>(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl);
  }
  }

@@ -3446,10 +3402,13 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
  SYCL_CHECK(CHECK_TRY_ERROR(
  stream->memset(dev_cur_src1_row.get(), 0, sizeof(int))));

+ const unsigned int max_work_group_size = ggml_sycl_info().max_work_group_sizes[ctx.device];
+ assert(max_work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
+
  {
- sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, 768u));
+ sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, max_work_group_size));
  sycl::range<3> grid_dims(1, n_ids, ids->ne[1]);
- sycl_launch(stream, [&](sycl::handler & cgh) {
+ stream->submit([&](sycl::handler &cgh) {
  sycl::local_accessor<int, 0> src1_row_acc(cgh);

  char *__restrict src1_contiguous_get =
@@ -3461,8 +3420,9 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
  size_t ids_nb_ct6 = ids->nb[1];
  size_t ids_nb_ct7 = ids->nb[0];

- sycl_parallel_for(
- cgh, sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+ cgh.parallel_for(
+ sycl::nd_range<3>(grid_dims * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) {
  k_copy_src1_to_contiguous(
  src1_original, src1_contiguous_get,
  dev_cur_src1_row_get,
@@ -3491,16 +3451,17 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
  ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row);

  {
- sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, 768u));
+ sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, max_work_group_size));
  sycl::range<3> grid_dims(1, 1, num_src1_rows);
- sycl_launch(stream, [&](sycl::handler & cgh) {
+ stream->submit([&](sycl::handler &cgh) {
  const char *__restrict dst_contiguous_get =
  dst_contiguous.get();
  const mmid_row_mapping *__restrict dev_row_mapping_get =
  dev_row_mapping.get();

- sycl_parallel_for(
- cgh, sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+ cgh.parallel_for(
+ sycl::nd_range<3>(grid_dims * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) {
  k_copy_dst_from_contiguous(dst_original,
  dst_contiguous_get,
  dev_row_mapping_get,
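The MUL_MAT_ID copy kernels above now clamp their work-group size to the device limit (queried once per device) instead of the previous hardcoded 768 work-items. A minimal sketch of that clamp:

    #include <algorithm>
    #include <sycl/sycl.hpp>

    // Work-group sizing as used above: never exceed the device's reported limit.
    static sycl::range<3> make_block_dims(unsigned int n_elems, unsigned int max_work_group_size) {
        return sycl::range<3>(1, 1, std::min(n_elems, max_work_group_size));
    }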
@@ -3603,6 +3564,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
  case GGML_OP_GET_ROWS:
  ggml_sycl_get_rows(ctx, dst);
  break;
+ case GGML_OP_SET_ROWS:
+ ggml_sycl_op_set_rows(ctx, dst);
+ break;
  case GGML_OP_DUP:
  ggml_sycl_dup(ctx, dst);
  break;
@@ -3613,6 +3577,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
  case GGML_OP_SUB:
  ggml_sycl_sub(ctx, dst);
  break;
+ case GGML_OP_COUNT_EQUAL:
+ ggml_sycl_count_equal(ctx, dst);
+ break;
  case GGML_OP_ACC:
  ggml_sycl_acc(ctx, dst);
  break;
@@ -3687,6 +3654,12 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
  case GGML_GLU_OP_SWIGLU:
  ggml_sycl_swiglu(ctx, dst);
  break;
+ case GGML_GLU_OP_GEGLU_ERF:
+ ggml_sycl_geglu_erf(ctx, dst);
+ break;
+ case GGML_GLU_OP_GEGLU_QUICK:
+ ggml_sycl_geglu_quick(ctx, dst);
+ break;
  default:
  return false;
  }
@@ -4100,6 +4073,7 @@ static ggml_backend_i ggml_backend_sycl_interface = {
  /* .graph_compute = */ ggml_backend_sycl_graph_compute,
  /* .event_record = */ ggml_backend_sycl_event_record,
  /* .event_wait = */ ggml_backend_sycl_event_wait,
+ /* .graph_optimize = */ NULL,
  };

  static ggml_guid_t ggml_backend_sycl_guid() {
@@ -4232,6 +4206,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
  case GGML_GLU_OP_REGLU:
  case GGML_GLU_OP_GEGLU:
  case GGML_GLU_OP_SWIGLU:
+ case GGML_GLU_OP_GEGLU_ERF:
+ case GGML_GLU_OP_GEGLU_QUICK:
  return ggml_is_contiguous_1(op->src[0]);
  default:
  return false;
@@ -4240,15 +4216,9 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
  case GGML_OP_MUL_MAT:
  case GGML_OP_MUL_MAT_ID:
  {
- struct ggml_tensor * a;
- struct ggml_tensor * b;
- if (op->op == GGML_OP_MUL_MAT) {
- a = op->src[0];
- b = op->src[1];
- } else {
- a = op->src[2];
- b = op->src[1];
- }
+ struct ggml_tensor * a = op->src[0];
+ struct ggml_tensor * b = op->src[1];
+
  if (a->ne[3] != b->ne[3]) {
  return false;
  }
@@ -4263,7 +4233,18 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
  }
  }
  ggml_type src0_type = op->src[0]->type;
- if (src0_type == GGML_TYPE_BF16) {
+ if (src0_type == GGML_TYPE_BF16 || src0_type == GGML_TYPE_MXFP4) {
+ // TODO: support MXFP4
+ // FIXME: keep a list of supported types to avoid breaking the backend when a new type is added
+ return false;
+ }
+ // TODO: The configuration below needs more work to be supported with oneDNN
+ if (ggml_is_permuted(a) && !ggml_is_contiguous(a) && a->ne[2] > 1 && a->ne[3] > 1) {
+ return false;
+ }
+ // TODO: This specific configuration can fail with oneDNN and needs more debugging
+ if (!ggml_is_permuted(a) && ggml_is_permuted(b) && b->ne[2] > 1 && b->ne[3] > 1 &&
+ a->ne[0] > 128 && a->ne[2] == 1 && src0_type == GGML_TYPE_F16) {
  return false;
  }
  return true;
@@ -4285,6 +4266,14 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
  return false;
  }
  }
+ case GGML_OP_SET_ROWS:
+ {
+ return ((op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16 ||
+ op->type == GGML_TYPE_Q8_0 || op->type == GGML_TYPE_Q5_1 || op->type == GGML_TYPE_Q5_0 ||
+ op->type == GGML_TYPE_Q4_1 || op->type == GGML_TYPE_Q4_0 || op->type == GGML_TYPE_IQ4_NL) &&
+ (op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32));
+ }
+ break;
  case GGML_OP_CPY:
  {
  ggml_type src0_type = op->src[0]->type;
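supports_op now accepts GGML_OP_SET_ROWS for a fixed list of destination types with I32/I64 index tensors. The same check, restated as a standalone predicate for readability (the helper name is illustrative only):

    #include "ggml.h"

    // The SET_ROWS support check above, as a free function.
    static bool sycl_supports_set_rows(ggml_type dst_type, ggml_type idx_type) {
        const bool dst_ok = dst_type == GGML_TYPE_F32 || dst_type == GGML_TYPE_F16 || dst_type == GGML_TYPE_BF16 ||
                            dst_type == GGML_TYPE_Q8_0 || dst_type == GGML_TYPE_Q5_1 || dst_type == GGML_TYPE_Q5_0 ||
                            dst_type == GGML_TYPE_Q4_1 || dst_type == GGML_TYPE_Q4_0 || dst_type == GGML_TYPE_IQ4_NL;
        const bool idx_ok = idx_type == GGML_TYPE_I64 || idx_type == GGML_TYPE_I32;
        return dst_ok && idx_ok;
    }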
@@ -4370,6 +4359,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
  case GGML_OP_ADD:
  case GGML_OP_ADD1:
  case GGML_OP_SUB:
+ case GGML_OP_COUNT_EQUAL:
  case GGML_OP_MUL:
  case GGML_OP_DIV:
  case GGML_OP_REPEAT:
@@ -4386,29 +4376,44 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
  return (op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32) && (op->type == op->src[0]->type);
  #endif
  case GGML_OP_NORM:
- case GGML_OP_RMS_NORM:
  return true;
  case GGML_OP_L2_NORM:
  case GGML_OP_GROUP_NORM:
  return ggml_is_contiguous(op->src[0]);
+ case GGML_OP_RMS_NORM:
+ return ((op->src[0]->ne[0] % WARP_SIZE) == 0);
  case GGML_OP_SCALE:
  return true;
  case GGML_OP_CONT:
  return op->src[0]->type != GGML_TYPE_BF16;
- case GGML_OP_DIAG_MASK_INF:
  case GGML_OP_SOFT_MAX:
- return true;
+ // TODO: support batching
+ if (op->src[0]->ne[3] != 1) {
+ return false;
+ }
+ // TODO: support attention sinks [TAG_ATTN_SINKS]
+ if (op->src[2]) {
+ return false;
+ }
+ // TODO: support broadcast
+ // ref: https://github.com/ggml-org/llama.cpp/pull/14435
+ return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
+ case GGML_OP_DIAG_MASK_INF:
  case GGML_OP_ROPE:
  case GGML_OP_IM2COL:
  return true;
  case GGML_OP_UPSCALE:
  return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
- case GGML_OP_POOL_2D:
  case GGML_OP_SUM:
  case GGML_OP_SUM_ROWS:
  case GGML_OP_ARGSORT:
+ return ggml_is_contiguous(op->src[0]);
+ case GGML_OP_POOL_2D:
  case GGML_OP_ACC:
+ return true;
  case GGML_OP_PAD:
+ return (ggml_get_op_params_i32(op, 0) == 0) && (ggml_get_op_params_i32(op, 2) == 0) &&
+ (ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0);
  case GGML_OP_LEAKY_RELU:
  case GGML_OP_TIMESTEP_EMBEDDING:
  case GGML_OP_RWKV_WKV6:
@@ -4619,10 +4624,10 @@ ggml_backend_t ggml_backend_sycl_init(int device) {
  };

  ggml_backend_t sycl_backend = new ggml_backend {
- /* .guid = */ ggml_backend_sycl_guid(),
- /* .interface = */ ggml_backend_sycl_interface,
- /* .device = */ ggml_backend_reg_dev_get(ggml_backend_sycl_reg(), device),
- /* .context = */ ctx
+ /* .guid = */ ggml_backend_sycl_guid(),
+ /* .iface = */ ggml_backend_sycl_interface,
+ /* .device = */ ggml_backend_reg_dev_get(ggml_backend_sycl_reg(), device),
+ /* .context = */ ctx
  };

  return sycl_backend;