whispercpp 1.3.3 → 1.3.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (586)
  1. checksums.yaml +4 -4
  2. data/ext/ruby_whisper_params.c +55 -25
  3. data/ext/sources/CMakeLists.txt +1 -1
  4. data/ext/sources/bindings/javascript/package.json +1 -1
  5. data/ext/sources/build-xcframework.sh +24 -0
  6. data/ext/sources/examples/CMakeLists.txt +1 -0
  7. data/ext/sources/examples/addon.node/addon.cpp +19 -19
  8. data/ext/sources/examples/addon.node/index.js +7 -5
  9. data/ext/sources/examples/bench/bench.cpp +26 -16
  10. data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
  11. data/ext/sources/examples/cli/cli.cpp +4 -2
  12. data/ext/sources/examples/command/command.cpp +26 -24
  13. data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
  14. data/ext/sources/examples/common-ggml.cpp +2 -0
  15. data/ext/sources/examples/lsp/lsp.cpp +19 -17
  16. data/ext/sources/examples/server/server.cpp +24 -13
  17. data/ext/sources/examples/server.py +6 -1
  18. data/ext/sources/examples/stream/stream.cpp +4 -2
  19. data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
  20. data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
  21. data/ext/sources/examples/talk-llama/CMakeLists.txt +2 -2
  22. data/ext/sources/examples/talk-llama/llama-adapter.cpp +101 -4
  23. data/ext/sources/examples/talk-llama/llama-adapter.h +6 -0
  24. data/ext/sources/examples/talk-llama/llama-arch.cpp +588 -15
  25. data/ext/sources/examples/talk-llama/llama-arch.h +58 -1
  26. data/ext/sources/examples/talk-llama/llama-batch.cpp +103 -71
  27. data/ext/sources/examples/talk-llama/llama-batch.h +31 -18
  28. data/ext/sources/examples/talk-llama/llama-chat.cpp +120 -5
  29. data/ext/sources/examples/talk-llama/llama-chat.h +7 -0
  30. data/ext/sources/examples/talk-llama/llama-context.cpp +460 -357
  31. data/ext/sources/examples/talk-llama/llama-context.h +44 -29
  32. data/ext/sources/examples/talk-llama/llama-cparams.h +4 -4
  33. data/ext/sources/examples/talk-llama/llama-graph.cpp +543 -271
  34. data/ext/sources/examples/talk-llama/llama-graph.h +278 -168
  35. data/ext/sources/examples/talk-llama/llama-hparams.cpp +118 -4
  36. data/ext/sources/examples/talk-llama/llama-hparams.h +61 -15
  37. data/ext/sources/examples/talk-llama/llama-impl.h +2 -0
  38. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +326 -0
  39. data/ext/sources/examples/talk-llama/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +38 -29
  40. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2020 -0
  41. data/ext/sources/examples/talk-llama/llama-kv-cache.h +358 -27
  42. data/ext/sources/examples/talk-llama/llama-kv-cells.h +80 -28
  43. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +56 -36
  44. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +30 -29
  45. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +48 -19
  46. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +13 -14
  47. data/ext/sources/examples/talk-llama/llama-memory.h +16 -10
  48. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +2 -0
  49. data/ext/sources/examples/talk-llama/llama-model-loader.h +3 -2
  50. data/ext/sources/examples/talk-llama/llama-model.cpp +7165 -2336
  51. data/ext/sources/examples/talk-llama/llama-model.h +60 -9
  52. data/ext/sources/examples/talk-llama/llama-quant.cpp +48 -10
  53. data/ext/sources/examples/talk-llama/llama-sampling.cpp +226 -126
  54. data/ext/sources/examples/talk-llama/llama-vocab.cpp +440 -13
  55. data/ext/sources/examples/talk-llama/llama-vocab.h +45 -0
  56. data/ext/sources/examples/talk-llama/llama.cpp +65 -10
  57. data/ext/sources/examples/talk-llama/llama.h +95 -177
  58. data/ext/sources/examples/talk-llama/talk-llama.cpp +9 -6
  59. data/ext/sources/examples/talk-llama/unicode.cpp +207 -0
  60. data/ext/sources/examples/talk-llama/unicode.h +45 -0
  61. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +4 -2
  62. data/ext/sources/examples/whisper.wasm/index-tmpl.html +17 -16
  63. data/ext/sources/ggml/CMakeLists.txt +59 -31
  64. data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
  65. data/ext/sources/ggml/include/ggml-backend.h +17 -1
  66. data/ext/sources/ggml/include/ggml-cpu.h +1 -1
  67. data/ext/sources/ggml/include/ggml-metal.h +1 -6
  68. data/ext/sources/ggml/include/ggml-opt.h +25 -6
  69. data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
  70. data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
  71. data/ext/sources/ggml/include/ggml.h +221 -16
  72. data/ext/sources/ggml/src/CMakeLists.txt +17 -2
  73. data/ext/sources/ggml/src/ggml-alloc.c +265 -141
  74. data/ext/sources/ggml/src/ggml-backend-impl.h +4 -1
  75. data/ext/sources/ggml/src/ggml-backend-reg.cpp +30 -13
  76. data/ext/sources/ggml/src/ggml-backend.cpp +221 -38
  77. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  78. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -4
  79. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +14 -0
  80. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
  81. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +903 -717
  82. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +143 -25
  83. data/ext/sources/ggml/src/ggml-cann/common.h +143 -1
  84. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +488 -69
  85. data/ext/sources/ggml/src/ggml-common.h +17 -0
  86. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +40 -18
  87. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +4 -2
  88. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +132 -596
  89. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +14 -286
  90. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +103 -582
  91. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
  92. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +265 -437
  93. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
  94. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +521 -353
  95. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
  96. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
  97. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +4679 -1657
  98. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +32 -2
  99. data/ext/sources/ggml/src/ggml-cpu/common.h +14 -0
  100. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +13 -6
  101. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +70 -42
  102. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +35 -28
  103. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +152 -18
  104. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +7 -1
  105. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +227 -97
  106. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +474 -1116
  107. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1587 -1177
  108. data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -8
  109. data/ext/sources/ggml/src/ggml-cpu/quants.c +35 -0
  110. data/ext/sources/ggml/src/ggml-cpu/quants.h +8 -0
  111. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +458 -47
  112. data/ext/sources/ggml/src/ggml-cpu/repack.h +22 -0
  113. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +89 -60
  114. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1024 -0
  115. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
  116. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
  117. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
  118. data/ext/sources/ggml/src/ggml-cpu/traits.cpp +2 -2
  119. data/ext/sources/ggml/src/ggml-cpu/traits.h +1 -1
  120. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +170 -26
  121. data/ext/sources/ggml/src/ggml-cpu/vec.h +506 -63
  122. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +20 -16
  123. data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
  124. data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
  125. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
  126. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
  127. data/ext/sources/ggml/src/ggml-cuda/common.cuh +250 -63
  128. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
  129. data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
  130. data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
  131. data/ext/sources/ggml/src/ggml-cuda/convert.cu +95 -22
  132. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +15 -0
  133. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
  134. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +64 -307
  135. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  136. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
  137. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +498 -367
  138. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +137 -91
  139. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +755 -0
  140. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +3 -0
  141. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +593 -0
  142. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +86 -50
  143. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +185 -198
  144. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  145. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
  146. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +379 -107
  147. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
  148. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
  149. data/ext/sources/ggml/src/ggml-cuda/mean.cu +56 -2
  150. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +198 -45
  151. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +123 -0
  152. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +496 -0
  153. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +206 -57
  154. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1262 -721
  155. data/ext/sources/ggml/src/ggml-cuda/{mmv.cu → mmvf.cu} +53 -53
  156. data/ext/sources/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +3 -3
  157. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +64 -73
  158. data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
  159. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
  160. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
  161. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
  162. data/ext/sources/ggml/src/ggml-cuda/pad.cu +46 -23
  163. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
  164. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
  165. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +12 -10
  166. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
  167. data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
  168. data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
  169. data/ext/sources/ggml/src/ggml-cuda/rope.cu +21 -27
  170. data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
  171. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +276 -0
  172. data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  173. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
  174. data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
  175. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +126 -59
  176. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  177. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +322 -100
  178. data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
  179. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +21 -4
  180. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
  181. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
  182. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
  183. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
  184. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
  185. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
  186. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
  187. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
  188. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
  189. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
  190. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
  191. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
  192. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
  193. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
  194. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
  195. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
  196. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
  197. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
  198. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
  199. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
  200. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
  201. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
  202. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
  203. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
  204. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
  205. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
  206. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
  207. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
  208. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
  209. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
  210. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
  211. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
  212. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
  213. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
  214. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
  215. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
  216. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +21 -18
  217. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
  218. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
  219. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
  220. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
  221. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
  222. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
  223. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
  224. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
  225. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
  226. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
  227. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
  228. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
  229. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
  230. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
  231. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
  232. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
  233. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
  234. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +259 -0
  235. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +14 -0
  236. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
  237. data/ext/sources/ggml/src/ggml-cuda/unary.cu +90 -0
  238. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +8 -0
  239. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +92 -6
  240. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
  241. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  242. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +58 -36
  243. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +4 -3
  244. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +10 -2
  245. data/ext/sources/ggml/src/ggml-impl.h +119 -9
  246. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -7
  247. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
  248. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
  249. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
  250. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +600 -0
  251. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1376 -0
  252. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +226 -0
  253. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1308 -0
  254. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +136 -63
  255. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +3158 -0
  256. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +82 -0
  257. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +718 -0
  258. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +2854 -1503
  259. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +18 -8
  260. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +18 -0
  261. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +2510 -242
  262. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
  263. data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  264. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  265. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  266. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +84 -0
  267. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +66 -0
  268. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
  269. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +370 -0
  270. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
  271. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  272. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +177 -0
  273. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
  274. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  275. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  276. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
  277. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  278. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
  279. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
  280. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  281. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
  282. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
  283. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
  284. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  285. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
  286. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
  287. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
  288. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
  289. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
  290. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  291. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +189 -0
  292. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
  293. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
  294. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
  295. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
  296. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
  297. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
  298. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +2 -2
  299. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  300. data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
  301. data/ext/sources/ggml/src/ggml-quants.c +111 -16
  302. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  303. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +67 -47
  304. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +2 -0
  305. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +15 -5
  306. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +6 -0
  307. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +25 -16
  308. data/ext/sources/ggml/src/ggml-sycl/conv.cpp +10 -4
  309. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +166 -99
  310. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +72 -306
  311. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
  312. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +67 -49
  313. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +1 -31
  314. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +79 -29
  315. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
  316. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +14 -26
  317. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +9 -6
  318. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +328 -323
  319. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
  320. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +2 -2
  321. data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +80 -60
  322. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +201 -132
  323. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +74 -55
  324. data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
  325. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +8 -9
  326. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +35 -42
  327. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
  328. data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  329. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
  330. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +12 -6
  331. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
  332. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +16 -12
  333. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +3492 -883
  334. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
  335. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
  336. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
  337. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
  338. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +349 -0
  339. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
  340. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +66 -12
  341. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +154 -0
  342. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
  343. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  344. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +2 -1
  345. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +6 -5
  346. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +4 -2
  347. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
  348. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
  349. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  350. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
  351. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
  352. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  353. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
  354. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +69 -24
  355. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +60 -20
  356. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +98 -42
  357. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +64 -27
  358. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +74 -13
  359. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  360. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  361. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  362. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +4 -17
  363. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +19 -10
  364. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +25 -15
  365. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +4 -0
  366. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
  367. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
  368. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +18 -14
  369. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +126 -0
  370. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +65 -1
  371. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
  372. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +140 -0
  373. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -531
  374. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +206 -38
  375. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp +556 -0
  376. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +12 -5
  377. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +15 -9
  378. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
  379. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
  380. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +24 -3
  381. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +53 -3
  382. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +55 -11
  383. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
  384. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  385. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
  386. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  387. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  388. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  389. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
  390. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  391. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +29 -7
  392. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +4 -0
  393. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
  394. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
  395. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
  396. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +4 -3
  397. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +101 -9
  398. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  399. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
  400. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -77
  401. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
  402. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1558 -0
  403. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +44 -0
  404. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +41 -0
  405. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
  406. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
  407. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
  408. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +124 -0
  409. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
  410. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  411. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +44 -0
  412. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +41 -0
  413. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
  414. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +57 -0
  415. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +48 -0
  416. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
  417. data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  418. data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
  419. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
  420. data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
  421. data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
  422. data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
  423. data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
  424. data/ext/sources/ggml/src/ggml.c +478 -98
  425. data/ext/sources/ggml/src/gguf.cpp +8 -1
  426. data/ext/sources/src/whisper.cpp +23 -46
  427. data/ext/sources/tests/CMakeLists.txt +8 -1
  428. data/ext/sources/tests/test-vad-full.cpp +3 -3
  429. data/ext/sources/tests/test-vad.cpp +2 -2
  430. data/lib/whisper/model/uri.rb +1 -1
  431. data/sig/whisper.rbs +7 -0
  432. data/test/test_params.rb +8 -0
  433. data/test/test_whisper.rb +1 -1
  434. data/whispercpp.gemspec +1 -1
  435. metadata +164 -157
  436. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +0 -279
  437. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +0 -1841
  438. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +0 -303
  439. data/ext/sources/ggml/include/ggml-kompute.h +0 -50
  440. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
  441. data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
  442. data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
  443. data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
  444. data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
  445. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  446. data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  447. data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  448. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  449. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  450. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  451. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  452. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  453. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  454. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
  455. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
  456. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
  457. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
  458. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
  459. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
  460. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
  461. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
  462. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
  463. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
  464. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
  465. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
  466. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
  467. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
  468. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
  469. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
  470. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
  471. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
  472. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
  473. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
  474. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
  475. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
  476. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
  477. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
  478. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
  479. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
  480. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
  481. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
  482. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
  483. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
  484. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
  485. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
  486. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
  487. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
  488. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
  489. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
  490. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
  491. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
  492. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
  493. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
  494. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
  495. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
  496. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
  497. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
  498. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
  499. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
  500. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
  501. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
  502. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
  503. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
  504. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
  505. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
  506. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
  507. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
  508. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
  509. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
  510. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
  511. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
  512. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
  513. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
  514. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
  515. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
  516. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
  517. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
  518. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
  519. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
  520. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
  521. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
  522. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
  523. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
  524. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
  525. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
  526. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
  527. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
  528. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
  529. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
  530. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
  531. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
  532. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
  533. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
  534. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
  535. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
  536. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
  537. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
  538. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
  539. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
  540. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
  541. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
  542. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
  543. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
  544. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
  545. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
  546. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
  547. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  548. data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  549. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  550. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  551. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  552. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  553. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  554. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  555. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  556. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  557. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  558. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  559. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  560. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  561. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  562. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  563. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  564. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  565. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  566. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  567. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  568. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  569. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  570. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  571. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  572. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  573. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  574. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  575. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  576. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  577. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  578. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  579. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  580. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  581. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  582. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  583. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  584. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  585. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
  586. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -6280
data/ext/sources/examples/talk-llama/llama-context.cpp

@@ -35,14 +35,12 @@ llama_context::llama_context(
 
     cparams.n_threads = params.n_threads;
     cparams.n_threads_batch = params.n_threads_batch;
-    cparams.yarn_ext_factor = params.yarn_ext_factor;
-    cparams.yarn_attn_factor = params.yarn_attn_factor;
-    cparams.yarn_beta_fast = params.yarn_beta_fast;
-    cparams.yarn_beta_slow = params.yarn_beta_slow;
-    cparams.defrag_thold = params.defrag_thold;
+    cparams.yarn_ext_factor = params.yarn_ext_factor >= 0.0f ? params.yarn_ext_factor : hparams.yarn_ext_factor;
+    cparams.yarn_attn_factor = params.yarn_attn_factor >= 0.0f ? params.yarn_attn_factor : hparams.yarn_attn_factor;
+    cparams.yarn_beta_fast = params.yarn_beta_fast >= 0.0f ? params.yarn_beta_fast : hparams.yarn_beta_fast;
+    cparams.yarn_beta_slow = params.yarn_beta_slow >= 0.0f ? params.yarn_beta_slow : hparams.yarn_beta_slow;
     cparams.embeddings = params.embeddings;
     cparams.offload_kqv = params.offload_kqv;
-    cparams.flash_attn = params.flash_attn;
     cparams.no_perf = params.no_perf;
     cparams.pooling_type = params.pooling_type;
     cparams.warmup = false;
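
The hunk above replaces the unconditional copy of the YaRN parameters with a "negative means unset" convention: a caller value >= 0.0f is taken as an explicit choice, anything negative falls back to the model's hparams default. It also drops the defrag_thold copy and the direct flash_attn copy (flash attention is now resolved later from flash_attn_type). A minimal standalone sketch of the fallback pattern, with illustrative names rather than llama.cpp's actual types:

    #include <cstdio>

    // Sketch of the "negative means unset" convention introduced above
    // (resolve_param is an illustrative helper, not llama.cpp API).
    static float resolve_param(float requested, float model_default) {
        return requested >= 0.0f ? requested : model_default;
    }

    int main() {
        const float hparams_yarn_beta_fast = 32.0f;                        // model default
        std::printf("%g\n", resolve_param(-1.0f, hparams_yarn_beta_fast)); // unset -> 32
        std::printf("%g\n", resolve_param(16.0f, hparams_yarn_beta_fast)); // explicit -> 16
        return 0;
    }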
@@ -87,21 +85,32 @@ llama_context::llama_context(
         cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL;
     }
 
+    cparams.flash_attn = params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED;
+
     // with causal attention, the batch size is limited by the context size
     cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
 
     // the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask
     // this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext)
     // ref: https://github.com/ggerganov/llama.cpp/pull/5021
-    // TODO: this padding is not needed for the cache-less context so we should probably move it to llama_context_kv_self
+    // TODO: this padding is not needed for the cache-less context so we should probably move it to llama_memory
     if (cparams.n_batch < GGML_KQ_MASK_PAD) {
         LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
         cparams.n_batch = GGML_KQ_MASK_PAD;
     }
-
     cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
 
     cparams.op_offload = params.op_offload;
+    cparams.kv_unified = params.kv_unified;
+
+    {
+        const char * LLAMA_GRAPH_REUSE_DISABLE = getenv("LLAMA_GRAPH_REUSE_DISABLE");
+        graph_reuse_disable = LLAMA_GRAPH_REUSE_DISABLE ? (atoi(LLAMA_GRAPH_REUSE_DISABLE) != 0) : graph_reuse_disable;
+
+        if (graph_reuse_disable) {
+            LLAMA_LOG_WARN("%s: graph reuse disabled\n", __func__);
+        }
+    }
 
     const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
 
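The block appended at the end of this hunk adds a runtime kill switch: setting the LLAMA_GRAPH_REUSE_DISABLE environment variable to a non-zero value disables graph reuse. A self-contained sketch of the same getenv toggle, where env_flag is a hypothetical helper rather than llama.cpp API:

    #include <cstdlib>

    // The env var, when present and non-zero, overrides the compiled-in default.
    static bool env_flag(const char * name, bool fallback) {
        const char * v = std::getenv(name);
        return v ? (std::atoi(v) != 0) : fallback;
    }

    int main() {
        // e.g. run with: LLAMA_GRAPH_REUSE_DISABLE=1 ./a.out
        const bool graph_reuse_disable = env_flag("LLAMA_GRAPH_REUSE_DISABLE", false);
        return graph_reuse_disable ? 1 : 0;
    }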
@@ -111,7 +120,8 @@ llama_context::llama_context(
     LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch);
     LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
     LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn);
-    LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn);
+    LLAMA_LOG_INFO("%s: flash_attn = %s\n", __func__, llama_flash_attn_type_name(params.flash_attn_type));
+    LLAMA_LOG_INFO("%s: kv_unified = %s\n", __func__, cparams.kv_unified ? "true" : "false");
     LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
     LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
 
@@ -125,11 +135,6 @@ llama_context::llama_context(
             __func__, n_ctx_per_seq, hparams.n_ctx_train);
     }
 
-    if (!params.swa_full && cparams.n_seq_max > 1 && hparams.is_swa_any()) {
-        LLAMA_LOG_WARN("%s: requested n_seq_max (%u) > 1, but swa_full is not enabled -- performance may be degraded: %s\n",
-            __func__, cparams.n_seq_max, "https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573");
-    }
-
     if (!hparams.vocab_only) {
         // GPU backends
         for (auto * dev : model.devices) {
@@ -176,7 +181,7 @@ llama_context::llama_context(
     // graph outputs buffer
     {
         // resized during inference when a batch uses more outputs
-        if ((uint32_t) output_reserve(params.n_seq_max) < params.n_seq_max) {
+        if (output_reserve(params.n_seq_max) < params.n_seq_max) {
            throw std::runtime_error("failed to reserve initial output buffer");
        }
 
@@ -227,8 +232,8 @@ llama_context::llama_context(
 
     LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes);
 
-    // buffer used to store the computation graph and the tensor meta data
-    buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
+    gf_res_prev.reset(new llm_graph_result(max_nodes));
+    gf_res_reserve.reset(new llm_graph_result(max_nodes));
 
     // TODO: move these checks to ggml_backend_sched
     // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
@@ -265,29 +270,76 @@ llama_context::llama_context(
         }
     }
 
-    // reserve worst-case graph
-    if (!hparams.vocab_only && memory) {
-        const uint32_t n_seqs = cparams.n_seq_max;
+    if (!hparams.vocab_only) {
+        llama_memory_context_ptr mctx;
+        if (memory) {
+            LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__);
+            mctx = memory->init_full();
+            if (!mctx) {
+                throw std::runtime_error("failed to initialize memory module");
+            }
+        }
+
+        cross.v_embd.clear();
+
+        const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
         const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
+        // avoid reserving graphs with zero outputs - assume one output per sequence
+        n_outputs = n_seqs;
+
         LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
 
+        // resolve automatic Flash Attention use
+        if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) {
+            auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true);
+            if (!gf) {
+                throw std::runtime_error("failed to split graph for Flash Attention check");
+            }
+
+            const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1;
+            bool fa_device_mismatch = false;
+            for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
+                ggml_tensor * n = ggml_graph_node(gf, i);
+                if (n->op != GGML_OP_FLASH_ATTN_EXT) {
+                    continue;
+                }
+                ggml_backend_dev_t device_fa = ggml_backend_get_device(
+                    ggml_backend_sched_get_tensor_backend(sched.get(), n));
+
+                // TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer
+                GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0);
+                const int il = std::stoi(n->name + prefix_len);
+                ggml_backend_dev_t device_kv = model.dev_layer(il);
+                if (device_fa != device_kv) {
+                    LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor "
+                        "is assigned to device %s (usually due to missing support)\n",
+                        __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_fa));
+                    // FIXME: fa_device_mismatch logic is wrong for --no-kv-offload, but this is broken anyways
+                    fa_device_mismatch = true;
+                    break;
+                }
+            }
+            if (fa_device_mismatch) {
+                cparams.flash_attn = false;
+                LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__);
+                if (ggml_is_quantized(params.type_v)) {
+                    throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention");
+                }
+            } else {
+                cparams.flash_attn = true;
+                LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__);
+            }
+        }
+
+        // reserve worst-case graph
         int n_splits_pp = -1;
         int n_nodes_pp = -1;
 
         int n_splits_tg = -1;
         int n_nodes_tg = -1;
 
-        // simulate full KV cache
-
-        const auto mctx = memory->init_full();
-        if (!mctx) {
-            throw std::runtime_error("failed to initialize KV cache");
-        }
-
-        cross.v_embd.clear();
-
-        // reserve pp graph first so that buffers are only allocated once
+        // reserve pp (prompt processing) graph first so that buffers are only allocated once
        {
            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
            if (!gf) {
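
The new LLAMA_FLASH_ATTN_TYPE_AUTO path above reserves a trial graph, scans it for GGML_OP_FLASH_ATTN_EXT nodes, recovers each node's layer index from its tensor name, and compares the device the scheduler picked for the node against the device holding that layer's KV data; any mismatch disables Flash Attention. A sketch of the name-to-layer parse, using a made-up prefix in place of LLAMA_TENSOR_NAME_FATTN:

    #include <cstdio>
    #include <cstring>
    #include <string>

    // Tensor names are assumed to look like "<prefix>-<layer>"; the +1 skips the '-'.
    static int layer_from_tensor_name(const char * name, const char * prefix) {
        const size_t prefix_len = std::strlen(prefix) + 1;
        return std::stoi(std::string(name + prefix_len));
    }

    int main() {
        std::printf("%d\n", layer_from_tensor_name("fattn-12", "fattn")); // prints 12
        return 0;
    }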
@@ -298,9 +350,9 @@ llama_context::llama_context(
            n_nodes_pp = ggml_graph_n_nodes(gf);
        }
 
-        // reserve with tg graph to get the number of splits and nodes
+        // reserve with tg (token generation) graph to get the number of splits and nodes
        {
-            auto * gf = graph_reserve(1, 1, 1, mctx.get());
+            auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get());
            if (!gf) {
                throw std::runtime_error("failed to allocate compute tg buffers");
            }
@@ -311,6 +363,10 @@ llama_context::llama_context(
 
        // reserve again with pp graph to avoid ggml-alloc reallocations during inference
        {
+            // TODO: not sure if the following graph would be worster case for multi-stream KV caches:
+            //
+            //   auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get());
+            //
            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
            if (!gf) {
                throw std::runtime_error("failed to allocate compute pp buffers");
@@ -388,10 +444,6 @@ ggml_backend_sched_t llama_context::get_sched() const {
     return sched.get();
 }
 
-ggml_context * llama_context::get_ctx_compute() const {
-    return ctx_compute.get();
-}
-
 uint32_t llama_context::n_ctx() const {
     return cparams.n_ctx;
 }
@@ -424,26 +476,12 @@ llama_memory_t llama_context::get_memory() const {
     return memory.get();
 }
 
-// deprecated
-void llama_context::kv_self_defrag_sched() {
-    if (!memory) {
-        return;
-    }
-
-    memory_force_optimize = true;
-}
-
-// deprecated
-bool llama_context::kv_self_update(bool optimize) {
+bool llama_context::memory_update(bool optimize) {
     if (!memory) {
         return false;
     }
 
     {
-        // TODO: remove in the future
-        optimize |= memory_force_optimize;
-        memory_force_optimize = false;
-
        const auto mctx = memory->init_update(this, optimize);
        switch (mctx->get_status()) {
            case LLAMA_MEMORY_STATUS_SUCCESS:
@@ -463,6 +501,11 @@ bool llama_context::kv_self_update(bool optimize) {
  }
  }

+ // reset the previous graph result to make sure that it won't be reused
+ // TODO: change the mctx->apply() to return information on whether a graph reserve is needed
+ // reset the graph result only if the memory module did reset the scheduler
+ gf_res_prev->reset();
+
  if (!mctx->apply()) {
  LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
  }
@@ -475,7 +518,7 @@ bool llama_context::kv_self_update(bool optimize) {
  throw std::runtime_error("failed to initialize memory context");
  }

- const uint32_t n_seqs = cparams.n_seq_max;
+ const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
  const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);

  auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
@@ -492,12 +535,16 @@ enum llama_pooling_type llama_context::pooling_type() const {
  }

  float * llama_context::get_logits() {
+ output_reorder();
+
  return logits;
  }

  float * llama_context::get_logits_ith(int32_t i) {
  int64_t j = -1;

+ output_reorder();
+
  try {
  if (logits == nullptr) {
  throw std::runtime_error("no logits");
@@ -534,12 +581,16 @@ float * llama_context::get_logits_ith(int32_t i) {
  }

  float * llama_context::get_embeddings() {
+ output_reorder();
+
  return embd;
  }

  float * llama_context::get_embeddings_ith(int32_t i) {
  int64_t j = -1;

+ output_reorder();
+
  try {
  if (embd == nullptr) {
  throw std::runtime_error("no embeddings");
@@ -678,38 +729,59 @@ bool llama_context::apply_adapter_cvec(
  return cvec.apply(model, data, len, n_embd, il_start, il_end);
  }

- llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
+ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
  if (mctx && !mctx->apply()) {
  LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
  ret = GGML_STATUS_FAILED;
  return nullptr;
  }

- auto * gf = graph_init();
- if (!gf) {
- LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
- ret = GGML_STATUS_FAILED;
- return nullptr;
- }
+ auto * res = gf_res_prev.get();
+ auto * gf = res->get_gf();

- auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mctx);
- if (!res) {
- LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__);
- ret = GGML_STATUS_FAILED;
- return nullptr;
- }
+ // the new graph parameters
+ // in order to correctly reuse a graph, its full topology has to be uniquely determined by these parameters
+ const auto gparams = graph_params(res, ubatch, mctx, gtype);

- // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
+ if (!graph_reuse_disable && res->can_reuse(gparams)) {
+ //LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__);

- if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
- LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
- ret = GGML_STATUS_ALLOC_FAILED;
- return nullptr;
+ n_reused++;
+ } else {
+ res->reset();
+
+ ggml_backend_sched_reset(sched.get());
+ ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
+
+ //const auto t_start_us = ggml_time_us();
+
+ gf = model.build_graph(gparams);
+
+ //LLAMA_LOG_INFO("graph build time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);
+
+ if (!gf) {
+ LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
+ ret = GGML_STATUS_FAILED;
+ return nullptr;
+ }
+
+ if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
+ LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
+ ret = GGML_STATUS_ALLOC_FAILED;
+ return nullptr;
+ }
  }

- res->set_inputs(&ubatch);
+ // set the input data for the input tensors
+ {
+ //const auto t_start_us = ggml_time_us();
+
+ res->set_inputs(&ubatch);
+
+ //LLAMA_LOG_INFO("graph set inputs time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);
+ }

- const auto status = graph_compute(gf, ubatch.n_tokens > 1);
+ const auto status = graph_compute(res->get_gf(), ubatch.n_tokens > 1);
  if (status != GGML_STATUS_SUCCESS) {
  LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status);
  ret = status;
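The rewritten process_ubatch keeps the previous llm_graph_result and skips the rebuild whenever the new parameters cannot change the graph topology; only the input data is refreshed. A minimal sketch of that reuse pattern, with an illustrative two-field params struct standing in for llm_graph_params (not the real API):

    #include <cstdio>

    // illustrative stand-in: reuse is valid only if every field that determines
    // the graph topology is unchanged
    struct GraphParams {
        int n_tokens;
        int n_outputs;
        bool operator==(const GraphParams & o) const {
            return n_tokens == o.n_tokens && n_outputs == o.n_outputs;
        }
    };

    struct GraphCache {
        GraphParams prev{};
        bool        valid    = false;
        int         n_reused = 0;

        void process(const GraphParams & p) {
            if (valid && prev == p) {
                n_reused++;                         // reuse: skip rebuild and re-allocation
            } else {
                std::printf("rebuild + alloc\n");   // reset result, rebuild graph, re-allocate
                prev  = p;
                valid = true;
            }
            // in either case the input tensors still receive fresh data each ubatch
        }
    };

    int main() {
        GraphCache cache;
        cache.process({1, 1}); // first decode: build
        cache.process({1, 1}); // same shape: reused
        cache.process({8, 8}); // new shape: rebuild
        std::printf("graphs reused: %d\n", cache.n_reused); // graphs reused: 1
    }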
@@ -731,16 +803,19 @@ int llama_context::encode(const llama_batch & batch_inp) {

  const auto & hparams = model.hparams;

- const int64_t n_embd = hparams.n_embd;
+ const int64_t n_embd = hparams.n_embd;
+ const int64_t n_vocab = model.vocab.n_tokens();

  // note: during encode, we always pass the full sequence starting from pos = 0
- if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, true)) {
+ if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
  LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
  return -1;
  }

  const uint32_t n_tokens = balloc->get_n_tokens();

+ // [TAG_NO_CACHE_PAD]
+ // TODO: add new split mode where we pad the input sequences so that ubatch.equal_seqs == true
  const llama_ubatch ubatch = balloc->split_simple(n_tokens);

  // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
@@ -767,9 +842,6 @@ int llama_context::encode(const llama_batch & batch_inp) {

  n_outputs = n_tokens;

- ggml_backend_sched_reset(sched.get());
- ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
-
  const auto causal_attn_org = cparams.causal_attn;

  // always use non-causal attention for encoder graphs
@@ -778,7 +850,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
  cparams.causal_attn = false;

  ggml_status status;
- const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);
+ const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);

  cparams.causal_attn = causal_attn_org;

@@ -791,10 +863,20 @@ int llama_context::encode(const llama_batch & batch_inp) {
  }
  }

+ auto * t_logits = res->get_logits();
  auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();

+ // extract logits
+ if (logits && t_logits) {
+ ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
+ GGML_ASSERT(backend_res != nullptr);
+ GGML_ASSERT(logits != nullptr);
+
+ ggml_backend_tensor_get_async(backend_res, t_logits, logits, 0, n_tokens*n_vocab*sizeof(float));
+ }
+
  // extract embeddings
- if (t_embd) {
+ if (embd && t_embd) {
  ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
  GGML_ASSERT(backend_embd != nullptr);

@@ -844,10 +926,6 @@ int llama_context::encode(const llama_batch & batch_inp) {
  }
  }

- // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
- // overlap with device computation.
- ggml_backend_sched_reset(sched.get());
-
  // TODO: hacky solution
  if (model.arch == LLM_ARCH_T5 && t_embd) {
  //cross.t_embd = t_embd;
@@ -893,13 +971,13 @@ int llama_context::decode(const llama_batch & batch_inp) {
  const auto & vocab = model.vocab;
  const auto & hparams = model.hparams;

- const int32_t n_vocab = vocab.n_tokens();
+ const int64_t n_vocab = vocab.n_tokens();
  const int64_t n_embd = hparams.n_embd;

  // when computing embeddings, all tokens are output
  const bool output_all = cparams.embeddings;

- if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, output_all)) {
+ if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, output_all)) {
  LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
  return -1;
  }
@@ -927,11 +1005,12 @@ int llama_context::decode(const llama_batch & batch_inp) {

  // TODO: this clear of the buffer can easily be forgotten - need something better
  embd_seq.clear();
+ output_swaps.clear();

  bool did_optimize = false;

- // handle any pending defrags/shifts
- kv_self_update(false);
+ // handle any pending shifts/copies
+ memory_update(false);

  llama_memory_context_ptr mctx;

@@ -956,7 +1035,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
  if (!did_optimize) {
  did_optimize = true;

- if (kv_self_update(true)) {
+ if (memory_update(true)) {
  LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, balloc->get_n_tokens());

  continue;
@@ -1005,14 +1084,11 @@ int llama_context::decode(const llama_batch & batch_inp) {
  n_outputs = n_outputs_new;
  }

- ggml_backend_sched_reset(sched.get());
- ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
-
  ggml_status status;
- const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
+ const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);

  if (!res) {
- // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
+ // the last ubatch failed or was aborted -> remove all positions of that ubatch from the memory module
  llama_pos pos_min[LLAMA_MAX_SEQ];
  for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
  pos_min[s] = std::numeric_limits<llama_pos>::max();
@@ -1029,7 +1105,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
  continue;
  }

- LLAMA_LOG_WARN("%s: removing KV cache entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]);
+ LLAMA_LOG_WARN("%s: removing memory module entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]);

  memory->seq_rm(s, pos_min[s], -1);
  }
@@ -1149,9 +1225,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
  // make the outputs have the same order they had in the user-provided batch
  // note: this is mostly relevant for recurrent models atm
  if (!sorted_output) {
- const uint32_t n_vocab = model.vocab.n_tokens();
- const uint64_t n_embd = model.hparams.n_embd;
-
  GGML_ASSERT((size_t) n_outputs == out_ids.size());

  // TODO: is there something more efficient which also minimizes swaps?
@@ -1167,16 +1240,9 @@ int llama_context::decode(const llama_batch & batch_inp) {
  continue;
  }
  std::swap(out_ids[i], out_ids[j_min]);
- if (logits_size > 0) {
- for (uint32_t k = 0; k < n_vocab; k++) {
- std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
- }
- }
- if (embd_size > 0) {
- for (uint32_t k = 0; k < n_embd; k++) {
- std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]);
- }
- }
+
+ // remember the swaps and apply them lazily upon logits/embeddings access
+ output_swaps.push_back({ i, j_min });
  }

  std::fill(output_ids.begin(), output_ids.end(), -1);
@@ -1190,10 +1256,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
  // wait for the computation to finish (automatically done when obtaining the model output)
  //synchronize();

- // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
- // overlap with device computation.
- ggml_backend_sched_reset(sched.get());
-
  return 0;
  }

@@ -1271,28 +1333,45 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
  return n_outputs_max;
  }

+ void llama_context::output_reorder() {
+ const uint64_t n_vocab = model.vocab.n_tokens();
+ const uint64_t n_embd = model.hparams.n_embd;
+
+ for (size_t s = 0; s < output_swaps.size(); ++s) {
+ const uint64_t i0 = output_swaps[s].i0;
+ const uint64_t i1 = output_swaps[s].i1;
+
+ if (logits_size > 0) {
+ for (uint64_t k = 0; k < n_vocab; k++) {
+ std::swap(logits[i0*n_vocab + k], logits[i1*n_vocab + k]);
+ }
+ }
+
+ if (embd_size > 0) {
+ for (uint64_t k = 0; k < n_embd; k++) {
+ std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]);
+ }
+ }
+ }
+
+ output_swaps.clear();
+ }
+
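With this change, decode only records which output rows are out of order; output_reorder() applies the swaps on first access through get_logits()/get_embeddings() and then clears the list. A toy model of the deferred swap, using one float per output row for brevity:

    #include <cstdio>
    #include <utility>
    #include <vector>

    int main() {
        // pretend these are per-row outputs, one value per row for brevity
        std::vector<float> logits = {0.f, 1.f, 2.f, 3.f};

        // swaps recorded during decode instead of moving the data immediately
        std::vector<std::pair<size_t, size_t>> output_swaps = {{0, 2}, {1, 3}};

        // what output_reorder() does on first access: apply the swaps, then clear
        for (auto [i0, i1] : output_swaps) {
            std::swap(logits[i0], logits[i1]);
        }
        output_swaps.clear();

        for (float v : logits) {
            std::printf("%g ", v); // prints: 2 3 0 1
        }
        std::printf("\n");
    }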
  //
  // graph
  //

- int32_t llama_context::graph_max_nodes() const {
- return std::max<int32_t>(65536, 5*model.n_tensors());
+ uint32_t llama_context::graph_max_nodes() const {
+ return std::max<uint32_t>(1024u, 8u*model.n_tensors());
  }

- ggml_cgraph * llama_context::graph_init() {
- ggml_init_params params = {
- /*.mem_size =*/ buf_compute_meta.size(),
- /*.mem_buffer =*/ buf_compute_meta.data(),
- /*.no_alloc =*/ true,
- };
-
- ctx_compute.reset(ggml_init(params));
-
- return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
+ llm_graph_result * llama_context::get_gf_res_reserve() const {
+ return static_cast<llm_graph_result *>(gf_res_reserve.get());
  }

- ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
+ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only) {
  LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
+ GGML_ASSERT(n_outputs >= 1);

  if (n_tokens % n_seqs != 0) {
  n_tokens = ((n_tokens + (n_seqs - 1)) / n_seqs) * n_seqs; // round to next multiple of n_seqs
@@ -1301,6 +1380,11 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
  LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs);
  }

+ ggml_backend_sched_reset(sched.get());
+
+ // when the scheduler is reset, we cannot reuse the old graph, so we reset the previous graph result to prevent that
+ gf_res_prev->reset();
+
  // store the n_outputs as it is, and restore it afterwards
  // TODO: not sure if needed, might simplify in the future by removing this
  const auto save_n_outputs = this->n_outputs;
@@ -1310,20 +1394,20 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
  llama_batch_allocr balloc(model.hparams.n_pos_per_embd());
  llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);

- auto * gf = graph_init();
- auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx);
+ auto * res = gf_res_reserve.get();

- this->n_outputs = save_n_outputs;
+ const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT);

- if (!res) {
- LLAMA_LOG_ERROR("%s: failed to build worst-case graph\n", __func__);
- return nullptr;
- }
+ res->reset();

- ggml_backend_sched_reset(sched.get());
+ auto * gf = model.build_graph(gparams);
+
+ this->n_outputs = save_n_outputs;

  // initialize scheduler with the specified graph
- if (!ggml_backend_sched_reserve(sched.get(), gf)) {
+ if (split_only) {
+ ggml_backend_sched_split_graph(sched.get(), gf);
+ } else if (!ggml_backend_sched_reserve(sched.get(), gf)) {
  LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
  return nullptr;
  }
@@ -1331,28 +1415,27 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
  return gf;
  }

- llm_graph_result_ptr llama_context::graph_build(
- ggml_context * ctx,
- ggml_cgraph * gf,
- const llama_ubatch & ubatch,
- llm_graph_type gtype,
- const llama_memory_context_i * mctx) {
- return model.build_graph(
- {
- /*.ctx =*/ ctx,
- /*.arch =*/ model.arch,
- /*.hparams =*/ model.hparams,
- /*.cparams =*/ cparams,
- /*.ubatch =*/ ubatch,
- /*.sched =*/ sched.get(),
- /*.backend_cpu =*/ backend_cpu,
- /*.cvec =*/ &cvec,
- /*.loras =*/ &loras,
- /*.mctx =*/ mctx,
- /*.cross =*/ &cross,
- /*.n_outputs =*/ n_outputs,
- /*.cb =*/ graph_get_cb(),
- }, gf, gtype);
+ llm_graph_params llama_context::graph_params(
+ llm_graph_result * res,
+ const llama_ubatch & ubatch,
+ const llama_memory_context_i * mctx,
+ llm_graph_type gtype) const {
+ return {
+ /*.arch =*/ model.arch,
+ /*.hparams =*/ model.hparams,
+ /*.cparams =*/ cparams,
+ /*.ubatch =*/ ubatch,
+ /*.gtype =*/ gtype,
+ /*.sched =*/ sched.get(),
+ /*.backend_cpu =*/ backend_cpu,
+ /*.cvec =*/ &cvec,
+ /*.loras =*/ &loras,
+ /*.mctx =*/ mctx,
+ /*.cross =*/ &cross,
+ /*.n_outputs =*/ n_outputs,
+ /*.cb =*/ graph_get_cb(),
+ /*.res =*/ res,
+ };
  }

  ggml_status llama_context::graph_compute(
@@ -1364,7 +1447,9 @@ ggml_status llama_context::graph_compute(
  if (backend_cpu != nullptr) {
  auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_cpu));
  auto * set_threadpool_fn = (decltype(ggml_backend_cpu_set_threadpool) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_set_threadpool");
- set_threadpool_fn(backend_cpu, tp);
+ if (set_threadpool_fn) {
+ set_threadpool_fn(backend_cpu, tp);
+ }
  }

  // set the number of threads for all the backends
@@ -1583,30 +1668,30 @@ size_t llama_context::state_set_data(const uint8_t * src, size_t size) {
  }
  }

- size_t llama_context::state_seq_get_size(llama_seq_id seq_id) {
+ size_t llama_context::state_seq_get_size(llama_seq_id seq_id, llama_state_seq_flags flags) {
  llama_io_write_dummy io;
  try {
- return state_seq_write_data(io, seq_id);
+ return state_seq_write_data(io, seq_id, flags);
  } catch (const std::exception & err) {
  LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
  return 0;
  }
  }

- size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) {
+ size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size, llama_state_seq_flags flags) {
  llama_io_write_buffer io(dst, size);
  try {
- return state_seq_write_data(io, seq_id);
+ return state_seq_write_data(io, seq_id, flags);
  } catch (const std::exception & err) {
  LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
  return 0;
  }
  }

- size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) {
+ size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size, llama_state_seq_flags flags) {
  llama_io_read_buffer io(src, size);
  try {
- return state_seq_read_data(io, seq_id);
+ return state_seq_read_data(io, seq_id, flags);
  } catch (const std::exception & err) {
  LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
  return 0;
@@ -1704,7 +1789,7 @@ size_t llama_context::state_seq_load_file(llama_seq_id seq_id, const char * file
  {
  const size_t state_size = file.size() - file.tell();
  llama_io_read_file io(&file);
- const size_t nread = state_seq_read_data(io, seq_id);
+ const size_t nread = state_seq_read_data(io, seq_id, 0);
  if (!nread) {
  LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__);
  return 0;
@@ -1728,7 +1813,7 @@ size_t llama_context::state_seq_save_file(llama_seq_id seq_id, const char * file

  // save the context state using stream saving
  llama_io_write_file io(&file);
- state_seq_write_data(io, seq_id);
+ state_seq_write_data(io, seq_id, 0);

  const size_t res = file.tell();
  GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + io.n_bytes());
@@ -1803,7 +1888,7 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
  }

  if (memory != nullptr) {
- LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
+ LLAMA_LOG_DEBUG("%s: - writing memory module\n", __func__);
  memory->state_write(io);
  }

@@ -1889,7 +1974,7 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
  }

  if (memory) {
- LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
+ LLAMA_LOG_DEBUG("%s: - reading memory module\n", __func__);

  memory->state_read(io);
  }
@@ -1897,21 +1982,21 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
  return io.n_bytes();
  }

- size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) {
+ size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
  GGML_UNUSED(seq_id);

  if (memory) {
- memory->state_write(io, seq_id);
+ memory->state_write(io, seq_id, flags);
  }

  return io.n_bytes();
  }

- size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) {
+ size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
  GGML_UNUSED(seq_id);

  if (memory) {
- memory->state_read(io, seq_id);
+ memory->state_read(io, seq_id, flags);
  }

  return io.n_bytes();
@@ -1930,6 +2015,7 @@ llama_perf_context_data llama_context::perf_get_data() const {
  data.t_eval_ms = 1e-3 * t_eval_us;
  data.n_p_eval = std::max(1, n_p_eval);
  data.n_eval = std::max(1, n_eval);
+ data.n_reused = std::max(0, n_reused);

  return data;
  }
@@ -1938,6 +2024,22 @@ void llama_context::perf_reset() {
  t_start_us = ggml_time_us();
  t_eval_us = n_eval = 0;
  t_p_eval_us = n_p_eval = 0;
+ n_reused = 0;
+ }
+
+ std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> llama_context::memory_breakdown() const {
+ std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> ret;
+ for (const auto & buft_size : model.memory_breakdown()) {
+ ret[buft_size.first].model += buft_size.second;
+ }
+ for (const auto & buft_size : memory->memory_breakdown()) {
+ ret[buft_size.first].context += buft_size.second;
+ }
+ for (const auto & backend_ptr : backends) {
+ ggml_backend_t backend = backend_ptr.get();
+ ret[ggml_backend_sched_get_buffer_type(sched.get(), backend)].compute += ggml_backend_sched_get_buffer_size(sched.get(), backend);
+ }
+ return ret;
  }

  //
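memory_breakdown() aggregates one llama_memory_breakdown_data (model/context/compute byte counts) per backend buffer type. A standalone sketch of consuming such a map, with a string key standing in for ggml_backend_buffer_type_t and illustrative byte counts:

    #include <cstdio>
    #include <map>
    #include <string>

    // toy mirror of llama_memory_breakdown_data: bytes attributed per buffer type
    struct Breakdown { unsigned long long model = 0, context = 0, compute = 0; };

    int main() {
        constexpr unsigned long long MiB = 1024 * 1024;

        // stand-in for the map keyed by ggml_backend_buffer_type_t in the real API
        std::map<std::string, Breakdown> mb = {
            {"CPU",   { 128 * MiB,   0 * MiB,  64 * MiB}},
            {"CUDA0", {4096 * MiB, 512 * MiB, 256 * MiB}},
        };

        for (const auto & [name, b] : mb) {
            std::printf("%-5s self=%llu MiB (model=%llu context=%llu compute=%llu)\n",
                        name.c_str(), (b.model + b.context + b.compute) / MiB,
                        b.model / MiB, b.context / MiB, b.compute / MiB);
        }
    }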
@@ -1972,7 +2074,7 @@ void llama_context::opt_init(struct llama_model * model, struct llama_opt_params
  opt_params.opt_period = n_batch / n_ubatch;
  opt_params.get_opt_pars = lopt_params.get_opt_pars;
  opt_params.get_opt_pars_ud = lopt_params.get_opt_pars_ud;
-
+ opt_params.optimizer = lopt_params.optimizer_type;
  opt_ctx = ggml_opt_init(opt_params);

  llama_opt_param_filter param_filter = lopt_params.param_filter;
@@ -2028,7 +2130,7 @@ void llama_context::opt_epoch_iter(
  batch.logits [pos_batch] = true;
  }

- if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, true)) {
+ if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
  LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
  return;
  }
@@ -2064,8 +2166,13 @@ void llama_context::opt_epoch_iter(
  break;
  }

- auto * gf = graph_init();
- auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx.get());
+ auto * res = gf_res_prev.get();
+
+ const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT);
+
+ res->reset();
+
+ auto * gf = model.build_graph(gparams);

  struct ggml_context * ctx_compute_opt;
  {
@@ -2167,12 +2274,13 @@ llama_context_params llama_context_default_params() {
  /*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
  /*.pooling_type =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
  /*.attention_type =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED,
+ /*.flash_attn_type =*/ LLAMA_FLASH_ATTN_TYPE_AUTO,
  /*.rope_freq_base =*/ 0.0f,
  /*.rope_freq_scale =*/ 0.0f,
  /*.yarn_ext_factor =*/ -1.0f,
- /*.yarn_attn_factor =*/ 1.0f,
- /*.yarn_beta_fast =*/ 32.0f,
- /*.yarn_beta_slow =*/ 1.0f,
+ /*.yarn_attn_factor =*/ -1.0f,
+ /*.yarn_beta_fast =*/ -1.0f,
+ /*.yarn_beta_slow =*/ -1.0f,
  /*.yarn_orig_ctx =*/ 0,
  /*.defrag_thold =*/ -1.0f,
  /*.cb_eval =*/ nullptr,
@@ -2183,10 +2291,10 @@ llama_context_params llama_context_default_params() {
  /*.abort_callback_data =*/ nullptr,
  /*.embeddings =*/ false,
  /*.offload_kqv =*/ true,
- /*.flash_attn =*/ false,
  /*.no_perf =*/ true,
  /*.op_offload =*/ true,
  /*.swa_full =*/ true,
+ /*.kv_unified =*/ false,
  };

  return result;
@@ -2210,12 +2318,30 @@ llama_context * llama_init_from_model(
  return nullptr;
  }

- if (params.flash_attn && model->arch == LLM_ARCH_GROK) {
+ if (params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED && model->arch == LLM_ARCH_GROK) {
  LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__);
- params.flash_attn = false;
+ params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED;
  }

- if (ggml_is_quantized(params.type_v) && !params.flash_attn) {
+ if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_k)) {
+ const uint32_t blck_size = ggml_blck_size(params.type_k);
+ if (model->hparams.n_embd_head_k % blck_size != 0) {
+ LLAMA_LOG_ERROR("%s: K cache type %s with block size %u does not divide n_embd_head_k=%u\n",
+ __func__, ggml_type_name(params.type_k), blck_size, model->hparams.n_embd_head_k);
+ return nullptr;
+ }
+ }
+
+ if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_v)) {
+ const uint32_t blck_size = ggml_blck_size(params.type_v);
+ if (model->hparams.n_embd_head_v % blck_size != 0) {
+ LLAMA_LOG_ERROR("%s: V cache type %s with block size %u does not divide n_embd_head_v=%u\n",
+ __func__, ggml_type_name(params.type_v), blck_size, model->hparams.n_embd_head_v);
+ return nullptr;
+ }
+ }
+
+ if (ggml_is_quantized(params.type_v) && params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_DISABLED) {
  LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
  return nullptr;
  }
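The two new guards only accept a quantized K/V cache type when its block size divides the corresponding head dimension. The arithmetic in isolation (the block size of 32 matches q8_0, but is illustrative here; real code queries ggml_blck_size):

    #include <cstdio>
    #include <cstdint>

    // returns true if a quantized cache type with the given block size can store
    // heads of dimension n_embd_head (same divisibility test as the guards above)
    static bool cache_type_fits(uint32_t n_embd_head, uint32_t blck_size) {
        return n_embd_head % blck_size == 0;
    }

    int main() {
        std::printf("head=128 blck=32 -> %s\n", cache_type_fits(128, 32) ? "ok" : "rejected");
        std::printf("head=112 blck=32 -> %s\n", cache_type_fits(112, 32) ? "ok" : "rejected");
    }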
@@ -2261,16 +2387,6 @@ const llama_model * llama_get_model(const llama_context * ctx) {
  return &ctx->get_model();
  }

- // deprecated
- llama_kv_cache * llama_get_kv_self(llama_context * ctx) {
- return dynamic_cast<llama_kv_cache *>(ctx->get_memory());
- }
-
- // deprecated
- void llama_kv_self_update(llama_context * ctx) {
- ctx->kv_self_update(false);
- }
-
  enum llama_pooling_type llama_pooling_type(const llama_context * ctx) {
  return ctx->pooling_type();
  }
@@ -2488,168 +2604,6 @@ bool llama_memory_can_shift(llama_memory_t mem) {
  return mem->get_can_shift();
  }

- //
- // kv cache
- //
-
- // deprecated
- int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
- const auto * kv = llama_get_memory(ctx);
- if (!kv) {
- return 0;
- }
-
- int32_t res = 0;
-
- for (uint32_t s = 0; s < ctx->get_cparams().n_seq_max; s++) {
- const llama_pos p0 = kv->seq_pos_min(s);
- const llama_pos p1 = kv->seq_pos_max(s);
-
- if (p0 >= 0) {
- res += (p1 - p0) + 1;
- }
- }
-
- return res;
- }
-
- // deprecated
- // note: this is the same as above - will be removed anyway, so it's ok
- int32_t llama_kv_self_used_cells(const llama_context * ctx) {
- const auto * kv = llama_get_memory(ctx);
- if (!kv) {
- return 0;
- }
-
- int32_t res = 0;
-
- for (uint32_t s = 0; s < ctx->get_cparams().n_seq_max; s++) {
- const llama_pos p0 = kv->seq_pos_min(s);
- const llama_pos p1 = kv->seq_pos_max(s);
-
- if (p0 >= 0) {
- res += (p1 - p0) + 1;
- }
- }
-
- return res;
- }
-
- // deprecated
- void llama_kv_self_clear(llama_context * ctx) {
- auto * kv = llama_get_memory(ctx);
- if (!kv) {
- return;
- }
-
- llama_memory_clear(kv, true);
- }
-
- // deprecated
- bool llama_kv_self_seq_rm(
- llama_context * ctx,
- llama_seq_id seq_id,
- llama_pos p0,
- llama_pos p1) {
- auto * kv = llama_get_memory(ctx);
- if (!kv) {
- return true;
- }
-
- return llama_memory_seq_rm(kv, seq_id, p0, p1);
- }
-
- // deprecated
- void llama_kv_self_seq_cp(
- llama_context * ctx,
- llama_seq_id seq_id_src,
- llama_seq_id seq_id_dst,
- llama_pos p0,
- llama_pos p1) {
- auto * kv = llama_get_memory(ctx);
- if (!kv) {
- return;
- }
-
- llama_memory_seq_cp(kv, seq_id_src, seq_id_dst, p0, p1);
- }
-
- // deprecated
- void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
- auto * kv = llama_get_memory(ctx);
- if (!kv) {
- return;
- }
-
- llama_memory_seq_keep(kv, seq_id);
- }
-
- // deprecated
- void llama_kv_self_seq_add(
- llama_context * ctx,
- llama_seq_id seq_id,
- llama_pos p0,
- llama_pos p1,
- llama_pos delta) {
- auto * kv = llama_get_memory(ctx);
- if (!kv) {
- return;
- }
-
- llama_memory_seq_add(kv, seq_id, p0, p1, delta);
- }
-
- // deprecated
- void llama_kv_self_seq_div(
- llama_context * ctx,
- llama_seq_id seq_id,
- llama_pos p0,
- llama_pos p1,
- int d) {
- auto * kv = llama_get_memory(ctx);
- if (!kv) {
- return;
- }
-
- llama_memory_seq_div(kv, seq_id, p0, p1, d);
- }
-
- // deprecated
- llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) {
- auto * kv = llama_get_memory(ctx);
- if (!kv) {
- return -1;
- }
-
- return llama_memory_seq_pos_min(kv, seq_id);
- }
-
- // deprecated
- llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
- auto * kv = llama_get_memory(ctx);
- if (!kv) {
- return -1;
- }
-
- return llama_memory_seq_pos_max(kv, seq_id);
- }
-
- // deprecated
- void llama_kv_self_defrag(llama_context * ctx) {
- // force defrag
- ctx->kv_self_defrag_sched();
- }
-
- // deprecated
- bool llama_kv_self_can_shift(const llama_context * ctx) {
- auto * kv = llama_get_memory(ctx);
- if (!kv) {
- return false;
- }
-
- return llama_memory_can_shift(kv);
- }
-
  // llama state API

  // deprecated
@@ -2719,19 +2673,31 @@ bool llama_state_save_file(llama_context * ctx, const char * path_session, const
  }

  size_t llama_state_seq_get_size(llama_context * ctx, llama_seq_id seq_id) {
- return ctx->state_seq_get_size(seq_id);
+ return llama_state_seq_get_size_ext(ctx, seq_id, 0);
  }

  size_t llama_state_seq_get_data(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) {
+ return llama_state_seq_get_data_ext(ctx, dst, size, seq_id, 0);
+ }
+
+ size_t llama_state_seq_set_data(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id) {
+ return llama_state_seq_set_data_ext(ctx, src, size, seq_id, 0);
+ }
+
+ size_t llama_state_seq_get_size_ext(llama_context * ctx, llama_seq_id seq_id, llama_state_seq_flags flags) {
+ return ctx->state_seq_get_size(seq_id, flags);
+ }
+
+ size_t llama_state_seq_get_data_ext(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) {
  ctx->synchronize();

- return ctx->state_seq_get_data(seq_id, dst, size);
+ return ctx->state_seq_get_data(seq_id, dst, size, flags);
  }

- size_t llama_state_seq_set_data(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id) {
+ size_t llama_state_seq_set_data_ext(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) {
  ctx->synchronize();

- return ctx->state_seq_set_data(seq_id, src, size);
+ return ctx->state_seq_set_data(seq_id, src, size, flags);
  }

  size_t llama_state_seq_save_file(llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) {
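The legacy entry points now forward to the new _ext variants with flags = 0. A usage sketch under the assumption that the _ext functions are declared in llama.h as shown in this hunk (error handling kept minimal):

    #include <cstdint>
    #include <vector>

    #include "llama.h"

    // round-trip one sequence's state through the flag-taking API; flags = 0
    // reproduces the legacy behavior, exactly as the wrappers above do
    static bool roundtrip_seq_state(llama_context * ctx, llama_seq_id seq_id) {
        std::vector<uint8_t> buf(llama_state_seq_get_size_ext(ctx, seq_id, 0));
        const size_t written = llama_state_seq_get_data_ext(ctx, buf.data(), buf.size(), seq_id, 0);
        if (written == 0) {
            return false;
        }
        return llama_state_seq_set_data_ext(ctx, buf.data(), written, seq_id, 0) == written;
    }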
@@ -2807,12 +2773,149 @@ void llama_perf_context_print(const llama_context * ctx) {
  LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
  __func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval);
  LLAMA_LOG_INFO("%s: total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval));
+ LLAMA_LOG_INFO("%s: graphs reused = %10d\n", __func__, data.n_reused);
  }

  void llama_perf_context_reset(llama_context * ctx) {
  ctx->perf_reset();
  }

+ void llama_memory_breakdown_print(const struct llama_context * ctx) {
+ const std::vector<ggml_backend_dev_t> & devices = ctx->get_model().devices;
+
+ std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> memory_breakdown = ctx->memory_breakdown();
+
+ std::vector<std::array<std::string, 9>> table_data;
+ table_data.reserve(devices.size());
+ const std::string template_header = "%s: | %s | %s %s %s %s %s %s %s |\n";
+ const std::string template_gpu = "%s: | %s | %s = %s + (%s = %s + %s + %s) + %s |\n";
+ const std::string template_other = "%s: | %s | %s %s %s = %s + %s + %s %s |\n";
+
+ table_data.push_back({template_header, "memory breakdown [MiB]", "total", "free", "self", "model", "context", "compute", "unaccounted"});
+
+ constexpr size_t MiB = 1024 * 1024;
+ const std::vector<std::string> desc_prefixes_strip = {"NVIDIA ", "GeForce ", "Tesla ", "AMD ", "Radeon ", "Instinct "};
+
+ // track seen buffer types to avoid double counting:
+ std::set<ggml_backend_buffer_type_t> seen_buffer_types;
+
+ // cumulative memory breakdown for each device and for host:
+ std::vector<llama_memory_breakdown_data> mb_dev(devices.size());
+ llama_memory_breakdown_data mb_host;
+
+ for (const auto & buft_mb : memory_breakdown) {
+ ggml_backend_buffer_type_t buft = buft_mb.first;
+ const llama_memory_breakdown_data & mb = buft_mb.second;
+ if (ggml_backend_buft_is_host(buft)) {
+ mb_host.model += mb.model;
+ mb_host.context += mb.context;
+ mb_host.compute += mb.compute;
+ seen_buffer_types.insert(buft);
+ continue;
+ }
+ ggml_backend_dev_t dev = ggml_backend_buft_get_device(buft);
+ if (dev) {
+ int i_dev = -1;
+ for (size_t i = 0; i < devices.size(); i++) {
+ if (devices[i] == dev) {
+ i_dev = i;
+ break;
+ }
+ }
+ if (i_dev != -1) {
+ mb_dev[i_dev].model += mb.model;
+ mb_dev[i_dev].context += mb.context;
+ mb_dev[i_dev].compute += mb.compute;
+ seen_buffer_types.insert(buft);
+ continue;
+ }
+ }
+ }
+
+ // print memory breakdown for each device:
+ for (size_t i = 0; i < devices.size(); i++) {
+ ggml_backend_dev_t dev = devices[i];
+ llama_memory_breakdown_data mb = mb_dev[i];
+
+ const std::string name = ggml_backend_dev_name(dev);
+ std::string desc = ggml_backend_dev_description(dev);
+ for (const std::string & prefix : desc_prefixes_strip) {
+ if (desc.length() >= prefix.length() && desc.substr(0, prefix.length()) == prefix) {
+ desc = desc.substr(prefix.length());
+ }
+ }
+
+ size_t free, total;
+ ggml_backend_dev_memory(dev, &free, &total);
+
+ const size_t self = mb.model + mb.context + mb.compute;
+ const size_t unaccounted = total - self - free;
+
+ table_data.push_back({
+ template_gpu,
+ " - " + name + " (" + desc + ")",
+ std::to_string(total / MiB),
+ std::to_string(free / MiB),
+ std::to_string(self / MiB),
+ std::to_string(mb.model / MiB),
+ std::to_string(mb.context / MiB),
+ std::to_string(mb.compute / MiB),
+ std::to_string(unaccounted / MiB)});
+ }
+
+ // print memory breakdown for host:
+ {
+ const size_t self = mb_host.model + mb_host.context + mb_host.compute;
+ table_data.push_back({
+ template_other,
+ " - Host",
+ "", // total
+ "", // free
+ std::to_string(self / MiB),
+ std::to_string(mb_host.model / MiB),
+ std::to_string(mb_host.context / MiB),
+ std::to_string(mb_host.compute / MiB),
+ ""}); // unaccounted
+ }
+
+ // print memory breakdown for all remaining buffer types:
+ for (const auto & buft_mb : memory_breakdown) {
+ ggml_backend_buffer_type_t buft = buft_mb.first;
+ const llama_memory_breakdown_data & mb = buft_mb.second;
+ if (seen_buffer_types.count(buft) == 1) {
+ continue;
+ }
+ const std::string name = ggml_backend_buft_name(buft);
+ const size_t self = mb.model + mb.context + mb.compute;
+ table_data.push_back({
+ template_other,
+ " - " + name,
+ "", // total
+ "", // free
+ std::to_string(self / MiB),
+ std::to_string(mb.model / MiB),
+ std::to_string(mb.context / MiB),
+ std::to_string(mb.compute / MiB),
+ ""}); // unaccounted
+ seen_buffer_types.insert(buft);
+ }
+
+ for (size_t j = 1; j < table_data[0].size(); j++) {
+ size_t max_len = 0;
+ for (const auto & td : table_data) {
+ max_len = std::max(max_len, td[j].length());
+ }
+ for (auto & td : table_data) {
+ td[j].insert(j == 1 ? td[j].length() : 0, max_len - td[j].length(), ' ');
+ }
+ }
+ for (const auto & td : table_data) {
+ LLAMA_LOG_INFO(td[0].c_str(),
+ __func__, td[1].c_str(), td[2].c_str(), td[3].c_str(), td[4].c_str(), td[5].c_str(),
+ td[6].c_str(), td[7].c_str(), td[8].c_str());
+ }
+ }
+
  //
  // training
  //