whispercpp 1.3.5 → 1.3.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (610) hide show
  1. checksums.yaml +4 -4
  2. data/LICENSE +1 -1
  3. data/README.md +99 -2
  4. data/ext/extconf.rb +1 -0
  5. data/ext/ruby_whisper.c +20 -4
  6. data/ext/ruby_whisper.h +30 -2
  7. data/ext/ruby_whisper_context.c +216 -124
  8. data/ext/ruby_whisper_context_params.c +163 -0
  9. data/ext/ruby_whisper_model.c +0 -1
  10. data/ext/ruby_whisper_params.c +0 -1
  11. data/ext/ruby_whisper_segment.c +0 -1
  12. data/ext/ruby_whisper_token.c +29 -9
  13. data/ext/ruby_whisper_transcribe.cpp +4 -1
  14. data/ext/ruby_whisper_vad_context.c +48 -1
  15. data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
  16. data/ext/ruby_whisper_vad_params.c +0 -1
  17. data/ext/ruby_whisper_vad_segment.c +0 -1
  18. data/ext/ruby_whisper_vad_segments.c +0 -1
  19. data/ext/sources/CMakeLists.txt +1 -1
  20. data/ext/sources/bindings/javascript/package.json +1 -1
  21. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  22. data/ext/sources/examples/bench/bench.cpp +23 -18
  23. data/ext/sources/examples/cli/cli.cpp +8 -0
  24. data/ext/sources/examples/common-ggml.cpp +2 -0
  25. data/ext/sources/examples/miniaudio.h +4507 -2131
  26. data/ext/sources/examples/server/server.cpp +18 -4
  27. data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -2
  28. data/ext/sources/examples/talk-llama/llama-adapter.cpp +7 -13
  29. data/ext/sources/examples/talk-llama/llama-adapter.h +4 -3
  30. data/ext/sources/examples/talk-llama/llama-arch.cpp +335 -17
  31. data/ext/sources/examples/talk-llama/llama-arch.h +42 -0
  32. data/ext/sources/examples/talk-llama/llama-batch.cpp +3 -1
  33. data/ext/sources/examples/talk-llama/llama-chat.cpp +21 -1
  34. data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
  35. data/ext/sources/examples/talk-llama/llama-context.cpp +508 -520
  36. data/ext/sources/examples/talk-llama/llama-context.h +27 -28
  37. data/ext/sources/examples/talk-llama/llama-cparams.h +5 -0
  38. data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
  39. data/ext/sources/examples/talk-llama/llama-grammar.cpp +8 -8
  40. data/ext/sources/examples/talk-llama/llama-graph.cpp +583 -130
  41. data/ext/sources/examples/talk-llama/llama-graph.h +131 -10
  42. data/ext/sources/examples/talk-llama/llama-hparams.cpp +57 -40
  43. data/ext/sources/examples/talk-llama/llama-hparams.h +79 -10
  44. data/ext/sources/examples/talk-llama/llama-impl.cpp +4 -4
  45. data/ext/sources/examples/talk-llama/llama-impl.h +13 -1
  46. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +3 -1
  47. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +274 -89
  48. data/ext/sources/examples/talk-llama/llama-kv-cache.h +2 -3
  49. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
  50. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
  51. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +11 -13
  52. data/ext/sources/examples/talk-llama/llama-mmap.cpp +28 -11
  53. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +527 -119
  54. data/ext/sources/examples/talk-llama/llama-model-loader.h +35 -5
  55. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +60 -46
  56. data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
  57. data/ext/sources/examples/talk-llama/llama-model.cpp +1365 -647
  58. data/ext/sources/examples/talk-llama/llama-model.h +72 -19
  59. data/ext/sources/examples/talk-llama/llama-quant.cpp +578 -346
  60. data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +190 -76
  61. data/ext/sources/examples/talk-llama/{llama-sampling.h → llama-sampler.h} +0 -2
  62. data/ext/sources/examples/talk-llama/llama-vocab.cpp +118 -48
  63. data/ext/sources/examples/talk-llama/llama-vocab.h +5 -0
  64. data/ext/sources/examples/talk-llama/llama.cpp +76 -22
  65. data/ext/sources/examples/talk-llama/llama.h +63 -30
  66. data/ext/sources/examples/talk-llama/models/afmoe.cpp +2 -3
  67. data/ext/sources/examples/talk-llama/models/apertus.cpp +3 -3
  68. data/ext/sources/examples/talk-llama/models/arcee.cpp +3 -3
  69. data/ext/sources/examples/talk-llama/models/arctic.cpp +4 -5
  70. data/ext/sources/examples/talk-llama/models/baichuan.cpp +4 -3
  71. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +1 -2
  72. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +3 -5
  73. data/ext/sources/examples/talk-llama/models/bert.cpp +13 -7
  74. data/ext/sources/examples/talk-llama/models/bitnet.cpp +9 -24
  75. data/ext/sources/examples/talk-llama/models/bloom.cpp +2 -2
  76. data/ext/sources/examples/talk-llama/models/chameleon.cpp +3 -3
  77. data/ext/sources/examples/talk-llama/models/chatglm.cpp +2 -2
  78. data/ext/sources/examples/talk-llama/models/codeshell.cpp +3 -3
  79. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +3 -3
  80. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +2 -2
  81. data/ext/sources/examples/talk-llama/models/command-r.cpp +2 -2
  82. data/ext/sources/examples/talk-llama/models/dbrx.cpp +4 -5
  83. data/ext/sources/examples/talk-llama/models/deci.cpp +3 -3
  84. data/ext/sources/examples/talk-llama/models/deepseek.cpp +4 -6
  85. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +24 -21
  86. data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
  87. data/ext/sources/examples/talk-llama/models/dots1.cpp +4 -6
  88. data/ext/sources/examples/talk-llama/models/dream.cpp +3 -3
  89. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +4 -6
  90. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +3 -3
  91. data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
  92. data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
  93. data/ext/sources/examples/talk-llama/models/exaone.cpp +3 -3
  94. data/ext/sources/examples/talk-llama/models/exaone4.cpp +3 -3
  95. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +2 -4
  96. data/ext/sources/examples/talk-llama/models/falcon.cpp +3 -3
  97. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +1 -1
  98. data/ext/sources/examples/talk-llama/models/gemma.cpp +1 -1
  99. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +1 -1
  100. data/ext/sources/examples/talk-llama/models/gemma3.cpp +1 -1
  101. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +7 -7
  102. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +3 -3
  103. data/ext/sources/examples/talk-llama/models/glm4.cpp +14 -7
  104. data/ext/sources/examples/talk-llama/models/gpt2.cpp +2 -2
  105. data/ext/sources/examples/talk-llama/models/gptneox.cpp +2 -2
  106. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +4 -5
  107. data/ext/sources/examples/talk-llama/models/granite.cpp +4 -5
  108. data/ext/sources/examples/talk-llama/models/grok.cpp +4 -4
  109. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +5 -7
  110. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +3 -3
  111. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +4 -5
  112. data/ext/sources/examples/talk-llama/models/internlm2.cpp +3 -3
  113. data/ext/sources/examples/talk-llama/models/jais.cpp +2 -2
  114. data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
  115. data/ext/sources/examples/talk-llama/models/jamba.cpp +3 -3
  116. data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
  117. data/ext/sources/examples/talk-llama/models/lfm2.cpp +145 -124
  118. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +4 -4
  119. data/ext/sources/examples/talk-llama/models/llada.cpp +3 -3
  120. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +4 -4
  121. data/ext/sources/examples/talk-llama/models/llama.cpp +18 -11
  122. data/ext/sources/examples/talk-llama/models/maincoder.cpp +3 -3
  123. data/ext/sources/examples/talk-llama/models/{graph-context-mamba.cpp → mamba-base.cpp} +9 -3
  124. data/ext/sources/examples/talk-llama/models/mamba.cpp +1 -2
  125. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +11 -5
  126. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +14 -13
  127. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +4 -5
  128. data/ext/sources/examples/talk-llama/models/mistral3.cpp +4 -4
  129. data/ext/sources/examples/talk-llama/models/models.h +181 -46
  130. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +2 -9
  131. data/ext/sources/examples/talk-llama/models/mpt.cpp +2 -2
  132. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +26 -14
  133. data/ext/sources/examples/talk-llama/models/nemotron.cpp +3 -3
  134. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +2 -2
  135. data/ext/sources/examples/talk-llama/models/olmo.cpp +3 -3
  136. data/ext/sources/examples/talk-llama/models/olmo2.cpp +3 -3
  137. data/ext/sources/examples/talk-llama/models/olmoe.cpp +4 -4
  138. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +1 -1
  139. data/ext/sources/examples/talk-llama/models/openelm.cpp +3 -3
  140. data/ext/sources/examples/talk-llama/models/orion.cpp +3 -3
  141. data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
  142. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +3 -3
  143. data/ext/sources/examples/talk-llama/models/phi2.cpp +2 -2
  144. data/ext/sources/examples/talk-llama/models/phi3.cpp +3 -3
  145. data/ext/sources/examples/talk-llama/models/plamo.cpp +3 -3
  146. data/ext/sources/examples/talk-llama/models/plamo2.cpp +9 -5
  147. data/ext/sources/examples/talk-llama/models/plamo3.cpp +2 -2
  148. data/ext/sources/examples/talk-llama/models/plm.cpp +15 -14
  149. data/ext/sources/examples/talk-llama/models/qwen.cpp +2 -2
  150. data/ext/sources/examples/talk-llama/models/qwen2.cpp +3 -3
  151. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +4 -4
  152. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +3 -3
  153. data/ext/sources/examples/talk-llama/models/qwen3.cpp +12 -9
  154. data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
  155. data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
  156. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +15 -8
  157. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +84 -432
  158. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +9 -18
  159. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +8 -17
  160. data/ext/sources/examples/talk-llama/models/refact.cpp +2 -2
  161. data/ext/sources/examples/talk-llama/models/rnd1.cpp +4 -4
  162. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +2 -0
  163. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +2 -0
  164. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +3 -3
  165. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +4 -4
  166. data/ext/sources/examples/talk-llama/models/smollm3.cpp +3 -3
  167. data/ext/sources/examples/talk-llama/models/stablelm.cpp +2 -2
  168. data/ext/sources/examples/talk-llama/models/starcoder.cpp +2 -2
  169. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +3 -3
  170. data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
  171. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +2 -2
  172. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +2 -2
  173. data/ext/sources/examples/talk-llama/models/xverse.cpp +3 -3
  174. data/ext/sources/examples/talk-llama/unicode.cpp +21 -65
  175. data/ext/sources/ggml/CMakeLists.txt +9 -3
  176. data/ext/sources/ggml/include/ggml-backend.h +1 -1
  177. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  178. data/ext/sources/ggml/include/ggml-cpu.h +5 -0
  179. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  180. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  181. data/ext/sources/ggml/include/ggml-rpc.h +6 -1
  182. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  183. data/ext/sources/ggml/include/ggml.h +56 -9
  184. data/ext/sources/ggml/src/CMakeLists.txt +3 -0
  185. data/ext/sources/ggml/src/ggml-alloc.c +4 -9
  186. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  187. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  188. data/ext/sources/ggml/src/ggml-backend-reg.cpp +28 -86
  189. data/ext/sources/ggml/src/ggml-backend.cpp +5 -2
  190. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  191. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +6 -2
  192. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
  193. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
  194. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +348 -189
  195. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +40 -85
  196. data/ext/sources/ggml/src/ggml-cann/common.h +3 -4
  197. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +44 -62
  198. data/ext/sources/ggml/src/ggml-common.h +11 -0
  199. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +16 -11
  200. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -19
  201. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  202. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
  203. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +85 -1
  204. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2744 -548
  205. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1653 -0
  206. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
  207. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
  208. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
  209. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +118 -18
  210. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +107 -26
  211. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  212. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  213. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +3 -0
  214. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +59 -12
  215. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +15 -0
  216. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
  217. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +965 -252
  218. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +584 -197
  219. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +903 -188
  220. data/ext/sources/ggml/src/ggml-cpu/ops.h +1 -0
  221. data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
  222. data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
  223. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2890 -679
  224. data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
  225. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
  226. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +111 -3
  227. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
  228. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +17 -0
  229. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +1 -1
  230. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +19 -10
  231. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +32 -30
  232. data/ext/sources/ggml/src/ggml-cuda/common.cuh +134 -18
  233. data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
  234. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +6 -3
  235. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +78 -64
  236. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +384 -143
  237. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +36 -22
  238. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +3 -3
  239. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +26 -5
  240. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
  241. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +127 -12
  242. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
  243. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  244. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +595 -200
  245. data/ext/sources/ggml/src/ggml-cuda/mean.cu +9 -8
  246. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +173 -6
  247. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
  248. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +158 -85
  249. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +34 -22
  250. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +127 -67
  251. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
  252. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +157 -65
  253. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -0
  254. data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
  255. data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
  256. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +1 -1
  257. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
  258. data/ext/sources/ggml/src/ggml-cuda/rope.cu +233 -133
  259. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
  260. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
  261. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +56 -32
  262. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  263. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
  264. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
  265. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
  266. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
  267. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
  268. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
  269. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +3 -3
  270. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +0 -1
  271. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +199 -135
  272. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
  273. data/ext/sources/ggml/src/ggml-cuda/unary.cu +55 -0
  274. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
  275. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
  276. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +10 -0
  277. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +82 -45
  278. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +334 -160
  279. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +7 -5
  280. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +328 -197
  281. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
  282. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +765 -234
  283. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
  284. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +412 -265
  285. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +23 -23
  286. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +1 -1
  287. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.h → hex-dma.h} +28 -3
  288. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
  289. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  290. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
  291. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +1 -1
  292. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +27 -37
  293. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +6 -35
  294. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  295. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
  296. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
  297. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  302. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  303. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  304. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
  305. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  306. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  307. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +20 -1347
  308. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +211 -13
  309. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +1119 -952
  310. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +254 -244
  311. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +36 -36
  312. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +155 -138
  313. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
  314. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  315. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +209 -114
  316. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
  317. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  318. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  319. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  320. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
  321. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +6 -0
  322. data/ext/sources/ggml/src/ggml-impl.h +62 -0
  323. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  324. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
  325. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  326. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
  327. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +274 -73
  328. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +22 -4
  329. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +102 -36
  330. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +174 -23
  331. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +580 -280
  332. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +5 -4
  333. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +320 -107
  334. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1068 -825
  335. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +19 -1
  336. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +3108 -636
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +204 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  365. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
  366. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  367. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  368. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  369. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
  370. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  371. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
  372. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  373. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
  374. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
  375. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  376. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  377. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  378. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  379. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  380. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  381. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  382. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  383. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  384. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  385. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  386. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  387. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  388. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  389. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  390. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  391. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  392. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
  393. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  394. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  395. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  396. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  397. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  398. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  399. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
  400. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
  401. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
  402. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
  403. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  404. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  405. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  406. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  407. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  408. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
  409. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  410. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
  411. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
  412. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
  413. data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
  414. data/ext/sources/ggml/src/ggml-quants.c +96 -5
  415. data/ext/sources/ggml/src/ggml-quants.h +3 -0
  416. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +15 -88
  417. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
  418. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +1 -0
  419. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
  420. data/ext/sources/ggml/src/ggml-sycl/common.hpp +315 -10
  421. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +69 -1
  422. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
  423. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
  424. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
  425. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +78 -68
  426. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
  427. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
  428. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
  429. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
  430. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
  431. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
  432. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  433. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
  434. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
  435. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +316 -51
  436. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
  437. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  438. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
  439. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
  440. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
  441. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  442. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
  443. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  444. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  445. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  446. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  447. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  448. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  449. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  450. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  451. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  452. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
  453. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
  454. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
  455. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
  456. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
  457. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
  458. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
  459. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
  460. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
  461. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
  462. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
  463. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
  464. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
  465. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
  466. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
  467. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
  468. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
  469. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
  470. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
  471. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
  472. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
  473. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
  474. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
  475. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
  476. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
  477. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
  478. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
  479. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
  480. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
  481. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
  482. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
  483. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
  484. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
  485. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
  486. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
  487. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
  488. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +13 -0
  489. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  490. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  491. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  492. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  493. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  494. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  495. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  496. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  497. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  498. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  499. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  500. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  501. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  502. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  503. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  504. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  505. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  506. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  507. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  508. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  509. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  510. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  511. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
  512. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
  513. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  514. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
  515. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  516. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  517. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  518. data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
  519. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  520. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  521. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  522. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  523. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  524. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  525. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  526. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
  527. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  528. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  529. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  530. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
  531. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
  532. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +1 -1
  533. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1250 -465
  534. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
  535. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  536. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +374 -170
  537. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +66 -22
  538. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +389 -201
  539. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +106 -58
  540. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  541. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
  542. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
  543. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
  544. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
  545. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -3
  546. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +8 -4
  547. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
  548. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
  549. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +3 -3
  550. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
  551. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +36 -63
  552. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
  553. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
  554. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
  555. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +10 -5
  556. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
  557. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  558. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +16 -10
  559. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +55 -35
  560. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1314 -109
  561. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1660 -1371
  562. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  563. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  564. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  565. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
  566. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
  567. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
  568. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +6 -0
  569. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  570. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +40 -5
  571. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +105 -60
  572. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
  573. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +68 -257
  574. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +692 -23
  575. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_reg_tile.tmpl.wgsl → mul_mat_reg_tile.wgsl} +28 -128
  576. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +31 -137
  577. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
  578. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  579. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  580. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +9 -36
  581. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
  582. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  583. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
  584. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
  585. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
  586. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +9 -6
  587. data/ext/sources/ggml/src/ggml.c +167 -33
  588. data/ext/sources/ggml/src/gguf.cpp +229 -44
  589. data/ext/sources/src/whisper.cpp +6 -28
  590. data/sig/whisper.rbs +43 -2
  591. data/test/test_context_params.rb +82 -0
  592. data/test/test_token.rb +11 -0
  593. data/test/test_vad_context.rb +58 -8
  594. data/test/test_whisper.rb +20 -0
  595. data/whispercpp.gemspec +1 -1
  596. metadata +240 -28
  597. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  598. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
  599. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
  600. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
  601. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
  602. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
  603. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
  604. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
  605. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
  606. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
  607. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  608. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
  609. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
  610. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
@@ -97,6 +97,8 @@ llama_kv_cache::llama_kv_cache(
97
97
  __func__, hparams.n_embd_v_gqa_max());
98
98
  }
99
99
 
100
+ const bool is_mla = hparams.is_mla();
101
+
100
102
  for (uint32_t il = 0; il < hparams.n_layer; il++) {
101
103
  if (!hparams.has_kv(il)) {
102
104
  LLAMA_LOG_DEBUG("%s: layer %3d: does not have KV cache\n", __func__, il);
@@ -130,18 +132,21 @@ llama_kv_cache::llama_kv_cache(
130
132
  throw std::runtime_error("failed to create ggml context for kv cache");
131
133
  }
132
134
 
133
- ggml_tensor * k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream);
134
- ggml_tensor * v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream);
135
+ const bool has_k = true;
136
+ const bool has_v = !is_mla;
137
+
138
+ ggml_tensor * k = has_k ? ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream) : nullptr;
139
+ ggml_tensor * v = has_v ? ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream) : nullptr;
135
140
 
136
- ggml_format_name(k, "cache_k_l%d", il);
137
- ggml_format_name(v, "cache_v_l%d", il);
141
+ has_k && ggml_format_name(k, "cache_k_l%d", il);
142
+ has_v && ggml_format_name(v, "cache_v_l%d", il);
138
143
 
139
144
  std::vector<ggml_tensor *> k_stream;
140
145
  std::vector<ggml_tensor *> v_stream;
141
146
 
142
147
  for (uint32_t s = 0; s < n_stream; ++s) {
143
- k_stream.push_back(ggml_view_2d(ctx, k, n_embd_k_gqa, kv_size, k->nb[1], s*k->nb[2]));
144
- v_stream.push_back(ggml_view_2d(ctx, v, n_embd_v_gqa, kv_size, v->nb[1], s*v->nb[2]));
148
+ k_stream.push_back(has_k ? ggml_view_2d(ctx, k, n_embd_k_gqa, kv_size, k->nb[1], s*k->nb[2]) : nullptr);
149
+ v_stream.push_back(has_v ? ggml_view_2d(ctx, v, n_embd_v_gqa, kv_size, v->nb[1], s*v->nb[2]) : nullptr);
145
150
  }
146
151
 
147
152
  map_layer_ids[il] = layers.size();
@@ -578,7 +583,7 @@ llama_kv_cache::slot_info_vec_t llama_kv_cache::prepare(const std::vector<llama_
578
583
  break;
579
584
  }
580
585
 
581
- // remeber the position that we found
586
+ // remember the position that we found
582
587
  res.push_back(sinfo_new);
583
588
 
584
589
  // store the old state of the cells in the recovery stack
@@ -647,7 +652,10 @@ bool llama_kv_cache::update(llama_context * lctx, bool do_shift, const stream_co
647
652
  const auto & layer = layers[il];
648
653
 
649
654
  ggml_backend_tensor_copy(layer.k_stream[ssrc], layer.k_stream[sdst]);
650
- ggml_backend_tensor_copy(layer.v_stream[ssrc], layer.v_stream[sdst]);
655
+
656
+ if (layer.v_stream[ssrc]) {
657
+ ggml_backend_tensor_copy(layer.v_stream[ssrc], layer.v_stream[sdst]);
658
+ }
651
659
  }
652
660
  }
653
661
  }
@@ -852,7 +860,7 @@ llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch,
852
860
  const llama_seq_id seq_id_cell = cells.seq_get(idx);
853
861
 
854
862
  // SWA mask
855
- if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
863
+ if (llama_hparams::is_masked_swa(n_swa, swa_type, pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
856
864
  can_use = true;
857
865
  }
858
866
  }
@@ -966,6 +974,13 @@ void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch &
966
974
  }
967
975
 
968
976
  bool llama_kv_cache::get_can_shift() const {
977
+ // Step35 uses per-layer RoPE dims; K-shift assumes a single global n_rot.
978
+ if (model.arch == LLM_ARCH_STEP35) {
979
+ return false;
980
+ }
981
+ if (hparams.n_pos_per_embd() > 1) {
982
+ return false;
983
+ }
969
984
  return true;
970
985
  }
971
986
 
@@ -1018,8 +1033,8 @@ ggml_tensor * llama_kv_cache::get_k(ggml_context * ctx, int32_t il, uint32_t n_k
1018
1033
  const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
1019
1034
 
1020
1035
  return ggml_view_4d(ctx, k,
1021
- hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv, ns,
1022
- ggml_row_size(k->type, hparams.n_embd_head_k),
1036
+ hparams.n_embd_head_k(il), hparams.n_head_kv(il), n_kv, ns,
1037
+ ggml_row_size(k->type, hparams.n_embd_head_k(il)),
1023
1038
  ggml_row_size(k->type, n_embd_k_gqa),
1024
1039
  ggml_row_size(k->type, n_embd_k_gqa*kv_size),
1025
1040
  ggml_row_size(k->type, n_embd_k_gqa*kv_size)*sinfo.s0);
@@ -1041,8 +1056,8 @@ ggml_tensor * llama_kv_cache::get_v(ggml_context * ctx, int32_t il, uint32_t n_k
1041
1056
  if (!v_trans) {
1042
1057
  // note: v->nb[1] <= v->nb[2]
1043
1058
  return ggml_view_4d(ctx, v,
1044
- hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, ns,
1045
- ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1]
1059
+ hparams.n_embd_head_v(il), hparams.n_head_kv(il), n_kv, ns,
1060
+ ggml_row_size(v->type, hparams.n_embd_head_v(il)), // v->nb[1]
1046
1061
  ggml_row_size(v->type, n_embd_v_gqa), // v->nb[2]
1047
1062
  ggml_row_size(v->type, n_embd_v_gqa*kv_size), // v->nb[3]
1048
1063
  ggml_row_size(v->type, n_embd_v_gqa*kv_size)*sinfo.s0);
@@ -1050,8 +1065,8 @@ ggml_tensor * llama_kv_cache::get_v(ggml_context * ctx, int32_t il, uint32_t n_k
1050
1065
 
1051
1066
  // note: v->nb[1] > v->nb[2]
1052
1067
  return ggml_view_4d(ctx, v,
1053
- n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, ns,
1054
- ggml_row_size(v->type, kv_size*hparams.n_embd_head_v), // v->nb[1]
1068
+ n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v(il), ns,
1069
+ ggml_row_size(v->type, kv_size*hparams.n_embd_head_v(il)), // v->nb[1]
1055
1070
  ggml_row_size(v->type, kv_size), // v->nb[2]
1056
1071
  ggml_row_size(v->type, kv_size*n_embd_v_gqa), // v->nb[3]
1057
1072
  ggml_row_size(v->type, kv_size*n_embd_v_gqa)*sinfo.s0);
@@ -1237,90 +1252,236 @@ void llama_kv_cache::set_input_k_shift(ggml_tensor * dst) const {
1237
1252
  }
1238
1253
  }
1239
1254
 
1240
- void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
1241
- const uint32_t n_tokens = ubatch->n_tokens;
1255
+ struct args_set_input_kq_mask {
1256
+ const llama_hparams & hparams;
1257
+ const llama_ubatch * ubatch;
1242
1258
 
1243
- GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
1244
- float * data = (float *) dst->data;
1259
+ const std::vector<llama_kv_cells> & v_cells;
1260
+ const std::vector<uint32_t> & seq_to_stream;
1245
1261
 
1246
- const int64_t n_kv = dst->ne[0];
1247
- const int64_t n_stream = dst->ne[3]; // num streams in the current ubatch
1262
+ uint32_t n_swa;
1263
+ llama_swa_type swa_type;
1248
1264
 
1249
- GGML_ASSERT(n_tokens%n_stream == 0);
1265
+ int64_t n_kv;
1266
+ int64_t n_stream;
1267
+ int64_t n_tps;
1268
+ };
1250
1269
 
1251
- // n_tps == n_tokens_per_stream
1252
- const int64_t n_tps = n_tokens/n_stream;
1270
+ template<bool causal, bool swa, bool is_2d, bool alibi>
1271
+ static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) {
1272
+ //const auto & hparams = args.hparams;
1273
+ const auto & ubatch = args.ubatch;
1253
1274
 
1254
- std::fill(data, data + ggml_nelements(dst), -INFINITY);
1255
-
1256
- // Use only the previous KV cells of the correct sequence for each token of the ubatch.
1257
- // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
1258
- // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
1259
- // Causal mask:
1260
- // xxx-------
1261
- // xxxx------
1262
- // xxxxx-----
1263
- // Non-causal mask:
1264
- // xxxxx-----
1265
- // xxxxx-----
1266
- // xxxxx-----
1267
- // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
1268
- // TODO: optimize this section
1269
- for (uint32_t h = 0; h < 1; ++h) {
1270
- for (uint32_t s = 0; s < n_stream; ++s) {
1271
- for (uint32_t ii = 0; ii < n_tps; ++ii) {
1272
- const uint32_t i = s*n_tps + ii;
1275
+ const auto & v_cells = args.v_cells;
1276
+ const auto & seq_to_stream = args.seq_to_stream;
1273
1277
 
1274
- const llama_seq_id seq_id = ubatch->seq_id[i][0];
1278
+ const uint32_t n_swa = args.n_swa;
1279
+ const llama_swa_type swa_type = args.swa_type;
1275
1280
 
1276
- const auto & cells = v_cells[seq_to_stream[seq_id]];
1281
+ const int64_t n_kv = args.n_kv;
1282
+ const int64_t n_stream = args.n_stream;
1283
+ const int64_t n_tps = args.n_tps;
1277
1284
 
1278
- const llama_pos p1 = ubatch->pos[i];
1285
+ // the min position in the batch for each sequence
1286
+ llama_pos seq_pos_min[LLAMA_MAX_SEQ];
1287
+ std::fill(seq_pos_min, seq_pos_min + LLAMA_MAX_SEQ, INT32_MAX);
1279
1288
 
1280
- // for M-RoPE
1281
- const bool is_2d = ubatch->is_pos_2d();
1282
- const llama_pos p1_x = is_2d ? ubatch->pos[i + ubatch->n_tokens*2] : 0;
1283
- const llama_pos p1_y = is_2d ? ubatch->pos[i + ubatch->n_tokens] : 0;
1289
+ for (uint32_t i = 0; i < ubatch->n_tokens; ++i) {
1290
+ const llama_seq_id seq_id = ubatch->seq_id[i][0];
1284
1291
 
1285
- const uint64_t idst = n_kv*(h*n_stream*n_tps + s*n_tps + ii);
1292
+ seq_pos_min[seq_id] = std::min(seq_pos_min[seq_id], ubatch->pos[i]);
1293
+ }
1286
1294
 
1287
- for (uint32_t j = 0; j < n_kv; ++j) {
1288
- if (cells.is_empty(j)) {
1289
- continue;
1290
- }
1295
+ for (uint32_t s = 0; s < n_stream; ++s) {
1296
+ // bookkeeping of the KQ mask cells that could change for other tokens of the same sequence
1297
+ std::unordered_map<llama_seq_id, uint32_t> seq_srct;
1298
+ std::unordered_map<llama_seq_id, std::vector<uint32_t>> seq_idxs;
1299
+
1300
+ for (uint32_t ii = 0; ii < n_tps; ++ii) {
1301
+ const uint32_t i = s*n_tps + ii;
1302
+
1303
+ const llama_seq_id seq_id = ubatch->seq_id[i][0];
1304
+
1305
+ const auto & cells = v_cells.at(seq_to_stream[seq_id]);
1306
+
1307
+ llama_pos p0 = -1;
1308
+ const llama_pos p1 = ubatch->pos[i];
1309
+
1310
+ // for M-RoPE
1311
+ const llama_pos p1_x = is_2d ? ubatch->pos[i + ubatch->n_tokens*2] : 0;
1312
+ const llama_pos p1_y = is_2d ? ubatch->pos[i + ubatch->n_tokens] : 0;
1313
+
1314
+ const uint64_t idst = n_kv*i;
1291
1315
 
1292
- // mask the token if not the same sequence
1293
- if (!cells.seq_has(j, seq_id)) {
1294
- continue;
1316
+ // for tokens of the same sequence, the mask is mostly the same, so we can reuse it
1317
+ // the only cells that could change are the ones that are with similar positions as the
1318
+ // ones in the batch (i.e. due to causal masking, SWA, etc.)
1319
+ // keep track of those cells and shortcut the loop to save time
1320
+ // note: this optimization is not compatible with Alibi position encoding
1321
+ // ref: https://github.com/ggml-org/llama.cpp/pull/18842
1322
+ bool prev = false;
1323
+
1324
+ auto & idxs = seq_idxs[seq_id];
1325
+
1326
+ if (!alibi) {
1327
+ if (seq_srct.find(seq_id) != seq_srct.end()) {
1328
+ const uint32_t srct = seq_srct[seq_id];
1329
+
1330
+ const uint64_t idst_prev = n_kv*srct;
1331
+
1332
+ std::copy(data + idst_prev, data + idst_prev + n_kv, data + idst);
1333
+
1334
+ prev = true;
1335
+ } else {
1336
+ idxs.clear();
1337
+ idxs.reserve(ubatch->n_tokens + n_swa + 32);
1338
+
1339
+ seq_srct[seq_id] = i;
1340
+ }
1341
+ }
1342
+
1343
+ for (uint32_t jj = 0; jj < n_kv; ++jj) {
1344
+ uint32_t j = jj;
1345
+
1346
+ // we have an exiting mask for this sequence -> update just seq_idxs
1347
+ if (!alibi) {
1348
+ if (prev) {
1349
+ if (jj >= idxs.size()) {
1350
+ break;
1351
+ }
1352
+
1353
+ j = idxs[jj];
1295
1354
  }
1355
+ }
1296
1356
 
1297
- const llama_pos p0 = cells.pos_get(j);
1357
+ if (cells.is_empty(j)) {
1358
+ goto skip;
1359
+ }
1360
+
1361
+ // mask the token if not the same sequence
1362
+ if (!cells.seq_has(j, seq_id)) {
1363
+ goto skip;
1364
+ }
1365
+
1366
+ p0 = cells.pos_get(j);
1298
1367
 
1368
+ if (!alibi) {
1369
+ if (!prev) {
1370
+ // record all cells for which: p0 >= seq_pos_min[seq_id] - n_swa - 32
1371
+ if (p0 + (int32_t) (n_swa + 32) >= seq_pos_min[seq_id]) {
1372
+ idxs.push_back(j);
1373
+ }
1374
+ }
1375
+ }
1376
+
1377
+ if (causal) {
1299
1378
  // mask future tokens
1300
- if (causal_attn && p0 > p1) {
1301
- continue;
1379
+ if (p0 > p1) {
1380
+ goto skip;
1302
1381
  }
1303
1382
 
1304
1383
  // M-RoPE causal mask
1305
- if (causal_attn && is_2d && p0 == p1) {
1306
- const auto & p0_ext = cells.ext_get(j);
1307
- if (p0_ext.is_2d_gt(p1_x, p1_y)) {
1308
- continue;
1384
+ if (is_2d) {
1385
+ if (p0 == p1) {
1386
+ const auto & p0_ext = cells.ext_get(j);
1387
+
1388
+ if (p0_ext.is_2d_gt(p1_x, p1_y)) {
1389
+ goto skip;
1390
+ }
1309
1391
  }
1310
1392
  }
1393
+ }
1311
1394
 
1312
- // apply SWA if any
1313
- if (is_masked_swa(p0, p1)) {
1314
- continue;
1395
+ // apply SWA if any
1396
+ if (swa) {
1397
+ if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
1398
+ goto skip;
1315
1399
  }
1400
+ }
1316
1401
 
1317
- data[idst + j] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
1402
+ if (alibi) {
1403
+ data[idst + j] = -std::abs(p0 - p1);
1404
+ } else {
1405
+ data[idst + j] = 0.0f;
1318
1406
  }
1407
+
1408
+ continue;
1409
+ skip:
1410
+ data[idst + j] = -INFINITY;
1319
1411
  }
1320
1412
  }
1321
1413
  }
1322
1414
  }
1323
1415
 
1416
+ template<bool causal, bool swa, bool is_2d>
1417
+ static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) {
1418
+ const bool alibi = args.hparams.use_alibi;
1419
+ if (alibi) {
1420
+ set_input_kq_mask_impl<causal, swa, is_2d, true> (args, data);
1421
+ } else {
1422
+ set_input_kq_mask_impl<causal, swa, is_2d, false>(args, data);
1423
+ }
1424
+ }
1425
+
1426
+ template<bool causal, bool swa>
1427
+ static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) {
1428
+ const bool is_2d = args.ubatch->is_pos_2d();
1429
+ if (is_2d) {
1430
+ set_input_kq_mask_impl<causal, swa, true> (args, data);
1431
+ } else {
1432
+ set_input_kq_mask_impl<causal, swa, false>(args, data);
1433
+ }
1434
+ }
1435
+
1436
+ template<bool causal>
1437
+ static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) {
1438
+ const bool swa = args.swa_type != LLAMA_SWA_TYPE_NONE;
1439
+ if (swa) {
1440
+ set_input_kq_mask_impl<causal, true> (args, data);
1441
+ } else {
1442
+ set_input_kq_mask_impl<causal, false>(args, data);
1443
+ }
1444
+ }
1445
+
1446
+ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
1447
+ const uint32_t n_tokens = ubatch->n_tokens;
1448
+
1449
+ GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
1450
+ float * data = (float *) dst->data;
1451
+
1452
+ const int64_t n_kv = dst->ne[0];
1453
+ const int64_t n_stream = dst->ne[3]; // num streams in the current ubatch
1454
+
1455
+ GGML_ASSERT(n_tokens%n_stream == 0);
1456
+
1457
+ // n_tps == n_tokens_per_stream
1458
+ const int64_t n_tps = n_tokens/n_stream;
1459
+
1460
+ //const int64_t t_start = ggml_time_us();
1461
+
1462
+ const args_set_input_kq_mask args = {
1463
+ /*.hparams =*/ hparams,
1464
+ /*.ubatch =*/ ubatch,
1465
+ /*.v_cells =*/ v_cells,
1466
+ /*.seq_to_stream =*/ seq_to_stream,
1467
+ /*.n_swa =*/ n_swa,
1468
+ /*.swa_type =*/ swa_type,
1469
+ /*.n_kv =*/ n_kv,
1470
+ /*.n_stream =*/ n_stream,
1471
+ /*.n_tps =*/ n_tps,
1472
+ };
1473
+
1474
+ if (causal_attn) {
1475
+ set_input_kq_mask_impl<true> (args, data);
1476
+ } else {
1477
+ set_input_kq_mask_impl<false>(args, data);
1478
+ }
1479
+
1480
+ //const int64_t t_end = ggml_time_us();
1481
+
1482
+ //LLAMA_LOG_ERROR("%s: kq mask time: %0.3f ms\n", __func__, (t_end - t_start)/1000.0);
1483
+ }
1484
+
1324
1485
  void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
1325
1486
  const int64_t n_tokens = ubatch->n_tokens;
1326
1487
 
@@ -1370,7 +1531,7 @@ size_t llama_kv_cache::size_v_bytes() const {
1370
1531
  size_t size_v_bytes = 0;
1371
1532
 
1372
1533
  for (const auto & layer : layers) {
1373
- size_v_bytes += ggml_nbytes(layer.v);
1534
+ size_v_bytes += layer.v ? ggml_nbytes(layer.v) : 0;
1374
1535
  }
1375
1536
 
1376
1537
  return size_v_bytes;
@@ -1383,7 +1544,8 @@ ggml_tensor * llama_kv_cache::build_rope_shift(
1383
1544
  ggml_tensor * shift,
1384
1545
  ggml_tensor * factors,
1385
1546
  float freq_base,
1386
- float freq_scale) const {
1547
+ float freq_scale,
1548
+ uint32_t il) const {
1387
1549
  const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
1388
1550
 
1389
1551
  const auto & yarn_ext_factor = cparams.yarn_ext_factor;
@@ -1391,7 +1553,7 @@ ggml_tensor * llama_kv_cache::build_rope_shift(
1391
1553
  const auto & yarn_beta_slow = cparams.yarn_beta_slow;
1392
1554
  const auto & yarn_attn_factor = cparams.yarn_attn_factor;
1393
1555
 
1394
- const auto & n_rot = hparams.n_rot;
1556
+ const auto & n_rot = hparams.n_rot(il);
1395
1557
  const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE || hparams.rope_type == LLAMA_ROPE_TYPE_IMROPE
1396
1558
  // @ngxson : this is a workaround
1397
1559
  // for M-RoPE, we want to rotate the whole vector when doing KV shift
@@ -1445,9 +1607,6 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
1445
1607
  auto * ctx = res->get_ctx();
1446
1608
  auto * gf = res->get_gf();
1447
1609
 
1448
- const auto & n_embd_head_k = hparams.n_embd_head_k;
1449
- //const auto & n_embd_head_v = hparams.n_embd_head_v;
1450
-
1451
1610
  auto inp = std::make_unique<llm_graph_input_k_shift>(this);
1452
1611
 
1453
1612
  inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_stream);
@@ -1461,6 +1620,10 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
1461
1620
  const int64_t n_head_kv = hparams.n_head_kv(il);
1462
1621
  const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
1463
1622
 
1623
+ const auto n_rot = hparams.n_rot(il);
1624
+ const auto n_embd_head_k = hparams.n_embd_head_k(il);
1625
+ const auto n_embd_nope = hparams.n_lora_kv > 0 ? n_embd_head_k - n_rot : 0;
1626
+
1464
1627
  const float freq_base_l = model.get_rope_freq_base (cparams, il);
1465
1628
  const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
1466
1629
 
@@ -1468,12 +1631,12 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
1468
1631
 
1469
1632
  ggml_tensor * k =
1470
1633
  ggml_view_3d(ctx, layer.k,
1471
- n_embd_head_k, n_head_kv, get_size()*n_stream,
1634
+ n_rot, n_head_kv, get_size()*n_stream,
1472
1635
  ggml_row_size(layer.k->type, n_embd_head_k),
1473
1636
  ggml_row_size(layer.k->type, n_embd_k_gqa),
1474
- 0);
1637
+ ggml_row_size(layer.k->type, n_embd_nope));
1475
1638
 
1476
- ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l);
1639
+ ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, il);
1477
1640
 
1478
1641
  ggml_build_forward_expand(gf, cur);
1479
1642
  }
@@ -1483,10 +1646,6 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
1483
1646
  return gf;
1484
1647
  }
1485
1648
 
1486
- bool llama_kv_cache::is_masked_swa(llama_pos p0, llama_pos p1) const {
1487
- return llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1);
1488
- }
1489
-
1490
1649
  void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
1491
1650
  GGML_UNUSED(flags);
1492
1651
 
@@ -1599,8 +1758,10 @@ void llama_kv_cache::state_write_meta(llama_io_write_i & io, const cell_ranges_t
1599
1758
  io.write(&pos, sizeof(pos));
1600
1759
  io.write(&n_seq_id, sizeof(n_seq_id));
1601
1760
 
1602
- // TODO: we also need to save llama_kv_cell_ext when apply_ubatch() support loading it
1603
- // see: https://github.com/ggml-org/llama.cpp/pull/16825#issuecomment-3460868350
1761
+ if (hparams.n_pos_per_embd() > 1) {
1762
+ const llama_kv_cell_ext ext = cells.ext_get(i);
1763
+ io.write(&ext, sizeof(ext));
1764
+ }
1604
1765
 
1605
1766
  for (const auto & seq_id : seq_ids) {
1606
1767
  io.write(&seq_id, sizeof(seq_id));
@@ -1618,8 +1779,6 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t
1618
1779
  io.write(&v_trans, sizeof(v_trans));
1619
1780
  io.write(&n_layer, sizeof(n_layer));
1620
1781
 
1621
- std::vector<uint8_t> tmp_buf;
1622
-
1623
1782
  // Iterate and write all the keys first, each row is a cell
1624
1783
  // Get whole range at a time
1625
1784
  for (const auto & layer : layers) {
@@ -1637,7 +1796,7 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t
1637
1796
  const uint64_t k_size_row = ggml_row_size(k->type, n_embd_k_gqa);
1638
1797
  io.write(&k_size_row, sizeof(k_size_row));
1639
1798
 
1640
- // Read each range of cells of k_size length each into tmp_buf and write out
1799
+ // Read each range of cells of k_size length and write out
1641
1800
  for (const auto & range : cr.data) {
1642
1801
  const size_t range_size = range.second - range.first;
1643
1802
  const size_t buf_size = range_size * k_size_row;
@@ -1652,6 +1811,9 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t
1652
1811
  const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
1653
1812
 
1654
1813
  auto * v = layer.v_stream[cr.strm];
1814
+ if (!v) {
1815
+ continue;
1816
+ }
1655
1817
 
1656
1818
  // Write value type
1657
1819
  const int32_t v_type_i = (int32_t) v->type;
@@ -1661,7 +1823,7 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t
1661
1823
  const uint64_t v_size_row = ggml_row_size(v->type, n_embd_v_gqa);
1662
1824
  io.write(&v_size_row, sizeof(v_size_row));
1663
1825
 
1664
- // Read each range of cells of v_size length each into tmp_buf and write out
1826
+ // Read each range of cells of v_size length and write out
1665
1827
  for (const auto & range : cr.data) {
1666
1828
  const size_t range_size = range.second - range.first;
1667
1829
  const size_t buf_size = range_size * v_size_row;
@@ -1678,6 +1840,9 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t
1678
1840
  const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
1679
1841
 
1680
1842
  auto * v = layer.v_stream[cr.strm];
1843
+ if (!v) {
1844
+ continue;
1845
+ }
1681
1846
 
1682
1847
  // Write value type
1683
1848
  const int32_t v_type_i = (int32_t) v->type;
@@ -1692,7 +1857,7 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t
1692
1857
 
1693
1858
  // For each row, we get the element values of each cell
1694
1859
  for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
1695
- // Read each range of cells of v_size_el length each into tmp_buf and write out
1860
+ // Read each range of cells of v_size_el length and write out
1696
1861
  for (const auto & range : cr.data) {
1697
1862
  const size_t range_size = range.second - range.first;
1698
1863
  const size_t src_offset = (range.first + j * kv_size) * v_size_el;
@@ -1730,6 +1895,14 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
1730
1895
  return false;
1731
1896
  }
1732
1897
 
1898
+ if (hparams.n_pos_per_embd() > 1) {
1899
+ llama_kv_cell_ext ext;
1900
+ io.read_to(&ext, sizeof(ext));
1901
+
1902
+ ubatch.pos[i + ubatch.n_tokens] = ext.y;
1903
+ ubatch.pos[i + ubatch.n_tokens*2] = ext.x;
1904
+ }
1905
+
1733
1906
  // read the sequence id, but directly discard it - we will use dest_seq_id instead
1734
1907
  {
1735
1908
  llama_seq_id seq_id;
@@ -1780,6 +1953,12 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
1780
1953
 
1781
1954
  cells.pos_set(i, pos);
1782
1955
 
1956
+ if (hparams.n_pos_per_embd() > 1) {
1957
+ llama_kv_cell_ext ext;
1958
+ io.read_to(&ext, sizeof(ext));
1959
+ cells.ext_set(i, ext);
1960
+ }
1961
+
1783
1962
  for (uint32_t j = 0; j < n_seq_id; ++j) {
1784
1963
  llama_seq_id seq_id;
1785
1964
  io.read_to(&seq_id, sizeof(seq_id));
@@ -1881,6 +2060,9 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
1881
2060
  const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
1882
2061
 
1883
2062
  auto * v = layer.v_stream[strm];
2063
+ if (!v) {
2064
+ continue;
2065
+ }
1884
2066
 
1885
2067
  // Read type of value
1886
2068
  int32_t v_type_i_ref;
@@ -1922,6 +2104,9 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
1922
2104
  const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
1923
2105
 
1924
2106
  auto * v = layer.v_stream[strm];
2107
+ if (!v) {
2108
+ continue;
2109
+ }
1925
2110
 
1926
2111
  // Read type of value
1927
2112
  int32_t v_type_i_ref;
@@ -257,8 +257,6 @@ private:
257
257
  size_t size_k_bytes() const;
258
258
  size_t size_v_bytes() const;
259
259
 
260
- bool is_masked_swa(llama_pos p0, llama_pos p1) const;
261
-
262
260
  ggml_tensor * build_rope_shift(
263
261
  const llama_cparams & cparams,
264
262
  ggml_context * ctx,
@@ -266,7 +264,8 @@ private:
266
264
  ggml_tensor * shift,
267
265
  ggml_tensor * factors,
268
266
  float freq_base,
269
- float freq_scale) const;
267
+ float freq_scale,
268
+ uint32_t il) const;
270
269
 
271
270
  ggml_cgraph * build_graph_shift(
272
271
  llm_graph_result * res,