whispercpp 1.3.4 → 1.3.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (891)
  1. checksums.yaml +4 -4
  2. data/LICENSE +1 -1
  3. data/README.md +158 -44
  4. data/ext/extconf.rb +3 -2
  5. data/ext/ruby_whisper.c +34 -6
  6. data/ext/ruby_whisper.h +67 -0
  7. data/ext/ruby_whisper_context.c +236 -144
  8. data/ext/ruby_whisper_context_params.c +163 -0
  9. data/ext/ruby_whisper_model.c +12 -13
  10. data/ext/ruby_whisper_params.c +47 -24
  11. data/ext/ruby_whisper_segment.c +84 -20
  12. data/ext/ruby_whisper_token.c +371 -0
  13. data/ext/ruby_whisper_transcribe.cpp +5 -2
  14. data/ext/ruby_whisper_vad_context.c +122 -0
  15. data/ext/ruby_whisper_vad_context_detect.cpp +51 -0
  16. data/ext/ruby_whisper_vad_params.c +0 -1
  17. data/ext/ruby_whisper_vad_segment.c +138 -0
  18. data/ext/ruby_whisper_vad_segments.c +105 -0
  19. data/ext/sources/CMakeLists.txt +4 -1
  20. data/ext/sources/bindings/javascript/package.json +1 -1
  21. data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
  22. data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
  23. data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  24. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  25. data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
  26. data/ext/sources/examples/addon.node/vad-example.js +2 -2
  27. data/ext/sources/examples/bench/bench.cpp +23 -18
  28. data/ext/sources/examples/cli/cli.cpp +129 -112
  29. data/ext/sources/examples/common-ggml.cpp +2 -0
  30. data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
  31. data/ext/sources/examples/miniaudio.h +4507 -2131
  32. data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
  33. data/ext/sources/examples/server/server.cpp +28 -15
  34. data/ext/sources/examples/talk-llama/CMakeLists.txt +8 -3
  35. data/ext/sources/examples/talk-llama/llama-adapter.cpp +5 -2
  36. data/ext/sources/examples/talk-llama/llama-adapter.h +7 -0
  37. data/ext/sources/examples/talk-llama/llama-arch.cpp +2378 -1988
  38. data/ext/sources/examples/talk-llama/llama-arch.h +109 -2
  39. data/ext/sources/examples/talk-llama/llama-batch.cpp +78 -34
  40. data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
  41. data/ext/sources/examples/talk-llama/llama-chat.cpp +100 -4
  42. data/ext/sources/examples/talk-llama/llama-chat.h +5 -0
  43. data/ext/sources/examples/talk-llama/llama-context.cpp +1088 -403
  44. data/ext/sources/examples/talk-llama/llama-context.h +70 -23
  45. data/ext/sources/examples/talk-llama/llama-cparams.h +6 -0
  46. data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
  47. data/ext/sources/examples/talk-llama/llama-grammar.cpp +295 -60
  48. data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
  49. data/ext/sources/examples/talk-llama/llama-graph.cpp +925 -155
  50. data/ext/sources/examples/talk-llama/llama-graph.h +234 -23
  51. data/ext/sources/examples/talk-llama/llama-hparams.cpp +79 -38
  52. data/ext/sources/examples/talk-llama/llama-hparams.h +118 -18
  53. data/ext/sources/examples/talk-llama/llama-impl.cpp +11 -7
  54. data/ext/sources/examples/talk-llama/llama-impl.h +14 -2
  55. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +8 -4
  56. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +405 -140
  57. data/ext/sources/examples/talk-llama/llama-kv-cache.h +24 -10
  58. data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
  59. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
  60. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
  61. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
  62. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +42 -31
  63. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
  64. data/ext/sources/examples/talk-llama/llama-mmap.cpp +197 -45
  65. data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
  66. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +606 -116
  67. data/ext/sources/examples/talk-llama/llama-model-loader.h +41 -5
  68. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +61 -44
  69. data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
  70. data/ext/sources/examples/talk-llama/llama-model.cpp +2756 -13643
  71. data/ext/sources/examples/talk-llama/llama-model.h +112 -18
  72. data/ext/sources/examples/talk-llama/llama-quant.cpp +582 -365
  73. data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +1409 -199
  74. data/ext/sources/examples/talk-llama/llama-sampler.h +42 -0
  75. data/ext/sources/examples/talk-llama/llama-vocab.cpp +248 -82
  76. data/ext/sources/examples/talk-llama/llama-vocab.h +50 -40
  77. data/ext/sources/examples/talk-llama/llama.cpp +802 -21
  78. data/ext/sources/examples/talk-llama/llama.h +210 -39
  79. data/ext/sources/examples/talk-llama/models/afmoe.cpp +190 -0
  80. data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
  81. data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
  82. data/ext/sources/examples/talk-llama/models/arctic.cpp +137 -0
  83. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
  84. data/ext/sources/examples/talk-llama/models/baichuan.cpp +123 -0
  85. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +143 -0
  86. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +133 -0
  87. data/ext/sources/examples/talk-llama/models/bert.cpp +184 -0
  88. data/ext/sources/examples/talk-llama/models/bitnet.cpp +145 -0
  89. data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
  90. data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
  91. data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
  92. data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
  93. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
  94. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
  95. data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
  96. data/ext/sources/examples/talk-llama/models/dbrx.cpp +122 -0
  97. data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
  98. data/ext/sources/examples/talk-llama/models/deepseek.cpp +142 -0
  99. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +262 -0
  100. data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
  101. data/ext/sources/examples/talk-llama/models/dots1.cpp +132 -0
  102. data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
  103. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +148 -0
  104. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
  105. data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
  106. data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
  107. data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
  108. data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
  109. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +111 -0
  110. data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
  111. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
  112. data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
  113. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
  114. data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
  115. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
  116. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
  117. data/ext/sources/examples/talk-llama/models/glm4.cpp +157 -0
  118. data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
  119. data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
  120. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +195 -0
  121. data/ext/sources/examples/talk-llama/models/granite.cpp +210 -0
  122. data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
  123. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +139 -0
  124. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
  125. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +153 -0
  126. data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
  127. data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
  128. data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
  129. data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
  130. data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
  131. data/ext/sources/examples/talk-llama/models/lfm2.cpp +196 -0
  132. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
  133. data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
  134. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
  135. data/ext/sources/examples/talk-llama/models/llama.cpp +175 -0
  136. data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
  137. data/ext/sources/examples/talk-llama/models/mamba-base.cpp +289 -0
  138. data/ext/sources/examples/talk-llama/models/mamba.cpp +54 -0
  139. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +129 -0
  140. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +200 -0
  141. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +123 -0
  142. data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
  143. data/ext/sources/examples/talk-llama/models/models.h +704 -0
  144. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +109 -0
  145. data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
  146. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +162 -0
  147. data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
  148. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
  149. data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
  150. data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
  151. data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
  152. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
  153. data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
  154. data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
  155. data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
  156. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
  157. data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
  158. data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
  159. data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
  160. data/ext/sources/examples/talk-llama/models/plamo2.cpp +320 -0
  161. data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
  162. data/ext/sources/examples/talk-llama/models/plm.cpp +169 -0
  163. data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
  164. data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
  165. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
  166. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
  167. data/ext/sources/examples/talk-llama/models/qwen3.cpp +120 -0
  168. data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
  169. data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
  170. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +131 -0
  171. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +525 -0
  172. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +140 -0
  173. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +132 -0
  174. data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
  175. data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
  176. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +164 -0
  177. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
  178. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
  179. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +137 -0
  180. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
  181. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
  182. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
  183. data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
  184. data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
  185. data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
  186. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
  187. data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
  188. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
  189. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
  190. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
  191. data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
  192. data/ext/sources/examples/talk-llama/unicode.cpp +121 -79
  193. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
  194. data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
  195. data/ext/sources/ggml/CMakeLists.txt +90 -56
  196. data/ext/sources/ggml/include/ggml-alloc.h +9 -0
  197. data/ext/sources/ggml/include/ggml-backend.h +5 -2
  198. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  199. data/ext/sources/ggml/include/ggml-cpu.h +6 -0
  200. data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
  201. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  202. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  203. data/ext/sources/ggml/include/ggml-rpc.h +14 -12
  204. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  205. data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
  206. data/ext/sources/ggml/include/ggml.h +246 -21
  207. data/ext/sources/ggml/src/CMakeLists.txt +85 -11
  208. data/ext/sources/ggml/src/ggml-alloc.c +128 -50
  209. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  210. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  211. data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
  212. data/ext/sources/ggml/src/ggml-backend-reg.cpp +54 -88
  213. data/ext/sources/ggml/src/ggml-backend.cpp +76 -23
  214. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +18 -4
  215. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +11 -11
  216. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +58 -46
  217. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +139 -48
  218. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2427 -1785
  219. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -362
  220. data/ext/sources/ggml/src/ggml-cann/common.h +285 -211
  221. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +663 -831
  222. data/ext/sources/ggml/src/ggml-common.h +11 -0
  223. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +170 -95
  224. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -18
  225. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  226. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
  227. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  228. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +513 -27
  229. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +4192 -992
  230. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
  231. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  232. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1761 -49
  233. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
  234. data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  235. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
  236. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
  237. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +124 -24
  238. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +157 -28
  239. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  240. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  241. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +8 -3
  242. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +251 -80
  243. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +19 -0
  244. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +587 -119
  245. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
  246. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1093 -194
  247. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1284 -203
  248. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
  249. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1519 -527
  250. data/ext/sources/ggml/src/ggml-cpu/ops.h +6 -4
  251. data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
  252. data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
  253. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +3632 -781
  254. data/ext/sources/ggml/src/ggml-cpu/repack.h +129 -4
  255. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
  256. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +152 -46
  257. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
  258. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +152 -1
  259. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
  260. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +140 -0
  261. data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
  262. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
  263. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
  264. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +132 -6
  265. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
  266. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +33 -31
  267. data/ext/sources/ggml/src/ggml-cuda/common.cuh +474 -85
  268. data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
  269. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  270. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
  271. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +342 -246
  272. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
  273. data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
  274. data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
  275. data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
  276. data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
  277. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +98 -74
  278. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +973 -665
  279. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
  280. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1255 -0
  281. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +33 -40
  282. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +40 -18
  283. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  284. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +206 -45
  285. data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
  286. data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
  287. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
  288. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  289. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1688 -302
  290. data/ext/sources/ggml/src/ggml-cuda/mean.cu +12 -10
  291. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +908 -48
  292. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +88 -20
  293. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +502 -90
  294. data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
  295. data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
  296. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
  297. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +532 -193
  298. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +460 -104
  299. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +5 -2
  300. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +360 -122
  301. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +2 -1
  302. data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
  303. data/ext/sources/ggml/src/ggml-cuda/pad.cu +73 -39
  304. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +152 -1
  305. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
  306. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
  307. data/ext/sources/ggml/src/ggml-cuda/rope.cu +364 -149
  308. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
  309. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
  310. data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
  311. data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
  312. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +163 -41
  313. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
  314. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
  315. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +68 -50
  316. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  317. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
  318. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
  319. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
  320. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
  321. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
  322. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
  323. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
  324. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  325. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  326. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  327. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  328. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  329. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  330. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  331. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  332. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  333. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +22 -4
  334. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +95 -0
  335. data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
  336. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +275 -119
  337. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -7
  338. data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
  339. data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
  340. data/ext/sources/ggml/src/ggml-cuda/unary.cu +160 -11
  341. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +38 -0
  342. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
  343. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
  344. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  345. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +22 -1
  346. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
  347. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +117 -0
  348. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3325 -0
  349. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +46 -0
  350. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +813 -0
  351. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
  352. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +891 -0
  353. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  354. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
  355. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +713 -0
  356. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
  357. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.c +63 -0
  358. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +182 -0
  359. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
  360. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  361. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
  362. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  363. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +155 -0
  364. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +63 -0
  365. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  366. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  367. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
  368. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
  369. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
  370. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  371. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
  372. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  373. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  374. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  375. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  376. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
  377. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  378. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  379. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +26 -0
  380. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1199 -0
  381. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2670 -0
  382. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +497 -0
  383. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
  384. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +419 -0
  385. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
  386. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  387. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +382 -0
  388. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +293 -0
  389. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  390. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  391. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  392. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  393. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
  394. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
  395. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +14 -13
  396. data/ext/sources/ggml/src/ggml-impl.h +129 -6
  397. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  398. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +15 -4
  399. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  400. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +173 -34
  401. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +912 -344
  402. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +124 -59
  403. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +588 -144
  404. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +396 -23
  405. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1724 -421
  406. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +16 -3
  407. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +333 -114
  408. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +3050 -1539
  409. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
  410. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +30 -1
  411. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4279 -497
  412. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
  413. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
  414. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  415. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +267 -0
  416. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  417. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  418. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +113 -0
  419. data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  420. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
  421. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  422. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  423. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  424. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
  425. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  426. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  427. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  428. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +140 -0
  429. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  430. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
  431. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
  432. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  433. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  434. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  435. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
  436. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  437. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  438. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  439. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  440. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  441. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
  442. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  443. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  444. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  445. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
  446. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  447. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  448. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
  449. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +116 -0
  450. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  451. data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  452. data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  453. data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  454. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  455. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  456. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +39 -0
  457. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  458. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  459. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  460. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
  461. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  462. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
  463. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  464. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
  465. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
  466. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  467. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  468. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  469. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  470. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  471. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  472. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  473. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  474. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  475. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  476. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  477. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  478. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  479. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  480. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  481. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  482. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  483. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
  484. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  485. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  486. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  487. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  488. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  489. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  490. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
  491. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
  492. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
  493. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
  494. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  495. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  496. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  497. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  498. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  499. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
  500. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  501. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
  502. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
  503. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
  504. data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
  505. data/ext/sources/ggml/src/ggml-quants.c +96 -5
  506. data/ext/sources/ggml/src/ggml-quants.h +3 -0
  507. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
  508. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +59 -87
  509. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +81 -0
  510. data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
  511. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +7 -0
  512. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -29
  513. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
  514. data/ext/sources/ggml/src/ggml-sycl/common.hpp +427 -20
  515. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
  516. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +103 -1
  517. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
  518. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  519. data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  520. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
  521. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
  522. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +867 -50
  523. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +401 -358
  524. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +12 -2
  525. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
  526. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
  527. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
  528. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
  529. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
  530. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  531. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
  532. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
  533. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +645 -155
  534. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
  535. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +221 -66
  536. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
  537. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  538. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
  539. data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
  540. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
  541. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
  542. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +5 -0
  543. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
  544. data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  545. data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  546. data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
  547. data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
  548. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +457 -281
  549. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  550. data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
  551. data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
  552. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
  553. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
  554. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  555. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  556. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  557. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  558. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  559. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  560. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  561. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  562. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  563. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  564. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  565. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
  566. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
  567. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
  568. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
  569. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
  570. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
  571. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
  572. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
  573. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
  574. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
  575. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
  576. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
  577. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
  578. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
  579. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
  580. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
  581. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
  582. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
  583. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
  584. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
  585. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
  586. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
  587. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
  588. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
  589. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
  590. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
  591. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
  592. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
  593. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
  594. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
  595. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
  596. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
  597. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
  598. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
  599. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
  600. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
  601. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +71 -0
  602. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  603. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  604. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  605. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  606. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  607. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  608. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  609. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  610. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  611. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  612. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  613. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  614. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  615. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  616. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  617. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  618. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  619. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  620. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  621. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  622. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  623. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  624. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
  625. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
  626. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  627. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
  628. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  629. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  630. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  631. data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
  632. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  633. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  634. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  635. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  636. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  637. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  638. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  639. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
  640. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  641. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  642. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  643. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
  644. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
  645. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +39 -19
  646. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5994 -3055
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +18 -10
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  663. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  664. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
  665. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
  666. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  667. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  668. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  669. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
  670. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
  671. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
  672. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
  673. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  674. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
  675. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
  676. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  677. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  678. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  679. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  680. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  681. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
  682. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
  683. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
  684. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  685. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  686. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
  687. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
  688. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  689. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  690. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  691. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
  692. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  693. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  694. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
  695. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  696. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  697. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
  698. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  699. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  700. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  701. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
  702. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  703. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +386 -160
  704. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +82 -20
  705. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +400 -174
  706. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +123 -37
  707. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  708. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +10 -9
  709. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  710. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
  711. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  712. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
  713. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
  714. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  715. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
  716. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  717. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
  718. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
  719. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
  720. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
  721. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
  722. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
  723. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  724. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
  725. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
  726. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
  727. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
  728. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +13 -10
  729. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  730. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  731. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  732. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
  733. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +77 -29
  734. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
  735. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
  736. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
  737. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
  738. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
  739. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
  740. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
  741. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
  742. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
  743. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  744. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
  745. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
  746. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
  747. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
  748. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
  749. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
  750. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
  751. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +88 -105
  752. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +41 -26
  753. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
  754. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +74 -0
  755. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +92 -230
  756. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
  757. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  758. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
  759. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  760. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  761. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  762. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
  763. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
  764. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  765. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
  766. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  767. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  768. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  769. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  770. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +49 -4
  771. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  772. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
  773. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
  774. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +207 -0
  775. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  776. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +8 -49
  777. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +8 -32
  778. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +8 -32
  779. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +33 -0
  780. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +8 -38
  781. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  782. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
  783. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  784. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  785. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  786. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  787. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  788. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
  789. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
  790. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
  791. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
  792. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
  793. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
  794. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  795. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
  796. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
  797. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  798. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +50 -0
  799. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
  800. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  801. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  802. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
  803. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
  804. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  805. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
  806. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  807. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
  808. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
  809. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
  810. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
  811. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
  812. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  813. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
  814. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
  815. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +384 -180
  816. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
  817. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
  818. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1374 -0
  819. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +2544 -726
  820. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  821. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  822. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  823. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  824. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
  825. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
  826. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
  827. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +107 -0
  828. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  829. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +73 -15
  830. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +636 -0
  831. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
  832. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  833. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +72 -261
  834. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +766 -0
  835. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +147 -0
  836. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +196 -0
  837. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
  838. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  839. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  840. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
  841. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  842. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +63 -0
  843. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
  844. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  845. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  846. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
  847. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
  848. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +91 -0
  849. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +469 -0
  850. data/ext/sources/ggml/src/ggml.c +590 -64
  851. data/ext/sources/ggml/src/gguf.cpp +229 -44
  852. data/ext/sources/include/whisper.h +1 -0
  853. data/ext/sources/src/CMakeLists.txt +3 -1
  854. data/ext/sources/src/whisper.cpp +106 -62
  855. data/ext/sources/tests/CMakeLists.txt +2 -2
  856. data/ext/sources/tests/test-vad-full.cpp +4 -2
  857. data/ext/sources/tests/test-vad.cpp +1 -1
  858. data/extsources.rb +1 -0
  859. data/lib/whisper/model/uri.rb +17 -18
  860. data/sig/whisper.rbs +162 -4
  861. data/test/test_context_params.rb +82 -0
  862. data/test/test_params.rb +16 -8
  863. data/test/test_segment.rb +0 -1
  864. data/test/test_token.rb +81 -0
  865. data/test/test_vad.rb +1 -1
  866. data/test/test_vad_context.rb +100 -0
  867. data/test/test_vad_segment.rb +19 -0
  868. data/test/test_vad_segments.rb +16 -0
  869. data/test/test_whisper.rb +27 -0
  870. data/whispercpp.gemspec +1 -1
  871. metadata +502 -37
  872. data/ext/sources/build-xcframework.sh +0 -571
  873. data/ext/sources/examples/talk-llama/llama-sampling.h +0 -32
  874. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  875. data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
  876. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
  877. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
  878. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
  879. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
  880. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  881. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
  882. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
  883. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
  884. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
  885. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  886. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  887. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  888. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  889. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
  890. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
  891. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
@@ -1,4 +1,5 @@
1
1
  #include "rope.hpp"
2
+ #include "convert.hpp"
2
3
  #include "ggml-sycl/common.hpp"
3
4
  #include "ggml.h"
4
5
 
@@ -15,355 +16,489 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
15
16
  return 1.0f - sycl::min(1.0f, sycl::max(0.0f, y));
16
17
  }
17
18
 
18
- // YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
19
- // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
20
- static void rope_yarn(
21
- float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale,
22
- float * cos_theta, float * sin_theta) {
23
- // Get n-d rotational scaling corrected for extrapolation
19
+ template <bool forward>
20
+ static void rope_yarn(const float theta_extrap, const float freq_scale,
21
+ const rope_corr_dims corr_dims, const int64_t i0,
22
+ const float ext_factor, float mscale, float &cos_theta,
23
+ float &sin_theta) {
24
24
  float theta_interp = freq_scale * theta_extrap;
25
25
  float theta = theta_interp;
26
26
  if (ext_factor != 0.0f) {
27
- float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
27
+ float ramp_mix =
28
+ rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
28
29
  theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
29
30
 
30
- // Get n-d magnitude scaling corrected for interpolation
31
31
  mscale *= 1.0f + 0.1f * sycl::log(1.0f / freq_scale);
32
32
  }
33
- *cos_theta = sycl::cos(theta) * mscale;
34
- *sin_theta = sycl::sin(theta) * mscale;
33
+ cos_theta = sycl::cos(theta) * mscale;
34
+ sin_theta = sycl::sin(theta) * mscale;
35
+ if (!forward) {
36
+ sin_theta *= -1.0f;
37
+ }
35
38
  }
36
39
 
37
- template <typename T, bool has_ff>
38
- static void rope_norm(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
39
- const int32_t * pos, float freq_scale, float ext_factor, float attn_factor,
40
- const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors,
41
- const sycl::nd_item<3> & item_ct1) {
42
- const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + item_ct1.get_local_id(1));
43
-
44
- if (i0 >= ne0) {
40
+ template <bool forward, bool has_ff, typename T, typename D>
41
+ static void rope_norm(const T *x, D *dst, const int ne00, const int ne01,
42
+ const int ne02, const int s01, const int s02,
43
+ const int s03, const int s1, const int s2, const int s3,
44
+ const int n_dims, const int32_t *pos,
45
+ const float freq_scale, const float ext_factor,
46
+ const float attn_factor, const rope_corr_dims corr_dims,
47
+ const float theta_scale, const float *freq_factors,
48
+ const int64_t *row_indices, const int set_rows_stride) {
49
+ auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
50
+ const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
51
+ item_ct1.get_local_id(1));
52
+
53
+ if (i0 >= ne00) {
45
54
  return;
46
55
  }
47
56
 
48
- const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
57
+ const int row_dst = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
58
+ item_ct1.get_local_id(2);
49
59
 
50
- const int row0 = row % ne1;
51
- const int channel0 = row / ne1;
60
+ const uint32_t i3 = row_dst / (ne01 * ne02);
61
+ const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
62
+ const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
52
63
 
53
- const int i = row * ne0 + i0;
54
- const int i2 = channel0 * s2 + row0 * s1 + i0;
64
+ int idst = i0 + i1 * s1 + i2 * s2 + i3 * s3;
65
+ const int ix = i0 + i1 * s01 + i2 * s02 + i3 * s03;
66
+
67
+ if (set_rows_stride != 0) {
68
+ idst = i1 * s1 + i0;
69
+ idst += row_indices[i2] * set_rows_stride;
70
+ }
55
71
 
72
+ const auto &store_coaelsced = [&](float x0, float x1) {
73
+ if constexpr (std::is_same_v<float, D>) {
74
+ sycl::float2 v = sycl::float2(x0, x1);
75
+ ggml_sycl_memcpy_1<8>(dst + idst, &v);
76
+ } else if constexpr (std::is_same_v<sycl::half, D>) {
77
+ sycl::half2 v = sycl::half2(x0, x1);
78
+ ggml_sycl_memcpy_1<4>(dst + idst, &v);
79
+ }
80
+ };
56
81
  if (i0 >= n_dims) {
57
- *reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i2);
82
+ store_coaelsced(x[ix + 0], x[ix + 1]);
58
83
  return;
59
84
  }
60
85
 
61
- const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
86
+ const float theta_base = pos[i2] * dpct::pow(theta_scale, i0 / 2.0f);
62
87
 
63
88
  const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
64
89
 
65
90
  float cos_theta;
66
91
  float sin_theta;
67
92
 
68
- rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
93
+ rope_yarn<forward>(theta_base / freq_factor, freq_scale, corr_dims, i0,
94
+ ext_factor, attn_factor, cos_theta, sin_theta);
69
95
 
70
- const float x0 = x[i2 + 0];
71
- const float x1 = x[i2 + 1];
96
+ const float x0 = x[ix + 0];
97
+ const float x1 = x[ix + 1];
72
98
 
73
- dst[i + 0] = x0 * cos_theta - x1 * sin_theta;
74
- dst[i + 1] = x0 * sin_theta + x1 * cos_theta;
99
+ store_coaelsced(x0 * cos_theta - x1 * sin_theta,
100
+ x0 * sin_theta + x1 * cos_theta);
75
101
  }
76
102
 
77
- template <typename T, bool has_ff>
78
- static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
79
- const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
80
- const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors,
81
- const sycl::nd_item<3> & item_ct1) {
82
- const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + item_ct1.get_local_id(1));
83
-
84
- if (i0 >= ne0) {
103
+ template <bool forward, bool has_ff, typename T, typename D>
104
+ static void rope_neox(const T *x, D *dst, const int ne00, const int ne01,
105
+ const int ne02, const int s01, const int s02,
106
+ const int s03, const int s1, const int s2, const int s3,
107
+ const int n_dims, const int32_t *pos,
108
+ const float freq_scale, const float ext_factor,
109
+ const float attn_factor, const rope_corr_dims corr_dims,
110
+ const float theta_scale, const float *freq_factors,
111
+ const int64_t *row_indices, const int set_rows_stride) {
112
+ auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
113
+ const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
114
+ item_ct1.get_local_id(1));
115
+
116
+ if (i0 >= ne00) {
85
117
  return;
86
118
  }
87
119
 
88
- const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
120
+ const int row_dst = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
121
+ item_ct1.get_local_id(2);
122
+
123
+ const uint32_t i3 = row_dst / (ne01 * ne02);
124
+ const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
125
+ const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
89
126
 
90
- const int row0 = row % ne1;
91
- const int channel0 = row / ne1;
127
+ int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;
128
+ const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
92
129
 
93
- const int i = row * ne0 + i0 / 2;
94
- const int i2 = channel0 * s2 + row0 * s1 + i0 / 2;
130
+ if (set_rows_stride != 0) {
131
+ idst = i1 * s1 + i0 / 2;
132
+ idst += row_indices[i2] * set_rows_stride;
133
+ }
95
134
 
96
135
  if (i0 >= n_dims) {
97
- *reinterpret_cast<sycl::vec<T, 2> *>(dst + i + i0 / 2) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i2 + i0 / 2);
136
+ dst[idst + i0 / 2 + 0] = ggml_sycl_cast<D>(x[ix + i0 / 2 + 0]);
137
+ dst[idst + i0 / 2 + 1] = ggml_sycl_cast<D>(x[ix + i0 / 2 + 1]);
138
+
98
139
  return;
99
140
  }
100
141
 
101
- const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
142
+ const float theta_base = pos[i2] * dpct::pow(theta_scale, i0 / 2.0f);
102
143
 
103
144
  const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
104
145
 
105
146
  float cos_theta;
106
147
  float sin_theta;
107
148
 
108
- rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
149
+ rope_yarn<forward>(theta_base / freq_factor, freq_scale, corr_dims, i0,
150
+ ext_factor, attn_factor, cos_theta, sin_theta);
109
151
 
110
- const float x0 = x[i2 + 0];
111
- const float x1 = x[i2 + n_dims / 2];
152
+ const float x0 = x[ix + 0];
153
+ const float x1 = x[ix + n_dims / 2];
112
154
 
113
- dst[i + 0] = x0 * cos_theta - x1 * sin_theta;
114
- dst[i + n_dims / 2] = x0 * sin_theta + x1 * cos_theta;
155
+ dst[idst + 0] = ggml_sycl_cast<D>(x0 * cos_theta - x1 * sin_theta);
156
+ dst[idst + n_dims / 2] = ggml_sycl_cast<D>(x0 * sin_theta + x1 * cos_theta);
115
157
  }
116
158
 
117
- template <typename T, bool has_ff>
118
- static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
119
- const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale,
120
- const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
121
- const float theta_scale, const float * freq_factors, const mrope_sections sections,
122
- const sycl::nd_item<3> & item_ct1) {
123
- // get index pos
124
- const int i0 = 2 * (item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1));
125
- if (i0 >= ne0) {
159
+ template <bool forward, bool has_ff, typename T>
160
+ static void rope_multi(const T *x, T *dst, const int ne00, const int ne01,
161
+ const int ne02, const int s01, const int s02,
162
+ const int s03, const int s1, const int s2, const int s3,
163
+ const int n_dims, const int32_t *pos,
164
+ const float freq_scale, const float ext_factor,
165
+ const float attn_factor, const rope_corr_dims corr_dims,
166
+ const float theta_scale, const float *freq_factors,
167
+ const mrope_sections sections, const bool is_imrope) {
168
+ auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
169
+ const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
170
+ item_ct1.get_local_id(1));
171
+
172
+ if (i0 >= ne00) {
126
173
  return;
127
174
  }
128
- const int row_dst = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2);
129
175
 
130
- const int row_x = row_dst % ne1;
131
- const int channel_x = row_dst / ne1;
132
- const int idst = (row_dst * ne0) + (i0 / 2);
133
- const size_t ix = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2);
176
+ const int row_dst = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
177
+ item_ct1.get_local_id(2);
178
+
179
+ const uint32_t i3 = row_dst / (ne01 * ne02);
180
+ const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
181
+ const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
182
+
183
+ int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;
184
+ const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
134
185
 
135
186
  if (i0 >= n_dims) {
136
- *reinterpret_cast<sycl::vec<T, 2> *>(dst + idst + i0 / 2) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i0 / 2 + ix);
187
+ dst[idst + i0 / 2 + 0] = x[ix + i0 / 2 + 0];
188
+ dst[idst + i0 / 2 + 1] = x[ix + i0 / 2 + 1];
189
+
137
190
  return;
138
191
  }
139
192
 
140
- const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
193
+ const int sect_dims =
194
+ sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
141
195
  const int sec_w = sections.v[1] + sections.v[0];
142
196
  const int sector = (i0 / 2) % sect_dims;
143
197
 
144
-
145
198
  float theta_base = 0.0;
146
- if (sector < sections.v[0]) {
147
- theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f);
148
- }
149
- else if (sector >= sections.v[0] && sector < sec_w) {
150
- theta_base = pos[channel_x + ne2 * 1]*sycl::pow(theta_scale, i0/2.0f);
151
- }
152
- else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
153
- theta_base = pos[channel_x + ne2 * 2]*sycl::pow(theta_scale, i0/2.0f);
154
- }
155
- else if (sector >= sec_w + sections.v[2]) {
156
- theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f);
199
+ if (is_imrope) {
200
+ if (sector % 3 == 1 && sector < 3 * sections.v[1]) { // h
201
+ theta_base = pos[i2 + ne02 * 1] * dpct::pow(theta_scale, i0 / 2.0f);
202
+ } else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w
203
+ theta_base = pos[i2 + ne02 * 2] * dpct::pow(theta_scale, i0 / 2.0f);
204
+ } else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t
205
+ theta_base = pos[i2] * dpct::pow(theta_scale, i0 / 2.0f);
206
+ } else {
207
+ theta_base = pos[i2 + ne02 * 3] * dpct::pow(theta_scale, i0 / 2.0f);
208
+ }
209
+ } else {
210
+ if (sector < sections.v[0]) {
211
+ theta_base = pos[i2] * dpct::pow(theta_scale, i0 / 2.0f);
212
+ } else if (sector >= sections.v[0] && sector < sec_w) {
213
+ theta_base = pos[i2 + ne02 * 1] * dpct::pow(theta_scale, i0 / 2.0f);
214
+ } else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
215
+ theta_base = pos[i2 + ne02 * 2] * dpct::pow(theta_scale, i0 / 2.0f);
216
+ } else if (sector >= sec_w + sections.v[2]) {
217
+ theta_base = pos[i2 + ne02 * 3] * dpct::pow(theta_scale, i0 / 2.0f);
218
+ }
157
219
  }
158
220
 
159
221
  const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
160
- float cos_theta;
161
- float sin_theta;
162
- rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
163
- const float x0 = x[ix + 0];
164
- const float x1 = x[ix + n_dims/2];
165
222
 
166
- // store results in dst
167
- dst[idst + 0] = x0 * cos_theta - x1 * sin_theta;
168
- dst[idst + n_dims/2] = x0 * sin_theta + x1 * cos_theta;
169
- }
223
+ float cos_theta;
224
+ float sin_theta;
170
225
 
226
+ rope_yarn<forward>(theta_base / freq_factor, freq_scale, corr_dims, i0,
227
+ ext_factor, attn_factor, cos_theta, sin_theta);
171
228
 
229
+ const float x0 = x[ix + 0];
230
+ const float x1 = x[ix + n_dims / 2];
172
231
 
173
- template <typename T, bool has_ff>
174
- static void rope_vision(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
175
- const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale,
176
- const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
177
- const float theta_scale, const float * freq_factors, const mrope_sections sections,
178
- const sycl::nd_item<3> & item_ct1) {
179
- // get index pos
180
- const int i0 = 2 * (item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1));
181
- if (i0 >= ne0) {
232
+ dst[idst + 0] = x0 * cos_theta - x1 * sin_theta;
233
+ dst[idst + n_dims / 2] = x0 * sin_theta + x1 * cos_theta;
234
+ }
235
+
236
+ template <bool forward, bool has_ff, typename T>
237
+ static void rope_vision(const T *x, T *dst, const int ne00, const int ne01,
238
+ const int ne02, const int s01, const int s02,
239
+ const int s03, const int s1, const int s2, const int s3,
240
+ const int n_dims, const int32_t *pos,
241
+ const float freq_scale, const float ext_factor,
242
+ const float attn_factor, const rope_corr_dims corr_dims,
243
+ const float theta_scale, const float *freq_factors,
244
+ const mrope_sections sections) {
245
+ auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
246
+ const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
247
+ item_ct1.get_local_id(1));
248
+
249
+ if (i0 >= ne00) {
182
250
  return;
183
251
  }
184
- const int row_dst = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2);
185
- const int row_x = row_dst % ne1;
186
- const int channel_x = row_dst / ne1;
187
- const int idst = (row_dst * ne0) + (i0 / 2);
188
- const size_t ix = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2);
252
+
253
+ const int row_dst = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
254
+ item_ct1.get_local_id(2);
255
+
256
+ const uint32_t i3 = row_dst / (ne01 * ne02);
257
+ const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
258
+ const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
259
+
260
+ int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;
261
+ const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
189
262
 
190
263
  const int sect_dims = sections.v[0] + sections.v[1];
191
- const int sector = (i0 / 2) % sect_dims;
264
+ const int sec_w = sections.v[1] + sections.v[0];
265
+ const int sector = (i0 / 2) % sect_dims;
192
266
 
193
- float theta_base = 0.0f;
267
+ float theta_base = 0.0;
194
268
  if (sector < sections.v[0]) {
195
269
  const int p = sector;
196
- theta_base = pos[channel_x] * sycl::pow(theta_scale, (float) p);
197
- } else {
198
- // Simplified from CUDA backend code: if (sector >= sections.v[0] && sector < sec_w) which is just sector >= sections.v[0]
270
+ theta_base = pos[i2] * dpct::pow(theta_scale, p);
271
+ } else if (sector >= sections.v[0] && sector < sec_w) {
199
272
  const int p = sector - sections.v[0];
200
- theta_base = pos[channel_x + ne2] * sycl::pow(theta_scale, (float) p);
273
+ theta_base = pos[i2 + ne02] * dpct::pow(theta_scale, p);
201
274
  }
202
275
 
203
276
  const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
204
- float cos_theta;
205
- float sin_theta;
206
- rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
277
+
278
+ float cos_theta;
279
+ float sin_theta;
280
+
281
+ rope_yarn<forward>(theta_base / freq_factor, freq_scale, corr_dims, i0,
282
+ ext_factor, attn_factor, cos_theta, sin_theta);
283
+
207
284
  const float x0 = x[ix + 0];
208
285
  const float x1 = x[ix + n_dims];
209
286
 
210
- // store results in dst
211
- dst[idst + 0] = x0 * cos_theta - x1 * sin_theta;
287
+ dst[idst + 0] = x0 * cos_theta - x1 * sin_theta;
212
288
  dst[idst + n_dims] = x0 * sin_theta + x1 * cos_theta;
213
289
  }
214
290
 
215
- template <typename T>
216
- static void rope_norm_sycl(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2,
217
- const int n_dims, int nr, const int32_t * pos, const float freq_scale, const float freq_base,
218
- const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
219
- const float * freq_factors, queue_ptr stream) {
220
- GGML_ASSERT(ne0 % 2 == 0);
221
- const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
222
- const int num_blocks_x = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
223
- const sycl::range<3> block_nums(1, num_blocks_x, nr);
291
+ template <bool forward, typename T, typename D>
292
+ static void
293
+ rope_norm_sycl(const T *x, D *dst, const int ne00, const int ne01,
294
+ const int ne02, const int s01, const int s02, const int s03,
295
+ const int s1, const int s2, const int s3, const int n_dims,
296
+ const int nr, const int32_t *pos, const float freq_scale,
297
+ const float freq_base, const float ext_factor,
298
+ const float attn_factor, const rope_corr_dims corr_dims,
299
+ const float *freq_factors, const int64_t *row_indices,
300
+ const int set_rows_stride, dpct::queue_ptr stream) {
301
+ GGML_ASSERT(ne00 % 2 == 0);
302
+ const dpct::dim3 block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
303
+ const int n_blocks_x =
304
+ (ne00 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
305
+ const dpct::dim3 block_nums(nr, n_blocks_x, 1);
224
306
 
225
307
  const float theta_scale = powf(freq_base, -2.0f / n_dims);
226
308
 
227
- dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
228
-
229
309
  if (freq_factors == nullptr) {
230
- /*
231
- DPCT1049:40: The work-group size passed to the SYCL kernel may exceed
232
- the limit. To get the device limit, query
233
- info::device::max_work_group_size. Adjust the work-group size if needed.
234
- */
235
- stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
236
- rope_norm<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
237
- theta_scale, freq_factors, item_ct1);
238
- });
310
+ stream->parallel_for(
311
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
312
+ [=](sycl::nd_item<3> item_ct1) {
313
+ GGML_UNUSED(item_ct1);
314
+ rope_norm<forward, false>(
315
+ x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
316
+ pos, freq_scale, ext_factor, attn_factor, corr_dims,
317
+ theta_scale, freq_factors, row_indices, set_rows_stride);
318
+ });
239
319
  } else {
240
- /*
241
- DPCT1049:41: The work-group size passed to the SYCL kernel may exceed
242
- the limit. To get the device limit, query
243
- info::device::max_work_group_size. Adjust the work-group size if needed.
244
- */
245
- stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
246
- rope_norm<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
247
- theta_scale, freq_factors, item_ct1);
248
- });
320
+ stream->parallel_for(
321
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
322
+ [=](sycl::nd_item<3> item_ct1) {
323
+ GGML_UNUSED(item_ct1);
324
+ rope_norm<forward, true>(
325
+ x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
326
+ pos, freq_scale, ext_factor, attn_factor, corr_dims,
327
+ theta_scale, freq_factors, row_indices, set_rows_stride);
328
+ });
249
329
  }
250
330
  }
251
331
 
252
- template <typename T>
253
- static void rope_neox_sycl(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2,
254
- const int n_dims, const int nr, const int32_t * pos, const float freq_scale,
255
- const float freq_base, const float ext_factor, const float attn_factor,
256
- const rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) {
257
- GGML_ASSERT(ne0 % 2 == 0);
258
- const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
259
- const int num_blocks_x = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
260
- const sycl::range<3> block_nums(1, num_blocks_x, nr);
332
+ template <bool forward, typename T, typename D>
333
+ static void
334
+ rope_neox_sycl(const T *x, D *dst, const int ne00, const int ne01,
335
+ const int ne02, const int s01, const int s02, const int s03,
336
+ const int s1, const int s2, const int s3, const int n_dims,
337
+ const int nr, const int32_t *pos, const float freq_scale,
338
+ const float freq_base, const float ext_factor,
339
+ const float attn_factor, const rope_corr_dims corr_dims,
340
+ const float *freq_factors, const int64_t *row_indices,
341
+ const int set_rows_stride, dpct::queue_ptr stream) {
342
+ GGML_ASSERT(ne00 % 2 == 0);
343
+ const dpct::dim3 block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
344
+ const int n_blocks_x =
345
+ (ne00 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
346
+ const dpct::dim3 block_nums(nr, n_blocks_x, 1);
261
347
 
262
348
  const float theta_scale = powf(freq_base, -2.0f / n_dims);
263
349
 
264
- dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
265
-
266
350
  if (freq_factors == nullptr) {
267
- stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
268
- rope_neox<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
269
- theta_scale, freq_factors, item_ct1);
270
- });
351
+ stream->parallel_for(
352
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
353
+ [=](sycl::nd_item<3> item_ct1) {
354
+ GGML_UNUSED(item_ct1);
355
+ rope_neox<forward, false>(
356
+ x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
357
+ pos, freq_scale, ext_factor, attn_factor, corr_dims,
358
+ theta_scale, freq_factors, row_indices, set_rows_stride);
359
+ });
271
360
  } else {
272
- stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
273
- rope_neox<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
274
- theta_scale, freq_factors, item_ct1);
275
- });
361
+ stream->parallel_for(
362
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
363
+ [=](sycl::nd_item<3> item_ct1) {
364
+ GGML_UNUSED(item_ct1);
365
+ rope_neox<forward, true>(
366
+ x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
367
+ pos, freq_scale, ext_factor, attn_factor, corr_dims,
368
+ theta_scale, freq_factors, row_indices, set_rows_stride);
369
+ });
276
370
  }
277
371
  }
278
372
 
279
- template <typename T>
280
- static void rope_multi_sycl(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
281
- const size_t s2, const int n_dims, const int nr, const int32_t * pos,
282
- const float freq_scale, const float freq_base, const float ext_factor,
283
- const float attn_factor, const rope_corr_dims corr_dims, const float * freq_factors,
284
- const mrope_sections sections, queue_ptr stream) {
285
- GGML_ASSERT(ne0 % 2 == 0);
286
- const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
287
- const int n_blocks_y = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
288
- const sycl::range<3> grid_dims(1, n_blocks_y, nr);
289
- const sycl::nd_range<3> nd_range(grid_dims * block_dims, block_dims);
290
-
291
- const float theta_scale = std::pow(freq_base, -2.0f / n_dims);
292
- // Add FP16 capability check if T could be sycl::half
293
- if constexpr (std::is_same_v<T, sycl::half>) {
294
- dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
295
- }
296
- // launch kernel
373
+ template <bool forward, typename T>
374
+ static void
375
+ rope_multi_sycl(const T *x, T *dst, const int ne00, const int ne01,
376
+ const int ne02, const int s01, const int s02, const int s03,
377
+ const int s1, const int s2, const int s3, const int n_dims,
378
+ const int nr, const int32_t *pos, const float freq_scale,
379
+ const float freq_base, const float ext_factor,
380
+ const float attn_factor, const rope_corr_dims corr_dims,
381
+ const float *freq_factors, const mrope_sections sections,
382
+ const bool is_imrope, dpct::queue_ptr stream) {
383
+ GGML_ASSERT(ne00 % 2 == 0);
384
+ const dpct::dim3 block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
385
+ const int n_blocks_x =
386
+ (ne00 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
387
+ const dpct::dim3 block_nums(nr, n_blocks_x, 1);
388
+
389
+ const float theta_scale = powf(freq_base, -2.0f / n_dims);
390
+
297
391
  if (freq_factors == nullptr) {
298
- stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
299
- rope_multi<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
300
- corr_dims, theta_scale, freq_factors, sections, item_ct1);
301
- });
392
+ stream->parallel_for(
393
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
394
+ [=](sycl::nd_item<3> item_ct1) {
395
+ GGML_UNUSED(item_ct1);
396
+ rope_multi<forward, false, T>(
397
+ x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
398
+ pos, freq_scale, ext_factor, attn_factor, corr_dims,
399
+ theta_scale, freq_factors, sections, is_imrope);
400
+ });
302
401
  } else {
303
- stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
304
- rope_multi<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
305
- corr_dims, theta_scale, freq_factors, sections, item_ct1);
306
- });
402
+ stream->parallel_for(
403
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
404
+ [=](sycl::nd_item<3> item_ct1) {
405
+ GGML_UNUSED(item_ct1);
406
+ rope_multi<forward, true, T>(
407
+ x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
408
+ pos, freq_scale, ext_factor, attn_factor, corr_dims,
409
+ theta_scale, freq_factors, sections, is_imrope);
410
+ });
307
411
  }
308
412
  }
309
413
 
414
+ template <bool forward, typename T>
415
+ static void
416
+ rope_vision_sycl(const T *x, T *dst, const int ne00, const int ne01,
417
+ const int ne02, const int s01, const int s02, const int s03,
418
+ const int s1, const int s2, const int s3, const int n_dims,
419
+ const int nr, const int32_t *pos, const float freq_scale,
420
+ const float freq_base, const float ext_factor,
421
+ const float attn_factor, const rope_corr_dims corr_dims,
422
+ const float *freq_factors, const mrope_sections sections,
423
+ dpct::queue_ptr stream) {
424
+ GGML_ASSERT(ne00 % 2 == 0);
425
+ const dpct::dim3 block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
426
+ const int n_blocks_x =
427
+ (ne00 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
428
+ const dpct::dim3 block_nums(nr, n_blocks_x, 1);
310
429
 
430
+ const float theta_scale = powf(freq_base, -2.0f / n_dims);
311
431
 
312
-
313
- // rope vision
314
- template <typename T>
315
- static void rope_vision_sycl(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
316
- const size_t s2, const int n_dims, const int nr, const int32_t * pos,
317
- const float freq_scale, const float freq_base, const float ext_factor,
318
- const float attn_factor, const rope_corr_dims corr_dims, const float * freq_factors,
319
- const mrope_sections sections, queue_ptr stream) {
320
- GGML_ASSERT(ne0 % 2 == 0);
321
- const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
322
- const int n_blocks_y = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
323
- const sycl::range<3> grid_dims(1, n_blocks_y, nr);
324
- const sycl::nd_range<3> nd_range(grid_dims * block_dims, block_dims);
325
-
326
- const float theta_scale = std::pow(freq_base, -2.0f / n_dims);
327
- // Add FP16 capability check if T could be sycl::half
328
- if constexpr (std::is_same_v<T, sycl::half>) {
329
- dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
330
- }
331
- // launch kernel
332
432
  if (freq_factors == nullptr) {
333
- stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
334
- rope_vision<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
335
- corr_dims, theta_scale, freq_factors, sections, item_ct1);
336
- });
433
+ stream->parallel_for(
434
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
435
+ [=](sycl::nd_item<3> item_ct1) {
436
+ GGML_UNUSED(item_ct1);
437
+ rope_vision<forward, false, T>(
438
+ x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
439
+ pos, freq_scale, ext_factor, attn_factor, corr_dims,
440
+ theta_scale, freq_factors, sections);
441
+ });
337
442
  } else {
338
- stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
339
- rope_vision<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
340
- corr_dims, theta_scale, freq_factors, sections, item_ct1);
341
- });
443
+ stream->parallel_for(
444
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
445
+ [=](sycl::nd_item<3> item_ct1) {
446
+ GGML_UNUSED(item_ct1);
447
+ rope_vision<forward, true, T>(
448
+ x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
449
+ pos, freq_scale, ext_factor, attn_factor, corr_dims,
450
+ theta_scale, freq_factors, sections);
451
+ });
342
452
  }
343
453
  }
344
454
 
345
- inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
455
+ template <bool forward>
456
+ void ggml_sycl_op_rope_impl(ggml_backend_sycl_context &ctx, ggml_tensor *dst,
457
+ const ggml_tensor *set_rows = nullptr) {
458
+ const ggml_tensor *src0 = dst->src[0];
459
+ const ggml_tensor *src1 = dst->src[1];
460
+ const ggml_tensor *src2 = dst->src[2];
461
+
462
+ const float *src0_d = (const float *)src0->data;
463
+ const float *src1_d = (const float *)src1->data;
464
+
465
+ void *dst_d = dst->data;
466
+ const int64_t *row_indices = nullptr;
467
+ ggml_type dst_type = dst->type;
468
+ int set_rows_stride = 0;
469
+
470
+ if (set_rows != nullptr) {
471
+ GGML_ASSERT(forward);
472
+ dst_d = set_rows->data;
473
+ row_indices = (const int64_t *)set_rows->src[1]->data;
474
+ dst_type = set_rows->type;
475
+ set_rows_stride = set_rows->nb[1] / ggml_type_size(set_rows->type);
476
+ }
477
+ dpct::queue_ptr stream = ctx.stream();
478
+
479
+ GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
480
+ GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
481
+ GGML_ASSERT(src0->type == dst->type ||
482
+ (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16));
346
483
 
347
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
348
- GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
349
- GGML_ASSERT(dst->src[0]->type == dst->type);
350
- const int64_t ne00 = dst->src[0]->ne[0]; // head dims
351
- const int64_t ne01 = dst->src[0]->ne[1]; // num heads
352
- const int64_t ne02 = dst->src[0]->ne[2]; // num heads
353
- const int64_t nr = ggml_nrows(dst->src[0]);
484
+ const int64_t ne00 = src0->ne[0]; // head dims
485
+ const int64_t ne01 = src0->ne[1]; // num heads
486
+ const int64_t ne02 = src0->ne[2]; // num heads
487
+ const int64_t nr = ggml_nrows(src0);
354
488
 
355
- const size_t s01 = dst->src[0]->nb[1] / ggml_type_size(dst->src[0]->type);
356
- const size_t s02 = dst->src[0]->nb[2] / ggml_type_size(dst->src[0]->type);
489
+ const size_t s01 = src0->nb[1] / ggml_type_size(src0->type);
490
+ const size_t s02 = src0->nb[2] / ggml_type_size(src0->type);
491
+ const size_t s03 = src0->nb[3] / ggml_type_size(src0->type);
357
492
 
493
+ const size_t s1 = dst->nb[1] / ggml_type_size(dst->type);
494
+ const size_t s2 = dst->nb[2] / ggml_type_size(dst->type);
495
+ const size_t s3 = dst->nb[3] / ggml_type_size(dst->type);
358
496
 
359
- //const int n_past = ((int32_t *) dst->op_params)[0];
360
- const int n_dims = ((int32_t *) dst->op_params)[1];
361
- const int mode = ((int32_t *) dst->op_params)[2];
362
- //const int n_ctx = ((int32_t *) dst->op_params)[3];
363
- const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
497
+ const int n_dims = ((int32_t *)dst->op_params)[1];
498
+ const int mode = ((int32_t *)dst->op_params)[2];
499
+ const int n_ctx_orig = ((int32_t *)dst->op_params)[4];
364
500
  mrope_sections sections;
365
501
 
366
- // RoPE alteration for extended context
367
502
  float freq_base;
368
503
  float freq_scale;
369
504
  float ext_factor;
@@ -371,95 +506,136 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
371
506
  float beta_fast;
372
507
  float beta_slow;
373
508
 
374
- memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
375
- memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
376
- memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
377
- memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
378
- memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
379
- memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
380
- memcpy(&sections.v, (int32_t *) dst->op_params + 11, sizeof(int)*4);
509
+ memcpy(&freq_base, (int32_t *)dst->op_params + 5, sizeof(float));
510
+ memcpy(&freq_scale, (int32_t *)dst->op_params + 6, sizeof(float));
511
+ memcpy(&ext_factor, (int32_t *)dst->op_params + 7, sizeof(float));
512
+ memcpy(&attn_factor, (int32_t *)dst->op_params + 8, sizeof(float));
513
+ memcpy(&beta_fast, (int32_t *)dst->op_params + 9, sizeof(float));
514
+ memcpy(&beta_slow, (int32_t *)dst->op_params + 10, sizeof(float));
515
+ memcpy(&sections.v, (int32_t *)dst->op_params + 11, sizeof(int) * 4);
381
516
 
382
517
  const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
383
518
  const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
519
+ const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
384
520
  const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
385
521
 
386
522
  if (is_mrope) {
387
- GGML_ASSERT(sections.v[0] > 0 || sections.v[1] > 0 || sections.v[2] > 0);
523
+ GGML_ASSERT(sections.v[0] > 0 || sections.v[1] > 0 ||
524
+ sections.v[2] > 0);
388
525
  }
389
526
 
390
527
  if (is_vision) {
391
- GGML_ASSERT(n_dims == ne00/2);
528
+ GGML_ASSERT(n_dims == ne00 / 2);
392
529
  }
393
530
 
394
- const int32_t * pos = (const int32_t *) dst->src[1]->data;
531
+ const int32_t *pos = (const int32_t *)src1_d;
395
532
 
396
- const float * freq_factors = nullptr;
397
- if (dst->src[2] != nullptr) {
398
- freq_factors = (const float *) dst->src[2]->data;
533
+ const float *freq_factors = nullptr;
534
+ if (src2 != nullptr) {
535
+ freq_factors = (const float *)src2->data;
399
536
  }
400
537
 
401
538
  rope_corr_dims corr_dims;
402
- ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
403
-
404
- dpct::queue_ptr main_stream = ctx.stream();
405
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
539
+ ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast,
540
+ beta_slow, corr_dims.v);
406
541
 
407
542
  // compute
408
543
  if (is_neox) {
409
544
  GGML_SYCL_DEBUG("%s: neox path\n", __func__);
410
- if (dst->src[0]->type == GGML_TYPE_F32) {
411
- rope_neox_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, s01, s02, n_dims, nr,
412
- pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, main_stream);
413
- } else if (dst->src[0]->type == GGML_TYPE_F16) {
414
- rope_neox_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, s01, s02,
415
- n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
416
- main_stream);
545
+ if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
546
+ rope_neox_sycl<forward, float, float>(
547
+ (const float *)src0_d, (float *)dst_d, ne00, ne01, ne02, s01,
548
+ s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
549
+ ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
550
+ set_rows_stride, stream);
551
+ } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
552
+ rope_neox_sycl<forward, float, sycl::half>(
553
+ (const float *)src0_d, (sycl::half *)dst_d, ne00, ne01, ne02,
554
+ s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
555
+ freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
556
+ row_indices, set_rows_stride, stream);
557
+ } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
558
+ rope_neox_sycl<forward, sycl::half, sycl::half>(
559
+ (const sycl::half *)src0_d, (sycl::half *)dst_d, ne00, ne01,
560
+ ne02, s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
561
+ freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
562
+ row_indices, set_rows_stride, stream);
417
563
  } else {
418
- GGML_ABORT("fatal error");
564
+ GGML_ABORT("Fatal error: Tensor type unsupported!");
419
565
  }
420
566
  } else if (is_mrope && !is_vision) {
421
567
  GGML_SYCL_DEBUG("%s: mrope path\n", __func__);
422
- if (dst->src[0]->type == GGML_TYPE_F16) {
423
- rope_multi_sycl((const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, ne01, ne02, s01,
424
- s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
425
- freq_factors, sections, main_stream);
426
- } else if (dst->src[0]->type == GGML_TYPE_F32) {
427
- rope_multi_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, ne02, s01, s02, n_dims,
428
- nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections,
429
- main_stream);
568
+ if (src0->type == GGML_TYPE_F32) {
569
+ rope_multi_sycl<forward>((const float *)src0_d, (float *)dst_d,
570
+ ne00, ne01, ne02, s01, s02, s03, s1, s2,
571
+ s3, n_dims, nr, pos, freq_scale, freq_base,
572
+ ext_factor, attn_factor, corr_dims,
573
+ freq_factors, sections, is_imrope, stream);
574
+ } else if (src0->type == GGML_TYPE_F16) {
575
+ rope_multi_sycl<forward>(
576
+ (const sycl::half *)src0_d, (sycl::half *)dst_d, ne00, ne01,
577
+ ne02, s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
578
+ freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
579
+ sections, is_imrope, stream);
430
580
  } else {
431
581
  GGML_ABORT("Fatal error: Tensor type unsupported!");
432
582
  }
433
583
  } else if (is_vision) {
434
584
  GGML_SYCL_DEBUG("%s: vision path\n", __func__);
435
- if (dst->src[0]->type == GGML_TYPE_F16) {
436
- rope_vision_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, ne02, s01,
437
- s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
438
- freq_factors, sections, main_stream);
439
- } else if (dst->src[0]->type == GGML_TYPE_F32) {
440
- rope_vision_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, ne02, s01, s02, n_dims,
441
- nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections,
442
- main_stream);
585
+ if (src0->type == GGML_TYPE_F32) {
586
+ rope_vision_sycl<forward>(
587
+ (const float *)src0_d, (float *)dst_d, ne00, ne01, ne02, s01,
588
+ s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
589
+ ext_factor, attn_factor, corr_dims, freq_factors, sections,
590
+ stream);
591
+ } else if (src0->type == GGML_TYPE_F16) {
592
+ rope_vision_sycl<forward>(
593
+ (const sycl::half *)src0_d, (sycl::half *)dst_d, ne00, ne01,
594
+ ne02, s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
595
+ freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
596
+ sections, stream);
443
597
  } else {
444
598
  GGML_ABORT("Fatal error: Tensor type unsupported!");
445
599
  }
446
600
  } else {
447
601
  GGML_SYCL_DEBUG("%s: norm path\n", __func__);
448
- if (dst->src[0]->type == GGML_TYPE_F32) {
449
- rope_norm_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, s01, s02, n_dims, nr,
450
- pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, main_stream);
451
- } else if (dst->src[0]->type == GGML_TYPE_F16) {
452
- rope_norm_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, s01, s02,
453
- n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
454
- main_stream);
602
+ if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
603
+ rope_norm_sycl<forward, float, float>(
604
+ (const float *)src0_d, (float *)dst_d, ne00, ne01, ne02, s01,
605
+ s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
606
+ ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
607
+ set_rows_stride, stream);
608
+ } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
609
+ rope_norm_sycl<forward, float, sycl::half>(
610
+ (const float *)src0_d, (sycl::half *)dst_d, ne00, ne01, ne02,
611
+ s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
612
+ freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
613
+ row_indices, set_rows_stride, stream);
614
+ } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
615
+ rope_norm_sycl<forward, sycl::half, sycl::half>(
616
+ (const sycl::half *)src0_d, (sycl::half *)dst_d, ne00, ne01,
617
+ ne02, s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
618
+ freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
619
+ row_indices, set_rows_stride, stream);
455
620
  } else {
456
- GGML_ABORT("fatal error");
621
+ GGML_ABORT("Fatal error: Tensor type unsupported!");
457
622
  }
458
623
  }
459
624
  }
460
625
 
461
// Forward RoPE entry point: dispatches to ggml_sycl_op_rope_impl<true>, where the
// template argument is the `forward` flag that the impl passes to the rope_*_sycl launchers.
// NOTE(review): this hunk replaces the former direct call to ggml_sycl_op_rope.
- void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
626
+ void ggml_sycl_rope(ggml_backend_sycl_context &ctx, ggml_tensor *dst) {
462
627
  scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/3);
463
- ggml_sycl_op_rope(ctx, dst);
628
+
629
+ ggml_sycl_op_rope_impl<true>(ctx, dst);
464
630
  }
465
631
 
632
// Backward RoPE entry point (presumably wired to the ROPE_BACK op; confirm at the dispatch site).
// It uses the same implementation as ggml_sycl_rope, but sets the `forward` template flag to false.
+ void ggml_sycl_rope_back(ggml_backend_sycl_context &ctx, ggml_tensor *dst) {
633
+ scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/3);
634
+ ggml_sycl_op_rope_impl<false>(ctx, dst);
635
+ }
636
+
637
// Fused RoPE + SET_ROWS entry point: runs the forward RoPE impl on `rope` and also
// passes it the `set_rows` tensor. The impl presumably derives row_indices and
// set_rows_stride from set_rows, so RoPE output lands directly in the destination rows.
// TODO: confirm this against the start of ggml_sycl_op_rope_impl.
// NOTE(review): in the impl, only the neox and norm paths forward row_indices and
// set_rows_stride; the mrope and vision paths do not. Fusion should therefore only be
// selected for neox/norm modes; confirm at the call site.
+ void ggml_sycl_rope_fused(ggml_backend_sycl_context &ctx, ggml_tensor *rope,
638
+ ggml_tensor *set_rows) {
639
+ scope_op_debug_print scope_dbg_print(__func__, rope, /*num_src=*/3);
640
+ ggml_sycl_op_rope_impl<true>(ctx, rope, set_rows);
641
+ }