whispercpp 1.3.5 → 1.3.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (610) hide show
  1. checksums.yaml +4 -4
  2. data/LICENSE +1 -1
  3. data/README.md +99 -2
  4. data/ext/extconf.rb +1 -0
  5. data/ext/ruby_whisper.c +20 -4
  6. data/ext/ruby_whisper.h +30 -2
  7. data/ext/ruby_whisper_context.c +216 -124
  8. data/ext/ruby_whisper_context_params.c +163 -0
  9. data/ext/ruby_whisper_model.c +0 -1
  10. data/ext/ruby_whisper_params.c +0 -1
  11. data/ext/ruby_whisper_segment.c +0 -1
  12. data/ext/ruby_whisper_token.c +29 -9
  13. data/ext/ruby_whisper_transcribe.cpp +4 -1
  14. data/ext/ruby_whisper_vad_context.c +48 -1
  15. data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
  16. data/ext/ruby_whisper_vad_params.c +0 -1
  17. data/ext/ruby_whisper_vad_segment.c +0 -1
  18. data/ext/ruby_whisper_vad_segments.c +0 -1
  19. data/ext/sources/CMakeLists.txt +1 -1
  20. data/ext/sources/bindings/javascript/package.json +1 -1
  21. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  22. data/ext/sources/examples/bench/bench.cpp +23 -18
  23. data/ext/sources/examples/cli/cli.cpp +8 -0
  24. data/ext/sources/examples/common-ggml.cpp +2 -0
  25. data/ext/sources/examples/miniaudio.h +4507 -2131
  26. data/ext/sources/examples/server/server.cpp +18 -4
  27. data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -2
  28. data/ext/sources/examples/talk-llama/llama-adapter.cpp +7 -13
  29. data/ext/sources/examples/talk-llama/llama-adapter.h +4 -3
  30. data/ext/sources/examples/talk-llama/llama-arch.cpp +335 -17
  31. data/ext/sources/examples/talk-llama/llama-arch.h +42 -0
  32. data/ext/sources/examples/talk-llama/llama-batch.cpp +3 -1
  33. data/ext/sources/examples/talk-llama/llama-chat.cpp +21 -1
  34. data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
  35. data/ext/sources/examples/talk-llama/llama-context.cpp +508 -520
  36. data/ext/sources/examples/talk-llama/llama-context.h +27 -28
  37. data/ext/sources/examples/talk-llama/llama-cparams.h +5 -0
  38. data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
  39. data/ext/sources/examples/talk-llama/llama-grammar.cpp +8 -8
  40. data/ext/sources/examples/talk-llama/llama-graph.cpp +583 -130
  41. data/ext/sources/examples/talk-llama/llama-graph.h +131 -10
  42. data/ext/sources/examples/talk-llama/llama-hparams.cpp +57 -40
  43. data/ext/sources/examples/talk-llama/llama-hparams.h +79 -10
  44. data/ext/sources/examples/talk-llama/llama-impl.cpp +4 -4
  45. data/ext/sources/examples/talk-llama/llama-impl.h +13 -1
  46. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +3 -1
  47. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +274 -89
  48. data/ext/sources/examples/talk-llama/llama-kv-cache.h +2 -3
  49. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
  50. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
  51. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +11 -13
  52. data/ext/sources/examples/talk-llama/llama-mmap.cpp +28 -11
  53. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +527 -119
  54. data/ext/sources/examples/talk-llama/llama-model-loader.h +35 -5
  55. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +60 -46
  56. data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
  57. data/ext/sources/examples/talk-llama/llama-model.cpp +1365 -647
  58. data/ext/sources/examples/talk-llama/llama-model.h +72 -19
  59. data/ext/sources/examples/talk-llama/llama-quant.cpp +578 -346
  60. data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +190 -76
  61. data/ext/sources/examples/talk-llama/{llama-sampling.h → llama-sampler.h} +0 -2
  62. data/ext/sources/examples/talk-llama/llama-vocab.cpp +118 -48
  63. data/ext/sources/examples/talk-llama/llama-vocab.h +5 -0
  64. data/ext/sources/examples/talk-llama/llama.cpp +76 -22
  65. data/ext/sources/examples/talk-llama/llama.h +63 -30
  66. data/ext/sources/examples/talk-llama/models/afmoe.cpp +2 -3
  67. data/ext/sources/examples/talk-llama/models/apertus.cpp +3 -3
  68. data/ext/sources/examples/talk-llama/models/arcee.cpp +3 -3
  69. data/ext/sources/examples/talk-llama/models/arctic.cpp +4 -5
  70. data/ext/sources/examples/talk-llama/models/baichuan.cpp +4 -3
  71. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +1 -2
  72. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +3 -5
  73. data/ext/sources/examples/talk-llama/models/bert.cpp +13 -7
  74. data/ext/sources/examples/talk-llama/models/bitnet.cpp +9 -24
  75. data/ext/sources/examples/talk-llama/models/bloom.cpp +2 -2
  76. data/ext/sources/examples/talk-llama/models/chameleon.cpp +3 -3
  77. data/ext/sources/examples/talk-llama/models/chatglm.cpp +2 -2
  78. data/ext/sources/examples/talk-llama/models/codeshell.cpp +3 -3
  79. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +3 -3
  80. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +2 -2
  81. data/ext/sources/examples/talk-llama/models/command-r.cpp +2 -2
  82. data/ext/sources/examples/talk-llama/models/dbrx.cpp +4 -5
  83. data/ext/sources/examples/talk-llama/models/deci.cpp +3 -3
  84. data/ext/sources/examples/talk-llama/models/deepseek.cpp +4 -6
  85. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +24 -21
  86. data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
  87. data/ext/sources/examples/talk-llama/models/dots1.cpp +4 -6
  88. data/ext/sources/examples/talk-llama/models/dream.cpp +3 -3
  89. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +4 -6
  90. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +3 -3
  91. data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
  92. data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
  93. data/ext/sources/examples/talk-llama/models/exaone.cpp +3 -3
  94. data/ext/sources/examples/talk-llama/models/exaone4.cpp +3 -3
  95. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +2 -4
  96. data/ext/sources/examples/talk-llama/models/falcon.cpp +3 -3
  97. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +1 -1
  98. data/ext/sources/examples/talk-llama/models/gemma.cpp +1 -1
  99. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +1 -1
  100. data/ext/sources/examples/talk-llama/models/gemma3.cpp +1 -1
  101. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +7 -7
  102. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +3 -3
  103. data/ext/sources/examples/talk-llama/models/glm4.cpp +14 -7
  104. data/ext/sources/examples/talk-llama/models/gpt2.cpp +2 -2
  105. data/ext/sources/examples/talk-llama/models/gptneox.cpp +2 -2
  106. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +4 -5
  107. data/ext/sources/examples/talk-llama/models/granite.cpp +4 -5
  108. data/ext/sources/examples/talk-llama/models/grok.cpp +4 -4
  109. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +5 -7
  110. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +3 -3
  111. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +4 -5
  112. data/ext/sources/examples/talk-llama/models/internlm2.cpp +3 -3
  113. data/ext/sources/examples/talk-llama/models/jais.cpp +2 -2
  114. data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
  115. data/ext/sources/examples/talk-llama/models/jamba.cpp +3 -3
  116. data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
  117. data/ext/sources/examples/talk-llama/models/lfm2.cpp +145 -124
  118. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +4 -4
  119. data/ext/sources/examples/talk-llama/models/llada.cpp +3 -3
  120. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +4 -4
  121. data/ext/sources/examples/talk-llama/models/llama.cpp +18 -11
  122. data/ext/sources/examples/talk-llama/models/maincoder.cpp +3 -3
  123. data/ext/sources/examples/talk-llama/models/{graph-context-mamba.cpp → mamba-base.cpp} +9 -3
  124. data/ext/sources/examples/talk-llama/models/mamba.cpp +1 -2
  125. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +11 -5
  126. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +14 -13
  127. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +4 -5
  128. data/ext/sources/examples/talk-llama/models/mistral3.cpp +4 -4
  129. data/ext/sources/examples/talk-llama/models/models.h +181 -46
  130. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +2 -9
  131. data/ext/sources/examples/talk-llama/models/mpt.cpp +2 -2
  132. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +26 -14
  133. data/ext/sources/examples/talk-llama/models/nemotron.cpp +3 -3
  134. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +2 -2
  135. data/ext/sources/examples/talk-llama/models/olmo.cpp +3 -3
  136. data/ext/sources/examples/talk-llama/models/olmo2.cpp +3 -3
  137. data/ext/sources/examples/talk-llama/models/olmoe.cpp +4 -4
  138. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +1 -1
  139. data/ext/sources/examples/talk-llama/models/openelm.cpp +3 -3
  140. data/ext/sources/examples/talk-llama/models/orion.cpp +3 -3
  141. data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
  142. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +3 -3
  143. data/ext/sources/examples/talk-llama/models/phi2.cpp +2 -2
  144. data/ext/sources/examples/talk-llama/models/phi3.cpp +3 -3
  145. data/ext/sources/examples/talk-llama/models/plamo.cpp +3 -3
  146. data/ext/sources/examples/talk-llama/models/plamo2.cpp +9 -5
  147. data/ext/sources/examples/talk-llama/models/plamo3.cpp +2 -2
  148. data/ext/sources/examples/talk-llama/models/plm.cpp +15 -14
  149. data/ext/sources/examples/talk-llama/models/qwen.cpp +2 -2
  150. data/ext/sources/examples/talk-llama/models/qwen2.cpp +3 -3
  151. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +4 -4
  152. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +3 -3
  153. data/ext/sources/examples/talk-llama/models/qwen3.cpp +12 -9
  154. data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
  155. data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
  156. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +15 -8
  157. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +84 -432
  158. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +9 -18
  159. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +8 -17
  160. data/ext/sources/examples/talk-llama/models/refact.cpp +2 -2
  161. data/ext/sources/examples/talk-llama/models/rnd1.cpp +4 -4
  162. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +2 -0
  163. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +2 -0
  164. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +3 -3
  165. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +4 -4
  166. data/ext/sources/examples/talk-llama/models/smollm3.cpp +3 -3
  167. data/ext/sources/examples/talk-llama/models/stablelm.cpp +2 -2
  168. data/ext/sources/examples/talk-llama/models/starcoder.cpp +2 -2
  169. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +3 -3
  170. data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
  171. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +2 -2
  172. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +2 -2
  173. data/ext/sources/examples/talk-llama/models/xverse.cpp +3 -3
  174. data/ext/sources/examples/talk-llama/unicode.cpp +21 -65
  175. data/ext/sources/ggml/CMakeLists.txt +9 -3
  176. data/ext/sources/ggml/include/ggml-backend.h +1 -1
  177. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  178. data/ext/sources/ggml/include/ggml-cpu.h +5 -0
  179. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  180. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  181. data/ext/sources/ggml/include/ggml-rpc.h +6 -1
  182. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  183. data/ext/sources/ggml/include/ggml.h +56 -9
  184. data/ext/sources/ggml/src/CMakeLists.txt +3 -0
  185. data/ext/sources/ggml/src/ggml-alloc.c +4 -9
  186. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  187. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  188. data/ext/sources/ggml/src/ggml-backend-reg.cpp +28 -86
  189. data/ext/sources/ggml/src/ggml-backend.cpp +5 -2
  190. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  191. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +6 -2
  192. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
  193. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
  194. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +348 -189
  195. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +40 -85
  196. data/ext/sources/ggml/src/ggml-cann/common.h +3 -4
  197. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +44 -62
  198. data/ext/sources/ggml/src/ggml-common.h +11 -0
  199. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +16 -11
  200. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -19
  201. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  202. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
  203. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +85 -1
  204. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2744 -548
  205. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1653 -0
  206. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
  207. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
  208. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
  209. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +118 -18
  210. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +107 -26
  211. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  212. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  213. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +3 -0
  214. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +59 -12
  215. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +15 -0
  216. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
  217. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +965 -252
  218. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +584 -197
  219. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +903 -188
  220. data/ext/sources/ggml/src/ggml-cpu/ops.h +1 -0
  221. data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
  222. data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
  223. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2890 -679
  224. data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
  225. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
  226. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +111 -3
  227. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
  228. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +17 -0
  229. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +1 -1
  230. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +19 -10
  231. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +32 -30
  232. data/ext/sources/ggml/src/ggml-cuda/common.cuh +134 -18
  233. data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
  234. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +6 -3
  235. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +78 -64
  236. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +384 -143
  237. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +36 -22
  238. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +3 -3
  239. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +26 -5
  240. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
  241. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +127 -12
  242. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
  243. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  244. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +595 -200
  245. data/ext/sources/ggml/src/ggml-cuda/mean.cu +9 -8
  246. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +173 -6
  247. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
  248. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +158 -85
  249. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +34 -22
  250. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +127 -67
  251. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
  252. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +157 -65
  253. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -0
  254. data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
  255. data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
  256. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +1 -1
  257. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
  258. data/ext/sources/ggml/src/ggml-cuda/rope.cu +233 -133
  259. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
  260. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
  261. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +56 -32
  262. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  263. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
  264. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
  265. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
  266. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
  267. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
  268. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
  269. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +3 -3
  270. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +0 -1
  271. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +199 -135
  272. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
  273. data/ext/sources/ggml/src/ggml-cuda/unary.cu +55 -0
  274. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
  275. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
  276. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +10 -0
  277. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +82 -45
  278. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +334 -160
  279. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +7 -5
  280. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +328 -197
  281. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
  282. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +765 -234
  283. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
  284. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +412 -265
  285. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +23 -23
  286. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +1 -1
  287. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.h → hex-dma.h} +28 -3
  288. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
  289. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  290. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
  291. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +1 -1
  292. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +27 -37
  293. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +6 -35
  294. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  295. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
  296. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
  297. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  302. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  303. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  304. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
  305. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  306. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  307. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +20 -1347
  308. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +211 -13
  309. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +1119 -952
  310. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +254 -244
  311. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +36 -36
  312. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +155 -138
  313. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
  314. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  315. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +209 -114
  316. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
  317. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  318. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  319. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  320. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
  321. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +6 -0
  322. data/ext/sources/ggml/src/ggml-impl.h +62 -0
  323. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  324. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
  325. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  326. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
  327. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +274 -73
  328. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +22 -4
  329. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +102 -36
  330. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +174 -23
  331. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +580 -280
  332. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +5 -4
  333. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +320 -107
  334. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1068 -825
  335. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +19 -1
  336. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +3108 -636
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +204 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  365. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
  366. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  367. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  368. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  369. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
  370. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  371. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
  372. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  373. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
  374. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
  375. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  376. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  377. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  378. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  379. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  380. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  381. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  382. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  383. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  384. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  385. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  386. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  387. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  388. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  389. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  390. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  391. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  392. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
  393. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  394. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  395. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  396. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  397. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  398. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  399. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
  400. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
  401. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
  402. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
  403. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  404. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  405. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  406. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  407. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  408. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
  409. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  410. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
  411. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
  412. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
  413. data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
  414. data/ext/sources/ggml/src/ggml-quants.c +96 -5
  415. data/ext/sources/ggml/src/ggml-quants.h +3 -0
  416. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +15 -88
  417. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
  418. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +1 -0
  419. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
  420. data/ext/sources/ggml/src/ggml-sycl/common.hpp +315 -10
  421. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +69 -1
  422. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
  423. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
  424. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
  425. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +78 -68
  426. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
  427. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
  428. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
  429. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
  430. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
  431. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
  432. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  433. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
  434. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
  435. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +316 -51
  436. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
  437. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  438. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
  439. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
  440. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
  441. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  442. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
  443. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  444. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  445. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  446. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  447. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  448. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  449. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  450. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  451. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  452. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
  453. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
  454. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
  455. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
  456. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
  457. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
  458. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
  459. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
  460. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
  461. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
  462. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
  463. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
  464. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
  465. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
  466. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
  467. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
  468. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
  469. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
  470. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
  471. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
  472. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
  473. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
  474. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
  475. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
  476. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
  477. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
  478. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
  479. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
  480. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
  481. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
  482. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
  483. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
  484. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
  485. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
  486. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
  487. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
  488. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +13 -0
  489. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  490. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  491. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  492. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  493. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  494. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  495. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  496. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  497. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  498. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  499. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  500. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  501. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  502. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  503. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  504. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  505. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  506. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  507. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  508. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  509. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  510. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  511. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
  512. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
  513. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  514. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
  515. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  516. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  517. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  518. data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
  519. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  520. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  521. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  522. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  523. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  524. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  525. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  526. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
  527. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  528. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  529. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  530. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
  531. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
  532. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +1 -1
  533. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1250 -465
  534. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
  535. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  536. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +374 -170
  537. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +66 -22
  538. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +389 -201
  539. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +106 -58
  540. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  541. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
  542. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
  543. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
  544. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
  545. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -3
  546. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +8 -4
  547. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
  548. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
  549. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +3 -3
  550. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
  551. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +36 -63
  552. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
  553. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
  554. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
  555. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +10 -5
  556. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
  557. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  558. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +16 -10
  559. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +55 -35
  560. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1314 -109
  561. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1660 -1371
  562. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  563. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  564. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  565. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
  566. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
  567. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
  568. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +6 -0
  569. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  570. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +40 -5
  571. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +105 -60
  572. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
  573. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +68 -257
  574. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +692 -23
  575. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_reg_tile.tmpl.wgsl → mul_mat_reg_tile.wgsl} +28 -128
  576. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +31 -137
  577. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
  578. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  579. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  580. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +9 -36
  581. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
  582. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  583. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
  584. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
  585. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
  586. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +9 -6
  587. data/ext/sources/ggml/src/ggml.c +167 -33
  588. data/ext/sources/ggml/src/gguf.cpp +229 -44
  589. data/ext/sources/src/whisper.cpp +6 -28
  590. data/sig/whisper.rbs +43 -2
  591. data/test/test_context_params.rb +82 -0
  592. data/test/test_token.rb +11 -0
  593. data/test/test_vad_context.rb +58 -8
  594. data/test/test_whisper.rb +20 -0
  595. data/whispercpp.gemspec +1 -1
  596. metadata +240 -28
  597. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  598. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
  599. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
  600. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
  601. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
  602. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
  603. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
  604. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
  605. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
  606. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
  607. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  608. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
  609. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
  610. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
@@ -43,10 +43,15 @@ static __device__ void rope_yarn(
43
43
  template <bool forward, bool has_ff, typename T, typename D>
44
44
  static __global__ void rope_norm(const T * x,
45
45
  D * dst,
46
- const int ne0,
47
- const int ne1,
46
+ const int ne00,
47
+ const int ne01,
48
+ const int ne02,
49
+ const int s01,
50
+ const int s02,
51
+ const int s03,
48
52
  const int s1,
49
53
  const int s2,
54
+ const int s3,
50
55
  const int n_dims,
51
56
  const int32_t * pos,
52
57
  const float freq_scale,
@@ -59,23 +64,23 @@ static __global__ void rope_norm(const T * x,
59
64
  const int set_rows_stride) {
60
65
  const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
61
66
 
62
- if (i0 >= ne0) {
67
+ if (i0 >= ne00) {
63
68
  return;
64
69
  }
65
70
 
66
71
  const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
67
72
 
68
- const int row_x = row_dst % ne1;
69
- const int channel_x = row_dst / ne1;
70
-
71
- int idst = row_dst * ne0 + i0;
72
- const int ix = channel_x*s2 + row_x*s1 + i0;
73
+ const uint32_t i3 = row_dst / (ne01 * ne02);
74
+ const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
75
+ const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
73
76
 
77
+ int idst = i0 + i1 * s1 + i2 * s2 + i3 * s3;
78
+ const int ix = i0 + i1 * s01 + i2 * s02 + i3 * s03;
74
79
  // Fusion optimization: ROPE + VIEW + SET_ROWS.
75
80
  // The rope output is viewed as a 1D tensor and offset based on a row index in row_indices.
76
81
  if (set_rows_stride != 0) {
77
- idst = row_x * ne0 + i0;
78
- idst += row_indices[channel_x] * set_rows_stride;
82
+ idst = i1 * s1 + i0;
83
+ idst += row_indices[i2] * set_rows_stride;
79
84
  }
80
85
 
81
86
  const auto & store_coaelsced = [&](float x0, float x1) {
@@ -92,7 +97,7 @@ static __global__ void rope_norm(const T * x,
92
97
  return;
93
98
  }
94
99
 
95
- const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
100
+ const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
96
101
 
97
102
  const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
98
103
 
@@ -110,10 +115,15 @@ static __global__ void rope_norm(const T * x,
110
115
  template <bool forward, bool has_ff, typename T, typename D>
111
116
  static __global__ void rope_neox(const T * x,
112
117
  D * dst,
113
- const int ne0,
114
- const int ne1,
118
+ const int ne00,
119
+ const int ne01,
120
+ const int ne02,
121
+ const int s01,
122
+ const int s02,
123
+ const int s03,
115
124
  const int s1,
116
125
  const int s2,
126
+ const int s3,
117
127
  const int n_dims,
118
128
  const int32_t * pos,
119
129
  const float freq_scale,
@@ -126,23 +136,24 @@ static __global__ void rope_neox(const T * x,
126
136
  const int set_rows_stride) {
127
137
  const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
128
138
 
129
- if (i0 >= ne0) {
139
+ if (i0 >= ne00) {
130
140
  return;
131
141
  }
132
142
 
133
143
  const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
134
144
 
135
- const int row_x = row_dst % ne1;
136
- const int channel_x = row_dst / ne1;
145
+ const uint32_t i3 = row_dst / (ne01 * ne02);
146
+ const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
147
+ const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
137
148
 
138
- int idst = row_dst * ne0 + i0 / 2;
139
- const int ix = channel_x*s2 + row_x*s1 + i0/2;
149
+ int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;
150
+ const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
140
151
 
141
152
  // Fusion optimization: ROPE + VIEW + SET_ROWS.
142
153
  // The rope output is viewed as a 1D tensor and offset based on a row index in row_indices.
143
154
  if (set_rows_stride != 0) {
144
- idst = row_x * ne0 + i0 / 2;
145
- idst += row_indices[channel_x] * set_rows_stride;
155
+ idst = i1 * s1 + i0 / 2;
156
+ idst += row_indices[i2] * set_rows_stride;
146
157
  }
147
158
 
148
159
  if (i0 >= n_dims) {
@@ -152,7 +163,7 @@ static __global__ void rope_neox(const T * x,
152
163
  return;
153
164
  }
154
165
 
155
- const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
166
+ const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
156
167
 
157
168
  const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
158
169
 
@@ -168,24 +179,42 @@ static __global__ void rope_neox(const T * x,
168
179
  dst[idst + n_dims / 2] = ggml_cuda_cast<D>(x0 * sin_theta + x1 * cos_theta);
169
180
  }
170
181
 
171
- template<bool forward, bool has_ff, typename T>
172
- static __global__ void rope_multi(
173
- const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2,
174
- const int n_dims, const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
175
- const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections, const bool is_imrope) {
176
- const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
177
-
178
- if (i0 >= ne0) {
182
+ template <bool forward, bool has_ff, typename T>
183
+ static __global__ void rope_multi(const T * x,
184
+ T * dst,
185
+ const int ne00,
186
+ const int ne01,
187
+ const int ne02,
188
+ const int s01,
189
+ const int s02,
190
+ const int s03,
191
+ const int s1,
192
+ const int s2,
193
+ const int s3,
194
+ const int n_dims,
195
+ const int32_t * pos,
196
+ const float freq_scale,
197
+ const float ext_factor,
198
+ const float attn_factor,
199
+ const rope_corr_dims corr_dims,
200
+ const float theta_scale,
201
+ const float * freq_factors,
202
+ const mrope_sections sections,
203
+ const bool is_imrope) {
204
+ const int i0 = 2 * (blockDim.y * blockIdx.y + threadIdx.y);
205
+
206
+ if (i0 >= ne00) {
179
207
  return;
180
208
  }
181
209
 
182
210
  const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
183
211
 
184
- const int row_x = row_dst % ne1;
185
- const int channel_x = row_dst / ne1;
212
+ const uint32_t i3 = row_dst / (ne01 * ne02);
213
+ const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
214
+ const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
186
215
 
187
- const int idst = row_dst*ne0 + i0/2;
188
- const int ix = channel_x*s2 + row_x*s1 + i0/2;
216
+ int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;
217
+ const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
189
218
 
190
219
  if (i0 >= n_dims) {
191
220
  dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
@@ -200,27 +229,24 @@ static __global__ void rope_multi(
200
229
 
201
230
  float theta_base = 0.0;
202
231
  if (is_imrope) {
203
- if (sector % 3 == 1 && sector < 3 * sections.v[1]) { // h
204
- theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f);
205
- } else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w
206
- theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f);
207
- } else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t
208
- theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
232
+ if (sector % 3 == 1 && sector < 3 * sections.v[1]) { // h
233
+ theta_base = pos[i2 + ne02 * 1] * powf(theta_scale, i0 / 2.0f);
234
+ } else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w
235
+ theta_base = pos[i2 + ne02 * 2] * powf(theta_scale, i0 / 2.0f);
236
+ } else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t
237
+ theta_base = pos[i2] * powf(theta_scale, i0 / 2.0f);
209
238
  } else {
210
- theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
239
+ theta_base = pos[i2 + ne02 * 3] * powf(theta_scale, i0 / 2.0f);
211
240
  }
212
241
  } else {
213
242
  if (sector < sections.v[0]) {
214
- theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
215
- }
216
- else if (sector >= sections.v[0] && sector < sec_w) {
217
- theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f);
218
- }
219
- else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
220
- theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f);
221
- }
222
- else if (sector >= sec_w + sections.v[2]) {
223
- theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
243
+ theta_base = pos[i2] * powf(theta_scale, i0 / 2.0f);
244
+ } else if (sector >= sections.v[0] && sector < sec_w) {
245
+ theta_base = pos[i2 + ne02 * 1] * powf(theta_scale, i0 / 2.0f);
246
+ } else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
247
+ theta_base = pos[i2 + ne02 * 2] * powf(theta_scale, i0 / 2.0f);
248
+ } else if (sector >= sec_w + sections.v[2]) {
249
+ theta_base = pos[i2 + ne02 * 3] * powf(theta_scale, i0 / 2.0f);
224
250
  }
225
251
  }
226
252
 
@@ -238,37 +264,53 @@ static __global__ void rope_multi(
238
264
  dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta;
239
265
  }
240
266
 
241
- template<bool forward, bool has_ff, typename T>
242
- static __global__ void rope_vision(
243
- const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims,
244
- const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
245
- const float theta_scale, const float * freq_factors, const mrope_sections sections) {
267
+ template <bool forward, bool has_ff, typename T>
268
+ static __global__ void rope_vision(const T * x,
269
+ T * dst,
270
+ const int ne00,
271
+ const int ne01,
272
+ const int ne02,
273
+ const int s01,
274
+ const int s02,
275
+ const int s03,
276
+ const int s1,
277
+ const int s2,
278
+ const int s3,
279
+ const int n_dims,
280
+ const int32_t * pos,
281
+ const float freq_scale,
282
+ const float ext_factor,
283
+ const float attn_factor,
284
+ const rope_corr_dims corr_dims,
285
+ const float theta_scale,
286
+ const float * freq_factors,
287
+ const mrope_sections sections) {
246
288
  const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
247
289
 
248
- if (i0 >= ne0) {
290
+ if (i0 >= ne00) {
249
291
  return;
250
292
  }
251
293
 
252
294
  const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
253
295
 
254
- const int row_x = row_dst % ne1;
255
- const int channel_x = row_dst / ne1;
296
+ const uint32_t i3 = row_dst / (ne01 * ne02);
297
+ const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
298
+ const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
256
299
 
257
- const int idst = row_dst*ne0 + i0/2;
258
- const int ix = channel_x*s2 + row_x*s1 + i0/2;
300
+ int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;
301
+ const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
259
302
 
260
303
  const int sect_dims = sections.v[0] + sections.v[1];
261
- const int sec_w = sections.v[1] + sections.v[0];
262
- const int sector = (i0 / 2) % sect_dims;
304
+ const int sec_w = sections.v[1] + sections.v[0];
305
+ const int sector = (i0 / 2) % sect_dims;
263
306
 
264
307
  float theta_base = 0.0;
265
308
  if (sector < sections.v[0]) {
266
309
  const int p = sector;
267
- theta_base = pos[channel_x]*powf(theta_scale, p);
268
- }
269
- else if (sector >= sections.v[0] && sector < sec_w) {
310
+ theta_base = pos[i2] * powf(theta_scale, p);
311
+ } else if (sector >= sections.v[0] && sector < sec_w) {
270
312
  const int p = sector - sections.v[0];
271
- theta_base = pos[channel_x + ne2]*powf(theta_scale, p);
313
+ theta_base = pos[i2 + ne02] * powf(theta_scale, p);
272
314
  }
273
315
 
274
316
  const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
@@ -288,10 +330,15 @@ static __global__ void rope_vision(
288
330
  template <bool forward, typename T, typename D>
289
331
  static void rope_norm_cuda(const T * x,
290
332
  D * dst,
291
- const int ne0,
292
- const int ne1,
333
+ const int ne00,
334
+ const int ne01,
335
+ const int ne02,
336
+ const int s01,
337
+ const int s02,
338
+ const int s03,
293
339
  const int s1,
294
340
  const int s2,
341
+ const int s3,
295
342
  const int n_dims,
296
343
  const int nr,
297
344
  const int32_t * pos,
@@ -304,31 +351,36 @@ static void rope_norm_cuda(const T * x,
304
351
  const int64_t * row_indices,
305
352
  const int set_rows_stride,
306
353
  cudaStream_t stream) {
307
- GGML_ASSERT(ne0 % 2 == 0);
354
+ GGML_ASSERT(ne00 % 2 == 0);
308
355
  const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
309
- const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
356
+ const int n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE);
310
357
  const dim3 block_nums(nr, n_blocks_x, 1);
311
358
 
312
- const float theta_scale = powf(freq_base, -2.0f/n_dims);
359
+ const float theta_scale = powf(freq_base, -2.0f / n_dims);
313
360
 
314
361
  if (freq_factors == nullptr) {
315
362
  rope_norm<forward, false><<<block_nums, block_dims, 0, stream>>>(
316
- x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
317
- freq_factors, row_indices, set_rows_stride);
363
+ x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
364
+ attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
318
365
  } else {
319
366
  rope_norm<forward, true><<<block_nums, block_dims, 0, stream>>>(
320
- x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
321
- freq_factors, row_indices, set_rows_stride);
367
+ x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
368
+ attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
322
369
  }
323
370
  }
324
371
 
325
372
  template <bool forward, typename T, typename D>
326
373
  static void rope_neox_cuda(const T * x,
327
374
  D * dst,
328
- const int ne0,
329
- const int ne1,
375
+ const int ne00,
376
+ const int ne01,
377
+ const int ne02,
378
+ const int s01,
379
+ const int s02,
380
+ const int s03,
330
381
  const int s1,
331
382
  const int s2,
383
+ const int s3,
332
384
  const int n_dims,
333
385
  const int nr,
334
386
  const int32_t * pos,
@@ -341,55 +393,92 @@ static void rope_neox_cuda(const T * x,
341
393
  const int64_t * row_indices,
342
394
  const int set_rows_stride,
343
395
  cudaStream_t stream) {
344
- GGML_ASSERT(ne0 % 2 == 0);
396
+ GGML_ASSERT(ne00 % 2 == 0);
345
397
  const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
346
- const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
398
+ const int n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE);
347
399
  const dim3 block_nums(nr, n_blocks_x, 1);
348
400
 
349
- const float theta_scale = powf(freq_base, -2.0f/n_dims);
401
+ const float theta_scale = powf(freq_base, -2.0f / n_dims);
350
402
 
351
403
  if (freq_factors == nullptr) {
352
404
  rope_neox<forward, false><<<block_nums, block_dims, 0, stream>>>(
353
- x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
354
- freq_factors, row_indices, set_rows_stride);
405
+ x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
406
+ attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
355
407
  } else {
356
408
  rope_neox<forward, true><<<block_nums, block_dims, 0, stream>>>(
357
- x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
358
- freq_factors, row_indices, set_rows_stride);
409
+ x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
410
+ attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
359
411
  }
360
412
  }
361
413
 
362
- template<bool forward, typename T>
363
- static void rope_multi_cuda(
364
- const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
365
- const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
366
- const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, const bool is_imrope, cudaStream_t stream) {
367
- GGML_ASSERT(ne0 % 2 == 0);
414
+ template <bool forward, typename T>
415
+ static void rope_multi_cuda(const T * x,
416
+ T * dst,
417
+ const int ne00,
418
+ const int ne01,
419
+ const int ne02,
420
+ const int s01,
421
+ const int s02,
422
+ const int s03,
423
+ const int s1,
424
+ const int s2,
425
+ const int s3,
426
+ const int n_dims,
427
+ const int nr,
428
+ const int32_t * pos,
429
+ const float freq_scale,
430
+ const float freq_base,
431
+ const float ext_factor,
432
+ const float attn_factor,
433
+ const rope_corr_dims corr_dims,
434
+ const float * freq_factors,
435
+ const mrope_sections sections,
436
+ const bool is_imrope,
437
+ cudaStream_t stream) {
438
+ GGML_ASSERT(ne00 % 2 == 0);
368
439
  const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
369
- const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
440
+ const int n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE);
370
441
  const dim3 block_nums(nr, n_blocks_x, 1);
371
442
 
372
- const float theta_scale = powf(freq_base, -2.0f/n_dims);
443
+ const float theta_scale = powf(freq_base, -2.0f / n_dims);
373
444
 
374
445
  if (freq_factors == nullptr) {
375
446
  rope_multi<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
376
- x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
447
+ x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
377
448
  attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
378
449
  } else {
379
450
  rope_multi<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
380
- x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
451
+ x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
381
452
  attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
382
453
  }
383
454
  }
384
455
 
385
- template<bool forward, typename T>
386
- static void rope_vision_cuda(
387
- const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
388
- const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
389
- const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) {
390
- GGML_ASSERT(ne0 % 2 == 0);
456
+ template <bool forward, typename T>
457
+ static void rope_vision_cuda(const T * x,
458
+ T * dst,
459
+ const int ne00,
460
+ const int ne01,
461
+ const int ne02,
462
+ const int s01,
463
+ const int s02,
464
+ const int s03,
465
+ const int s1,
466
+ const int s2,
467
+ const int s3,
468
+ const int n_dims,
469
+ const int nr,
470
+ const int32_t * pos,
471
+ const float freq_scale,
472
+ const float freq_base,
473
+ const float ext_factor,
474
+ const float attn_factor,
475
+ const rope_corr_dims corr_dims,
476
+ const float * freq_factors,
477
+ const mrope_sections sections,
478
+ cudaStream_t stream) {
479
+ GGML_ASSERT(ne00 % 2 == 0);
391
480
  const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
392
- const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
481
+ const int n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE);
393
482
  const dim3 block_nums(nr, n_blocks_x, 1);
394
483
  // break down (head_dim, heads, seq) into (CUDA_ROPE_BLOCK_SIZE, x, heads * seq)
395
484
  // where x ~= ceil(head_dim / CUDA_ROPE_BLOCK_SIZE);
@@ -398,11 +487,11 @@ static void rope_vision_cuda(
398
487
 
399
488
  if (freq_factors == nullptr) {
400
489
  rope_vision<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
401
- x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
490
+ x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
402
491
  attn_factor, corr_dims, theta_scale, freq_factors, sections);
403
492
  } else {
404
493
  rope_vision<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
405
- x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
494
+ x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
406
495
  attn_factor, corr_dims, theta_scale, freq_factors, sections);
407
496
  }
408
497
  }
@@ -445,6 +534,11 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx,
445
534
 
446
535
  const size_t s01 = src0->nb[1] / ggml_type_size(src0->type);
447
536
  const size_t s02 = src0->nb[2] / ggml_type_size(src0->type);
537
+ const size_t s03 = src0->nb[3] / ggml_type_size(src0->type);
538
+
539
+ const size_t s1 = dst->nb[1] / ggml_type_size(dst->type);
540
+ const size_t s2 = dst->nb[2] / ggml_type_size(dst->type);
541
+ const size_t s3 = dst->nb[3] / ggml_type_size(dst->type);
448
542
 
449
543
  //const int n_past = ((int32_t *) dst->op_params)[0];
450
544
  const int n_dims = ((int32_t *) dst->op_params)[1];
@@ -495,57 +589,63 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx,
495
589
  // compute
496
590
  if (is_neox) {
497
591
  if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
498
- rope_neox_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims,
499
- nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
500
- freq_factors, row_indices, set_rows_stride, stream);
592
+ rope_neox_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02,
593
+ s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
594
+ ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
595
+ set_rows_stride, stream);
501
596
  } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
502
- rope_neox_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims,
503
- nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
504
- freq_factors, row_indices, set_rows_stride, stream);
597
+ rope_neox_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02,
598
+ s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
599
+ ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
600
+ set_rows_stride, stream);
505
601
  } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
506
- rope_neox_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr,
507
- pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
508
- freq_factors, row_indices, set_rows_stride, stream);
602
+ rope_neox_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02,
603
+ s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
604
+ ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
605
+ set_rows_stride, stream);
509
606
  } else {
510
607
  GGML_ABORT("fatal error");
511
608
  }
512
609
  } else if (is_mrope && !is_vision) {
513
610
  if (src0->type == GGML_TYPE_F32) {
514
- rope_multi_cuda<forward>(
515
- (const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
516
- freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream);
611
+ rope_multi_cuda<forward>((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1,
612
+ s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor,
613
+ corr_dims, freq_factors, sections, is_imrope, stream);
517
614
  } else if (src0->type == GGML_TYPE_F16) {
518
- rope_multi_cuda<forward>(
519
- (const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
520
- freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream);
615
+ rope_multi_cuda<forward>((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1,
616
+ s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor,
617
+ corr_dims, freq_factors, sections, is_imrope, stream);
521
618
  } else {
522
619
  GGML_ABORT("fatal error");
523
620
  }
524
621
  } else if (is_vision) {
525
622
  if (src0->type == GGML_TYPE_F32) {
526
- rope_vision_cuda<forward>(
527
- (const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
528
- freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
623
+ rope_vision_cuda<forward>((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1,
624
+ s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor,
625
+ corr_dims, freq_factors, sections, stream);
529
626
  } else if (src0->type == GGML_TYPE_F16) {
530
- rope_vision_cuda<forward>(
531
- (const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
532
- freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
627
+ rope_vision_cuda<forward>((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1,
628
+ s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor,
629
+ corr_dims, freq_factors, sections, stream);
533
630
  } else {
534
631
  GGML_ABORT("fatal error");
535
632
  }
536
633
  } else {
537
634
  if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
538
- rope_norm_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims,
539
- nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
540
- freq_factors, row_indices, set_rows_stride, stream);
635
+ rope_norm_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02,
636
+ s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
637
+ ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
638
+ set_rows_stride, stream);
541
639
  } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
542
- rope_norm_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims,
543
- nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
544
- freq_factors, row_indices, set_rows_stride, stream);
640
+ rope_norm_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02,
641
+ s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
642
+ ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
643
+ set_rows_stride, stream);
545
644
  } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
546
- rope_norm_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr,
547
- pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
548
- freq_factors, row_indices, set_rows_stride, stream);
645
+ rope_norm_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02,
646
+ s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
647
+ ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
648
+ set_rows_stride, stream);
549
649
  } else {
550
650
  GGML_ABORT("fatal error");
551
651
  }