whispercpp 1.3.3 → 1.3.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (586) hide show
  1. checksums.yaml +4 -4
  2. data/ext/ruby_whisper_params.c +55 -25
  3. data/ext/sources/CMakeLists.txt +1 -1
  4. data/ext/sources/bindings/javascript/package.json +1 -1
  5. data/ext/sources/build-xcframework.sh +24 -0
  6. data/ext/sources/examples/CMakeLists.txt +1 -0
  7. data/ext/sources/examples/addon.node/addon.cpp +19 -19
  8. data/ext/sources/examples/addon.node/index.js +7 -5
  9. data/ext/sources/examples/bench/bench.cpp +26 -16
  10. data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
  11. data/ext/sources/examples/cli/cli.cpp +4 -2
  12. data/ext/sources/examples/command/command.cpp +26 -24
  13. data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
  14. data/ext/sources/examples/common-ggml.cpp +2 -0
  15. data/ext/sources/examples/lsp/lsp.cpp +19 -17
  16. data/ext/sources/examples/server/server.cpp +24 -13
  17. data/ext/sources/examples/server.py +6 -1
  18. data/ext/sources/examples/stream/stream.cpp +4 -2
  19. data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
  20. data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
  21. data/ext/sources/examples/talk-llama/CMakeLists.txt +2 -2
  22. data/ext/sources/examples/talk-llama/llama-adapter.cpp +101 -4
  23. data/ext/sources/examples/talk-llama/llama-adapter.h +6 -0
  24. data/ext/sources/examples/talk-llama/llama-arch.cpp +588 -15
  25. data/ext/sources/examples/talk-llama/llama-arch.h +58 -1
  26. data/ext/sources/examples/talk-llama/llama-batch.cpp +103 -71
  27. data/ext/sources/examples/talk-llama/llama-batch.h +31 -18
  28. data/ext/sources/examples/talk-llama/llama-chat.cpp +120 -5
  29. data/ext/sources/examples/talk-llama/llama-chat.h +7 -0
  30. data/ext/sources/examples/talk-llama/llama-context.cpp +460 -357
  31. data/ext/sources/examples/talk-llama/llama-context.h +44 -29
  32. data/ext/sources/examples/talk-llama/llama-cparams.h +4 -4
  33. data/ext/sources/examples/talk-llama/llama-graph.cpp +543 -271
  34. data/ext/sources/examples/talk-llama/llama-graph.h +278 -168
  35. data/ext/sources/examples/talk-llama/llama-hparams.cpp +118 -4
  36. data/ext/sources/examples/talk-llama/llama-hparams.h +61 -15
  37. data/ext/sources/examples/talk-llama/llama-impl.h +2 -0
  38. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +326 -0
  39. data/ext/sources/examples/talk-llama/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +38 -29
  40. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2020 -0
  41. data/ext/sources/examples/talk-llama/llama-kv-cache.h +358 -27
  42. data/ext/sources/examples/talk-llama/llama-kv-cells.h +80 -28
  43. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +56 -36
  44. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +30 -29
  45. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +48 -19
  46. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +13 -14
  47. data/ext/sources/examples/talk-llama/llama-memory.h +16 -10
  48. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +2 -0
  49. data/ext/sources/examples/talk-llama/llama-model-loader.h +3 -2
  50. data/ext/sources/examples/talk-llama/llama-model.cpp +7165 -2336
  51. data/ext/sources/examples/talk-llama/llama-model.h +60 -9
  52. data/ext/sources/examples/talk-llama/llama-quant.cpp +48 -10
  53. data/ext/sources/examples/talk-llama/llama-sampling.cpp +226 -126
  54. data/ext/sources/examples/talk-llama/llama-vocab.cpp +440 -13
  55. data/ext/sources/examples/talk-llama/llama-vocab.h +45 -0
  56. data/ext/sources/examples/talk-llama/llama.cpp +65 -10
  57. data/ext/sources/examples/talk-llama/llama.h +95 -177
  58. data/ext/sources/examples/talk-llama/talk-llama.cpp +9 -6
  59. data/ext/sources/examples/talk-llama/unicode.cpp +207 -0
  60. data/ext/sources/examples/talk-llama/unicode.h +45 -0
  61. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +4 -2
  62. data/ext/sources/examples/whisper.wasm/index-tmpl.html +17 -16
  63. data/ext/sources/ggml/CMakeLists.txt +59 -31
  64. data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
  65. data/ext/sources/ggml/include/ggml-backend.h +17 -1
  66. data/ext/sources/ggml/include/ggml-cpu.h +1 -1
  67. data/ext/sources/ggml/include/ggml-metal.h +1 -6
  68. data/ext/sources/ggml/include/ggml-opt.h +25 -6
  69. data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
  70. data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
  71. data/ext/sources/ggml/include/ggml.h +221 -16
  72. data/ext/sources/ggml/src/CMakeLists.txt +17 -2
  73. data/ext/sources/ggml/src/ggml-alloc.c +265 -141
  74. data/ext/sources/ggml/src/ggml-backend-impl.h +4 -1
  75. data/ext/sources/ggml/src/ggml-backend-reg.cpp +30 -13
  76. data/ext/sources/ggml/src/ggml-backend.cpp +221 -38
  77. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  78. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -4
  79. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +14 -0
  80. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
  81. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +903 -717
  82. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +143 -25
  83. data/ext/sources/ggml/src/ggml-cann/common.h +143 -1
  84. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +488 -69
  85. data/ext/sources/ggml/src/ggml-common.h +17 -0
  86. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +40 -18
  87. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +4 -2
  88. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +132 -596
  89. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +14 -286
  90. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +103 -582
  91. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
  92. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +265 -437
  93. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
  94. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +521 -353
  95. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
  96. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
  97. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +4679 -1657
  98. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +32 -2
  99. data/ext/sources/ggml/src/ggml-cpu/common.h +14 -0
  100. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +13 -6
  101. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +70 -42
  102. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +35 -28
  103. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +152 -18
  104. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +7 -1
  105. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +227 -97
  106. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +474 -1116
  107. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1587 -1177
  108. data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -8
  109. data/ext/sources/ggml/src/ggml-cpu/quants.c +35 -0
  110. data/ext/sources/ggml/src/ggml-cpu/quants.h +8 -0
  111. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +458 -47
  112. data/ext/sources/ggml/src/ggml-cpu/repack.h +22 -0
  113. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +89 -60
  114. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1024 -0
  115. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
  116. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
  117. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
  118. data/ext/sources/ggml/src/ggml-cpu/traits.cpp +2 -2
  119. data/ext/sources/ggml/src/ggml-cpu/traits.h +1 -1
  120. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +170 -26
  121. data/ext/sources/ggml/src/ggml-cpu/vec.h +506 -63
  122. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +20 -16
  123. data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
  124. data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
  125. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
  126. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
  127. data/ext/sources/ggml/src/ggml-cuda/common.cuh +250 -63
  128. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
  129. data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
  130. data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
  131. data/ext/sources/ggml/src/ggml-cuda/convert.cu +95 -22
  132. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +15 -0
  133. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
  134. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +64 -307
  135. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  136. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
  137. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +498 -367
  138. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +137 -91
  139. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +755 -0
  140. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +3 -0
  141. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +593 -0
  142. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +86 -50
  143. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +185 -198
  144. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  145. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
  146. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +379 -107
  147. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
  148. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
  149. data/ext/sources/ggml/src/ggml-cuda/mean.cu +56 -2
  150. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +198 -45
  151. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +123 -0
  152. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +496 -0
  153. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +206 -57
  154. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1262 -721
  155. data/ext/sources/ggml/src/ggml-cuda/{mmv.cu → mmvf.cu} +53 -53
  156. data/ext/sources/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +3 -3
  157. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +64 -73
  158. data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
  159. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
  160. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
  161. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
  162. data/ext/sources/ggml/src/ggml-cuda/pad.cu +46 -23
  163. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
  164. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
  165. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +12 -10
  166. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
  167. data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
  168. data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
  169. data/ext/sources/ggml/src/ggml-cuda/rope.cu +21 -27
  170. data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
  171. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +276 -0
  172. data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  173. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
  174. data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
  175. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +126 -59
  176. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  177. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +322 -100
  178. data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
  179. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +21 -4
  180. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
  181. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
  182. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
  183. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
  184. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
  185. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
  186. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
  187. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
  188. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
  189. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
  190. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
  191. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
  192. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
  193. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
  194. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
  195. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
  196. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
  197. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
  198. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
  199. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
  200. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
  201. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
  202. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
  203. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
  204. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
  205. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
  206. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
  207. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
  208. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
  209. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
  210. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
  211. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
  212. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
  213. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
  214. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
  215. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
  216. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +21 -18
  217. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
  218. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
  219. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
  220. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
  221. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
  222. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
  223. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
  224. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
  225. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
  226. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
  227. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
  228. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
  229. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
  230. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
  231. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
  232. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
  233. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
  234. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +259 -0
  235. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +14 -0
  236. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
  237. data/ext/sources/ggml/src/ggml-cuda/unary.cu +90 -0
  238. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +8 -0
  239. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +92 -6
  240. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
  241. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  242. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +58 -36
  243. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +4 -3
  244. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +10 -2
  245. data/ext/sources/ggml/src/ggml-impl.h +119 -9
  246. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -7
  247. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
  248. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
  249. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
  250. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +600 -0
  251. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1376 -0
  252. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +226 -0
  253. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1308 -0
  254. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +136 -63
  255. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +3158 -0
  256. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +82 -0
  257. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +718 -0
  258. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +2854 -1503
  259. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +18 -8
  260. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +18 -0
  261. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +2510 -242
  262. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
  263. data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  264. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  265. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  266. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +84 -0
  267. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +66 -0
  268. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
  269. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +370 -0
  270. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
  271. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  272. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +177 -0
  273. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
  274. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  275. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  276. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
  277. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  278. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
  279. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
  280. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  281. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
  282. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
  283. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
  284. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  285. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
  286. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
  287. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
  288. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
  289. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
  290. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  291. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +189 -0
  292. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
  293. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
  294. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
  295. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
  296. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
  297. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
  298. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +2 -2
  299. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  300. data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
  301. data/ext/sources/ggml/src/ggml-quants.c +111 -16
  302. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  303. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +67 -47
  304. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +2 -0
  305. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +15 -5
  306. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +6 -0
  307. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +25 -16
  308. data/ext/sources/ggml/src/ggml-sycl/conv.cpp +10 -4
  309. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +166 -99
  310. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +72 -306
  311. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
  312. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +67 -49
  313. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +1 -31
  314. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +79 -29
  315. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
  316. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +14 -26
  317. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +9 -6
  318. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +328 -323
  319. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
  320. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +2 -2
  321. data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +80 -60
  322. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +201 -132
  323. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +74 -55
  324. data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
  325. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +8 -9
  326. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +35 -42
  327. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
  328. data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  329. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
  330. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +12 -6
  331. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
  332. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +16 -12
  333. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +3492 -883
  334. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
  335. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
  336. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
  337. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
  338. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +349 -0
  339. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
  340. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +66 -12
  341. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +154 -0
  342. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
  343. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  344. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +2 -1
  345. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +6 -5
  346. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +4 -2
  347. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
  348. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
  349. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  350. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
  351. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
  352. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  353. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
  354. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +69 -24
  355. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +60 -20
  356. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +98 -42
  357. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +64 -27
  358. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +74 -13
  359. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  360. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  361. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  362. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +4 -17
  363. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +19 -10
  364. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +25 -15
  365. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +4 -0
  366. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
  367. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
  368. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +18 -14
  369. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +126 -0
  370. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +65 -1
  371. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
  372. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +140 -0
  373. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -531
  374. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +206 -38
  375. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp +556 -0
  376. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +12 -5
  377. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +15 -9
  378. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
  379. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
  380. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +24 -3
  381. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +53 -3
  382. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +55 -11
  383. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
  384. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  385. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
  386. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  387. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  388. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  389. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
  390. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  391. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +29 -7
  392. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +4 -0
  393. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
  394. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
  395. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
  396. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +4 -3
  397. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +101 -9
  398. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  399. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
  400. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -77
  401. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
  402. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1558 -0
  403. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +44 -0
  404. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +41 -0
  405. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
  406. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
  407. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
  408. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +124 -0
  409. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
  410. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  411. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +44 -0
  412. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +41 -0
  413. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
  414. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +57 -0
  415. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +48 -0
  416. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
  417. data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  418. data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
  419. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
  420. data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
  421. data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
  422. data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
  423. data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
  424. data/ext/sources/ggml/src/ggml.c +478 -98
  425. data/ext/sources/ggml/src/gguf.cpp +8 -1
  426. data/ext/sources/src/whisper.cpp +23 -46
  427. data/ext/sources/tests/CMakeLists.txt +8 -1
  428. data/ext/sources/tests/test-vad-full.cpp +3 -3
  429. data/ext/sources/tests/test-vad.cpp +2 -2
  430. data/lib/whisper/model/uri.rb +1 -1
  431. data/sig/whisper.rbs +7 -0
  432. data/test/test_params.rb +8 -0
  433. data/test/test_whisper.rb +1 -1
  434. data/whispercpp.gemspec +1 -1
  435. metadata +164 -157
  436. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +0 -279
  437. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +0 -1841
  438. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +0 -303
  439. data/ext/sources/ggml/include/ggml-kompute.h +0 -50
  440. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
  441. data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
  442. data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
  443. data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
  444. data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
  445. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  446. data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  447. data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  448. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  449. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  450. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  451. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  452. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  453. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  454. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
  455. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
  456. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
  457. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
  458. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
  459. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
  460. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
  461. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
  462. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
  463. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
  464. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
  465. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
  466. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
  467. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
  468. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
  469. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
  470. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
  471. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
  472. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
  473. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
  474. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
  475. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
  476. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
  477. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
  478. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
  479. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
  480. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
  481. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
  482. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
  483. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
  484. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
  485. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
  486. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
  487. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
  488. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
  489. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
  490. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
  491. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
  492. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
  493. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
  494. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
  495. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
  496. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
  497. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
  498. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
  499. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
  500. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
  501. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
  502. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
  503. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
  504. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
  505. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
  506. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
  507. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
  508. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
  509. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
  510. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
  511. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
  512. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
  513. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
  514. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
  515. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
  516. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
  517. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
  518. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
  519. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
  520. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
  521. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
  522. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
  523. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
  524. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
  525. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
  526. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
  527. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
  528. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
  529. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
  530. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
  531. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
  532. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
  533. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
  534. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
  535. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
  536. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
  537. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
  538. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
  539. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
  540. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
  541. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
  542. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
  543. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
  544. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
  545. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
  546. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
  547. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  548. data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  549. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  550. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  551. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  552. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  553. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  554. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  555. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  556. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  557. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  558. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  559. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  560. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  561. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  562. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  563. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  564. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  565. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  566. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  567. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  568. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  569. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  570. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  571. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  572. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  573. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  574. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  575. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  576. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  577. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  578. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  579. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  580. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  581. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  582. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  583. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  584. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  585. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
  586. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -6280
@@ -383,13 +383,31 @@ static void gated_op_fused_swiglu(const T * x, const T * g, T * dst, const uint6
383
383
  }
384
384
  }
385
385
 
386
+ template<typename T>
387
+ static void gated_op_fused_geglu_erf(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) {
388
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
389
+ const int64_t j0 = (i / n) * o0 + (i % n);
390
+ const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
391
+ dst[i] = op_gelu_erf(x[j0]) * g[j1];
392
+ }
393
+ }
394
+
395
+ template<typename T>
396
+ static void gated_op_fused_geglu_quick(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) {
397
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
398
+ const int64_t j0 = (i / n) * o0 + (i % n);
399
+ const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
400
+ dst[i] = op_gelu_quick(x[j0]) * g[j1];
401
+ }
402
+ }
403
+
386
404
  namespace ggml_sycl_detail {
387
405
  static void acc_f32_sycl(const float *x, const float *y, float *dst,
388
406
  const int n_elements, const int ne10, const int ne11,
389
407
  const int ne12, const int nb1, const int nb2,
390
408
  const int offset, queue_ptr stream) {
391
409
  int num_blocks = ceil_div(n_elements, SYCL_ACC_BLOCK_SIZE);
392
- sycl_parallel_for(stream,
410
+ stream->parallel_for(
393
411
  sycl::nd_range<1>(sycl::range<1>(num_blocks) *
394
412
  sycl::range<1>(SYCL_ACC_BLOCK_SIZE),
395
413
  sycl::range<1>(SYCL_ACC_BLOCK_SIZE)),
@@ -407,8 +425,8 @@ static void upscale_sycl(const T *x, T *dst, const int nb00, const int nb01,
407
425
  int dst_size = ne10 * ne11 * ne12 * ne13;
408
426
  int num_blocks = ceil_div(dst_size, SYCL_UPSCALE_BLOCK_SIZE);
409
427
  sycl::range<1> gridDim(num_blocks * SYCL_UPSCALE_BLOCK_SIZE);
410
- sycl_parallel_for<1>(
411
- stream, sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
428
+ stream->parallel_for(
429
+ sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
412
430
  upscale(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, item_ct1);
413
431
  });
414
432
  }
@@ -419,7 +437,7 @@ static void pad_sycl(const T *x, T *dst, const int ne00,
419
437
  const int ne1, const int ne2, queue_ptr stream) {
420
438
  int num_blocks = ceil_div(ne0, SYCL_PAD_BLOCK_SIZE);
421
439
  sycl::range<3> gridDim(ne2, ne1, num_blocks);
422
- sycl_parallel_for(stream,
440
+ stream->parallel_for(
423
441
  sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE),
424
442
  sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)),
425
443
  [=](sycl::nd_item<3> item_ct1) { pad(x, dst, ne0, ne00, ne01, ne02, item_ct1); });
@@ -621,7 +639,7 @@ static inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor
621
639
  ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
622
640
  [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
623
641
  const int num_blocks = ceil_div(k_elements, 256);
624
- sycl_parallel_for(stream,
642
+ stream->parallel_for(
625
643
  sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
626
644
  sycl::range<1>(256)),
627
645
  [=](sycl::nd_item<1> item_ct1) {
@@ -634,7 +652,7 @@ static inline void ggml_sycl_op_abs(ggml_backend_sycl_context & ctx, ggml_tensor
634
652
  ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
635
653
  [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
636
654
  const int num_blocks = ceil_div(k_elements, 256);
637
- sycl_parallel_for(stream,
655
+ stream->parallel_for(
638
656
  sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
639
657
  sycl::range<1>(256)),
640
658
  [=](sycl::nd_item<1> item_ct1) {
@@ -647,7 +665,7 @@ static inline void ggml_sycl_op_elu(ggml_backend_sycl_context & ctx, ggml_tensor
647
665
  ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
648
666
  [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
649
667
  const int num_blocks = ceil_div(k_elements, 256);
650
- sycl_parallel_for(stream,
668
+ stream->parallel_for(
651
669
  sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
652
670
  sycl::range<1>(256)),
653
671
  [=](sycl::nd_item<1> item_ct1) {
@@ -660,7 +678,7 @@ static inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tenso
660
678
  ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
661
679
  [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
662
680
  const int num_blocks = ceil_div(k_elements, SYCL_SILU_BLOCK_SIZE);
663
- sycl_parallel_for(stream,
681
+ stream->parallel_for(
664
682
  sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SILU_BLOCK_SIZE),
665
683
  sycl::range<1>(SYCL_SILU_BLOCK_SIZE)),
666
684
  [=](sycl::nd_item<1> item_ct1) {
@@ -673,7 +691,7 @@ static inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, ggml_tenso
673
691
  ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
674
692
  [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
675
693
  const int num_blocks = ceil_div(k_elements, SYCL_GELU_BLOCK_SIZE);
676
- sycl_parallel_for(stream,
694
+ stream->parallel_for(
677
695
  sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_GELU_BLOCK_SIZE),
678
696
  sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
679
697
  [=](sycl::nd_item<1> item_ct1) {
@@ -686,7 +704,7 @@ static inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml
686
704
  ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
687
705
  [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
688
706
  const int num_blocks = ceil_div(k_elements, SYCL_GELU_BLOCK_SIZE);
689
- sycl_parallel_for(stream,
707
+ stream->parallel_for(
690
708
  sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_GELU_BLOCK_SIZE),
691
709
  sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
692
710
  [=](sycl::nd_item<1> item_ct1) {
@@ -699,7 +717,7 @@ static inline void ggml_sycl_op_gelu_erf(ggml_backend_sycl_context & ctx, ggml_t
699
717
  ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
700
718
  [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
701
719
  const int num_blocks = ceil_div(k_elements, SYCL_GELU_BLOCK_SIZE);
702
- sycl_parallel_for(stream,
720
+ stream->parallel_for(
703
721
  sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_GELU_BLOCK_SIZE),
704
722
  sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
705
723
  [=](sycl::nd_item<1> item_ct1) {
@@ -712,7 +730,7 @@ static inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, ggml_tenso
712
730
  ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
713
731
  [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
714
732
  const int num_blocks = ceil_div(k_elements, SYCL_TANH_BLOCK_SIZE);
715
- sycl_parallel_for(stream,
733
+ stream->parallel_for(
716
734
  sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_TANH_BLOCK_SIZE),
717
735
  sycl::range<1>(SYCL_TANH_BLOCK_SIZE)),
718
736
  [=](sycl::nd_item<1> item_ct1) {
@@ -725,7 +743,7 @@ static inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, ggml_tenso
725
743
  ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
726
744
  [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
727
745
  const int num_blocks = ceil_div(k_elements, SYCL_RELU_BLOCK_SIZE);
728
- sycl_parallel_for(stream,
746
+ stream->parallel_for(
729
747
  sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_RELU_BLOCK_SIZE),
730
748
  sycl::range<1>(SYCL_RELU_BLOCK_SIZE)),
731
749
  [=](sycl::nd_item<1> item_ct1) {
@@ -738,7 +756,7 @@ static inline void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, ggm
738
756
  ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
739
757
  [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
740
758
  const int num_blocks = ceil_div(k_elements, SYCL_HARDSIGMOID_BLOCK_SIZE);
741
- sycl_parallel_for(stream,
759
+ stream->parallel_for(
742
760
  sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_HARDSIGMOID_BLOCK_SIZE),
743
761
  sycl::range<1>(SYCL_HARDSIGMOID_BLOCK_SIZE)),
744
762
  [=](sycl::nd_item<1> item_ct1) {
@@ -751,7 +769,7 @@ static inline void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, ggml_
751
769
  ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
752
770
  [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
753
771
  const int num_blocks = ceil_div(k_elements, SYCL_HARDSWISH_BLOCK_SIZE);
754
- sycl_parallel_for(stream,
772
+ stream->parallel_for(
755
773
  sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_HARDSWISH_BLOCK_SIZE),
756
774
  sycl::range<1>(SYCL_HARDSWISH_BLOCK_SIZE)),
757
775
  [=](sycl::nd_item<1> item_ct1) {
@@ -764,7 +782,7 @@ static inline void ggml_sycl_op_exp(ggml_backend_sycl_context & ctx, ggml_tensor
764
782
  ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
765
783
  [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
766
784
  const int num_blocks = ceil_div(k_elements, SYCL_EXP_BLOCK_SIZE);
767
- sycl_parallel_for(stream,
785
+ stream->parallel_for(
768
786
  sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_EXP_BLOCK_SIZE),
769
787
  sycl::range<1>(SYCL_EXP_BLOCK_SIZE)),
770
788
  [=](sycl::nd_item<1> item_ct1) {
@@ -777,7 +795,7 @@ static inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor
777
795
  ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
778
796
  [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
779
797
  const int num_blocks = ceil_div(k_elements, SYCL_EXP_BLOCK_SIZE); // Using EXP block size
780
- sycl_parallel_for(stream,
798
+ stream->parallel_for(
781
799
  sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_EXP_BLOCK_SIZE),
782
800
  sycl::range<1>(SYCL_EXP_BLOCK_SIZE)),
783
801
  [=](sycl::nd_item<1> item_ct1) {
@@ -790,7 +808,7 @@ static inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, ggml_tensor
790
808
  ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
791
809
  [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
792
810
  const int num_blocks = ceil_div(k_elements, SYCL_NEG_BLOCK_SIZE);
793
- sycl_parallel_for(stream,
811
+ stream->parallel_for(
794
812
  sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_NEG_BLOCK_SIZE),
795
813
  sycl::range<1>(SYCL_NEG_BLOCK_SIZE)),
796
814
  [=](sycl::nd_item<1> item_ct1) {
@@ -803,7 +821,7 @@ static inline void ggml_sycl_op_step(ggml_backend_sycl_context & ctx, ggml_tenso
803
821
  ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
804
822
  [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
805
823
  const int num_blocks = ceil_div(k_elements, SYCL_NEG_BLOCK_SIZE); // Using NEG block size
806
- sycl_parallel_for(stream,
824
+ stream->parallel_for(
807
825
  sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_NEG_BLOCK_SIZE),
808
826
  sycl::range<1>(SYCL_NEG_BLOCK_SIZE)),
809
827
  [=](sycl::nd_item<1> item_ct1) {
@@ -816,7 +834,7 @@ static inline void ggml_sycl_op_sigmoid(ggml_backend_sycl_context & ctx, ggml_te
816
834
  ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
817
835
  [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
818
836
  const int num_blocks = ceil_div(k_elements, SYCL_SIGMOID_BLOCK_SIZE);
819
- sycl_parallel_for(stream,
837
+ stream->parallel_for(
820
838
  sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIGMOID_BLOCK_SIZE),
821
839
  sycl::range<1>(SYCL_SIGMOID_BLOCK_SIZE)),
822
840
  [=](sycl::nd_item<1> item_ct1) {
@@ -829,7 +847,7 @@ static inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, ggml_tenso
829
847
  ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
830
848
  [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
831
849
  const int num_blocks = ceil_div(k_elements, SYCL_SQRT_BLOCK_SIZE);
832
- sycl_parallel_for(stream,
850
+ stream->parallel_for(
833
851
  sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SQRT_BLOCK_SIZE),
834
852
  sycl::range<1>(SYCL_SQRT_BLOCK_SIZE)),
835
853
  [=](sycl::nd_item<1> item_ct1) {
@@ -842,7 +860,7 @@ static inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, ggml_tensor
842
860
  ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
843
861
  [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
844
862
  const int num_blocks = ceil_div(k_elements, SYCL_SIN_BLOCK_SIZE);
845
- sycl_parallel_for(stream,
863
+ stream->parallel_for(
846
864
  sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIN_BLOCK_SIZE),
847
865
  sycl::range<1>(SYCL_SIN_BLOCK_SIZE)),
848
866
  [=](sycl::nd_item<1> item_ct1) {
@@ -855,7 +873,7 @@ static inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, ggml_tensor
855
873
  ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
856
874
  [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
857
875
  const int num_blocks = ceil_div(k_elements, SYCL_SIN_BLOCK_SIZE); // Using SIN block size
858
- sycl_parallel_for(stream,
876
+ stream->parallel_for(
859
877
  sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIN_BLOCK_SIZE),
860
878
  sycl::range<1>(SYCL_SIN_BLOCK_SIZE)),
861
879
  [=](sycl::nd_item<1> item_ct1) {
@@ -870,7 +888,7 @@ static inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, ggml
870
888
  ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
871
889
  [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream, float slope) {
872
890
  const int num_blocks = ceil_div(k_elements, SYCL_RELU_BLOCK_SIZE);
873
- sycl_parallel_for(stream,
891
+ stream->parallel_for(
874
892
  sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_RELU_BLOCK_SIZE),
875
893
  sycl::range<1>(SYCL_RELU_BLOCK_SIZE)),
876
894
  [=](sycl::nd_item<1> item_ct1) {
@@ -883,7 +901,7 @@ static inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, ggml_tensor
883
901
  ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
884
902
  [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
885
903
  const int num_blocks = ceil_div(k_elements, SYCL_SQR_BLOCK_SIZE);
886
- sycl_parallel_for(stream,
904
+ stream->parallel_for(
887
905
  sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SQR_BLOCK_SIZE),
888
906
  sycl::range<1>(SYCL_SQR_BLOCK_SIZE)),
889
907
  [=](sycl::nd_item<1> item_ct1) {
@@ -917,7 +935,7 @@ static inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tens
917
935
  ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
918
936
  [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream, float min_arg, float max_arg) {
919
937
  const int num_blocks = ceil_div(k_elements, SYCL_CLAMP_BLOCK_SIZE);
920
- sycl_parallel_for(stream,
938
+ stream->parallel_for(
921
939
  sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_CLAMP_BLOCK_SIZE),
922
940
  sycl::range<1>(SYCL_CLAMP_BLOCK_SIZE)),
923
941
  [=](sycl::nd_item<1> item_ct1) {
@@ -949,7 +967,7 @@ static inline void ggml_sycl_op_geglu(ggml_backend_sycl_context & ctx, ggml_tens
949
967
  ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
950
968
  [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
951
969
  const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
952
- sycl_parallel_for(main_stream,
970
+ main_stream->parallel_for(
953
971
  sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
954
972
  gated_op_fused_geglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
955
973
  });
@@ -960,7 +978,7 @@ static inline void ggml_sycl_op_reglu(ggml_backend_sycl_context & ctx, ggml_tens
960
978
  ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
961
979
  [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
962
980
  const uint32_t num_blocks = ceil_div((uint32_t)k, SYCL_RELU_BLOCK_SIZE); // Using RELU block size for reglu
963
- sycl_parallel_for(main_stream,
981
+ main_stream->parallel_for(
964
982
  sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
965
983
  gated_op_fused_reglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
966
984
  });
@@ -971,13 +989,35 @@ static inline void ggml_sycl_op_swiglu(ggml_backend_sycl_context & ctx, ggml_ten
971
989
  ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
972
990
  [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
973
991
  const uint32_t num_blocks = ceil_div((uint32_t)k, SYCL_SILU_BLOCK_SIZE); // Using SILU block size for swiglu
974
- sycl_parallel_for(main_stream,
992
+ main_stream->parallel_for(
975
993
  sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
976
994
  gated_op_fused_swiglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
977
995
  });
978
996
  });
979
997
  }
980
998
 
999
+ static inline void ggml_sycl_op_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1000
+ ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
1001
+ [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
1002
+ const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
1003
+ main_stream->parallel_for(
1004
+ sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
1005
+ gated_op_fused_geglu_erf(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
1006
+ });
1007
+ });
1008
+ }
1009
+
1010
+ static inline void ggml_sycl_op_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1011
+ ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
1012
+ [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
1013
+ const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
1014
+ main_stream->parallel_for(
1015
+ sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
1016
+ gated_op_fused_geglu_quick(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
1017
+ });
1018
+ });
1019
+ }
1020
+
981
1021
 
982
1022
  void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
983
1023
  scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
@@ -1118,3 +1158,13 @@ void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1118
1158
  scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1119
1159
  ggml_sycl_op_swiglu(ctx, dst);
1120
1160
  }
1161
+
1162
+ void ggml_sycl_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1163
+ scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1164
+ ggml_sycl_op_geglu_erf(ctx, dst);
1165
+ }
1166
+
1167
+ void ggml_sycl_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1168
+ scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1169
+ ggml_sycl_op_geglu_quick(ctx, dst);
1170
+ }
@@ -80,5 +80,7 @@ void ggml_sycl_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
80
80
  void ggml_sycl_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
81
81
  void ggml_sycl_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
82
82
  void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
83
+ void ggml_sycl_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
84
+ void ggml_sycl_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
83
85
 
84
86
  #endif // GGML_SYCL_ELEMENTWISE_HPP
@@ -32,39 +32,28 @@ public:
32
32
  else static_assert(0);
33
33
  }
34
34
 
35
- // matrix A has m rows, k columns
36
- // matrix B has k rows, n columns
37
- // nra - number of elements to skip when moving into next row in A
38
- // nrb - number of elements to skip when moving into next row in B
39
- // nca - number of elements to skip when moving into next column in A
40
- // ncb - number of elements to skip when moving into next column in B
41
- // stride_a - number of elements to skip when moving to next A matrix
42
- // stride_b - number of elements to skip when moving to next B matrix
43
- // batches_a - number of A matrices
44
- // batches_b - number of B matrices
45
35
  static void gemm(ggml_backend_sycl_context & ctx, int m, int n, int k,
46
- const void * a, dt at, dnnl_dim_t nra, dnnl_dim_t nca, dnnl_dim_t stride_a,
47
- const void * b, dt bt, dnnl_dim_t nrb, dnnl_dim_t ncb, dnnl_dim_t stride_b,
36
+ const void * a, dt at, dnnl_dim_t stra0, dnnl_dim_t stra1, dnnl_dim_t stra2,
37
+ const void * b, dt bt, dnnl_dim_t strb0, dnnl_dim_t strb1, dnnl_dim_t strb2,
48
38
  void * c, dt ct, const queue_ptr & q, dnnl_dim_t batches_a, dnnl_dim_t batches_b) {
49
39
 
50
40
  auto stream = ctx.stream_dnnl(q);
51
41
  auto eng = ctx.engine_dnnl(q);
52
42
 
53
- // { # strides, # rows, # columns }
54
- dnnl::memory::dims a_dims = { batches_a, m, k };
55
- dnnl::memory::dims b_dims = { batches_b, k, n };
56
- dnnl::memory::dims c_dims = { std::max(batches_a, batches_b), m, n };
57
-
58
- // { # elements to skip to next stride, # elements to skip to next row, # elements to skip to next column }
59
- dnnl::memory::dims a_strides = { stride_a, nra, nca };
60
- dnnl::memory::dims b_strides = { stride_b, nrb, ncb };
61
-
43
+ dnnl::memory::dims a_dims = {batches_a, m, k };
44
+ dnnl::memory::dims a_strides = {stra2, stra1, stra0};
62
45
  const auto a_in_md = dnnl::memory::desc(a_dims, at, a_strides);
46
+
47
+ dnnl::memory::dims b_dims = {batches_b, k, n };
48
+ dnnl::memory::dims b_strides = {strb2, strb0, strb1};
63
49
  const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_strides);
64
- const auto c_md = dnnl::memory::desc(c_dims, ct, tag::abc);
65
50
 
51
+ dnnl::memory::dims c_dims = { std::max(batches_a, batches_b), m, n};
52
+ dnnl::memory::dims c_strides = {m*n, 1, m };
53
+ const auto c_md = dnnl::memory::desc(c_dims, ct, c_strides);
66
54
  dnnl::primitive_attr primitive_attr;
67
55
  primitive_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
56
+
68
57
  #ifdef GGML_SYCL_F16
69
58
  primitive_attr.set_fpmath_mode(dnnl::fpmath_mode::f16);
70
59
  #endif
@@ -76,24 +65,23 @@ public:
76
65
 
77
66
  auto scratchpad_md = matmul_pd.scratchpad_desc();
78
67
  auto scratchpad_mem = ctx.get_scratchpad_mem(scratchpad_md, eng, q);
68
+
79
69
  auto matmul_prim = dnnl::matmul(matmul_pd);
80
70
 
81
71
  std::unordered_map<int, dnnl::memory> matmul_args;
82
72
  matmul_args.insert({ DNNL_ARG_SRC, a_mem });
83
73
  matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });
74
+
84
75
  matmul_args.insert({ DNNL_ARG_DST, c_mem });
85
76
  matmul_args.insert({ DNNL_ARG_SCRATCHPAD, scratchpad_mem });
86
77
 
87
78
  matmul_prim.execute(stream, matmul_args);
88
79
  }
89
80
 
90
- // matrices A and B are column major, both having k rows
91
- // matrix A has m column, matrix B has n columns
92
- // output: column major matrix C = A transposed * B
93
81
  static void row_gemm(ggml_backend_sycl_context & ctx, int m, int n, int k,
94
82
  const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) {
95
83
 
96
- gemm(ctx, m, n, k, a, at, k, 1, k * m, b, bt, 1, k, n * k, c, ct, q, 1, 1);
84
+ gemm(ctx, m, n, k, a, at, 1, k, k * m, b, bt, 1, k, n * k, c, ct, q, 1, 1);
97
85
  }
98
86
  };
99
87
 
@@ -118,10 +118,12 @@ static void get_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor *sr
118
118
 
119
119
  GGML_ASSERT(ne00 % 2 == 0);
120
120
 
121
- sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
122
- k_get_rows<qk, qr, dq>(src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2, s3, nb01, nb02, nb03, s10, s11, s12,
123
- item_ct1);
124
- });
121
+ stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
122
+ [=](sycl::nd_item<3> item_ct1) {
123
+ k_get_rows<qk, qr, dq>(
124
+ src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2,
125
+ s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
126
+ });
125
127
 
126
128
  GGML_UNUSED(dst);
127
129
  GGML_UNUSED(ctx);
@@ -154,8 +156,9 @@ static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, const ggml_tens
154
156
  dpct::has_capability_or_fail(stream->get_device(),
155
157
  {sycl::aspect::fp16});
156
158
 
157
- sycl_parallel_for(
158
- stream, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
159
+ stream->parallel_for(
160
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
161
+ [=](sycl::nd_item<3> item_ct1) {
159
162
  k_get_rows_float(src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2,
160
163
  s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
161
164
  });