whispercpp 1.3.3 → 1.3.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (586)
  1. checksums.yaml +4 -4
  2. data/ext/ruby_whisper_params.c +55 -25
  3. data/ext/sources/CMakeLists.txt +1 -1
  4. data/ext/sources/bindings/javascript/package.json +1 -1
  5. data/ext/sources/build-xcframework.sh +24 -0
  6. data/ext/sources/examples/CMakeLists.txt +1 -0
  7. data/ext/sources/examples/addon.node/addon.cpp +19 -19
  8. data/ext/sources/examples/addon.node/index.js +7 -5
  9. data/ext/sources/examples/bench/bench.cpp +26 -16
  10. data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
  11. data/ext/sources/examples/cli/cli.cpp +4 -2
  12. data/ext/sources/examples/command/command.cpp +26 -24
  13. data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
  14. data/ext/sources/examples/common-ggml.cpp +2 -0
  15. data/ext/sources/examples/lsp/lsp.cpp +19 -17
  16. data/ext/sources/examples/server/server.cpp +24 -13
  17. data/ext/sources/examples/server.py +6 -1
  18. data/ext/sources/examples/stream/stream.cpp +4 -2
  19. data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
  20. data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
  21. data/ext/sources/examples/talk-llama/CMakeLists.txt +2 -2
  22. data/ext/sources/examples/talk-llama/llama-adapter.cpp +101 -4
  23. data/ext/sources/examples/talk-llama/llama-adapter.h +6 -0
  24. data/ext/sources/examples/talk-llama/llama-arch.cpp +588 -15
  25. data/ext/sources/examples/talk-llama/llama-arch.h +58 -1
  26. data/ext/sources/examples/talk-llama/llama-batch.cpp +103 -71
  27. data/ext/sources/examples/talk-llama/llama-batch.h +31 -18
  28. data/ext/sources/examples/talk-llama/llama-chat.cpp +120 -5
  29. data/ext/sources/examples/talk-llama/llama-chat.h +7 -0
  30. data/ext/sources/examples/talk-llama/llama-context.cpp +460 -357
  31. data/ext/sources/examples/talk-llama/llama-context.h +44 -29
  32. data/ext/sources/examples/talk-llama/llama-cparams.h +4 -4
  33. data/ext/sources/examples/talk-llama/llama-graph.cpp +543 -271
  34. data/ext/sources/examples/talk-llama/llama-graph.h +278 -168
  35. data/ext/sources/examples/talk-llama/llama-hparams.cpp +118 -4
  36. data/ext/sources/examples/talk-llama/llama-hparams.h +61 -15
  37. data/ext/sources/examples/talk-llama/llama-impl.h +2 -0
  38. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +326 -0
  39. data/ext/sources/examples/talk-llama/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +38 -29
  40. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2020 -0
  41. data/ext/sources/examples/talk-llama/llama-kv-cache.h +358 -27
  42. data/ext/sources/examples/talk-llama/llama-kv-cells.h +80 -28
  43. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +56 -36
  44. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +30 -29
  45. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +48 -19
  46. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +13 -14
  47. data/ext/sources/examples/talk-llama/llama-memory.h +16 -10
  48. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +2 -0
  49. data/ext/sources/examples/talk-llama/llama-model-loader.h +3 -2
  50. data/ext/sources/examples/talk-llama/llama-model.cpp +7165 -2336
  51. data/ext/sources/examples/talk-llama/llama-model.h +60 -9
  52. data/ext/sources/examples/talk-llama/llama-quant.cpp +48 -10
  53. data/ext/sources/examples/talk-llama/llama-sampling.cpp +226 -126
  54. data/ext/sources/examples/talk-llama/llama-vocab.cpp +440 -13
  55. data/ext/sources/examples/talk-llama/llama-vocab.h +45 -0
  56. data/ext/sources/examples/talk-llama/llama.cpp +65 -10
  57. data/ext/sources/examples/talk-llama/llama.h +95 -177
  58. data/ext/sources/examples/talk-llama/talk-llama.cpp +9 -6
  59. data/ext/sources/examples/talk-llama/unicode.cpp +207 -0
  60. data/ext/sources/examples/talk-llama/unicode.h +45 -0
  61. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +4 -2
  62. data/ext/sources/examples/whisper.wasm/index-tmpl.html +17 -16
  63. data/ext/sources/ggml/CMakeLists.txt +59 -31
  64. data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
  65. data/ext/sources/ggml/include/ggml-backend.h +17 -1
  66. data/ext/sources/ggml/include/ggml-cpu.h +1 -1
  67. data/ext/sources/ggml/include/ggml-metal.h +1 -6
  68. data/ext/sources/ggml/include/ggml-opt.h +25 -6
  69. data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
  70. data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
  71. data/ext/sources/ggml/include/ggml.h +221 -16
  72. data/ext/sources/ggml/src/CMakeLists.txt +17 -2
  73. data/ext/sources/ggml/src/ggml-alloc.c +265 -141
  74. data/ext/sources/ggml/src/ggml-backend-impl.h +4 -1
  75. data/ext/sources/ggml/src/ggml-backend-reg.cpp +30 -13
  76. data/ext/sources/ggml/src/ggml-backend.cpp +221 -38
  77. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  78. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -4
  79. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +14 -0
  80. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
  81. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +903 -717
  82. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +143 -25
  83. data/ext/sources/ggml/src/ggml-cann/common.h +143 -1
  84. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +488 -69
  85. data/ext/sources/ggml/src/ggml-common.h +17 -0
  86. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +40 -18
  87. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +4 -2
  88. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +132 -596
  89. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +14 -286
  90. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +103 -582
  91. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
  92. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +265 -437
  93. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
  94. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +521 -353
  95. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
  96. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
  97. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +4679 -1657
  98. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +32 -2
  99. data/ext/sources/ggml/src/ggml-cpu/common.h +14 -0
  100. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +13 -6
  101. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +70 -42
  102. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +35 -28
  103. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +152 -18
  104. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +7 -1
  105. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +227 -97
  106. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +474 -1116
  107. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1587 -1177
  108. data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -8
  109. data/ext/sources/ggml/src/ggml-cpu/quants.c +35 -0
  110. data/ext/sources/ggml/src/ggml-cpu/quants.h +8 -0
  111. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +458 -47
  112. data/ext/sources/ggml/src/ggml-cpu/repack.h +22 -0
  113. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +89 -60
  114. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1024 -0
  115. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
  116. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
  117. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
  118. data/ext/sources/ggml/src/ggml-cpu/traits.cpp +2 -2
  119. data/ext/sources/ggml/src/ggml-cpu/traits.h +1 -1
  120. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +170 -26
  121. data/ext/sources/ggml/src/ggml-cpu/vec.h +506 -63
  122. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +20 -16
  123. data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
  124. data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
  125. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
  126. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
  127. data/ext/sources/ggml/src/ggml-cuda/common.cuh +250 -63
  128. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
  129. data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
  130. data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
  131. data/ext/sources/ggml/src/ggml-cuda/convert.cu +95 -22
  132. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +15 -0
  133. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
  134. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +64 -307
  135. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  136. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
  137. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +498 -367
  138. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +137 -91
  139. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +755 -0
  140. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +3 -0
  141. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +593 -0
  142. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +86 -50
  143. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +185 -198
  144. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  145. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
  146. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +379 -107
  147. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
  148. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
  149. data/ext/sources/ggml/src/ggml-cuda/mean.cu +56 -2
  150. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +198 -45
  151. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +123 -0
  152. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +496 -0
  153. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +206 -57
  154. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1262 -721
  155. data/ext/sources/ggml/src/ggml-cuda/{mmv.cu → mmvf.cu} +53 -53
  156. data/ext/sources/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +3 -3
  157. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +64 -73
  158. data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
  159. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
  160. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
  161. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
  162. data/ext/sources/ggml/src/ggml-cuda/pad.cu +46 -23
  163. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
  164. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
  165. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +12 -10
  166. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
  167. data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
  168. data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
  169. data/ext/sources/ggml/src/ggml-cuda/rope.cu +21 -27
  170. data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
  171. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +276 -0
  172. data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  173. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
  174. data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
  175. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +126 -59
  176. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  177. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +322 -100
  178. data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
  179. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +21 -4
  180. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
  181. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
  182. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
  183. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
  184. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
  185. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
  186. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
  187. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
  188. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
  189. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
  190. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
  191. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
  192. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
  193. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
  194. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
  195. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
  196. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
  197. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
  198. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
  199. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
  200. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
  201. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
  202. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
  203. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
  204. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
  205. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
  206. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
  207. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
  208. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
  209. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
  210. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
  211. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
  212. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
  213. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
  214. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
  215. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
  216. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +21 -18
  217. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
  218. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
  219. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
  220. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
  221. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
  222. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
  223. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
  224. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
  225. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
  226. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
  227. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
  228. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
  229. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
  230. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
  231. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
  232. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
  233. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
  234. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +259 -0
  235. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +14 -0
  236. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
  237. data/ext/sources/ggml/src/ggml-cuda/unary.cu +90 -0
  238. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +8 -0
  239. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +92 -6
  240. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
  241. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  242. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +58 -36
  243. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +4 -3
  244. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +10 -2
  245. data/ext/sources/ggml/src/ggml-impl.h +119 -9
  246. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -7
  247. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
  248. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
  249. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
  250. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +600 -0
  251. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1376 -0
  252. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +226 -0
  253. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1308 -0
  254. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +136 -63
  255. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +3158 -0
  256. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +82 -0
  257. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +718 -0
  258. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +2854 -1503
  259. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +18 -8
  260. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +18 -0
  261. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +2510 -242
  262. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
  263. data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  264. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  265. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  266. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +84 -0
  267. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +66 -0
  268. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
  269. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +370 -0
  270. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
  271. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  272. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +177 -0
  273. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
  274. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  275. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  276. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
  277. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  278. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
  279. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
  280. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  281. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
  282. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
  283. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
  284. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  285. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
  286. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
  287. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
  288. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
  289. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
  290. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  291. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +189 -0
  292. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
  293. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
  294. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
  295. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
  296. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
  297. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
  298. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +2 -2
  299. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  300. data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
  301. data/ext/sources/ggml/src/ggml-quants.c +111 -16
  302. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  303. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +67 -47
  304. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +2 -0
  305. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +15 -5
  306. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +6 -0
  307. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +25 -16
  308. data/ext/sources/ggml/src/ggml-sycl/conv.cpp +10 -4
  309. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +166 -99
  310. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +72 -306
  311. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
  312. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +67 -49
  313. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +1 -31
  314. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +79 -29
  315. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
  316. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +14 -26
  317. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +9 -6
  318. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +328 -323
  319. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
  320. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +2 -2
  321. data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +80 -60
  322. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +201 -132
  323. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +74 -55
  324. data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
  325. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +8 -9
  326. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +35 -42
  327. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
  328. data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  329. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
  330. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +12 -6
  331. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
  332. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +16 -12
  333. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +3492 -883
  334. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
  335. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
  336. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
  337. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
  338. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +349 -0
  339. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
  340. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +66 -12
  341. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +154 -0
  342. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
  343. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  344. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +2 -1
  345. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +6 -5
  346. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +4 -2
  347. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
  348. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
  349. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  350. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
  351. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
  352. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  353. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
  354. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +69 -24
  355. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +60 -20
  356. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +98 -42
  357. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +64 -27
  358. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +74 -13
  359. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  360. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  361. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  362. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +4 -17
  363. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +19 -10
  364. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +25 -15
  365. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +4 -0
  366. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
  367. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
  368. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +18 -14
  369. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +126 -0
  370. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +65 -1
  371. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
  372. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +140 -0
  373. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -531
  374. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +206 -38
  375. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp +556 -0
  376. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +12 -5
  377. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +15 -9
  378. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
  379. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
  380. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +24 -3
  381. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +53 -3
  382. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +55 -11
  383. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
  384. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  385. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
  386. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  387. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  388. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  389. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
  390. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  391. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +29 -7
  392. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +4 -0
  393. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
  394. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
  395. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
  396. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +4 -3
  397. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +101 -9
  398. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  399. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
  400. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -77
  401. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
  402. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1558 -0
  403. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +44 -0
  404. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +41 -0
  405. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
  406. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
  407. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
  408. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +124 -0
  409. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
  410. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  411. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +44 -0
  412. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +41 -0
  413. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
  414. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +57 -0
  415. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +48 -0
  416. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
  417. data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  418. data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
  419. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
  420. data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
  421. data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
  422. data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
  423. data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
  424. data/ext/sources/ggml/src/ggml.c +478 -98
  425. data/ext/sources/ggml/src/gguf.cpp +8 -1
  426. data/ext/sources/src/whisper.cpp +23 -46
  427. data/ext/sources/tests/CMakeLists.txt +8 -1
  428. data/ext/sources/tests/test-vad-full.cpp +3 -3
  429. data/ext/sources/tests/test-vad.cpp +2 -2
  430. data/lib/whisper/model/uri.rb +1 -1
  431. data/sig/whisper.rbs +7 -0
  432. data/test/test_params.rb +8 -0
  433. data/test/test_whisper.rb +1 -1
  434. data/whispercpp.gemspec +1 -1
  435. metadata +164 -157
  436. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +0 -279
  437. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +0 -1841
  438. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +0 -303
  439. data/ext/sources/ggml/include/ggml-kompute.h +0 -50
  440. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
  441. data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
  442. data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
  443. data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
  444. data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
  445. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  446. data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  447. data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  448. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  449. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  450. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  451. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  452. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  453. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  454. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
  455. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
  456. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
  457. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
  458. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
  459. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
  460. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
  461. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
  462. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
  463. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
  464. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
  465. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
  466. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
  467. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
  468. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
  469. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
  470. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
  471. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
  472. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
  473. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
  474. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
  475. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
  476. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
  477. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
  478. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
  479. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
  480. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
  481. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
  482. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
  483. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
  484. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
  485. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
  486. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
  487. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
  488. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
  489. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
  490. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
  491. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
  492. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
  493. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
  494. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
  495. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
  496. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
  497. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
  498. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
  499. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
  500. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
  501. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
  502. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
  503. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
  504. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
  505. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
  506. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
  507. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
  508. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
  509. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
  510. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
  511. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
  512. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
  513. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
  514. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
  515. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
  516. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
  517. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
  518. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
  519. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
  520. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
  521. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
  522. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
  523. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
  524. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
  525. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
  526. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
  527. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
  528. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
  529. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
  530. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
  531. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
  532. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
  533. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
  534. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
  535. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
  536. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
  537. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
  538. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
  539. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
  540. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
  541. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
  542. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
  543. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
  544. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
  545. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
  546. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
  547. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  548. data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  549. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  550. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  551. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  552. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  553. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  554. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  555. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  556. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  557. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  558. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  559. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  560. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  561. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  562. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  563. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  564. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  565. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  566. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  567. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  568. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  569. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  570. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  571. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  572. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  573. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  574. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  575. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  576. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  577. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  578. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  579. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  580. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  581. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  582. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  583. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  584. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  585. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
  586. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -6280
@@ -1,6 +1,7 @@
1
1
  #pragma once
2
2
 
3
3
  #include "llama-arch.h"
4
+ #include "llama-batch.h"
4
5
  #include "llama-hparams.h"
5
6
  #include "llama-adapter.h"
6
7
 
@@ -14,13 +15,12 @@ struct ggml_cgraph;
14
15
  struct ggml_context;
15
16
  struct ggml_tensor;
16
17
 
17
- struct llama_ubatch;
18
18
  struct llama_cparams;
19
19
 
20
20
  struct llama_memory_context_i;
21
21
 
22
- class llama_kv_cache_unified_context;
23
- class llama_kv_cache_unified_iswa_context;
22
+ class llama_kv_cache_context;
23
+ class llama_kv_cache_iswa_context;
24
24
  class llama_memory_recurrent_context;
25
25
  class llama_memory_hybrid_context;
26
26
 
@@ -39,6 +39,7 @@ enum llm_ffn_op_type {
39
39
  LLM_FFN_SWIGLU,
40
40
  LLM_FFN_GEGLU,
41
41
  LLM_FFN_REGLU,
42
+ LLM_FFN_SWIGLU_OAI_MOE,
42
43
  };
43
44
 
44
45
  enum llm_ffn_gate_type {
@@ -69,20 +70,38 @@ struct llama_cross {
69
70
  std::vector<std::set<llama_seq_id>> seq_ids_enc;
70
71
  };
71
72
 
73
+ struct llm_graph_params;
74
+
72
75
  //
73
76
  // llm_graph_input
74
77
  //
75
78
 
76
79
  class llm_graph_input_i {
77
80
  public:
81
+ llm_graph_input_i() {
82
+ const char * LLAMA_GRAPH_INPUT_DEBUG = getenv("LLAMA_GRAPH_INPUT_DEBUG");
83
+ debug = LLAMA_GRAPH_INPUT_DEBUG ? atoi(LLAMA_GRAPH_INPUT_DEBUG) : 0;
84
+ }
85
+
78
86
  virtual ~llm_graph_input_i() = default;
79
87
 
80
88
  virtual void set_input(const llama_ubatch * ubatch) = 0;
89
+
90
+ // return true if the resulting input tensors using the provided graph parameters would be
91
+ // the same as the previous input tensors that we have currently stored in the object
92
+ virtual bool can_reuse(const llm_graph_params & params) {
93
+ // returning false here by default will prevent from reusing the graph if the check
94
+ // for the input type has not been implemented yet
95
+ GGML_UNUSED(params);
96
+ return false;
97
+ }
98
+ protected:
99
+ // env: LLAMA_GRAPH_INPUT_DEBUG
100
+ int debug = 0;
81
101
  };
82
102
 
83
103
  using llm_graph_input_ptr = std::unique_ptr<llm_graph_input_i>;
84
104
 
85
-
86
105
  class llm_graph_input_embd : public llm_graph_input_i {
87
106
  public:
88
107
  llm_graph_input_embd() = default;
@@ -90,6 +109,8 @@ public:
90
109
 
91
110
  void set_input(const llama_ubatch * ubatch) override;
92
111
 
112
+ bool can_reuse(const llm_graph_params & params) override;
113
+
93
114
  ggml_tensor * tokens = nullptr; // I32 [n_batch]
94
115
  ggml_tensor * embd = nullptr; // F32 [n_embd, n_batch]
95
116
  };
@@ -101,6 +122,8 @@ public:
101
122
 
102
123
  void set_input(const llama_ubatch * ubatch) override;
103
124
 
125
+ bool can_reuse(const llm_graph_params & params) override;
126
+
104
127
  ggml_tensor * pos = nullptr; // I32 [n_batch]
105
128
 
106
129
  const uint32_t n_pos_per_embd = 1;
@@ -130,23 +153,23 @@ public:
130
153
 
131
154
  ggml_tensor * pos_bucket = nullptr; // I32 [n_batch, n_batch]
132
155
 
133
- const llama_hparams & hparams;
156
+ const llama_hparams hparams;
134
157
  };
135
158
 
136
159
  class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
137
160
  public:
138
161
  llm_graph_input_pos_bucket_kv(
139
162
  const llama_hparams & hparams,
140
- const llama_kv_cache_unified_context * mctx) : hparams(hparams), mctx(mctx) {}
163
+ const llama_kv_cache_context * mctx) : hparams(hparams), mctx(mctx) {}
141
164
  virtual ~llm_graph_input_pos_bucket_kv() = default;
142
165
 
143
166
  void set_input(const llama_ubatch * ubatch) override;
144
167
 
145
168
  ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]
146
169
 
147
- const llama_hparams & hparams;
170
+ const llama_hparams hparams;
148
171
 
149
- const llama_kv_cache_unified_context * mctx;
172
+ const llama_kv_cache_context * mctx;
150
173
  };
151
174
 
152
175
  class llm_graph_input_out_ids : public llm_graph_input_i {
@@ -154,17 +177,19 @@ public:
154
177
  llm_graph_input_out_ids(
155
178
  const llama_hparams & hparams,
156
179
  const llama_cparams & cparams,
157
- int32_t n_outputs) : hparams(hparams), cparams(cparams), n_outputs(n_outputs) {}
180
+ uint32_t n_outputs) : hparams(hparams), cparams(cparams), n_outputs(n_outputs) {}
158
181
  virtual ~llm_graph_input_out_ids() = default;
159
182
 
160
183
  void set_input(const llama_ubatch * ubatch) override;
161
184
 
185
+ bool can_reuse(const llm_graph_params & params) override;
186
+
162
187
  ggml_tensor * out_ids; // I32 [n_outputs]
163
188
 
164
- const llama_hparams & hparams;
165
- const llama_cparams & cparams;
189
+ const llama_hparams hparams;
190
+ const llama_cparams cparams;
166
191
 
167
- const int32_t n_outputs;
192
+ const uint32_t n_outputs;
168
193
  };
169
194
 
170
195
  class llm_graph_input_mean : public llm_graph_input_i {
@@ -176,19 +201,20 @@ public:
176
201
 
177
202
  ggml_tensor * mean; // F32 [n_batch, n_batch]
178
203
 
179
- const llama_cparams & cparams;
204
+ const llama_cparams cparams;
180
205
  };
181
206
 
182
207
  class llm_graph_input_cls : public llm_graph_input_i {
183
208
  public:
184
- llm_graph_input_cls(const llama_cparams & cparams) : cparams(cparams) {}
209
+ llm_graph_input_cls(const llama_cparams & cparams, const llm_arch arch) : cparams(cparams), arch(arch) {}
185
210
  virtual ~llm_graph_input_cls() = default;
186
211
 
187
212
  void set_input(const llama_ubatch * ubatch) override;
188
213
 
189
214
  ggml_tensor * cls; // I32 [n_batch]
190
215
 
191
- const llama_cparams & cparams;
216
+ const llama_cparams cparams;
217
+ const llm_arch arch;
192
218
  };
193
219
 
194
220
  class llm_graph_input_rs : public llm_graph_input_i {
@@ -198,7 +224,12 @@ public:
198
224
 
199
225
  void set_input(const llama_ubatch * ubatch) override;
200
226
 
201
- ggml_tensor * s_copy; // I32 [kv_size]
227
+ ggml_tensor * s_copy; // I32 [n_rs]
228
+
229
+ // views of s_copy, computed once per graph
230
+ // and shared across layers which use build_rs
231
+ ggml_tensor * s_copy_main; // I32 [n_seqs]
232
+ ggml_tensor * s_copy_extra; // I32 [n_rs - n_seqs]
202
233
 
203
234
  const llama_memory_recurrent_context * mctx;
204
235
  };
@@ -228,64 +259,87 @@ public:
228
259
 
229
260
  ggml_tensor * get_kq_mask() const { return kq_mask_cnv; }
230
261
 
231
- ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch]
232
- ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch]
262
+ ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch, 1, 1]
263
+ ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch, 1, 1]
233
264
 
234
- const llama_hparams & hparams;
235
- const llama_cparams & cparams;
265
+ const llama_hparams hparams;
266
+ const llama_cparams cparams;
236
267
  };
237
268
 
238
- class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
269
+ class llm_graph_input_attn_kv : public llm_graph_input_i {
239
270
  public:
240
- llm_graph_input_attn_kv_unified(
271
+ llm_graph_input_attn_kv(
241
272
  const llama_hparams & hparams,
242
273
  const llama_cparams & cparams,
243
- const llama_kv_cache_unified_context * mctx) :
274
+ const llama_kv_cache_context * mctx) :
244
275
  hparams(hparams),
245
276
  cparams(cparams),
246
277
  mctx(mctx) {
247
278
  }
248
- ~llm_graph_input_attn_kv_unified() = default;
279
+ ~llm_graph_input_attn_kv() = default;
249
280
 
250
281
  void set_input(const llama_ubatch * ubatch) override;
251
282
 
283
+ bool can_reuse(const llm_graph_params & params) override;
284
+
285
+ ggml_tensor * get_k_idxs() const { return self_k_idxs; }
286
+ ggml_tensor * get_v_idxs() const { return self_v_idxs; }
287
+
252
288
  ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
253
289
 
254
- ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
255
- ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
290
+ ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
291
+ ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
256
292
 
257
- const llama_hparams & hparams;
258
- const llama_cparams & cparams;
293
+ ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
294
+ ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
295
+
296
+ // note: these have to be copies because in order to be able to reuse a graph, its inputs
297
+ // need to carry these parameters with them. otherwise, they can point to freed
298
+ // llm_graph_params from a previous batch, causing stack-use-after-return
299
+ const llama_hparams hparams;
300
+ const llama_cparams cparams;
259
301
 
260
- const llama_kv_cache_unified_context * mctx;
302
+ const llama_kv_cache_context * mctx;
261
303
  };

- class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
+ class llm_graph_input_attn_kv_iswa : public llm_graph_input_i {
  public:
- llm_graph_input_attn_kv_unified_iswa(
+ llm_graph_input_attn_kv_iswa(
  const llama_hparams & hparams,
  const llama_cparams & cparams,
- const llama_kv_cache_unified_iswa_context * mctx) :
+ const llama_kv_cache_iswa_context * mctx) :
  hparams(hparams),
  cparams(cparams),
  mctx(mctx) {
  }
- ~llm_graph_input_attn_kv_unified_iswa() = default;
+ ~llm_graph_input_attn_kv_iswa() = default;

  void set_input(const llama_ubatch * ubatch) override;

+ bool can_reuse(const llm_graph_params & params) override;
+
+ ggml_tensor * get_k_idxs() const { return self_k_idxs; }
+ ggml_tensor * get_v_idxs() const { return self_v_idxs; }
+ ggml_tensor * get_k_idxs_swa() const { return self_k_idxs_swa; }
+ ggml_tensor * get_v_idxs_swa() const { return self_v_idxs_swa; }
+
  ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
  ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }

- ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
- ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
- ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch]
- ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch]
+ ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
+ ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
+ ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
+ ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]

- const llama_hparams & hparams;
- const llama_cparams & cparams;
+ ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
+ ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
+ ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
+ ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]

- const llama_kv_cache_unified_iswa_context * mctx;
+ const llama_hparams hparams;
+ const llama_cparams cparams;
+
+ const llama_kv_cache_iswa_context * mctx;
  };

  class llm_graph_input_attn_cross : public llm_graph_input_i {
@@ -297,8 +351,8 @@ public:

  ggml_tensor * get_kq_mask_cross() const { return cross_kq_mask_cnv; }

- ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch]
- ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch]
+ ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1]
+ ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1]

  const llama_cross * cross = nullptr;
  };
@@ -306,41 +360,25 @@ public:
  class llm_graph_input_mem_hybrid : public llm_graph_input_i {
  public:
  llm_graph_input_mem_hybrid(
- const llama_hparams & hparams,
- const llama_cparams & cparams,
- const llama_memory_hybrid_context * mctx) :
- hparams(hparams),
- cparams(cparams),
- mctx(mctx) {
- }
+ std::unique_ptr<llm_graph_input_attn_kv> inp_attn,
+ std::unique_ptr<llm_graph_input_rs> inp_rs,
+ const llama_memory_hybrid_context * mctx) :
+ inp_attn(std::move(inp_attn)),
+ inp_rs(std::move(inp_rs)),
+ mctx(mctx) { }
  virtual ~llm_graph_input_mem_hybrid() = default;

  void set_input(const llama_ubatch * ubatch) override;

- ggml_tensor * s_copy; // I32 [kv_size]
-
- ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
+ std::unique_ptr<llm_graph_input_attn_kv> inp_attn;
+ std::unique_ptr<llm_graph_input_rs> inp_rs;

- ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
- ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
-
- const llama_hparams & hparams;
- const llama_cparams & cparams;
+ llm_graph_input_attn_kv * get_attn() const { return inp_attn.get(); }
+ llm_graph_input_rs * get_recr() const { return inp_rs.get(); }

  const llama_memory_hybrid_context * mctx;
  };
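
`llm_graph_input_mem_hybrid` is rebuilt as a thin composite: instead of re-declaring its own mask and state tensors, it owns one attention input and one recurrent input and exposes them via `get_attn()`/`get_recr()`. Presumably its `set_input()` simply forwards to both children (the implementation lives in llama-graph.cpp, outside this header diff); a simplified sketch of the pattern:

    // Simplified composition sketch (hypothetical types; the real set_input
    // body is in llama-graph.cpp and is not shown in this diff).
    #include <memory>

    struct ubatch_t {};

    struct input_i {
        virtual ~input_i() = default;
        virtual void set_input(const ubatch_t * ub) = 0;
    };

    struct attn_input : input_i { void set_input(const ubatch_t *) override {} };
    struct rs_input   : input_i { void set_input(const ubatch_t *) override {} };

    struct hybrid_input : input_i {
        std::unique_ptr<attn_input> inp_attn;
        std::unique_ptr<rs_input>   inp_rs;

        // delegate to both sub-inputs instead of duplicating their tensors
        void set_input(const ubatch_t * ub) override {
            inp_attn->set_input(ub);
            inp_rs  ->set_input(ub);
        }
    };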

- // TODO: remove this when ggml_scale_add is implemented
- class llm_graph_input_one : public llm_graph_input_i {
- public:
- llm_graph_input_one() {}
- virtual ~llm_graph_input_one() = default;
-
- void set_input(const llama_ubatch *) override;
-
- ggml_tensor * one = nullptr; // F32
- };
-
  //
  // llm_graph_result
  //
@@ -351,40 +389,110 @@ public:
  // along with the input tensors, the object also provides commonly used output tensors, such as logits, embeddings, etc.
  // these are used by the llama_context to extract the relevant data, based on the compute parameters

- class llm_graph_result_i {
- public:
- virtual ~llm_graph_result_i() = default;
+ // callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
+ using llm_graph_cb = std::function<void(const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il)>;
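
`llm_graph_cb` moves up here from the old `llm_graph_context` section further down. Any callable with this signature works; the body below is illustrative only (per-layer tensor naming, similar in spirit to what the graph builder does with its `cb()` helper):

    // An llm_graph_cb instance; only the signature comes from the header above.
    llm_graph_cb cb = [](const llama_ubatch & ubatch, ggml_tensor * cur,
                         const char * name, int il) {
        (void) ubatch; // unused in this sketch
        if (il >= 0) {
            ggml_format_name(cur, "%s-%d", name, il); // per-layer tensors
        } else {
            ggml_set_name(cur, name);                 // layer-independent tensors
        }
    };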

- virtual ggml_tensor * get_tokens() = 0;
- virtual ggml_tensor * get_logits() = 0;
- virtual ggml_tensor * get_embd() = 0;
- virtual ggml_tensor * get_embd_pooled() = 0;
+ class llm_graph_result;

- virtual void set_inputs(const llama_ubatch * ubatch) = 0;
- };
+ struct llm_graph_params {
+ llm_arch arch = LLM_ARCH_UNKNOWN;

- using llm_graph_result_ptr = std::unique_ptr<llm_graph_result_i>;
+ llama_hparams hparams;
+ llama_cparams cparams;

+ llama_ubatch ubatch; // note: intentionally make a copy

- class llm_graph_result : public llm_graph_result_i {
- public:
- virtual ~llm_graph_result() = default;
+ llm_graph_type gtype;
+
+ ggml_backend_sched_t sched;
+ ggml_backend_t backend_cpu;

- ggml_tensor * get_tokens() override { return t_tokens; }
- ggml_tensor * get_logits() override { return t_logits; }
- ggml_tensor * get_embd() override { return t_embd; }
- ggml_tensor * get_embd_pooled() override { return t_embd_pooled; }
+ const llama_adapter_cvec * cvec;
+ const llama_adapter_loras * loras;
+ const llama_memory_context_i * mctx;
+ const llama_cross * cross;

- void set_inputs(const llama_ubatch * ubatch) override {
- for (auto & input : inputs) {
- input->set_input(ubatch);
+ uint32_t n_outputs;
+
+ llm_graph_cb cb;
+
+ llm_graph_result * res;
+
+ // return true if the "other" params would result in a graph with the same topology as with the current params
+ // having the same topology allows us to reuse the graph in some cases
+ bool allow_reuse(const llm_graph_params & other) const {
+ // first check the ubatch
+ bool can_reuse_ubatch =
+ ubatch.equal_seqs() == other.ubatch.equal_seqs() &&
+ ubatch.n_tokens == other.ubatch.n_tokens &&
+ ubatch.n_seq_tokens == other.ubatch.n_seq_tokens &&
+ ubatch.n_seqs == other.ubatch.n_seqs &&
+ ubatch.n_seqs_unq == other.ubatch.n_seqs_unq &&
+ (
+ (!ubatch.token && !other.ubatch.token) ||
+ (!ubatch.embd && !other.ubatch.embd)
+ );
+
+ // when we split the batch using "equal_seqs" we have to verify that the participating sequences are the same
+ // the reason is that the set of attention streams would be different for different sequences
+ if (can_reuse_ubatch && ubatch.equal_seqs()) {
+ if (!ubatch.data) {
+ // if the old ubatch does not own its data, then we cannot guarantee that it is still alive, and
+ // therefore we cannot perform the sequence id check. this should normally never happen
+ can_reuse_ubatch = false;
+ } else {
+ for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
+ can_reuse_ubatch &= ubatch.seq_id_unq[s] == other.ubatch.seq_id_unq[s];
+ }
+ }
+ }
+
+ if (!can_reuse_ubatch) {
+ return false;
  }
- }

- llm_graph_input_i * add_input(llm_graph_input_ptr input) {
- inputs.emplace_back(std::move(input));
- return inputs.back().get();
+ return
+ cparams.embeddings == other.cparams.embeddings &&
+ cparams.causal_attn == other.cparams.causal_attn &&
+ arch == other.arch &&
+ gtype == other.gtype &&
+ cvec == other.cvec &&
+ loras == other.loras &&
+ cross == other.cross &&
+ n_outputs == other.n_outputs;
  }
+ };
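
The per-sequence check in the `equal_seqs()` branch is the subtle part of `allow_reuse()`: batch shapes can match while the participating sequences differ, and the attention streams are laid out per sequence. An isolated restatement of that check, with plain vectors standing in for `llama_ubatch` (illustrative only):

    // Isolated sketch of the sequence-id comparison above (illustrative only;
    // llama_ubatch stores these ids in seq_id_unq / n_seqs_unq).
    #include <cstdint>
    #include <vector>

    bool same_seq_set(const std::vector<int32_t> & prev_ids,
                      const std::vector<int32_t> & next_ids) {
        if (prev_ids.size() != next_ids.size()) {
            return false;                      // mirrors the n_seqs_unq check
        }
        bool ok = true;
        for (size_t s = 0; s < prev_ids.size(); ++s) {
            ok &= prev_ids[s] == next_ids[s];  // element-wise: stream layout is
        }                                      // tied to the exact sequence set
        return ok;
    }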
+
+ class llm_graph_result {
+ public:
+ llm_graph_result(int64_t max_nodes);
+
+ virtual ~llm_graph_result() = default;
+
+ ggml_tensor * get_tokens() const { return t_tokens; }
+ ggml_tensor * get_logits() const { return t_logits; }
+ ggml_tensor * get_embd() const { return t_embd; }
+ ggml_tensor * get_embd_pooled() const { return t_embd_pooled; }
+
+ ggml_cgraph * get_gf() const { return gf; }
+ ggml_context * get_ctx() const { return ctx_compute.get(); }
+
+ int64_t get_max_nodes() const;
+
+ void reset();
+
+ void set_inputs(const llama_ubatch * ubatch);
+
+ // try to update the existing graph result using the new graph parameters in order to reuse it
+ // this can only be done if we determine that the resulting graph using the new graph parameters
+ // would be identical to the existing graph. in that case, we simply have to update the memory
+ // contexts of the input tensors of the graph and we can reuse it for another computation
+ // return true if the graph was updated and can be reused
+ bool can_reuse(const llm_graph_params & params);
+
+ llm_graph_input_i * add_input(llm_graph_input_ptr input);
+
+ void set_params(const llm_graph_params & params);

  // important graph nodes
  ggml_tensor * t_tokens = nullptr;
@@ -393,36 +501,34 @@ public:
  ggml_tensor * t_embd_pooled = nullptr;

  std::vector<llm_graph_input_ptr> inputs;
- };

- //
- // llm_graph_context
- //
+ ggml_context_ptr ctx_compute;

- // callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
- using llm_graph_cb = std::function<void(const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il)>;
+ // memory buffers used to evaluate the model
+ std::vector<uint8_t> buf_compute_meta;

- struct llm_graph_params {
- ggml_context * ctx;
+ ggml_cgraph * gf;

- const llm_arch arch;
+ int64_t max_nodes;

- const llama_hparams & hparams;
- const llama_cparams & cparams;
- const llama_ubatch & ubatch;
+ private:
+ // keep a copy of the previous graph parameters
+ // we will use this to determine whether the graph can be reused by comparing them with the new parameters
+ // note: these are updated after constructing the new graph
+ llm_graph_params params;

- ggml_backend_sched_t sched;
- ggml_backend_t backend_cpu;
+ // env: LLAMA_GRAPH_RESULT_DEBUG
+ int debug = 0;
+ };
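
Taken together, this replaces the old `llm_graph_result_i`/`llm_graph_result` pair with a single concrete type that also owns the `ggml_cgraph`, the compute context, and a copy of the previous parameters. A hedged sketch of the intended driver loop (the surrounding function is a placeholder, not a llama.cpp API):

    // Sketch of the reuse-or-rebuild decision (process_batch is a placeholder).
    bool process_batch(llm_graph_result & res, const llm_graph_params & params_new) {
        // can_reuse() compares params_new against the stored previous params and,
        // on success, only re-binds the graph's input tensors
        if (res.can_reuse(params_new)) {
            return true;   // identical topology: skip graph construction entirely
        }
        res.reset();       // different topology: clear inputs/graph and rebuild
        // ... build a fresh graph here, then res.set_params(params_new) ...
        return false;
    }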

- const llama_adapter_cvec * cvec;
- const llama_adapter_loras * loras;
- const llama_memory_context_i * mctx;
- const llama_cross * cross;
+ using llm_graph_result_ptr = std::unique_ptr<llm_graph_result>;

- uint32_t n_outputs;
+ //
+ // llm_graph_context
+ //

- const llm_graph_cb & cb;
- };
+ // used in build_rs to properly order writes and avoid unnecessary copies
+ using llm_graph_get_rows_fn = std::function<ggml_tensor * (ggml_context *, ggml_tensor * states, ggml_tensor * ids)>;
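
As the `build_rs` declarations below show, the recurrent-state gather is now parameterized by a `llm_graph_get_rows_fn` (defaulting to `ggml_get_rows`) in place of the old `avoid_copies` flag. A minimal custom instance that simply delegates (a real override would reorder or fuse the gather):

    // A drop-in llm_graph_get_rows_fn; since the default argument in build_rs
    // is plain ggml_get_rows, this lambda is behaviorally equivalent to it.
    llm_graph_get_rows_fn get_state_rows =
        [](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
            return ggml_get_rows(ctx, states, ids);
        };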

  struct llm_graph_context {
  const llm_arch arch;
@@ -460,8 +566,6 @@ struct llm_graph_context {
  const enum llama_pooling_type pooling_type;
  const enum llama_rope_type rope_type;

- ggml_context * ctx0 = nullptr;
-
  ggml_backend_sched_t sched;

  ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
@@ -473,7 +577,10 @@ struct llm_graph_context {

  const llm_graph_cb & cb_func;

- std::unique_ptr<llm_graph_result> res;
+ llm_graph_result * res;
+
+ ggml_context * ctx0 = nullptr;
+ ggml_cgraph * gf = nullptr;

  llm_graph_context(const llm_graph_params & params);
  virtual ~llm_graph_context() = default;
@@ -522,6 +629,7 @@ struct llm_graph_context {
  llm_ffn_gate_type type_gate,
  int il) const;

+ // build MoE FFN without bias tensors
  ggml_tensor * build_moe_ffn(
  ggml_tensor * cur,
  ggml_tensor * gate_inp,
@@ -536,7 +644,29 @@ struct llm_graph_context {
  bool scale_w,
  float w_scale,
  llama_expert_gating_func_type gating_op,
- int il) const;
+ int il,
+ ggml_tensor * probs_in = nullptr) const;
+
+ ggml_tensor * build_moe_ffn(
+ ggml_tensor * cur,
+ ggml_tensor * gate_inp,
+ ggml_tensor * gate_inp_b,
+ ggml_tensor * up_exps,
+ ggml_tensor * up_exps_b,
+ ggml_tensor * gate_exps,
+ ggml_tensor * gate_exps_b,
+ ggml_tensor * down_exps,
+ ggml_tensor * down_exps_b,
+ ggml_tensor * exp_probs_b,
+ int64_t n_expert,
+ int64_t n_expert_used,
+ llm_ffn_op_type type_op,
+ bool norm_w,
+ bool scale_w,
+ float w_scale,
+ llama_expert_gating_func_type gating_op,
+ int il,
+ ggml_tensor * probs_in = nullptr) const;

  //
  // inputs
@@ -554,64 +684,63 @@ struct llm_graph_context {
  ggml_tensor * build_inp_pos_bucket_dec() const;
  ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const;

- llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;
-
  //
  // attention
  //

  ggml_tensor * build_attn_mha(
- ggml_cgraph * gf,
- ggml_tensor * q, // [n_embd_head_q, n_head_q, n_tokens]
- ggml_tensor * k, // [n_embd_head_k, n_head_k, n_tokens]
- ggml_tensor * v, // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
- ggml_tensor * kq_b,
- ggml_tensor * kq_mask,
- ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
- float kq_scale) const;
+ ggml_tensor * q, // [n_embd_head_q, n_head_q, n_tokens]
+ ggml_tensor * k, // [n_embd_head_k, n_head_k, n_tokens]
+ ggml_tensor * v, // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
+ ggml_tensor * kq_b,
+ ggml_tensor * kq_mask,
+ ggml_tensor * sinks, // [n_head_q]
+ ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
+ float kq_scale,
+ int il) const;

  llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;

  ggml_tensor * build_attn(
  llm_graph_input_attn_no_cache * inp,
- ggml_cgraph * gf,
  ggml_tensor * wo,
  ggml_tensor * wo_b,
  ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
  ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
  ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
  ggml_tensor * kq_b,
+ ggml_tensor * sinks, // [n_head_q]
  ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
  float kq_scale,
  int il) const;

- llm_graph_input_attn_kv_unified * build_attn_inp_kv_unified() const;
+ llm_graph_input_attn_kv * build_attn_inp_kv() const;

  ggml_tensor * build_attn(
- llm_graph_input_attn_kv_unified * inp,
- ggml_cgraph * gf,
+ llm_graph_input_attn_kv * inp,
  ggml_tensor * wo,
  ggml_tensor * wo_b,
  ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
  ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
  ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
  ggml_tensor * kq_b,
+ ggml_tensor * sinks, // [n_head_q]
  ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
  float kq_scale,
  int il) const;

- llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const;
+ llm_graph_input_attn_kv_iswa * build_attn_inp_kv_iswa() const;

  // note: if k_cur or v_cur are not provided, they will not be stored in the memory
  ggml_tensor * build_attn(
- llm_graph_input_attn_kv_unified_iswa * inp,
- ggml_cgraph * gf,
+ llm_graph_input_attn_kv_iswa * inp,
  ggml_tensor * wo,
  ggml_tensor * wo_b,
  ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
  ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional
  ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional
  ggml_tensor * kq_b,
+ ggml_tensor * sinks, // [n_head_q]
  ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
  float kq_scale,
  int il) const;
@@ -620,86 +749,67 @@ struct llm_graph_context {

  ggml_tensor * build_attn(
  llm_graph_input_attn_cross * inp,
- ggml_cgraph * gf,
  ggml_tensor * wo,
  ggml_tensor * wo_b,
  ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
  ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
  ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
  ggml_tensor * kq_b,
+ ggml_tensor * sinks, // [n_head_q]
  ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
  float kq_scale,
  int il) const;

- ggml_tensor * build_attn(
- llm_graph_input_mem_hybrid * inp,
- ggml_cgraph * gf,
- ggml_tensor * wo,
- ggml_tensor * wo_b,
- ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
- ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
- ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
- ggml_tensor * kq_b,
- ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
- float kq_scale,
- int il) const;
  //
  // recurrent
  //

- // TODO: avoid notion of "kv"
  // TODO: move this implementation to llama_memory_recurrent.
- // this is analogous to llama_kv_cache_unified::cpy_k / cpy_v
+ // this is analogous to llama_kv_cache::cpy_k / cpy_v
  // when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the
  // implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
  // `llama_memory_recurrent`
  ggml_tensor * build_rs(
- ggml_cgraph * gf,
  ggml_tensor * s,
- ggml_tensor * state_copy,
+ ggml_tensor * state_copy_main,
+ ggml_tensor * state_copy_extra,
  int32_t state_size,
  int32_t n_seqs,
- uint32_t n_kv,
- uint32_t kv_head,
- uint32_t kv_size,
+ uint32_t n_rs,
+ uint32_t rs_head,
+ uint32_t rs_size,
  int32_t rs_zero,
- bool avoid_copies = false) const;
+ const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;

  llm_graph_input_rs * build_rs_inp() const;

  ggml_tensor * build_rs(
  llm_graph_input_rs * inp,
- ggml_cgraph * gf,
- ggml_tensor * s,
- int32_t state_size,
- int32_t n_seqs,
- bool avoid_copies = false) const;
-
- ggml_tensor * build_rs(
- llm_graph_input_mem_hybrid * inp,
- ggml_cgraph * gf,
  ggml_tensor * s,
  int32_t state_size,
  int32_t n_seqs,
- bool avoid_copies = false) const;
+ const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;

  ggml_tensor * build_rwkv_token_shift_load(
  llm_graph_input_rs * inp,
- ggml_cgraph * gf,
  const llama_ubatch & ubatch,
- int il) const;
+ int il) const;

  ggml_tensor * build_rwkv_token_shift_store(
  ggml_tensor * token_shift,
  const llama_ubatch & ubatch,
  int il) const;
+ //
+ // hybrid
+ //
+
+ llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;

  //
  // pooling
  //

  void build_pooling(
- ggml_cgraph * gf,
  ggml_tensor * cls,
  ggml_tensor * cls_b,
  ggml_tensor * cls_out,