whispercpp 1.3.5 → 1.3.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (610)
  1. checksums.yaml +4 -4
  2. data/LICENSE +1 -1
  3. data/README.md +99 -2
  4. data/ext/extconf.rb +1 -0
  5. data/ext/ruby_whisper.c +20 -4
  6. data/ext/ruby_whisper.h +30 -2
  7. data/ext/ruby_whisper_context.c +216 -124
  8. data/ext/ruby_whisper_context_params.c +163 -0
  9. data/ext/ruby_whisper_model.c +0 -1
  10. data/ext/ruby_whisper_params.c +0 -1
  11. data/ext/ruby_whisper_segment.c +0 -1
  12. data/ext/ruby_whisper_token.c +29 -9
  13. data/ext/ruby_whisper_transcribe.cpp +4 -1
  14. data/ext/ruby_whisper_vad_context.c +48 -1
  15. data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
  16. data/ext/ruby_whisper_vad_params.c +0 -1
  17. data/ext/ruby_whisper_vad_segment.c +0 -1
  18. data/ext/ruby_whisper_vad_segments.c +0 -1
  19. data/ext/sources/CMakeLists.txt +1 -1
  20. data/ext/sources/bindings/javascript/package.json +1 -1
  21. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  22. data/ext/sources/examples/bench/bench.cpp +23 -18
  23. data/ext/sources/examples/cli/cli.cpp +8 -0
  24. data/ext/sources/examples/common-ggml.cpp +2 -0
  25. data/ext/sources/examples/miniaudio.h +4507 -2131
  26. data/ext/sources/examples/server/server.cpp +18 -4
  27. data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -2
  28. data/ext/sources/examples/talk-llama/llama-adapter.cpp +7 -13
  29. data/ext/sources/examples/talk-llama/llama-adapter.h +4 -3
  30. data/ext/sources/examples/talk-llama/llama-arch.cpp +335 -17
  31. data/ext/sources/examples/talk-llama/llama-arch.h +42 -0
  32. data/ext/sources/examples/talk-llama/llama-batch.cpp +3 -1
  33. data/ext/sources/examples/talk-llama/llama-chat.cpp +21 -1
  34. data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
  35. data/ext/sources/examples/talk-llama/llama-context.cpp +508 -520
  36. data/ext/sources/examples/talk-llama/llama-context.h +27 -28
  37. data/ext/sources/examples/talk-llama/llama-cparams.h +5 -0
  38. data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
  39. data/ext/sources/examples/talk-llama/llama-grammar.cpp +8 -8
  40. data/ext/sources/examples/talk-llama/llama-graph.cpp +583 -130
  41. data/ext/sources/examples/talk-llama/llama-graph.h +131 -10
  42. data/ext/sources/examples/talk-llama/llama-hparams.cpp +57 -40
  43. data/ext/sources/examples/talk-llama/llama-hparams.h +79 -10
  44. data/ext/sources/examples/talk-llama/llama-impl.cpp +4 -4
  45. data/ext/sources/examples/talk-llama/llama-impl.h +13 -1
  46. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +3 -1
  47. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +274 -89
  48. data/ext/sources/examples/talk-llama/llama-kv-cache.h +2 -3
  49. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
  50. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
  51. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +11 -13
  52. data/ext/sources/examples/talk-llama/llama-mmap.cpp +28 -11
  53. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +527 -119
  54. data/ext/sources/examples/talk-llama/llama-model-loader.h +35 -5
  55. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +60 -46
  56. data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
  57. data/ext/sources/examples/talk-llama/llama-model.cpp +1365 -647
  58. data/ext/sources/examples/talk-llama/llama-model.h +72 -19
  59. data/ext/sources/examples/talk-llama/llama-quant.cpp +578 -346
  60. data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +190 -76
  61. data/ext/sources/examples/talk-llama/{llama-sampling.h → llama-sampler.h} +0 -2
  62. data/ext/sources/examples/talk-llama/llama-vocab.cpp +118 -48
  63. data/ext/sources/examples/talk-llama/llama-vocab.h +5 -0
  64. data/ext/sources/examples/talk-llama/llama.cpp +76 -22
  65. data/ext/sources/examples/talk-llama/llama.h +63 -30
  66. data/ext/sources/examples/talk-llama/models/afmoe.cpp +2 -3
  67. data/ext/sources/examples/talk-llama/models/apertus.cpp +3 -3
  68. data/ext/sources/examples/talk-llama/models/arcee.cpp +3 -3
  69. data/ext/sources/examples/talk-llama/models/arctic.cpp +4 -5
  70. data/ext/sources/examples/talk-llama/models/baichuan.cpp +4 -3
  71. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +1 -2
  72. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +3 -5
  73. data/ext/sources/examples/talk-llama/models/bert.cpp +13 -7
  74. data/ext/sources/examples/talk-llama/models/bitnet.cpp +9 -24
  75. data/ext/sources/examples/talk-llama/models/bloom.cpp +2 -2
  76. data/ext/sources/examples/talk-llama/models/chameleon.cpp +3 -3
  77. data/ext/sources/examples/talk-llama/models/chatglm.cpp +2 -2
  78. data/ext/sources/examples/talk-llama/models/codeshell.cpp +3 -3
  79. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +3 -3
  80. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +2 -2
  81. data/ext/sources/examples/talk-llama/models/command-r.cpp +2 -2
  82. data/ext/sources/examples/talk-llama/models/dbrx.cpp +4 -5
  83. data/ext/sources/examples/talk-llama/models/deci.cpp +3 -3
  84. data/ext/sources/examples/talk-llama/models/deepseek.cpp +4 -6
  85. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +24 -21
  86. data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
  87. data/ext/sources/examples/talk-llama/models/dots1.cpp +4 -6
  88. data/ext/sources/examples/talk-llama/models/dream.cpp +3 -3
  89. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +4 -6
  90. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +3 -3
  91. data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
  92. data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
  93. data/ext/sources/examples/talk-llama/models/exaone.cpp +3 -3
  94. data/ext/sources/examples/talk-llama/models/exaone4.cpp +3 -3
  95. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +2 -4
  96. data/ext/sources/examples/talk-llama/models/falcon.cpp +3 -3
  97. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +1 -1
  98. data/ext/sources/examples/talk-llama/models/gemma.cpp +1 -1
  99. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +1 -1
  100. data/ext/sources/examples/talk-llama/models/gemma3.cpp +1 -1
  101. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +7 -7
  102. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +3 -3
  103. data/ext/sources/examples/talk-llama/models/glm4.cpp +14 -7
  104. data/ext/sources/examples/talk-llama/models/gpt2.cpp +2 -2
  105. data/ext/sources/examples/talk-llama/models/gptneox.cpp +2 -2
  106. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +4 -5
  107. data/ext/sources/examples/talk-llama/models/granite.cpp +4 -5
  108. data/ext/sources/examples/talk-llama/models/grok.cpp +4 -4
  109. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +5 -7
  110. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +3 -3
  111. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +4 -5
  112. data/ext/sources/examples/talk-llama/models/internlm2.cpp +3 -3
  113. data/ext/sources/examples/talk-llama/models/jais.cpp +2 -2
  114. data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
  115. data/ext/sources/examples/talk-llama/models/jamba.cpp +3 -3
  116. data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
  117. data/ext/sources/examples/talk-llama/models/lfm2.cpp +145 -124
  118. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +4 -4
  119. data/ext/sources/examples/talk-llama/models/llada.cpp +3 -3
  120. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +4 -4
  121. data/ext/sources/examples/talk-llama/models/llama.cpp +18 -11
  122. data/ext/sources/examples/talk-llama/models/maincoder.cpp +3 -3
  123. data/ext/sources/examples/talk-llama/models/{graph-context-mamba.cpp → mamba-base.cpp} +9 -3
  124. data/ext/sources/examples/talk-llama/models/mamba.cpp +1 -2
  125. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +11 -5
  126. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +14 -13
  127. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +4 -5
  128. data/ext/sources/examples/talk-llama/models/mistral3.cpp +4 -4
  129. data/ext/sources/examples/talk-llama/models/models.h +181 -46
  130. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +2 -9
  131. data/ext/sources/examples/talk-llama/models/mpt.cpp +2 -2
  132. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +26 -14
  133. data/ext/sources/examples/talk-llama/models/nemotron.cpp +3 -3
  134. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +2 -2
  135. data/ext/sources/examples/talk-llama/models/olmo.cpp +3 -3
  136. data/ext/sources/examples/talk-llama/models/olmo2.cpp +3 -3
  137. data/ext/sources/examples/talk-llama/models/olmoe.cpp +4 -4
  138. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +1 -1
  139. data/ext/sources/examples/talk-llama/models/openelm.cpp +3 -3
  140. data/ext/sources/examples/talk-llama/models/orion.cpp +3 -3
  141. data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
  142. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +3 -3
  143. data/ext/sources/examples/talk-llama/models/phi2.cpp +2 -2
  144. data/ext/sources/examples/talk-llama/models/phi3.cpp +3 -3
  145. data/ext/sources/examples/talk-llama/models/plamo.cpp +3 -3
  146. data/ext/sources/examples/talk-llama/models/plamo2.cpp +9 -5
  147. data/ext/sources/examples/talk-llama/models/plamo3.cpp +2 -2
  148. data/ext/sources/examples/talk-llama/models/plm.cpp +15 -14
  149. data/ext/sources/examples/talk-llama/models/qwen.cpp +2 -2
  150. data/ext/sources/examples/talk-llama/models/qwen2.cpp +3 -3
  151. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +4 -4
  152. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +3 -3
  153. data/ext/sources/examples/talk-llama/models/qwen3.cpp +12 -9
  154. data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
  155. data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
  156. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +15 -8
  157. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +84 -432
  158. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +9 -18
  159. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +8 -17
  160. data/ext/sources/examples/talk-llama/models/refact.cpp +2 -2
  161. data/ext/sources/examples/talk-llama/models/rnd1.cpp +4 -4
  162. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +2 -0
  163. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +2 -0
  164. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +3 -3
  165. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +4 -4
  166. data/ext/sources/examples/talk-llama/models/smollm3.cpp +3 -3
  167. data/ext/sources/examples/talk-llama/models/stablelm.cpp +2 -2
  168. data/ext/sources/examples/talk-llama/models/starcoder.cpp +2 -2
  169. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +3 -3
  170. data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
  171. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +2 -2
  172. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +2 -2
  173. data/ext/sources/examples/talk-llama/models/xverse.cpp +3 -3
  174. data/ext/sources/examples/talk-llama/unicode.cpp +21 -65
  175. data/ext/sources/ggml/CMakeLists.txt +9 -3
  176. data/ext/sources/ggml/include/ggml-backend.h +1 -1
  177. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  178. data/ext/sources/ggml/include/ggml-cpu.h +5 -0
  179. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  180. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  181. data/ext/sources/ggml/include/ggml-rpc.h +6 -1
  182. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  183. data/ext/sources/ggml/include/ggml.h +56 -9
  184. data/ext/sources/ggml/src/CMakeLists.txt +3 -0
  185. data/ext/sources/ggml/src/ggml-alloc.c +4 -9
  186. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  187. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  188. data/ext/sources/ggml/src/ggml-backend-reg.cpp +28 -86
  189. data/ext/sources/ggml/src/ggml-backend.cpp +5 -2
  190. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  191. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +6 -2
  192. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
  193. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
  194. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +348 -189
  195. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +40 -85
  196. data/ext/sources/ggml/src/ggml-cann/common.h +3 -4
  197. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +44 -62
  198. data/ext/sources/ggml/src/ggml-common.h +11 -0
  199. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +16 -11
  200. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -19
  201. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  202. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
  203. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +85 -1
  204. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2744 -548
  205. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1653 -0
  206. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
  207. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
  208. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
  209. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +118 -18
  210. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +107 -26
  211. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  212. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  213. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +3 -0
  214. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +59 -12
  215. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +15 -0
  216. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
  217. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +965 -252
  218. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +584 -197
  219. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +903 -188
  220. data/ext/sources/ggml/src/ggml-cpu/ops.h +1 -0
  221. data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
  222. data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
  223. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2890 -679
  224. data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
  225. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
  226. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +111 -3
  227. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
  228. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +17 -0
  229. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +1 -1
  230. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +19 -10
  231. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +32 -30
  232. data/ext/sources/ggml/src/ggml-cuda/common.cuh +134 -18
  233. data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
  234. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +6 -3
  235. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +78 -64
  236. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +384 -143
  237. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +36 -22
  238. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +3 -3
  239. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +26 -5
  240. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
  241. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +127 -12
  242. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
  243. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  244. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +595 -200
  245. data/ext/sources/ggml/src/ggml-cuda/mean.cu +9 -8
  246. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +173 -6
  247. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
  248. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +158 -85
  249. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +34 -22
  250. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +127 -67
  251. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
  252. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +157 -65
  253. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -0
  254. data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
  255. data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
  256. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +1 -1
  257. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
  258. data/ext/sources/ggml/src/ggml-cuda/rope.cu +233 -133
  259. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
  260. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
  261. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +56 -32
  262. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  263. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
  264. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
  265. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
  266. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
  267. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
  268. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
  269. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +3 -3
  270. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +0 -1
  271. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +199 -135
  272. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
  273. data/ext/sources/ggml/src/ggml-cuda/unary.cu +55 -0
  274. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
  275. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
  276. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +10 -0
  277. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +82 -45
  278. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +334 -160
  279. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +7 -5
  280. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +328 -197
  281. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
  282. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +765 -234
  283. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
  284. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +412 -265
  285. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +23 -23
  286. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +1 -1
  287. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.h → hex-dma.h} +28 -3
  288. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
  289. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  290. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
  291. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +1 -1
  292. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +27 -37
  293. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +6 -35
  294. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  295. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
  296. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
  297. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  302. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  303. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  304. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
  305. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  306. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  307. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +20 -1347
  308. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +211 -13
  309. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +1119 -952
  310. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +254 -244
  311. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +36 -36
  312. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +155 -138
  313. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
  314. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  315. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +209 -114
  316. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
  317. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  318. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  319. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  320. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
  321. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +6 -0
  322. data/ext/sources/ggml/src/ggml-impl.h +62 -0
  323. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  324. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
  325. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  326. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
  327. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +274 -73
  328. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +22 -4
  329. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +102 -36
  330. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +174 -23
  331. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +580 -280
  332. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +5 -4
  333. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +320 -107
  334. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1068 -825
  335. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +19 -1
  336. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +3108 -636
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +204 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  365. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
  366. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  367. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  368. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  369. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
  370. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  371. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
  372. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  373. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
  374. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
  375. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  376. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  377. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  378. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  379. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  380. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  381. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  382. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  383. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  384. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  385. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  386. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  387. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  388. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  389. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  390. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  391. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  392. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
  393. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  394. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  395. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  396. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  397. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  398. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  399. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
  400. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
  401. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
  402. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
  403. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  404. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  405. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  406. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  407. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  408. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
  409. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  410. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
  411. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
  412. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
  413. data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
  414. data/ext/sources/ggml/src/ggml-quants.c +96 -5
  415. data/ext/sources/ggml/src/ggml-quants.h +3 -0
  416. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +15 -88
  417. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
  418. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +1 -0
  419. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
  420. data/ext/sources/ggml/src/ggml-sycl/common.hpp +315 -10
  421. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +69 -1
  422. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
  423. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
  424. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
  425. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +78 -68
  426. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
  427. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
  428. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
  429. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
  430. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
  431. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
  432. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  433. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
  434. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
  435. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +316 -51
  436. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
  437. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  438. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
  439. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
  440. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
  441. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  442. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
  443. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  444. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  445. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  446. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  447. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  448. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  449. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  450. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  451. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  452. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
  453. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
  454. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
  455. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
  456. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
  457. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
  458. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
  459. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
  460. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
  461. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
  462. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
  463. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
  464. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
  465. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
  466. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
  467. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
  468. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
  469. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
  470. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
  471. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
  472. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
  473. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
  474. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
  475. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
  476. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
  477. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
  478. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
  479. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
  480. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
  481. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
  482. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
  483. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
  484. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
  485. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
  486. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
  487. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
  488. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +13 -0
  489. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  490. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  491. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  492. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  493. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  494. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  495. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  496. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  497. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  498. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  499. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  500. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  501. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  502. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  503. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  504. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  505. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  506. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  507. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  508. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  509. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  510. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  511. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
  512. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
  513. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  514. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
  515. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  516. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  517. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  518. data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
  519. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  520. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  521. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  522. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  523. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  524. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  525. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  526. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
  527. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  528. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  529. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  530. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
  531. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
  532. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +1 -1
  533. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1250 -465
  534. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
  535. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  536. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +374 -170
  537. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +66 -22
  538. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +389 -201
  539. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +106 -58
  540. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  541. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
  542. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
  543. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
  544. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
  545. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -3
  546. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +8 -4
  547. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
  548. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
  549. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +3 -3
  550. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
  551. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +36 -63
  552. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
  553. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
  554. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
  555. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +10 -5
  556. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
  557. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  558. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +16 -10
  559. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +55 -35
  560. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1314 -109
  561. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1660 -1371
  562. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  563. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  564. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  565. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
  566. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
  567. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
  568. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +6 -0
  569. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  570. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +40 -5
  571. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +105 -60
  572. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
  573. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +68 -257
  574. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +692 -23
  575. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_reg_tile.tmpl.wgsl → mul_mat_reg_tile.wgsl} +28 -128
  576. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +31 -137
  577. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
  578. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  579. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  580. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +9 -36
  581. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
  582. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  583. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
  584. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
  585. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
  586. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +9 -6
  587. data/ext/sources/ggml/src/ggml.c +167 -33
  588. data/ext/sources/ggml/src/gguf.cpp +229 -44
  589. data/ext/sources/src/whisper.cpp +6 -28
  590. data/sig/whisper.rbs +43 -2
  591. data/test/test_context_params.rb +82 -0
  592. data/test/test_token.rb +11 -0
  593. data/test/test_vad_context.rb +58 -8
  594. data/test/test_whisper.rb +20 -0
  595. data/whispercpp.gemspec +1 -1
  596. metadata +240 -28
  597. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  598. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
  599. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
  600. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
  601. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
  602. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
  603. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
  604. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
  605. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
  606. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
  607. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  608. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
  609. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
  610. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
@@ -1,4 +1,5 @@
1
1
  #include "rope.hpp"
2
+ #include "convert.hpp"
2
3
  #include "ggml-sycl/common.hpp"
3
4
  #include "ggml.h"
4
5
 
@@ -15,367 +16,489 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
15
16
  return 1.0f - sycl::min(1.0f, sycl::max(0.0f, y));
16
17
  }
17
18
 
18
- // YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
19
- // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
20
- static void rope_yarn(
21
- float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale,
22
- float * cos_theta, float * sin_theta) {
23
- // Get n-d rotational scaling corrected for extrapolation
19
+ template <bool forward>
20
+ static void rope_yarn(const float theta_extrap, const float freq_scale,
21
+ const rope_corr_dims corr_dims, const int64_t i0,
22
+ const float ext_factor, float mscale, float &cos_theta,
23
+ float &sin_theta) {
24
24
  float theta_interp = freq_scale * theta_extrap;
25
25
  float theta = theta_interp;
26
26
  if (ext_factor != 0.0f) {
27
- float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
27
+ float ramp_mix =
28
+ rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
28
29
  theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
29
30
 
30
- // Get n-d magnitude scaling corrected for interpolation
31
31
  mscale *= 1.0f + 0.1f * sycl::log(1.0f / freq_scale);
32
32
  }
33
- *cos_theta = sycl::cos(theta) * mscale;
34
- *sin_theta = sycl::sin(theta) * mscale;
33
+ cos_theta = sycl::cos(theta) * mscale;
34
+ sin_theta = sycl::sin(theta) * mscale;
35
+ if (!forward) {
36
+ sin_theta *= -1.0f;
37
+ }
35
38
  }
36
39
 
37
- template <typename T, bool has_ff>
38
- static void rope_norm(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
39
- const int32_t * pos, float freq_scale, float ext_factor, float attn_factor,
40
- const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors,
41
- const sycl::nd_item<3> & item_ct1) {
42
- const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + item_ct1.get_local_id(1));
43
-
44
- if (i0 >= ne0) {
40
+ template <bool forward, bool has_ff, typename T, typename D>
41
+ static void rope_norm(const T *x, D *dst, const int ne00, const int ne01,
42
+ const int ne02, const int s01, const int s02,
43
+ const int s03, const int s1, const int s2, const int s3,
44
+ const int n_dims, const int32_t *pos,
45
+ const float freq_scale, const float ext_factor,
46
+ const float attn_factor, const rope_corr_dims corr_dims,
47
+ const float theta_scale, const float *freq_factors,
48
+ const int64_t *row_indices, const int set_rows_stride) {
49
+ auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
50
+ const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
51
+ item_ct1.get_local_id(1));
52
+
53
+ if (i0 >= ne00) {
45
54
  return;
46
55
  }
47
56
 
48
- const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
57
+ const int row_dst = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
58
+ item_ct1.get_local_id(2);
49
59
 
50
- const int row0 = row % ne1;
51
- const int channel0 = row / ne1;
60
+ const uint32_t i3 = row_dst / (ne01 * ne02);
61
+ const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
62
+ const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
52
63
 
53
- const int i = row * ne0 + i0;
54
- const int i2 = channel0 * s2 + row0 * s1 + i0;
64
+ int idst = i0 + i1 * s1 + i2 * s2 + i3 * s3;
65
+ const int ix = i0 + i1 * s01 + i2 * s02 + i3 * s03;
66
+
67
+ if (set_rows_stride != 0) {
68
+ idst = i1 * s1 + i0;
69
+ idst += row_indices[i2] * set_rows_stride;
70
+ }
55
71
 
72
+ const auto &store_coaelsced = [&](float x0, float x1) {
73
+ if constexpr (std::is_same_v<float, D>) {
74
+ sycl::float2 v = sycl::float2(x0, x1);
75
+ ggml_sycl_memcpy_1<8>(dst + idst, &v);
76
+ } else if constexpr (std::is_same_v<sycl::half, D>) {
77
+ sycl::half2 v = sycl::half2(x0, x1);
78
+ ggml_sycl_memcpy_1<4>(dst + idst, &v);
79
+ }
80
+ };
56
81
  if (i0 >= n_dims) {
57
- *reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i2);
82
+ store_coaelsced(x[ix + 0], x[ix + 1]);
58
83
  return;
59
84
  }
60
85
 
61
- const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
86
+ const float theta_base = pos[i2] * dpct::pow(theta_scale, i0 / 2.0f);
62
87
 
63
88
  const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
64
89
 
65
90
  float cos_theta;
66
91
  float sin_theta;
67
92
 
68
- rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
93
+ rope_yarn<forward>(theta_base / freq_factor, freq_scale, corr_dims, i0,
94
+ ext_factor, attn_factor, cos_theta, sin_theta);
69
95
 
70
- const float x0 = x[i2 + 0];
71
- const float x1 = x[i2 + 1];
96
+ const float x0 = x[ix + 0];
97
+ const float x1 = x[ix + 1];
72
98
 
73
- dst[i + 0] = x0 * cos_theta - x1 * sin_theta;
74
- dst[i + 1] = x0 * sin_theta + x1 * cos_theta;
99
+ store_coaelsced(x0 * cos_theta - x1 * sin_theta,
100
+ x0 * sin_theta + x1 * cos_theta);
75
101
  }
76
102
 
77
- template <typename T, bool has_ff>
78
- static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
79
- const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
80
- const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors,
81
- const sycl::nd_item<3> & item_ct1) {
82
- const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + item_ct1.get_local_id(1));
83
-
84
- if (i0 >= ne0) {
103
+ template <bool forward, bool has_ff, typename T, typename D>
104
+ static void rope_neox(const T *x, D *dst, const int ne00, const int ne01,
105
+ const int ne02, const int s01, const int s02,
106
+ const int s03, const int s1, const int s2, const int s3,
107
+ const int n_dims, const int32_t *pos,
108
+ const float freq_scale, const float ext_factor,
109
+ const float attn_factor, const rope_corr_dims corr_dims,
110
+ const float theta_scale, const float *freq_factors,
111
+ const int64_t *row_indices, const int set_rows_stride) {
112
+ auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
113
+ const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
114
+ item_ct1.get_local_id(1));
115
+
116
+ if (i0 >= ne00) {
85
117
  return;
86
118
  }
87
119
 
88
- const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
120
+ const int row_dst = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
121
+ item_ct1.get_local_id(2);
89
122
 
90
- const int row0 = row % ne1;
91
- const int channel0 = row / ne1;
123
+ const uint32_t i3 = row_dst / (ne01 * ne02);
124
+ const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
125
+ const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
92
126
 
93
- const int i = row * ne0 + i0 / 2;
94
- const int i2 = channel0 * s2 + row0 * s1 + i0 / 2;
127
+ int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;
128
+ const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
129
+
130
+ if (set_rows_stride != 0) {
131
+ idst = i1 * s1 + i0 / 2;
132
+ idst += row_indices[i2] * set_rows_stride;
133
+ }
95
134
 
96
135
  if (i0 >= n_dims) {
97
- *reinterpret_cast<sycl::vec<T, 2> *>(dst + i + i0 / 2) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i2 + i0 / 2);
136
+ dst[idst + i0 / 2 + 0] = ggml_sycl_cast<D>(x[ix + i0 / 2 + 0]);
137
+ dst[idst + i0 / 2 + 1] = ggml_sycl_cast<D>(x[ix + i0 / 2 + 1]);
138
+
98
139
  return;
99
140
  }
100
141
 
101
- const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
142
+ const float theta_base = pos[i2] * dpct::pow(theta_scale, i0 / 2.0f);
102
143
 
103
144
  const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
104
145
 
105
146
  float cos_theta;
106
147
  float sin_theta;
107
148
 
108
- rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
149
+ rope_yarn<forward>(theta_base / freq_factor, freq_scale, corr_dims, i0,
150
+ ext_factor, attn_factor, cos_theta, sin_theta);
109
151
 
110
- const float x0 = x[i2 + 0];
111
- const float x1 = x[i2 + n_dims / 2];
152
+ const float x0 = x[ix + 0];
153
+ const float x1 = x[ix + n_dims / 2];
112
154
 
113
- dst[i + 0] = x0 * cos_theta - x1 * sin_theta;
114
- dst[i + n_dims / 2] = x0 * sin_theta + x1 * cos_theta;
155
+ dst[idst + 0] = ggml_sycl_cast<D>(x0 * cos_theta - x1 * sin_theta);
156
+ dst[idst + n_dims / 2] = ggml_sycl_cast<D>(x0 * sin_theta + x1 * cos_theta);
115
157
  }
116
158
 
117
- template <typename T, bool has_ff>
118
- static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
119
- const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale,
120
- const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
121
- const float theta_scale, const float * freq_factors, const mrope_sections sections,
122
- const bool is_imrope, const sycl::nd_item<3> & item_ct1) {
123
- // get index pos
124
- const int i0 = 2 * (item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1));
125
- if (i0 >= ne0) {
159
+ template <bool forward, bool has_ff, typename T>
160
+ static void rope_multi(const T *x, T *dst, const int ne00, const int ne01,
161
+ const int ne02, const int s01, const int s02,
162
+ const int s03, const int s1, const int s2, const int s3,
163
+ const int n_dims, const int32_t *pos,
164
+ const float freq_scale, const float ext_factor,
165
+ const float attn_factor, const rope_corr_dims corr_dims,
166
+ const float theta_scale, const float *freq_factors,
167
+ const mrope_sections sections, const bool is_imrope) {
168
+ auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
169
+ const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
170
+ item_ct1.get_local_id(1));
171
+
172
+ if (i0 >= ne00) {
126
173
  return;
127
174
  }
128
- const int row_dst = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2);
129
175
 
130
- const int row_x = row_dst % ne1;
131
- const int channel_x = row_dst / ne1;
132
- const int idst = (row_dst * ne0) + (i0 / 2);
133
- const size_t ix = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2);
176
+ const int row_dst = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
177
+ item_ct1.get_local_id(2);
178
+
179
+ const uint32_t i3 = row_dst / (ne01 * ne02);
180
+ const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
181
+ const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
182
+
183
+ int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;
184
+ const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
134
185
 
135
186
  if (i0 >= n_dims) {
136
- *reinterpret_cast<sycl::vec<T, 2> *>(dst + idst + i0 / 2) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i0 / 2 + ix);
187
+ dst[idst + i0 / 2 + 0] = x[ix + i0 / 2 + 0];
188
+ dst[idst + i0 / 2 + 1] = x[ix + i0 / 2 + 1];
189
+
137
190
  return;
138
191
  }
139
192
 
140
- const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
193
+ const int sect_dims =
194
+ sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
141
195
  const int sec_w = sections.v[1] + sections.v[0];
142
196
  const int sector = (i0 / 2) % sect_dims;
143
197
 
144
-
145
198
  float theta_base = 0.0;
146
199
  if (is_imrope) {
147
- if (sector % 3 == 1 && sector < 3 * sections.v[1]) {
148
- theta_base = pos[channel_x + ne2 * 1]*sycl::pow(theta_scale, i0/2.0f);
149
- } else if (sector % 3 == 2 && sector < 3 * sections.v[2]) {
150
- theta_base = pos[channel_x + ne2 * 2]*sycl::pow(theta_scale, i0/2.0f);
151
- } else if (sector % 3 == 0 && sector < 3 * sections.v[0]) {
152
- theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f);
200
+ if (sector % 3 == 1 && sector < 3 * sections.v[1]) { // h
201
+ theta_base = pos[i2 + ne02 * 1] * dpct::pow(theta_scale, i0 / 2.0f);
202
+ } else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w
203
+ theta_base = pos[i2 + ne02 * 2] * dpct::pow(theta_scale, i0 / 2.0f);
204
+ } else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t
205
+ theta_base = pos[i2] * dpct::pow(theta_scale, i0 / 2.0f);
153
206
  } else {
154
- theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f);
207
+ theta_base = pos[i2 + ne02 * 3] * dpct::pow(theta_scale, i0 / 2.0f);
155
208
  }
156
209
  } else {
157
210
  if (sector < sections.v[0]) {
158
- theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f);
159
- }
160
- else if (sector >= sections.v[0] && sector < sec_w) {
161
- theta_base = pos[channel_x + ne2 * 1]*sycl::pow(theta_scale, i0/2.0f);
162
- }
163
- else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
164
- theta_base = pos[channel_x + ne2 * 2]*sycl::pow(theta_scale, i0/2.0f);
165
- }
166
- else if (sector >= sec_w + sections.v[2]) {
167
- theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f);
211
+ theta_base = pos[i2] * dpct::pow(theta_scale, i0 / 2.0f);
212
+ } else if (sector >= sections.v[0] && sector < sec_w) {
213
+ theta_base = pos[i2 + ne02 * 1] * dpct::pow(theta_scale, i0 / 2.0f);
214
+ } else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
215
+ theta_base = pos[i2 + ne02 * 2] * dpct::pow(theta_scale, i0 / 2.0f);
216
+ } else if (sector >= sec_w + sections.v[2]) {
217
+ theta_base = pos[i2 + ne02 * 3] * dpct::pow(theta_scale, i0 / 2.0f);
168
218
  }
169
219
  }
170
220
 
171
221
  const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
172
- float cos_theta;
173
- float sin_theta;
174
- rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
175
- const float x0 = x[ix + 0];
176
- const float x1 = x[ix + n_dims/2];
177
222
 
178
- // store results in dst
179
- dst[idst + 0] = x0 * cos_theta - x1 * sin_theta;
180
- dst[idst + n_dims/2] = x0 * sin_theta + x1 * cos_theta;
181
- }
223
+ float cos_theta;
224
+ float sin_theta;
182
225
 
226
+ rope_yarn<forward>(theta_base / freq_factor, freq_scale, corr_dims, i0,
227
+ ext_factor, attn_factor, cos_theta, sin_theta);
183
228
 
229
+ const float x0 = x[ix + 0];
230
+ const float x1 = x[ix + n_dims / 2];
231
+
232
+ dst[idst + 0] = x0 * cos_theta - x1 * sin_theta;
233
+ dst[idst + n_dims / 2] = x0 * sin_theta + x1 * cos_theta;
234
+ }
184
235
 
185
- template <typename T, bool has_ff>
186
- static void rope_vision(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
187
- const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale,
188
- const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
189
- const float theta_scale, const float * freq_factors, const mrope_sections sections,
190
- const sycl::nd_item<3> & item_ct1) {
191
- // get index pos
192
- const int i0 = 2 * (item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1));
193
- if (i0 >= ne0) {
236
+ template <bool forward, bool has_ff, typename T>
237
+ static void rope_vision(const T *x, T *dst, const int ne00, const int ne01,
238
+ const int ne02, const int s01, const int s02,
239
+ const int s03, const int s1, const int s2, const int s3,
240
+ const int n_dims, const int32_t *pos,
241
+ const float freq_scale, const float ext_factor,
242
+ const float attn_factor, const rope_corr_dims corr_dims,
243
+ const float theta_scale, const float *freq_factors,
244
+ const mrope_sections sections) {
245
+ auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
246
+ const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
247
+ item_ct1.get_local_id(1));
248
+
249
+ if (i0 >= ne00) {
194
250
  return;
195
251
  }
196
- const int row_dst = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2);
197
- const int row_x = row_dst % ne1;
198
- const int channel_x = row_dst / ne1;
199
- const int idst = (row_dst * ne0) + (i0 / 2);
200
- const size_t ix = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2);
252
+
253
+ const int row_dst = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
254
+ item_ct1.get_local_id(2);
255
+
256
+ const uint32_t i3 = row_dst / (ne01 * ne02);
257
+ const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
258
+ const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
259
+
260
+ int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;
261
+ const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
201
262
 
202
263
  const int sect_dims = sections.v[0] + sections.v[1];
203
- const int sector = (i0 / 2) % sect_dims;
264
+ const int sec_w = sections.v[1] + sections.v[0];
265
+ const int sector = (i0 / 2) % sect_dims;
204
266
 
205
- float theta_base = 0.0f;
267
+ float theta_base = 0.0;
206
268
  if (sector < sections.v[0]) {
207
269
  const int p = sector;
208
- theta_base = pos[channel_x] * sycl::pow(theta_scale, (float) p);
209
- } else {
210
- // Simplified from CUDA backend code: if (sector >= sections.v[0] && sector < sec_w) which is just sector >= sections.v[0]
270
+ theta_base = pos[i2] * dpct::pow(theta_scale, p);
271
+ } else if (sector >= sections.v[0] && sector < sec_w) {
211
272
  const int p = sector - sections.v[0];
212
- theta_base = pos[channel_x + ne2] * sycl::pow(theta_scale, (float) p);
273
+ theta_base = pos[i2 + ne02] * dpct::pow(theta_scale, p);
213
274
  }
214
275
 
215
276
  const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
216
- float cos_theta;
217
- float sin_theta;
218
- rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
277
+
278
+ float cos_theta;
279
+ float sin_theta;
280
+
281
+ rope_yarn<forward>(theta_base / freq_factor, freq_scale, corr_dims, i0,
282
+ ext_factor, attn_factor, cos_theta, sin_theta);
283
+
219
284
  const float x0 = x[ix + 0];
220
285
  const float x1 = x[ix + n_dims];
221
286
 
222
- // store results in dst
223
- dst[idst + 0] = x0 * cos_theta - x1 * sin_theta;
287
+ dst[idst + 0] = x0 * cos_theta - x1 * sin_theta;
224
288
  dst[idst + n_dims] = x0 * sin_theta + x1 * cos_theta;
225
289
  }
226
290
 
227
- template <typename T>
228
- static void rope_norm_sycl(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2,
229
- const int n_dims, int nr, const int32_t * pos, const float freq_scale, const float freq_base,
230
- const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
231
- const float * freq_factors, queue_ptr stream) {
232
- GGML_ASSERT(ne0 % 2 == 0);
233
- const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
234
- const int num_blocks_x = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
235
- const sycl::range<3> block_nums(1, num_blocks_x, nr);
291
+ template <bool forward, typename T, typename D>
292
+ static void
293
+ rope_norm_sycl(const T *x, D *dst, const int ne00, const int ne01,
294
+ const int ne02, const int s01, const int s02, const int s03,
295
+ const int s1, const int s2, const int s3, const int n_dims,
296
+ const int nr, const int32_t *pos, const float freq_scale,
297
+ const float freq_base, const float ext_factor,
298
+ const float attn_factor, const rope_corr_dims corr_dims,
299
+ const float *freq_factors, const int64_t *row_indices,
300
+ const int set_rows_stride, dpct::queue_ptr stream) {
301
+ GGML_ASSERT(ne00 % 2 == 0);
302
+ const dpct::dim3 block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
303
+ const int n_blocks_x =
304
+ (ne00 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
305
+ const dpct::dim3 block_nums(nr, n_blocks_x, 1);
236
306
 
237
307
  const float theta_scale = powf(freq_base, -2.0f / n_dims);
238
308
 
239
- dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
240
-
241
309
  if (freq_factors == nullptr) {
242
- /*
243
- DPCT1049:40: The work-group size passed to the SYCL kernel may exceed
244
- the limit. To get the device limit, query
245
- info::device::max_work_group_size. Adjust the work-group size if needed.
246
- */
247
- stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
248
- rope_norm<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
249
- theta_scale, freq_factors, item_ct1);
250
- });
310
+ stream->parallel_for(
311
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
312
+ [=](sycl::nd_item<3> item_ct1) {
313
+ GGML_UNUSED(item_ct1);
314
+ rope_norm<forward, false>(
315
+ x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
316
+ pos, freq_scale, ext_factor, attn_factor, corr_dims,
317
+ theta_scale, freq_factors, row_indices, set_rows_stride);
318
+ });
251
319
  } else {
252
- /*
253
- DPCT1049:41: The work-group size passed to the SYCL kernel may exceed
254
- the limit. To get the device limit, query
255
- info::device::max_work_group_size. Adjust the work-group size if needed.
256
- */
257
- stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
258
- rope_norm<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
259
- theta_scale, freq_factors, item_ct1);
260
- });
320
+ stream->parallel_for(
321
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
322
+ [=](sycl::nd_item<3> item_ct1) {
323
+ GGML_UNUSED(item_ct1);
324
+ rope_norm<forward, true>(
325
+ x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
326
+ pos, freq_scale, ext_factor, attn_factor, corr_dims,
327
+ theta_scale, freq_factors, row_indices, set_rows_stride);
328
+ });
261
329
  }
262
330
  }
263
331
 
264
- template <typename T>
265
- static void rope_neox_sycl(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2,
266
- const int n_dims, const int nr, const int32_t * pos, const float freq_scale,
267
- const float freq_base, const float ext_factor, const float attn_factor,
268
- const rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) {
269
- GGML_ASSERT(ne0 % 2 == 0);
270
- const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
271
- const int num_blocks_x = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
272
- const sycl::range<3> block_nums(1, num_blocks_x, nr);
332
+ template <bool forward, typename T, typename D>
333
+ static void
334
+ rope_neox_sycl(const T *x, D *dst, const int ne00, const int ne01,
335
+ const int ne02, const int s01, const int s02, const int s03,
336
+ const int s1, const int s2, const int s3, const int n_dims,
337
+ const int nr, const int32_t *pos, const float freq_scale,
338
+ const float freq_base, const float ext_factor,
339
+ const float attn_factor, const rope_corr_dims corr_dims,
340
+ const float *freq_factors, const int64_t *row_indices,
341
+ const int set_rows_stride, dpct::queue_ptr stream) {
342
+ GGML_ASSERT(ne00 % 2 == 0);
343
+ const dpct::dim3 block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
344
+ const int n_blocks_x =
345
+ (ne00 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
346
+ const dpct::dim3 block_nums(nr, n_blocks_x, 1);
273
347
 
274
348
  const float theta_scale = powf(freq_base, -2.0f / n_dims);
275
349
 
276
- dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
277
-
278
350
  if (freq_factors == nullptr) {
279
- stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
280
- rope_neox<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
281
- theta_scale, freq_factors, item_ct1);
282
- });
351
+ stream->parallel_for(
352
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
353
+ [=](sycl::nd_item<3> item_ct1) {
354
+ GGML_UNUSED(item_ct1);
355
+ rope_neox<forward, false>(
356
+ x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
357
+ pos, freq_scale, ext_factor, attn_factor, corr_dims,
358
+ theta_scale, freq_factors, row_indices, set_rows_stride);
359
+ });
283
360
  } else {
284
- stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
285
- rope_neox<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
286
- theta_scale, freq_factors, item_ct1);
287
- });
361
+ stream->parallel_for(
362
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
363
+ [=](sycl::nd_item<3> item_ct1) {
364
+ GGML_UNUSED(item_ct1);
365
+ rope_neox<forward, true>(
366
+ x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
367
+ pos, freq_scale, ext_factor, attn_factor, corr_dims,
368
+ theta_scale, freq_factors, row_indices, set_rows_stride);
369
+ });
288
370
  }
289
371
  }
290
372
 
291
- template <typename T>
292
- static void rope_multi_sycl(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
293
- const size_t s2, const int n_dims, const int nr, const int32_t * pos,
294
- const float freq_scale, const float freq_base, const float ext_factor,
295
- const float attn_factor, const rope_corr_dims corr_dims, const float * freq_factors,
296
- const mrope_sections sections, const bool is_imrope, queue_ptr stream) {
297
- GGML_ASSERT(ne0 % 2 == 0);
298
- const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
299
- const int n_blocks_y = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
300
- const sycl::range<3> grid_dims(1, n_blocks_y, nr);
301
- const sycl::nd_range<3> nd_range(grid_dims * block_dims, block_dims);
302
-
303
- const float theta_scale = std::pow(freq_base, -2.0f / n_dims);
304
- // Add FP16 capability check if T could be sycl::half
305
- if constexpr (std::is_same_v<T, sycl::half>) {
306
- dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
307
- }
308
- // launch kernel
373
+ template <bool forward, typename T>
374
+ static void
375
+ rope_multi_sycl(const T *x, T *dst, const int ne00, const int ne01,
376
+ const int ne02, const int s01, const int s02, const int s03,
377
+ const int s1, const int s2, const int s3, const int n_dims,
378
+ const int nr, const int32_t *pos, const float freq_scale,
379
+ const float freq_base, const float ext_factor,
380
+ const float attn_factor, const rope_corr_dims corr_dims,
381
+ const float *freq_factors, const mrope_sections sections,
382
+ const bool is_imrope, dpct::queue_ptr stream) {
383
+ GGML_ASSERT(ne00 % 2 == 0);
384
+ const dpct::dim3 block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
385
+ const int n_blocks_x =
386
+ (ne00 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
387
+ const dpct::dim3 block_nums(nr, n_blocks_x, 1);
388
+
389
+ const float theta_scale = powf(freq_base, -2.0f / n_dims);
390
+
309
391
  if (freq_factors == nullptr) {
310
- stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
311
- rope_multi<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
312
- corr_dims, theta_scale, freq_factors, sections, is_imrope, item_ct1);
313
- });
392
+ stream->parallel_for(
393
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
394
+ [=](sycl::nd_item<3> item_ct1) {
395
+ GGML_UNUSED(item_ct1);
396
+ rope_multi<forward, false, T>(
397
+ x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
398
+ pos, freq_scale, ext_factor, attn_factor, corr_dims,
399
+ theta_scale, freq_factors, sections, is_imrope);
400
+ });
314
401
  } else {
315
- stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
316
- rope_multi<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
317
- corr_dims, theta_scale, freq_factors, sections, is_imrope, item_ct1);
318
- });
402
+ stream->parallel_for(
403
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
404
+ [=](sycl::nd_item<3> item_ct1) {
405
+ GGML_UNUSED(item_ct1);
406
+ rope_multi<forward, true, T>(
407
+ x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
408
+ pos, freq_scale, ext_factor, attn_factor, corr_dims,
409
+ theta_scale, freq_factors, sections, is_imrope);
410
+ });
319
411
  }
320
412
  }
321
413
 
414
+ template <bool forward, typename T>
415
+ static void
416
+ rope_vision_sycl(const T *x, T *dst, const int ne00, const int ne01,
417
+ const int ne02, const int s01, const int s02, const int s03,
418
+ const int s1, const int s2, const int s3, const int n_dims,
419
+ const int nr, const int32_t *pos, const float freq_scale,
420
+ const float freq_base, const float ext_factor,
421
+ const float attn_factor, const rope_corr_dims corr_dims,
422
+ const float *freq_factors, const mrope_sections sections,
423
+ dpct::queue_ptr stream) {
424
+ GGML_ASSERT(ne00 % 2 == 0);
425
+ const dpct::dim3 block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
426
+ const int n_blocks_x =
427
+ (ne00 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
428
+ const dpct::dim3 block_nums(nr, n_blocks_x, 1);
322
429
 
430
+ const float theta_scale = powf(freq_base, -2.0f / n_dims);
323
431
 
324
-
325
- // rope vision
326
- template <typename T>
327
- static void rope_vision_sycl(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
328
- const size_t s2, const int n_dims, const int nr, const int32_t * pos,
329
- const float freq_scale, const float freq_base, const float ext_factor,
330
- const float attn_factor, const rope_corr_dims corr_dims, const float * freq_factors,
331
- const mrope_sections sections, queue_ptr stream) {
332
- GGML_ASSERT(ne0 % 2 == 0);
333
- const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
334
- const int n_blocks_y = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
335
- const sycl::range<3> grid_dims(1, n_blocks_y, nr);
336
- const sycl::nd_range<3> nd_range(grid_dims * block_dims, block_dims);
337
-
338
- const float theta_scale = std::pow(freq_base, -2.0f / n_dims);
339
- // Add FP16 capability check if T could be sycl::half
340
- if constexpr (std::is_same_v<T, sycl::half>) {
341
- dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
342
- }
343
- // launch kernel
344
432
  if (freq_factors == nullptr) {
345
- stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
346
- rope_vision<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
347
- corr_dims, theta_scale, freq_factors, sections, item_ct1);
348
- });
433
+ stream->parallel_for(
434
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
435
+ [=](sycl::nd_item<3> item_ct1) {
436
+ GGML_UNUSED(item_ct1);
437
+ rope_vision<forward, false, T>(
438
+ x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
439
+ pos, freq_scale, ext_factor, attn_factor, corr_dims,
440
+ theta_scale, freq_factors, sections);
441
+ });
349
442
  } else {
350
- stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
351
- rope_vision<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
352
- corr_dims, theta_scale, freq_factors, sections, item_ct1);
353
- });
443
+ stream->parallel_for(
444
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
445
+ [=](sycl::nd_item<3> item_ct1) {
446
+ GGML_UNUSED(item_ct1);
447
+ rope_vision<forward, true, T>(
448
+ x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
449
+ pos, freq_scale, ext_factor, attn_factor, corr_dims,
450
+ theta_scale, freq_factors, sections);
451
+ });
354
452
  }
355
453
  }
356
454
 
357
- inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
455
+ template <bool forward>
456
+ void ggml_sycl_op_rope_impl(ggml_backend_sycl_context &ctx, ggml_tensor *dst,
457
+ const ggml_tensor *set_rows = nullptr) {
458
+ const ggml_tensor *src0 = dst->src[0];
459
+ const ggml_tensor *src1 = dst->src[1];
460
+ const ggml_tensor *src2 = dst->src[2];
461
+
462
+ const float *src0_d = (const float *)src0->data;
463
+ const float *src1_d = (const float *)src1->data;
464
+
465
+ void *dst_d = dst->data;
466
+ const int64_t *row_indices = nullptr;
467
+ ggml_type dst_type = dst->type;
468
+ int set_rows_stride = 0;
469
+
470
+ if (set_rows != nullptr) {
471
+ GGML_ASSERT(forward);
472
+ dst_d = set_rows->data;
473
+ row_indices = (const int64_t *)set_rows->src[1]->data;
474
+ dst_type = set_rows->type;
475
+ set_rows_stride = set_rows->nb[1] / ggml_type_size(set_rows->type);
476
+ }
477
+ dpct::queue_ptr stream = ctx.stream();
478
+
479
+ GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
480
+ GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
481
+ GGML_ASSERT(src0->type == dst->type ||
482
+ (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16));
358
483
 
359
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
360
- GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
361
- GGML_ASSERT(dst->src[0]->type == dst->type);
362
- const int64_t ne00 = dst->src[0]->ne[0]; // head dims
363
- const int64_t ne01 = dst->src[0]->ne[1]; // num heads
364
- const int64_t ne02 = dst->src[0]->ne[2]; // num heads
365
- const int64_t nr = ggml_nrows(dst->src[0]);
484
+ const int64_t ne00 = src0->ne[0]; // head dims
485
+ const int64_t ne01 = src0->ne[1]; // num heads
486
+ const int64_t ne02 = src0->ne[2]; // num heads
487
+ const int64_t nr = ggml_nrows(src0);
366
488
 
367
- const size_t s01 = dst->src[0]->nb[1] / ggml_type_size(dst->src[0]->type);
368
- const size_t s02 = dst->src[0]->nb[2] / ggml_type_size(dst->src[0]->type);
489
+ const size_t s01 = src0->nb[1] / ggml_type_size(src0->type);
490
+ const size_t s02 = src0->nb[2] / ggml_type_size(src0->type);
491
+ const size_t s03 = src0->nb[3] / ggml_type_size(src0->type);
369
492
 
493
+ const size_t s1 = dst->nb[1] / ggml_type_size(dst->type);
494
+ const size_t s2 = dst->nb[2] / ggml_type_size(dst->type);
495
+ const size_t s3 = dst->nb[3] / ggml_type_size(dst->type);
370
496
 
371
- //const int n_past = ((int32_t *) dst->op_params)[0];
372
- const int n_dims = ((int32_t *) dst->op_params)[1];
373
- const int mode = ((int32_t *) dst->op_params)[2];
374
- //const int n_ctx = ((int32_t *) dst->op_params)[3];
375
- const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
497
+ const int n_dims = ((int32_t *)dst->op_params)[1];
498
+ const int mode = ((int32_t *)dst->op_params)[2];
499
+ const int n_ctx_orig = ((int32_t *)dst->op_params)[4];
376
500
  mrope_sections sections;
377
501
 
378
- // RoPE alteration for extended context
379
502
  float freq_base;
380
503
  float freq_scale;
381
504
  float ext_factor;
@@ -383,13 +506,13 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
383
506
  float beta_fast;
384
507
  float beta_slow;
385
508
 
386
- memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
387
- memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
388
- memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
389
- memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
390
- memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
391
- memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
392
- memcpy(&sections.v, (int32_t *) dst->op_params + 11, sizeof(int)*4);
509
+ memcpy(&freq_base, (int32_t *)dst->op_params + 5, sizeof(float));
510
+ memcpy(&freq_scale, (int32_t *)dst->op_params + 6, sizeof(float));
511
+ memcpy(&ext_factor, (int32_t *)dst->op_params + 7, sizeof(float));
512
+ memcpy(&attn_factor, (int32_t *)dst->op_params + 8, sizeof(float));
513
+ memcpy(&beta_fast, (int32_t *)dst->op_params + 9, sizeof(float));
514
+ memcpy(&beta_slow, (int32_t *)dst->op_params + 10, sizeof(float));
515
+ memcpy(&sections.v, (int32_t *)dst->op_params + 11, sizeof(int) * 4);
393
516
 
394
517
  const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
395
518
  const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
@@ -397,82 +520,122 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
397
520
  const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
398
521
 
399
522
  if (is_mrope) {
400
- GGML_ASSERT(sections.v[0] > 0 || sections.v[1] > 0 || sections.v[2] > 0);
523
+ GGML_ASSERT(sections.v[0] > 0 || sections.v[1] > 0 ||
524
+ sections.v[2] > 0);
401
525
  }
402
526
 
403
527
  if (is_vision) {
404
- GGML_ASSERT(n_dims == ne00/2);
528
+ GGML_ASSERT(n_dims == ne00 / 2);
405
529
  }
406
530
 
407
- const int32_t * pos = (const int32_t *) dst->src[1]->data;
531
+ const int32_t *pos = (const int32_t *)src1_d;
408
532
 
409
- const float * freq_factors = nullptr;
410
- if (dst->src[2] != nullptr) {
411
- freq_factors = (const float *) dst->src[2]->data;
533
+ const float *freq_factors = nullptr;
534
+ if (src2 != nullptr) {
535
+ freq_factors = (const float *)src2->data;
412
536
  }
413
537
 
414
538
  rope_corr_dims corr_dims;
415
- ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
416
-
417
- dpct::queue_ptr main_stream = ctx.stream();
418
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
539
+ ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast,
540
+ beta_slow, corr_dims.v);
419
541
 
420
542
  // compute
421
543
  if (is_neox) {
422
544
  GGML_SYCL_DEBUG("%s: neox path\n", __func__);
423
- if (dst->src[0]->type == GGML_TYPE_F32) {
424
- rope_neox_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, s01, s02, n_dims, nr,
425
- pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, main_stream);
426
- } else if (dst->src[0]->type == GGML_TYPE_F16) {
427
- rope_neox_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, s01, s02,
428
- n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
429
- main_stream);
545
+ if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
546
+ rope_neox_sycl<forward, float, float>(
547
+ (const float *)src0_d, (float *)dst_d, ne00, ne01, ne02, s01,
548
+ s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
549
+ ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
550
+ set_rows_stride, stream);
551
+ } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
552
+ rope_neox_sycl<forward, float, sycl::half>(
553
+ (const float *)src0_d, (sycl::half *)dst_d, ne00, ne01, ne02,
554
+ s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
555
+ freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
556
+ row_indices, set_rows_stride, stream);
557
+ } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
558
+ rope_neox_sycl<forward, sycl::half, sycl::half>(
559
+ (const sycl::half *)src0_d, (sycl::half *)dst_d, ne00, ne01,
560
+ ne02, s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
561
+ freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
562
+ row_indices, set_rows_stride, stream);
430
563
  } else {
431
- GGML_ABORT("fatal error");
564
+ GGML_ABORT("Fatal error: Tensor type unsupported!");
432
565
  }
433
566
  } else if (is_mrope && !is_vision) {
434
567
  GGML_SYCL_DEBUG("%s: mrope path\n", __func__);
435
- if (dst->src[0]->type == GGML_TYPE_F16) {
436
- rope_multi_sycl((const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, ne01, ne02, s01,
437
- s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
438
- freq_factors, sections, is_imrope, main_stream);
439
- } else if (dst->src[0]->type == GGML_TYPE_F32) {
440
- rope_multi_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, ne02, s01, s02, n_dims,
441
- nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections,
442
- is_imrope, main_stream);
568
+ if (src0->type == GGML_TYPE_F32) {
569
+ rope_multi_sycl<forward>((const float *)src0_d, (float *)dst_d,
570
+ ne00, ne01, ne02, s01, s02, s03, s1, s2,
571
+ s3, n_dims, nr, pos, freq_scale, freq_base,
572
+ ext_factor, attn_factor, corr_dims,
573
+ freq_factors, sections, is_imrope, stream);
574
+ } else if (src0->type == GGML_TYPE_F16) {
575
+ rope_multi_sycl<forward>(
576
+ (const sycl::half *)src0_d, (sycl::half *)dst_d, ne00, ne01,
577
+ ne02, s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
578
+ freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
579
+ sections, is_imrope, stream);
443
580
  } else {
444
581
  GGML_ABORT("Fatal error: Tensor type unsupported!");
445
582
  }
446
583
  } else if (is_vision) {
447
584
  GGML_SYCL_DEBUG("%s: vision path\n", __func__);
448
- if (dst->src[0]->type == GGML_TYPE_F16) {
449
- rope_vision_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, ne02, s01,
450
- s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
451
- freq_factors, sections, main_stream);
452
- } else if (dst->src[0]->type == GGML_TYPE_F32) {
453
- rope_vision_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, ne02, s01, s02, n_dims,
454
- nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections,
455
- main_stream);
585
+ if (src0->type == GGML_TYPE_F32) {
586
+ rope_vision_sycl<forward>(
587
+ (const float *)src0_d, (float *)dst_d, ne00, ne01, ne02, s01,
588
+ s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
589
+ ext_factor, attn_factor, corr_dims, freq_factors, sections,
590
+ stream);
591
+ } else if (src0->type == GGML_TYPE_F16) {
592
+ rope_vision_sycl<forward>(
593
+ (const sycl::half *)src0_d, (sycl::half *)dst_d, ne00, ne01,
594
+ ne02, s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
595
+ freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
596
+ sections, stream);
456
597
  } else {
457
598
  GGML_ABORT("Fatal error: Tensor type unsupported!");
458
599
  }
459
600
  } else {
460
601
  GGML_SYCL_DEBUG("%s: norm path\n", __func__);
461
- if (dst->src[0]->type == GGML_TYPE_F32) {
462
- rope_norm_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, s01, s02, n_dims, nr,
463
- pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, main_stream);
464
- } else if (dst->src[0]->type == GGML_TYPE_F16) {
465
- rope_norm_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, s01, s02,
466
- n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
467
- main_stream);
602
+ if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
603
+ rope_norm_sycl<forward, float, float>(
604
+ (const float *)src0_d, (float *)dst_d, ne00, ne01, ne02, s01,
605
+ s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
606
+ ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
607
+ set_rows_stride, stream);
608
+ } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
609
+ rope_norm_sycl<forward, float, sycl::half>(
610
+ (const float *)src0_d, (sycl::half *)dst_d, ne00, ne01, ne02,
611
+ s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
612
+ freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
613
+ row_indices, set_rows_stride, stream);
614
+ } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
615
+ rope_norm_sycl<forward, sycl::half, sycl::half>(
616
+ (const sycl::half *)src0_d, (sycl::half *)dst_d, ne00, ne01,
617
+ ne02, s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
618
+ freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
619
+ row_indices, set_rows_stride, stream);
468
620
  } else {
469
- GGML_ABORT("fatal error");
621
+ GGML_ABORT("Fatal error: Tensor type unsupported!");
470
622
  }
471
623
  }
472
624
  }
473
625
 
474
- void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
626
+ void ggml_sycl_rope(ggml_backend_sycl_context &ctx, ggml_tensor *dst) {
475
627
  scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/3);
476
- ggml_sycl_op_rope(ctx, dst);
628
+
629
+ ggml_sycl_op_rope_impl<true>(ctx, dst);
477
630
  }
478
631
 
632
+ void ggml_sycl_rope_back(ggml_backend_sycl_context &ctx, ggml_tensor *dst) {
633
+ scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/3);
634
+ ggml_sycl_op_rope_impl<false>(ctx, dst);
635
+ }
636
+
637
+ void ggml_sycl_rope_fused(ggml_backend_sycl_context &ctx, ggml_tensor *rope,
638
+ ggml_tensor *set_rows) {
639
+ scope_op_debug_print scope_dbg_print(__func__, rope, /*num_src=*/3);
640
+ ggml_sycl_op_rope_impl<true>(ctx, rope, set_rows);
641
+ }