whispercpp 1.3.2 → 1.3.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (664)
  1. checksums.yaml +4 -4
  2. data/.gitignore +6 -3
  3. data/README.md +71 -14
  4. data/Rakefile +20 -7
  5. data/ext/.gitignore +4 -6
  6. data/ext/dependencies.rb +36 -24
  7. data/ext/extconf.rb +1 -1
  8. data/ext/options.rb +48 -184
  9. data/ext/ruby_whisper.c +18 -0
  10. data/ext/ruby_whisper_context.c +43 -12
  11. data/ext/ruby_whisper_model.c +1 -1
  12. data/ext/ruby_whisper_params.c +59 -27
  13. data/ext/ruby_whisper_segment.c +81 -4
  14. data/ext/ruby_whisper_transcribe.cpp +13 -7
  15. data/ext/ruby_whisper_vad_params.c +1 -1
  16. data/ext/sources/CMakeLists.txt +5 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/build-xcframework.sh +24 -0
  19. data/ext/sources/examples/CMakeLists.txt +1 -0
  20. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +120 -24
  21. data/ext/sources/examples/addon.node/addon.cpp +154 -35
  22. data/ext/sources/examples/addon.node/index.js +10 -5
  23. data/ext/sources/examples/addon.node/vad-example.js +132 -0
  24. data/ext/sources/examples/bench/bench.cpp +29 -18
  25. data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
  26. data/ext/sources/examples/cli/cli.cpp +7 -4
  27. data/ext/sources/examples/command/command.cpp +58 -32
  28. data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
  29. data/ext/sources/examples/common-ggml.cpp +2 -0
  30. data/ext/sources/examples/common-whisper.cpp +14 -7
  31. data/ext/sources/examples/lsp/lsp.cpp +21 -17
  32. data/ext/sources/examples/quantize/quantize.cpp +3 -0
  33. data/ext/sources/examples/server/CMakeLists.txt +3 -0
  34. data/ext/sources/examples/server/server.cpp +193 -35
  35. data/ext/sources/examples/server.py +6 -1
  36. data/ext/sources/examples/stream/stream.cpp +10 -2
  37. data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
  38. data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
  39. data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -0
  40. data/ext/sources/examples/talk-llama/llama-adapter.cpp +101 -4
  41. data/ext/sources/examples/talk-llama/llama-adapter.h +6 -0
  42. data/ext/sources/examples/talk-llama/llama-arch.cpp +756 -15
  43. data/ext/sources/examples/talk-llama/llama-arch.h +85 -1
  44. data/ext/sources/examples/talk-llama/llama-batch.cpp +773 -272
  45. data/ext/sources/examples/talk-llama/llama-batch.h +126 -55
  46. data/ext/sources/examples/talk-llama/llama-chat.cpp +150 -13
  47. data/ext/sources/examples/talk-llama/llama-chat.h +8 -0
  48. data/ext/sources/examples/talk-llama/llama-context.cpp +814 -542
  49. data/ext/sources/examples/talk-llama/llama-context.h +68 -32
  50. data/ext/sources/examples/talk-llama/llama-cparams.cpp +1 -1
  51. data/ext/sources/examples/talk-llama/llama-cparams.h +4 -4
  52. data/ext/sources/examples/talk-llama/llama-graph.cpp +787 -440
  53. data/ext/sources/examples/talk-llama/llama-graph.h +333 -153
  54. data/ext/sources/examples/talk-llama/llama-hparams.cpp +128 -6
  55. data/ext/sources/examples/talk-llama/llama-hparams.h +80 -17
  56. data/ext/sources/examples/talk-llama/llama-impl.h +2 -0
  57. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +326 -0
  58. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +137 -0
  59. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +1248 -1967
  60. data/ext/sources/examples/talk-llama/llama-kv-cache.h +218 -345
  61. data/ext/sources/examples/talk-llama/llama-kv-cells.h +164 -52
  62. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +266 -0
  63. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +139 -0
  64. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1154 -0
  65. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +182 -0
  66. data/ext/sources/examples/talk-llama/llama-memory.cpp +58 -0
  67. data/ext/sources/examples/talk-llama/llama-memory.h +94 -4
  68. data/ext/sources/examples/talk-llama/llama-mmap.cpp +1 -1
  69. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +44 -17
  70. data/ext/sources/examples/talk-llama/llama-model-loader.h +3 -2
  71. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +1 -0
  72. data/ext/sources/examples/talk-llama/llama-model.cpp +11377 -5248
  73. data/ext/sources/examples/talk-llama/llama-model.h +87 -9
  74. data/ext/sources/examples/talk-llama/llama-quant.cpp +137 -16
  75. data/ext/sources/examples/talk-llama/llama-sampling.cpp +226 -126
  76. data/ext/sources/examples/talk-llama/llama-vocab.cpp +502 -38
  77. data/ext/sources/examples/talk-llama/llama-vocab.h +46 -0
  78. data/ext/sources/examples/talk-llama/llama.cpp +76 -17
  79. data/ext/sources/examples/talk-llama/llama.h +176 -151
  80. data/ext/sources/examples/talk-llama/talk-llama.cpp +11 -6
  81. data/ext/sources/examples/talk-llama/unicode.cpp +212 -0
  82. data/ext/sources/examples/talk-llama/unicode.h +45 -0
  83. data/ext/sources/examples/vad-speech-segments/speech.cpp +6 -0
  84. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +6 -2
  85. data/ext/sources/examples/whisper.wasm/index-tmpl.html +17 -16
  86. data/ext/sources/ggml/CMakeLists.txt +106 -33
  87. data/ext/sources/ggml/cmake/common.cmake +24 -0
  88. data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
  89. data/ext/sources/ggml/include/ggml-backend.h +18 -2
  90. data/ext/sources/ggml/include/ggml-cpu.h +2 -0
  91. data/ext/sources/ggml/include/ggml-metal.h +1 -6
  92. data/ext/sources/ggml/include/ggml-opt.h +25 -6
  93. data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
  94. data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
  95. data/ext/sources/ggml/include/ggml.h +365 -21
  96. data/ext/sources/ggml/src/CMakeLists.txt +98 -25
  97. data/ext/sources/ggml/src/ggml-alloc.c +265 -141
  98. data/ext/sources/ggml/src/ggml-backend-impl.h +4 -1
  99. data/ext/sources/ggml/src/ggml-backend-reg.cpp +35 -13
  100. data/ext/sources/ggml/src/ggml-backend.cpp +266 -60
  101. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +4 -4
  102. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -4
  103. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +15 -0
  104. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
  105. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +903 -717
  106. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +143 -25
  107. data/ext/sources/ggml/src/ggml-cann/common.h +149 -2
  108. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +521 -78
  109. data/ext/sources/ggml/src/ggml-common.h +21 -0
  110. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +165 -50
  111. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +5 -3
  112. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
  113. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  114. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +3650 -0
  115. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1891 -0
  116. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2160 -0
  117. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  118. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2305 -0
  119. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1897 -0
  120. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +342 -0
  121. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1468 -0
  122. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1221 -0
  123. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +3820 -0
  124. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6307 -0
  125. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +214 -0
  126. data/ext/sources/ggml/src/ggml-cpu/common.h +18 -3
  127. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +23 -7
  128. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +179 -110
  129. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +44 -33
  130. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  131. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +152 -18
  132. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +7 -1
  133. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +228 -98
  134. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +532 -1124
  135. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  136. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +3374 -2081
  137. data/ext/sources/ggml/src/ggml-cpu/ops.h +13 -8
  138. data/ext/sources/ggml/src/ggml-cpu/quants.c +1193 -0
  139. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +34 -0
  140. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1982 -0
  141. data/ext/sources/ggml/src/ggml-cpu/repack.h +120 -0
  142. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +367 -46
  143. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1024 -0
  144. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
  145. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
  146. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
  147. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +3 -3
  148. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +1 -1
  149. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +272 -35
  150. data/ext/sources/ggml/src/ggml-cpu/vec.h +794 -142
  151. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +20 -16
  152. data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
  153. data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
  154. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
  155. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
  156. data/ext/sources/ggml/src/ggml-cuda/common.cuh +291 -81
  157. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
  158. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  159. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  160. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  161. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  162. data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
  163. data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
  164. data/ext/sources/ggml/src/ggml-cuda/convert.cu +117 -22
  165. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +20 -0
  166. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
  167. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +64 -307
  168. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  169. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
  170. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +499 -368
  171. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +142 -93
  172. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +755 -0
  173. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +3 -0
  174. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +593 -0
  175. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +90 -50
  176. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +185 -198
  177. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  178. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
  179. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +636 -222
  180. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
  181. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
  182. data/ext/sources/ggml/src/ggml-cuda/mean.cu +73 -0
  183. data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
  184. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +198 -45
  185. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +123 -0
  186. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +496 -0
  187. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +206 -57
  188. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1262 -721
  189. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +506 -0
  190. data/ext/sources/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +4 -5
  191. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +64 -73
  192. data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
  193. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
  194. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
  195. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
  196. data/ext/sources/ggml/src/ggml-cuda/pad.cu +46 -23
  197. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
  198. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
  199. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +12 -10
  200. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
  201. data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
  202. data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
  203. data/ext/sources/ggml/src/ggml-cuda/rope.cu +21 -27
  204. data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
  205. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +276 -0
  206. data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  207. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
  208. data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
  209. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +126 -59
  210. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  211. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +322 -98
  212. data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
  213. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +23 -19
  214. data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  215. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
  216. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
  217. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
  218. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
  219. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
  220. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
  221. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
  222. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
  223. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
  224. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
  225. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
  226. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
  227. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
  228. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
  229. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
  230. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
  231. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
  232. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
  233. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
  234. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
  235. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
  236. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
  237. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
  238. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
  239. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
  240. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
  241. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
  242. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
  243. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
  244. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
  245. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
  246. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
  247. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
  248. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
  249. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
  250. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
  251. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +21 -18
  252. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
  253. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
  254. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
  255. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
  256. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
  257. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
  258. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
  259. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
  260. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
  261. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
  262. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
  263. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
  264. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
  265. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
  266. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
  267. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
  268. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
  269. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +259 -0
  270. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +14 -0
  271. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
  272. data/ext/sources/ggml/src/ggml-cuda/unary.cu +179 -0
  273. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +15 -0
  274. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +92 -6
  275. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
  276. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  277. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +58 -36
  278. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +4 -3
  279. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +14 -2
  280. data/ext/sources/ggml/src/ggml-impl.h +229 -175
  281. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +21 -17
  282. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
  283. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
  284. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
  285. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +600 -0
  286. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1376 -0
  287. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +226 -0
  288. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1308 -0
  289. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +163 -63
  290. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +3158 -0
  291. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +82 -0
  292. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +718 -0
  293. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +3208 -1575
  294. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +18 -8
  295. data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +2 -2
  296. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +32 -0
  297. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4430 -792
  298. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
  299. data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  300. data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
  301. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  302. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  303. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  304. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +84 -0
  305. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +138 -0
  306. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
  307. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +370 -0
  308. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
  309. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  310. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +378 -0
  311. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +121 -0
  312. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  313. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  314. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
  315. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  316. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
  317. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
  318. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  319. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
  320. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  321. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
  322. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
  323. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  324. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
  325. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
  326. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
  327. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
  328. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  329. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  330. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
  331. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  332. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +189 -0
  333. data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
  334. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
  335. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
  336. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +138 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +120 -0
  344. data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
  345. data/ext/sources/ggml/src/ggml-quants.c +117 -24
  346. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  347. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +85 -62
  348. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
  349. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +2 -0
  350. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +9 -0
  351. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +6 -0
  352. data/ext/sources/ggml/src/ggml-sycl/common.hpp +20 -48
  353. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +13 -17
  354. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +21 -2
  355. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +116 -211
  356. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
  357. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +32 -0
  358. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +700 -1041
  359. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +20 -9
  360. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +17 -26
  361. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +2 -96
  362. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +393 -250
  363. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +1 -1
  364. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +32 -8
  365. data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
  366. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +38 -11
  367. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +125 -21
  368. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
  369. data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  370. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  371. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  372. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +4 -3
  373. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +105 -17
  374. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +36 -32
  375. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +4198 -1145
  376. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
  377. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
  378. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
  379. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
  380. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
  381. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +349 -0
  382. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  383. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
  384. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +66 -12
  385. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +154 -0
  386. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
  387. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  388. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +2 -1
  389. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +6 -5
  390. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +4 -2
  391. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
  392. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
  393. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  394. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
  395. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
  396. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  397. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
  398. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +69 -24
  399. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +60 -20
  400. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +98 -42
  401. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +64 -27
  402. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +74 -13
  403. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  404. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  405. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  406. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  407. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +4 -17
  408. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +19 -10
  409. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +25 -15
  410. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +19 -0
  411. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  412. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
  413. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
  414. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +18 -14
  415. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +126 -0
  416. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +65 -1
  417. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
  418. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +140 -0
  419. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -531
  420. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +206 -38
  421. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp +556 -0
  422. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +12 -5
  423. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +15 -9
  424. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
  425. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
  426. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +24 -3
  427. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +53 -3
  428. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  429. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +64 -11
  430. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
  431. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  432. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
  433. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  434. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  435. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  436. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
  437. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  438. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +29 -7
  439. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +4 -0
  440. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
  441. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
  442. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  443. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
  444. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +4 -3
  445. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +101 -9
  446. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  447. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
  448. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +338 -71
  449. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
  450. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1558 -0
  451. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +44 -0
  452. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +41 -0
  453. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
  454. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
  455. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
  456. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +124 -0
  457. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
  458. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  459. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +44 -0
  460. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +41 -0
  461. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
  462. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +57 -0
  463. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +48 -0
  464. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
  465. data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  466. data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
  467. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
  468. data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
  469. data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
  470. data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
  471. data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
  472. data/ext/sources/ggml/src/ggml.c +802 -142
  473. data/ext/sources/ggml/src/ggml.cpp +26 -0
  474. data/ext/sources/ggml/src/gguf.cpp +32 -4
  475. data/ext/sources/include/whisper.h +2 -0
  476. data/ext/sources/src/CMakeLists.txt +2 -0
  477. data/ext/sources/src/coreml/whisper-compat.h +10 -0
  478. data/ext/sources/src/coreml/whisper-compat.m +35 -0
  479. data/ext/sources/src/coreml/whisper-decoder-impl.m +1 -0
  480. data/ext/sources/src/coreml/whisper-encoder-impl.m +1 -0
  481. data/ext/sources/src/whisper.cpp +241 -215
  482. data/ext/sources/tests/CMakeLists.txt +8 -1
  483. data/ext/sources/tests/test-vad-full.cpp +3 -3
  484. data/ext/sources/tests/test-vad.cpp +2 -2
  485. data/extsources.rb +15 -9
  486. data/lib/whisper/context.rb +15 -0
  487. data/lib/whisper/model/uri.rb +57 -2
  488. data/lib/whisper/segment.rb +58 -0
  489. data/sig/whisper.rbs +75 -38
  490. data/{tests → test}/helper.rb +1 -12
  491. data/{tests → test}/test_model.rb +9 -0
  492. data/test/test_package.rb +51 -0
  493. data/{tests → test}/test_params.rb +8 -0
  494. data/test/test_segment.rb +146 -0
  495. data/{tests → test}/test_whisper.rb +70 -0
  496. data/whispercpp.gemspec +2 -3
  497. metadata +246 -191
  498. data/ext/sources/.dockerignore +0 -3
  499. data/ext/sources/.github/workflows/bindings-ruby.yml +0 -21
  500. data/ext/sources/ci/run.sh +0 -336
  501. data/ext/sources/close-issue.yml +0 -28
  502. data/ext/sources/ggml/include/ggml-kompute.h +0 -50
  503. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
  504. data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
  505. data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
  506. data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
  507. data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
  508. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  509. data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  510. data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  511. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  512. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  513. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  514. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  515. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  516. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  517. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
  518. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -6431
  519. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  520. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13747
  521. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
  522. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
  523. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
  524. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
  525. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
  526. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
  527. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +0 -336
  528. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
  529. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
  530. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
  531. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
  532. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
  533. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
  534. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
  535. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
  536. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
  537. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
  538. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
  539. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
  540. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
  541. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
  542. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
  543. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
  544. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
  545. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
  546. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
  547. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
  548. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
  549. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
  550. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
  551. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
  552. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
  553. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
  554. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
  555. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
  556. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
  557. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
  558. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
  559. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
  560. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
  561. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
  562. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
  563. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
  564. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
  565. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
  566. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
  567. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
  568. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
  569. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
  570. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
  571. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
  572. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
  573. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
  574. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
  575. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
  576. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
  577. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
  578. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
  579. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
  580. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
  581. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
  582. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
  583. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
  584. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
  585. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
  586. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
  587. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
  588. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
  589. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
  590. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
  591. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
  592. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
  593. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
  594. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
  595. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
  596. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
  597. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
  598. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
  599. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
  600. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
  601. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
  602. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
  603. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
  604. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
  605. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
  606. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
  607. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
  608. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
  609. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
  610. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
  611. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
  612. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
  613. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
  614. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  615. data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  616. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  617. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  618. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  619. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  620. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  621. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  622. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  623. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  624. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  625. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  626. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  627. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  628. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  629. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  630. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  631. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  632. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  633. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  634. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  635. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  636. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  637. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  638. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  639. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  640. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  641. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  642. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  643. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  644. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  645. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  646. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  647. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  648. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  649. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  650. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  651. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  652. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
  653. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -5998
  654. data/tests/test_package.rb +0 -46
  655. data/tests/test_segment.rb +0 -74
  656. /data/ext/sources/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  657. /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  658. /data/{tests → test}/jfk_reader/.gitignore +0 -0
  659. /data/{tests → test}/jfk_reader/extconf.rb +0 -0
  660. /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
  661. /data/{tests → test}/test_callback.rb +0 -0
  662. /data/{tests → test}/test_error.rb +0 -0
  663. /data/{tests → test}/test_vad.rb +0 -0
  664. /data/{tests → test}/test_vad_params.rb +0 -0
data/ext/sources/examples/talk-llama/llama-context.cpp +814 -542

@@ -1,14 +1,16 @@
 #include "llama-context.h"
 
 #include "llama-impl.h"
+#include "llama-batch.h"
 #include "llama-io.h"
+#include "llama-memory.h"
 #include "llama-mmap.h"
 #include "llama-model.h"
-#include "llama-kv-cache.h"
 
+#include <cinttypes>
 #include <cstring>
+#include <limits>
 #include <stdexcept>
-#include <cinttypes>
 
 //
 // llama_context
@@ -17,7 +19,8 @@
 llama_context::llama_context(
         const llama_model & model,
         llama_context_params params) :
-    model(model) {
+    model(model),
+    balloc(std::make_unique<llama_batch_allocr>(model.hparams.n_pos_per_embd())) {
     LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);
 
     t_start_us = model.t_start_us;
@@ -26,20 +29,18 @@ llama_context::llama_context(
     const auto & hparams = model.hparams;
 
     cparams.n_seq_max = std::max(1u, params.n_seq_max);
-    if (cparams.n_seq_max > LLAMA_MAX_PARALLEL_SEQUENCES) {
-        throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_PARALLEL_SEQUENCES));
+    if (cparams.n_seq_max > LLAMA_MAX_SEQ) {
+        throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ));
     }
 
     cparams.n_threads = params.n_threads;
     cparams.n_threads_batch = params.n_threads_batch;
-    cparams.yarn_ext_factor = params.yarn_ext_factor;
-    cparams.yarn_attn_factor = params.yarn_attn_factor;
-    cparams.yarn_beta_fast = params.yarn_beta_fast;
-    cparams.yarn_beta_slow = params.yarn_beta_slow;
-    cparams.defrag_thold = params.defrag_thold;
+    cparams.yarn_ext_factor = params.yarn_ext_factor >= 0.0f ? params.yarn_ext_factor : hparams.yarn_ext_factor;
+    cparams.yarn_attn_factor = params.yarn_attn_factor >= 0.0f ? params.yarn_attn_factor : hparams.yarn_attn_factor;
+    cparams.yarn_beta_fast = params.yarn_beta_fast >= 0.0f ? params.yarn_beta_fast : hparams.yarn_beta_fast;
+    cparams.yarn_beta_slow = params.yarn_beta_slow >= 0.0f ? params.yarn_beta_slow : hparams.yarn_beta_slow;
     cparams.embeddings = params.embeddings;
     cparams.offload_kqv = params.offload_kqv;
-    cparams.flash_attn = params.flash_attn;
     cparams.no_perf = params.no_perf;
     cparams.pooling_type = params.pooling_type;
     cparams.warmup = false;
@@ -84,21 +85,32 @@ llama_context::llama_context(
         cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL;
     }
 
+    cparams.flash_attn = params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED;
+
     // with causal attention, the batch size is limited by the context size
     cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
 
     // the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask
     // this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext)
     // ref: https://github.com/ggerganov/llama.cpp/pull/5021
-    // TODO: this padding is not needed for the cache-less context so we should probably move it to llama_context_kv_self
+    // TODO: this padding is not needed for the cache-less context so we should probably move it to llama_memory
     if (cparams.n_batch < GGML_KQ_MASK_PAD) {
         LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
         cparams.n_batch = GGML_KQ_MASK_PAD;
     }
-
     cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
 
     cparams.op_offload = params.op_offload;
+    cparams.kv_unified = params.kv_unified;
+
+    {
+        const char * LLAMA_GRAPH_REUSE_DISABLE = getenv("LLAMA_GRAPH_REUSE_DISABLE");
+        graph_reuse_disable = LLAMA_GRAPH_REUSE_DISABLE ? (atoi(LLAMA_GRAPH_REUSE_DISABLE) != 0) : graph_reuse_disable;
+
+        if (graph_reuse_disable) {
+            LLAMA_LOG_WARN("%s: graph reuse disabled\n", __func__);
+        }
+    }
 
     const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
 
@@ -108,7 +120,8 @@ llama_context::llama_context(
     LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch);
     LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
     LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn);
-    LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn);
+    LLAMA_LOG_INFO("%s: flash_attn = %s\n", __func__, llama_flash_attn_type_name(params.flash_attn_type));
+    LLAMA_LOG_INFO("%s: kv_unified = %s\n", __func__, cparams.kv_unified ? "true" : "false");
     LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
     LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
 
@@ -168,7 +181,7 @@ llama_context::llama_context(
     // graph outputs buffer
     {
         // resized during inference when a batch uses more outputs
-        if ((uint32_t) output_reserve(params.n_seq_max) < params.n_seq_max) {
+        if (output_reserve(params.n_seq_max) < params.n_seq_max) {
             throw std::runtime_error("failed to reserve initial output buffer");
         }
 
@@ -219,8 +232,8 @@ llama_context::llama_context(
 
     LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes);
 
-    // buffer used to store the computation graph and the tensor meta data
-    buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
+    gf_res_prev.reset(new llm_graph_result(max_nodes));
+    gf_res_reserve.reset(new llm_graph_result(max_nodes));
 
     // TODO: move these checks to ggml_backend_sched
     // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
@@ -257,45 +270,79 @@ llama_context::llama_context(
257
270
  }
258
271
  }
259
272
 
260
- // reserve worst-case graph
261
- if (!hparams.vocab_only && memory) {
262
- const uint32_t n_seqs = 1; // TODO: worst-case number of sequences
263
- const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
273
+ if (!hparams.vocab_only) {
274
+ llama_memory_context_ptr mctx;
275
+ if (memory) {
276
+ LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__);
277
+ mctx = memory->init_full();
278
+ if (!mctx) {
279
+ throw std::runtime_error("failed to initialize memory module");
280
+ }
281
+ }
282
+
283
+ cross.v_embd.clear();
264
284
 
265
- llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
285
+ const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
286
+ const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
266
287
 
267
- // restore later
268
- // TODO: something cleaner
269
- const auto n_outputs_save = n_outputs;
288
+ // avoid reserving graphs with zero outputs - assume one output per sequence
289
+ n_outputs = n_seqs;
270
290
 
271
291
  LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
272
292
 
293
+ // resolve automatic Flash Attention use
294
+ if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) {
295
+ auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true);
296
+ if (!gf) {
297
+ throw std::runtime_error("failed to split graph for Flash Attention check");
298
+ }
299
+
300
+ const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1;
301
+ bool fa_device_mismatch = false;
302
+ for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
303
+ ggml_tensor * n = ggml_graph_node(gf, i);
304
+ if (n->op != GGML_OP_FLASH_ATTN_EXT) {
305
+ continue;
306
+ }
307
+ ggml_backend_dev_t device_fa = ggml_backend_get_device(
308
+ ggml_backend_sched_get_tensor_backend(sched.get(), n));
309
+
310
+ // TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer
311
+ GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0);
312
+ const int il = std::stoi(n->name + prefix_len);
313
+ ggml_backend_dev_t device_kv = model.dev_layer(il);
314
+ if (device_fa != device_kv) {
315
+ LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor "
316
+ "is assigned to device %s (usually due to missing support)\n",
317
+ __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_fa));
318
+ // FIXME: fa_device_mismatch logic is wrong for --no-kv-offload, but this is broken anyway
319
+ fa_device_mismatch = true;
320
+ break;
321
+ }
322
+ }
323
+ if (fa_device_mismatch) {
324
+ cparams.flash_attn = false;
325
+ LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__);
326
+ if (ggml_is_quantized(params.type_v)) {
327
+ throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention");
328
+ }
329
+ } else {
330
+ cparams.flash_attn = true;
331
+ LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__);
332
+ }
333
+ }
334
+
335
+ // reserve worst-case graph
273
336
  int n_splits_pp = -1;
274
337
  int n_nodes_pp = -1;
275
338
 
276
339
  int n_splits_tg = -1;
277
340
  int n_nodes_tg = -1;
278
341
 
279
- // simulate full KV cache
280
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
281
-
282
- kv_self->set_full();
283
-
284
- cross.v_embd.clear();
285
-
286
- // reserve pp graph first so that buffers are only allocated once
342
+ // reserve pp (prompt processing) graph first so that buffers are only allocated once
287
343
  {
288
- llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
289
-
290
- // max number of outputs
291
- n_outputs = ubatch_pp.n_tokens;
292
-
293
- LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs);
294
-
295
- auto * gf = graph_init();
296
- graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
297
-
298
- if (!ggml_backend_sched_reserve(sched.get(), gf)) {
344
+ auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
345
+ if (!gf) {
299
346
  throw std::runtime_error("failed to allocate compute pp buffers");
300
347
  }
301
348
 
@@ -303,18 +350,10 @@ llama_context::llama_context(
303
350
  n_nodes_pp = ggml_graph_n_nodes(gf);
304
351
  }
305
352
 
306
- // reserve with tg graph to get the number of splits and nodes
353
+ // reserve with tg (token generation) graph to get the number of splits and nodes
307
354
  {
308
- llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
309
-
310
- n_outputs = ubatch_tg.n_tokens;
311
-
312
- LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_tg.n_tokens, ubatch_tg.n_seqs);
313
-
314
- auto * gf = graph_init();
315
- graph_build(ctx_compute.get(), gf, ubatch_tg, LLM_GRAPH_TYPE_DEFAULT);
316
-
317
- if (!ggml_backend_sched_reserve(sched.get(), gf)) {
355
+ auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get());
356
+ if (!gf) {
318
357
  throw std::runtime_error("failed to allocate compute tg buffers");
319
358
  }
320
359
 
@@ -324,22 +363,16 @@ llama_context::llama_context(
324
363
 
325
364
  // reserve again with pp graph to avoid ggml-alloc reallocations during inference
326
365
  {
327
- llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
328
-
329
- n_outputs = ubatch_pp.n_tokens;
330
-
331
- LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs);
332
-
333
- auto * gf = graph_init();
334
- graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
335
-
336
- if (!ggml_backend_sched_reserve(sched.get(), gf)) {
366
+ // TODO: not sure if the following graph would be the worst case for multi-stream KV caches:
367
+ //
368
+ // auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get());
369
+ //
370
+ auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
371
+ if (!gf) {
337
372
  throw std::runtime_error("failed to allocate compute pp buffers");
338
373
  }
339
374
  }
340
375
 
341
- n_outputs = n_outputs_save;
342
-
343
376
  for (size_t i = 0; i < backend_ptrs.size(); ++i) {
344
377
  ggml_backend_t backend = backend_ptrs[i];
345
378
  ggml_backend_buffer_type_t buft = backend_buft[i];
@@ -411,10 +444,6 @@ ggml_backend_sched_t llama_context::get_sched() const {
411
444
  return sched.get();
412
445
  }
413
446
 
414
- ggml_context * llama_context::get_ctx_compute() const {
415
- return ctx_compute.get();
416
- }
417
-
418
447
  uint32_t llama_context::n_ctx() const {
419
448
  return cparams.n_ctx;
420
449
  }
@@ -443,46 +472,62 @@ uint32_t llama_context::n_threads_batch() const {
443
472
  return cparams.n_threads_batch;
444
473
  }
445
474
 
446
- llama_kv_cache * llama_context::get_kv_self() {
447
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
448
- return kv_self;
449
- }
450
-
451
- const llama_kv_cache * llama_context::get_kv_self() const {
452
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
453
- return kv_self;
475
+ llama_memory_t llama_context::get_memory() const {
476
+ return memory.get();
454
477
  }
455
478
 
456
- void llama_context::kv_self_update() {
457
- bool need_reserve = false;
458
-
459
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
460
-
461
- need_reserve = kv_self->update(*this);
479
+ bool llama_context::memory_update(bool optimize) {
480
+ if (!memory) {
481
+ return false;
482
+ }
462
483
 
463
- // reserve a worst case graph if needed
464
- if (need_reserve) {
465
- LLAMA_LOG_DEBUG("%s: reserving a worst case graph\n", __func__);
484
+ {
485
+ const auto mctx = memory->init_update(this, optimize);
486
+ switch (mctx->get_status()) {
487
+ case LLAMA_MEMORY_STATUS_SUCCESS:
488
+ {
489
+ // noop
490
+ } break;
491
+ case LLAMA_MEMORY_STATUS_NO_UPDATE:
492
+ {
493
+ // no updates need to be performed
494
+ return false;
495
+ }
496
+ case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
497
+ case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
498
+ {
499
+ LLAMA_LOG_ERROR("%s: failed to prepare memory update\n", __func__);
500
+ return false;
501
+ }
502
+ }
466
503
 
467
- // build worst-case graph
468
- uint32_t n_seqs = 1; // TODO: worst-case number of sequences
469
- uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
504
+ // reset the previous graph result to make sure that it won't be reused
505
+ // TODO: change the mctx->apply() to return information if a graph reserve is needed
506
+ // reset the graph result only if the memory module did reset the scheduler
507
+ gf_res_prev->reset();
470
508
 
471
- // simulate full KV cache
472
- kv_self->set_full();
509
+ if (!mctx->apply()) {
510
+ LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
511
+ }
512
+ }
473
513
 
474
- llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
475
- llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
514
+ // if the memory module did any computation, we have to reserve a new worst-case graph
515
+ {
516
+ const auto mctx = memory->init_full();
517
+ if (!mctx) {
518
+ throw std::runtime_error("failed to initialize memory context");
519
+ }
476
520
 
477
- auto * gf = graph_init();
478
- graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
521
+ const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
522
+ const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
479
523
 
480
- // initialize scheduler with the worst-case graph
481
- ggml_backend_sched_reset(sched.get());
482
- if (!ggml_backend_sched_reserve(sched.get(), gf)) {
483
- LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
524
+ auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
525
+ if (!gf) {
526
+ LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__);
484
527
  }
485
528
  }
529
+
530
+ return true;
486
531
  }
487
532
 
488
533
  enum llama_pooling_type llama_context::pooling_type() const {
@@ -490,11 +535,15 @@ enum llama_pooling_type llama_context::pooling_type() const {
490
535
  }
491
536
 
492
537
  float * llama_context::get_logits() {
538
+ output_reorder();
539
+
493
540
  return logits;
494
541
  }
495
542
 
496
543
  float * llama_context::get_logits_ith(int32_t i) {
497
- int32_t j = -1;
544
+ int64_t j = -1;
545
+
546
+ output_reorder();
498
547
 
499
548
  try {
500
549
  if (logits == nullptr) {
@@ -517,7 +566,7 @@ float * llama_context::get_logits_ith(int32_t i) {
517
566
  }
518
567
  if (j >= n_outputs) {
519
568
  // This should not happen
520
- throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, n_outputs));
569
+ throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
521
570
  }
522
571
 
523
572
  return logits + j*model.vocab.n_tokens();
@@ -532,11 +581,15 @@ float * llama_context::get_logits_ith(int32_t i) {
532
581
  }
533
582
 
534
583
  float * llama_context::get_embeddings() {
584
+ output_reorder();
585
+
535
586
  return embd;
536
587
  }
537
588
 
538
589
  float * llama_context::get_embeddings_ith(int32_t i) {
539
- int32_t j = -1;
590
+ int64_t j = -1;
591
+
592
+ output_reorder();
540
593
 
541
594
  try {
542
595
  if (embd == nullptr) {
@@ -559,7 +612,7 @@ float * llama_context::get_embeddings_ith(int32_t i) {
559
612
  }
560
613
  if (j >= n_outputs) {
561
614
  // This should not happen
562
- throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, n_outputs));
615
+ throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
563
616
  }
564
617
 
565
618
  return embd + j*model.hparams.n_embd;
@@ -676,72 +729,119 @@ bool llama_context::apply_adapter_cvec(
676
729
  return cvec.apply(model, data, len, n_embd, il_start, il_end);
677
730
  }
678
731
 
679
- int llama_context::encode(llama_batch & inp_batch) {
680
- if (inp_batch.n_tokens == 0) {
681
- LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
682
- return -1;
732
+ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
733
+ if (mctx && !mctx->apply()) {
734
+ LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
735
+ ret = GGML_STATUS_FAILED;
736
+ return nullptr;
683
737
  }
684
738
 
685
- // temporary allocate memory for the input batch if needed
686
- // note: during encode, we always pass the full sequence starting from pos = 0
687
- llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : 0);
739
+ auto * res = gf_res_prev.get();
740
+ auto * gf = res->get_gf();
688
741
 
689
- const llama_batch & batch = batch_allocr.batch;
690
- const int32_t n_tokens = batch.n_tokens;
742
+ // the new graph parameters
743
+ // in order to correctly reuse a graph, its full topology has to be uniquely determined by these parameters
744
+ const auto gparams = graph_params(res, ubatch, mctx, gtype);
691
745
 
692
- const auto & hparams = model.hparams;
746
+ if (!graph_reuse_disable && res->can_reuse(gparams)) {
747
+ //LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__);
693
748
 
694
- GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
749
+ n_reused++;
750
+ } else {
751
+ res->reset();
695
752
 
696
- // TODO: move the validation to the llama_batch_allocr
697
- if (batch.token) {
698
- for (int32_t i = 0; i < n_tokens; ++i) {
699
- if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
700
- LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
701
- return -1;
702
- }
753
+ ggml_backend_sched_reset(sched.get());
754
+ ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
703
755
 
704
- if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
705
- LLAMA_LOG_ERROR("%s: invalid seq_id[%d] = %d > %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
706
- throw -1;
707
- }
756
+ //const auto t_start_us = ggml_time_us();
757
+
758
+ gf = model.build_graph(gparams);
759
+
760
+ //LLAMA_LOG_INFO("graph build time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);
761
+
762
+ if (!gf) {
763
+ LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
764
+ ret = GGML_STATUS_FAILED;
765
+ return nullptr;
766
+ }
767
+
768
+ if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
769
+ LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
770
+ ret = GGML_STATUS_ALLOC_FAILED;
771
+ return nullptr;
708
772
  }
709
773
  }
710
774
 
775
+ // set the input data for the input tensors
776
+ {
777
+ //const auto t_start_us = ggml_time_us();
778
+
779
+ res->set_inputs(&ubatch);
780
+
781
+ //LLAMA_LOG_INFO("graph set inputs time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);
782
+ }
783
+
784
+ const auto status = graph_compute(res->get_gf(), ubatch.n_tokens > 1);
785
+ if (status != GGML_STATUS_SUCCESS) {
786
+ LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status);
787
+ ret = status;
788
+ return nullptr;
789
+ }
790
+
791
+ ret = GGML_STATUS_SUCCESS;
792
+
793
+ return res;
794
+ }
795
+
796
+ int llama_context::encode(const llama_batch & batch_inp) {
797
+ GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
798
+
799
+ if (batch_inp.n_tokens == 0) {
800
+ LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
801
+ return -1;
802
+ }
803
+
804
+ const auto & hparams = model.hparams;
805
+
806
+ const int64_t n_embd = hparams.n_embd;
807
+ const int64_t n_vocab = model.vocab.n_tokens();
808
+
809
+ // note: during encode, we always pass the full sequence starting from pos = 0
810
+ if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
811
+ LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
812
+ return -1;
813
+ }
814
+
815
+ const uint32_t n_tokens = balloc->get_n_tokens();
816
+
817
+ // [TAG_NO_CACHE_PAD]
818
+ // TODO: add new split mode where we pad the input sequences so that ubatch.equal_seqs == true
819
+ const llama_ubatch ubatch = balloc->split_simple(n_tokens);
820
+
711
821
  // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
712
- GGML_ASSERT(cparams.n_ubatch >= (uint32_t) n_tokens && "encoder requires n_ubatch >= n_tokens");
822
+ GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens");
713
823
 
714
824
  if (t_compute_start_us == 0) {
715
825
  t_compute_start_us = ggml_time_us();
716
826
  }
717
827
 
828
+ // TODO: this clear of the buffer can easily be forgotten - need something better
718
829
  embd_seq.clear();
719
830
 
720
831
  n_queued_tokens += n_tokens;
721
832
 
722
- const int64_t n_embd = hparams.n_embd;
723
-
724
- llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
725
-
726
- const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
727
-
728
833
  // reserve output buffer
729
834
  if (output_reserve(n_tokens) < n_tokens) {
730
835
  LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
731
836
  return -2;
732
837
  };
733
838
 
734
- for (int32_t i = 0; i < n_tokens; ++i) {
839
+ for (uint32_t i = 0; i < n_tokens; ++i) {
735
840
  output_ids[i] = i;
736
841
  }
737
842
 
738
843
  n_outputs = n_tokens;
739
844
 
740
- //batch_manager->prepare(ubatch);
741
-
742
- ggml_backend_sched_reset(sched.get());
743
- ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
744
-
745
845
  const auto causal_attn_org = cparams.causal_attn;
746
846
 
747
847
  // always use non-causal attention for encoder graphs
@@ -749,32 +849,34 @@ int llama_context::encode(llama_batch & inp_batch) {
749
849
  // ref: https://github.com/ggml-org/llama.cpp/pull/12181#issuecomment-2730451223
750
850
  cparams.causal_attn = false;
751
851
 
752
- auto * gf = graph_init();
753
- auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_ENCODER);
754
-
755
- ggml_backend_sched_alloc_graph(sched.get(), gf);
756
-
757
- res->set_inputs(&ubatch);
852
+ ggml_status status;
853
+ const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);
758
854
 
759
855
  cparams.causal_attn = causal_attn_org;
760
856
 
761
- const auto compute_status = graph_compute(gf, n_tokens > 1);
762
- switch (compute_status) {
763
- case GGML_STATUS_SUCCESS:
764
- break;
765
- case GGML_STATUS_ABORTED:
766
- return 2;
767
- case GGML_STATUS_ALLOC_FAILED:
768
- return -2;
769
- case GGML_STATUS_FAILED:
770
- default:
771
- return -3;
857
+ if (!res) {
858
+ switch (status) {
859
+ case GGML_STATUS_ABORTED: return 2;
860
+ case GGML_STATUS_ALLOC_FAILED: return -2;
861
+ case GGML_STATUS_FAILED: return -3;
862
+ case GGML_STATUS_SUCCESS: GGML_ABORT("should not happen");
863
+ }
772
864
  }
773
865
 
866
+ auto * t_logits = res->get_logits();
774
867
  auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
775
868
 
869
+ // extract logits
870
+ if (logits && t_logits) {
871
+ ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
872
+ GGML_ASSERT(backend_res != nullptr);
873
+ GGML_ASSERT(logits != nullptr);
874
+
875
+ ggml_backend_tensor_get_async(backend_res, t_logits, logits, 0, n_tokens*n_vocab*sizeof(float));
876
+ }
877
+
776
878
  // extract embeddings
777
- if (t_embd) {
879
+ if (embd && t_embd) {
778
880
  ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
779
881
  GGML_ASSERT(backend_embd != nullptr);
780
882
 
@@ -793,31 +895,28 @@ int llama_context::encode(llama_batch & inp_batch) {
793
895
  {
794
896
  // extract sequence embeddings
795
897
  auto & embd_seq_out = embd_seq;
796
- embd_seq_out.clear();
797
898
 
798
- GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
899
+ for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
900
+ const llama_seq_id seq_id = ubatch.seq_id_unq[s];
901
+ const int32_t seq_idx = ubatch.seq_idx[seq_id];
799
902
 
800
- for (int32_t i = 0; i < n_tokens; i++) {
801
- const llama_seq_id seq_id = ubatch.seq_id[i][0];
802
- if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
803
- continue;
804
- }
805
903
  embd_seq_out[seq_id].resize(n_embd);
806
- ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
904
+ ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float));
807
905
  }
808
906
  } break;
809
907
  case LLAMA_POOLING_TYPE_RANK:
810
908
  {
811
- // extract the rerank score - a single float per sequence
909
+ // extract the rerank score - n_cls_out floats per sequence
812
910
  auto & embd_seq_out = embd_seq;
813
911
 
814
- for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
815
- const llama_seq_id seq_id = ubatch.seq_id[s][0];
816
- if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
817
- continue;
818
- }
819
- embd_seq_out[seq_id].resize(1);
820
- ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
912
+ const uint32_t n_cls_out = hparams.n_cls_out;
913
+
914
+ for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
915
+ const llama_seq_id seq_id = ubatch.seq_id_unq[s];
916
+ const int32_t seq_idx = ubatch.seq_idx[seq_id];
917
+
918
+ embd_seq_out[seq_id].resize(n_cls_out);
919
+ ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float));
821
920
  }
822
921
  } break;
823
922
  case LLAMA_POOLING_TYPE_UNSPECIFIED:
@@ -827,10 +926,6 @@ int llama_context::encode(llama_batch & inp_batch) {
827
926
  }
828
927
  }
829
928
 
830
- // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
831
- // overlap with device computation.
832
- ggml_backend_sched_reset(sched.get());
833
-
834
929
  // TODO: hacky solution
835
930
  if (model.arch == LLM_ARCH_T5 && t_embd) {
836
931
  //cross.t_embd = t_embd;
@@ -842,12 +937,16 @@ int llama_context::encode(llama_batch & inp_batch) {
842
937
  cross.v_embd.resize(cross.n_embd*cross.n_enc);
843
938
  memcpy(cross.v_embd.data(), embd, ggml_nbytes(t_embd));
844
939
 
940
+ const auto & batch = balloc->get_batch();
941
+
845
942
  // remember the sequence ids used during the encoding - needed for cross attention later
846
943
  cross.seq_ids_enc.resize(n_tokens);
847
- for (int32_t i = 0; i < n_tokens; i++) {
944
+ for (uint32_t i = 0; i < n_tokens; i++) {
848
945
  cross.seq_ids_enc[i].clear();
849
- for (int s = 0; s < ubatch.n_seq_id[i]; s++) {
850
- llama_seq_id seq_id = ubatch.seq_id[i][s];
946
+
947
+ for (int s = 0; s < batch.n_seq_id[i]; s++) {
948
+ const llama_seq_id seq_id = batch.seq_id[i][s];
949
+
851
950
  cross.seq_ids_enc[i].insert(seq_id);
852
951
  }
853
952
  }
@@ -856,55 +955,42 @@ int llama_context::encode(llama_batch & inp_batch) {
856
955
  return 0;
857
956
  }
858
957
 
859
- int llama_context::decode(llama_batch & inp_batch) {
958
+ int llama_context::decode(const llama_batch & batch_inp) {
959
+ GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
960
+
860
961
  if (!memory) {
861
962
  LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
862
- return encode(inp_batch);
963
+ return encode(batch_inp);
863
964
  }
864
965
 
865
- if (inp_batch.n_tokens == 0) {
966
+ if (batch_inp.n_tokens == 0) {
866
967
  LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
867
968
  return -1;
868
969
  }
869
970
 
870
- if (!inp_batch.pos) {
871
- if (inp_batch.seq_id) {
872
- LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__);
873
- return -1;
874
- }
875
- }
876
-
877
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
878
-
879
- // temporary allocate memory for the input batch if needed
880
- llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->seq_pos_max(0) + 1);
881
-
882
- const llama_batch & batch = batch_allocr.batch;
883
-
884
971
  const auto & vocab = model.vocab;
885
972
  const auto & hparams = model.hparams;
886
973
 
887
- const int32_t n_vocab = vocab.n_tokens();
888
-
889
- const int64_t n_tokens_all = batch.n_tokens;
890
- const int64_t n_embd = hparams.n_embd;
974
+ const int64_t n_vocab = vocab.n_tokens();
975
+ const int64_t n_embd = hparams.n_embd;
891
976
 
892
- llama_kv_cache_guard kv_guard(kv_self);
977
+ // when computing embeddings, all tokens are output
978
+ const bool output_all = cparams.embeddings;
893
979
 
894
- GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
980
+ if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, output_all)) {
981
+ LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
982
+ return -1;
983
+ }
895
984
 
896
- // TODO: move the validation to the llama_batch_allocr
897
- if (batch.token) {
898
- for (int64_t i = 0; i < n_tokens_all; ++i) {
899
- if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
900
- LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]);
901
- return -1;
902
- }
985
+ const uint32_t n_tokens_all = balloc->get_n_tokens();
986
+ const uint32_t n_outputs_all = balloc->get_n_outputs();
903
987
 
904
- if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
905
- LLAMA_LOG_ERROR("%s: invalid seq_id[%" PRId64 "] = %d >= %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
906
- return -1;
907
- }
988
+ if (output_all) {
989
+ // require that all tokens are output
990
+ if (n_outputs_all != n_tokens_all) {
991
+ LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n",
992
+ __func__, n_outputs_all, n_tokens_all);
993
+ return -1;
908
994
  }
909
995
  }
910
996
 
@@ -917,49 +1003,78 @@ int llama_context::decode(llama_batch & inp_batch) {
917
1003
  }
918
1004
  n_queued_tokens += n_tokens_all;
919
1005
 
920
- // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
921
- const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
922
-
1006
+ // TODO: this clear of the buffer can easily be forgotten - need something better
923
1007
  embd_seq.clear();
1008
+ output_swaps.clear();
924
1009
 
925
- int64_t n_outputs_all = 0;
1010
+ bool did_optimize = false;
926
1011
 
927
- // count outputs
928
- if (batch.logits && !embd_pooled) {
929
- for (uint32_t i = 0; i < n_tokens_all; ++i) {
930
- n_outputs_all += batch.logits[i] != 0;
1012
+ // handle any pending shifts/copies
1013
+ memory_update(false);
1014
+
1015
+ llama_memory_context_ptr mctx;
1016
+
1017
+ while (true) {
1018
+ mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
1019
+ if (!mctx) {
1020
+ return -2;
931
1021
  }
932
- } else if (embd_pooled) {
933
- n_outputs_all = n_tokens_all;
934
- } else {
935
- // keep last output only
936
- n_outputs_all = 1;
937
- }
938
1022
 
939
- llama_sbatch sbatch = kv_self->sbatch_init(batch, /* logits_all */ n_outputs_all == n_tokens_all);
1023
+ switch (mctx->get_status()) {
1024
+ case LLAMA_MEMORY_STATUS_SUCCESS:
1025
+ {
1026
+ } break;
1027
+ case LLAMA_MEMORY_STATUS_NO_UPDATE:
1028
+ {
1029
+ LLAMA_LOG_ERROR("%s: unexpected memory context status: %d\n", __func__, mctx->get_status());
1030
+
1031
+ return -2;
1032
+ }
1033
+ case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
1034
+ {
1035
+ if (!did_optimize) {
1036
+ did_optimize = true;
1037
+
1038
+ if (memory_update(true)) {
1039
+ LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, balloc->get_n_tokens());
1040
+
1041
+ continue;
1042
+ }
1043
+ }
1044
+
1045
+ LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, balloc->get_n_tokens());
1046
+
1047
+ return 1;
1048
+ }
1049
+ case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
1050
+ {
1051
+ LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, balloc->get_n_tokens());
1052
+
1053
+ return -2;
1054
+ }
1055
+ }
1056
+
1057
+ break;
1058
+ }
940
1059
 
941
1060
  // reserve output buffer
942
1061
  if (output_reserve(n_outputs_all) < n_outputs_all) {
943
- LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
1062
+ LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
944
1063
  return -2;
945
1064
  };
946
1065
 
947
- // handle any pending defrags/shifts
948
- kv_self_update();
949
-
950
1066
  int64_t n_outputs_prev = 0;
951
1067
 
952
- while (sbatch.n_tokens > 0) {
953
- llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
1068
+ do {
1069
+ const auto & ubatch = mctx->get_ubatch();
954
1070
 
955
- // count the outputs in this u_batch
1071
+ // count the outputs in this ubatch
956
1072
  {
957
1073
  int32_t n_outputs_new = 0;
958
1074
 
959
1075
  if (n_outputs_all == n_tokens_all) {
960
1076
  n_outputs_new = ubatch.n_tokens;
961
1077
  } else {
962
- GGML_ASSERT(ubatch.output);
963
1078
  for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
964
1079
  n_outputs_new += (int32_t) (ubatch.output[i] != 0);
965
1080
  }
@@ -969,33 +1084,37 @@ int llama_context::decode(llama_batch & inp_batch) {
969
1084
  n_outputs = n_outputs_new;
970
1085
  }
971
1086
 
972
- // find KV slot
973
- if (!kv_self->find_slot(ubatch)) {
974
- return 1;
975
- }
1087
+ ggml_status status;
1088
+ const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
976
1089
 
977
- ggml_backend_sched_reset(sched.get());
978
- ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
1090
+ if (!res) {
1091
+ // the last ubatch failed or was aborted -> remove all positions of that ubatch from the memory module
1092
+ llama_pos pos_min[LLAMA_MAX_SEQ];
1093
+ for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
1094
+ pos_min[s] = std::numeric_limits<llama_pos>::max();
1095
+ }
979
1096
 
980
- auto * gf = graph_init();
981
- auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DECODER);
1097
+ for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
1098
+ const auto & seq_id = ubatch.seq_id[i][0];
982
1099
 
983
- // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
1100
+ pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]);
1101
+ }
984
1102
 
985
- ggml_backend_sched_alloc_graph(sched.get(), gf);
1103
+ for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
1104
+ if (pos_min[s] == std::numeric_limits<llama_pos>::max()) {
1105
+ continue;
1106
+ }
986
1107
 
987
- res->set_inputs(&ubatch);
1108
+ LLAMA_LOG_WARN("%s: removing memory module entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]);
988
1109
 
989
- const auto compute_status = graph_compute(gf, ubatch.n_tokens > 1);
990
- if (compute_status != GGML_STATUS_SUCCESS) {
991
- switch (compute_status) {
992
- case GGML_STATUS_ABORTED:
993
- return 2;
994
- case GGML_STATUS_ALLOC_FAILED:
995
- return -2;
996
- case GGML_STATUS_FAILED:
997
- default:
998
- return -3;
1110
+ memory->seq_rm(s, pos_min[s], -1);
1111
+ }
1112
+
1113
+ switch (status) {
1114
+ case GGML_STATUS_ABORTED: return 2;
1115
+ case GGML_STATUS_ALLOC_FAILED: return -2;
1116
+ case GGML_STATUS_FAILED: return -3;
1117
+ case GGML_STATUS_SUCCESS: GGML_ABORT("should not happen");
999
1118
  }
1000
1119
  }
1001
1120
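The failure paths above determine the integer results that propagate out of decode(). A caller-side sketch of how those results might be interpreted, illustrative only and not part of this diff (uses only the public llama_decode entry point):

    #include "llama.h"

    // sketch: map the decode() result codes implied by the paths above to caller actions
    static bool try_decode(llama_context * ctx, llama_batch batch) {
        const int rc = llama_decode(ctx, batch);
        if (rc == 1) {
            return false; // no memory slot found - retry with a smaller batch or a larger context
        }
        if (rc == 2) {
            return false; // aborted via the abort callback
        }
        if (rc < 0) {
            return false; // invalid input, allocation failure, or compute failure
        }
        return true;      // success
    }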
 
@@ -1004,7 +1123,7 @@ int llama_context::decode(llama_batch & inp_batch) {
1004
1123
  // ggml_graph_dump_dot(gf, NULL, "llama.dot");
1005
1124
  //}
1006
1125
 
1007
- auto * t_logits = cparams.embeddings ? nullptr : res->get_logits();
1126
+ auto * t_logits = res->get_logits();
1008
1127
  auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
1009
1128
 
1010
1129
  if (t_embd && res->get_embd_pooled()) {
@@ -1051,27 +1170,27 @@ int llama_context::decode(llama_batch & inp_batch) {
1051
1170
  // extract sequence embeddings (cleared before processing each batch)
1052
1171
  auto & embd_seq_out = embd_seq;
1053
1172
 
1054
- for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
1055
- const llama_seq_id seq_id = ubatch.seq_id[s][0];
1056
- if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
1057
- continue;
1058
- }
1173
+ for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
1174
+ const llama_seq_id seq_id = ubatch.seq_id_unq[s];
1175
+ const int32_t seq_idx = ubatch.seq_idx[seq_id];
1176
+
1059
1177
  embd_seq_out[seq_id].resize(n_embd);
1060
- ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
1178
+ ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float));
1061
1179
  }
1062
1180
  } break;
1063
1181
  case LLAMA_POOLING_TYPE_RANK:
1064
1182
  {
1065
- // extract the rerank score - a single float per sequence
1183
+ // extract the rerank score - n_cls_out floats per sequence
1066
1184
  auto & embd_seq_out = embd_seq;
1067
1185
 
1068
- for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
1069
- const llama_seq_id seq_id = ubatch.seq_id[s][0];
1070
- if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
1071
- continue;
1072
- }
1073
- embd_seq_out[seq_id].resize(1);
1074
- ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
1186
+ const uint32_t n_cls_out = hparams.n_cls_out;
1187
+
1188
+ for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
1189
+ const llama_seq_id seq_id = ubatch.seq_id_unq[s];
1190
+ const int32_t seq_idx = ubatch.seq_idx[seq_id];
1191
+
1192
+ embd_seq_out[seq_id].resize(n_cls_out);
1193
+ ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float));
1075
1194
  }
1076
1195
  } break;
1077
1196
  case LLAMA_POOLING_TYPE_UNSPECIFIED:
@@ -1082,23 +1201,20 @@ int llama_context::decode(llama_batch & inp_batch) {
1082
1201
  }
1083
1202
 
1084
1203
  n_outputs_prev += n_outputs;
1085
- }
1086
-
1087
- // finalize the batch processing
1088
- kv_guard.commit();
1204
+ } while (mctx->next());
1089
1205
 
1090
1206
  // set to total number of outputs in the batch, for use in llama_get_logits_ith
1091
1207
  n_outputs = n_outputs_all;
1092
1208
 
1093
1209
  // set output mappings
1094
- {
1210
+ if (n_outputs > 0) {
1095
1211
  bool sorted_output = true;
1096
1212
 
1097
- auto & out_ids = sbatch.out_ids;
1213
+ auto & out_ids = balloc->get_out_ids();
1098
1214
 
1099
- GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all);
1215
+ GGML_ASSERT(out_ids.size() == (size_t) n_outputs);
1100
1216
 
1101
- for (int64_t i = 0; i < n_outputs_all; ++i) {
1217
+ for (int64_t i = 0; i < n_outputs; ++i) {
1102
1218
  int64_t out_id = out_ids[i];
1103
1219
  output_ids[out_id] = i;
1104
1220
  if (out_id != i) {
@@ -1109,35 +1225,29 @@ int llama_context::decode(llama_batch & inp_batch) {
1109
1225
  // make the outputs have the same order they had in the user-provided batch
1110
1226
  // note: this is mostly relevant for recurrent models atm
1111
1227
  if (!sorted_output) {
1112
- const uint32_t n_vocab = model.vocab.n_tokens();
1113
- const uint32_t n_embd = model.hparams.n_embd;
1114
-
1115
1228
  GGML_ASSERT((size_t) n_outputs == out_ids.size());
1116
1229
 
1117
1230
  // TODO: is there something more efficient which also minimizes swaps?
1118
1231
  // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
1119
- for (int32_t i = 0; i < n_outputs - 1; ++i) {
1120
- int32_t j_min = i;
1121
- for (int32_t j = i + 1; j < n_outputs; ++j) {
1232
+ for (uint32_t i = 0; i < n_outputs - 1; ++i) {
1233
+ uint32_t j_min = i;
1234
+ for (uint32_t j = i + 1; j < n_outputs; ++j) {
1122
1235
  if (out_ids[j] < out_ids[j_min]) {
1123
1236
  j_min = j;
1124
1237
  }
1125
1238
  }
1126
- if (j_min == i) { continue; }
1127
- std::swap(out_ids[i], out_ids[j_min]);
1128
- if (logits_size > 0) {
1129
- for (uint32_t k = 0; k < n_vocab; k++) {
1130
- std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
1131
- }
1132
- }
1133
- if (embd_size > 0) {
1134
- for (uint32_t k = 0; k < n_embd; k++) {
1135
- std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]);
1136
- }
1239
+ if (j_min == i) {
1240
+ continue;
1137
1241
  }
1242
+ std::swap(out_ids[i], out_ids[j_min]);
1243
+
1244
+ // remember the swaps and apply them lazily upon logits/embeddings access
1245
+ output_swaps.push_back({ i, j_min });
1138
1246
  }
1247
+
1139
1248
  std::fill(output_ids.begin(), output_ids.end(), -1);
1140
- for (int32_t i = 0; i < n_outputs; ++i) {
1249
+
1250
+ for (uint32_t i = 0; i < n_outputs; ++i) {
1141
1251
  output_ids[out_ids[i]] = i;
1142
1252
  }
1143
1253
  }
@@ -1146,15 +1256,6 @@ int llama_context::decode(llama_batch & inp_batch) {
1146
1256
  // wait for the computation to finish (automatically done when obtaining the model output)
1147
1257
  //synchronize();
1148
1258
 
1149
- // decide if we need to defrag the kv cache
1150
- if (cparams.defrag_thold > 0.0f) {
1151
- kv_self->defrag_sched(cparams.defrag_thold);
1152
- }
1153
-
1154
- // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
1155
- // overlap with device computation.
1156
- ggml_backend_sched_reset(sched.get());
1157
-
1158
1259
  return 0;
1159
1260
  }
1160
1261
 
@@ -1162,7 +1263,7 @@ int llama_context::decode(llama_batch & inp_batch) {
1162
1263
  // output
1163
1264
  //
1164
1265
 
1165
- int32_t llama_context::output_reserve(int32_t n_outputs) {
1266
+ uint32_t llama_context::output_reserve(int32_t n_outputs) {
1166
1267
  const auto & hparams = model.hparams;
1167
1268
  const auto & vocab = model.vocab;
1168
1269
 
@@ -1172,9 +1273,8 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
1172
1273
  const auto n_vocab = vocab.n_tokens();
1173
1274
  const auto n_embd = hparams.n_embd;
1174
1275
 
1175
- // TODO: use a per-batch flag for logits presence instead
1176
- bool has_logits = !cparams.embeddings;
1177
- bool has_embd = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
1276
+ bool has_logits = true;
1277
+ bool has_embd = cparams.embeddings;
1178
1278
 
1179
1279
  // TODO: hacky enc-dec support
1180
1280
  if (model.arch == LLM_ARCH_T5) {
@@ -1228,53 +1328,114 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
1228
1328
  // set all ids as invalid (negative)
1229
1329
  std::fill(output_ids.begin(), output_ids.end(), -1);
1230
1330
 
1231
- this->n_outputs = 0;
1232
- this->n_outputs_max = n_outputs_max;
1331
+ this->n_outputs = 0;
1233
1332
 
1234
1333
  return n_outputs_max;
1235
1334
  }
1236
1335
 
1336
+ void llama_context::output_reorder() {
1337
+ const uint64_t n_vocab = model.vocab.n_tokens();
1338
+ const uint64_t n_embd = model.hparams.n_embd;
1339
+
1340
+ for (size_t s = 0; s < output_swaps.size(); ++s) {
1341
+ const uint64_t i0 = output_swaps[s].i0;
1342
+ const uint64_t i1 = output_swaps[s].i1;
1343
+
1344
+ if (logits_size > 0) {
1345
+ for (uint64_t k = 0; k < n_vocab; k++) {
1346
+ std::swap(logits[i0*n_vocab + k], logits[i1*n_vocab + k]);
1347
+ }
1348
+ }
1349
+
1350
+ if (embd_size > 0) {
1351
+ for (uint64_t k = 0; k < n_embd; k++) {
1352
+ std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]);
1353
+ }
1354
+ }
1355
+ }
1356
+
1357
+ output_swaps.clear();
1358
+ }
1359
+
1237
1360
  //
1238
1361
  // graph
1239
1362
  //
1240
1363
 
1241
- int32_t llama_context::graph_max_nodes() const {
1242
- return std::max<int32_t>(65536, 5*model.n_tensors());
1364
+ uint32_t llama_context::graph_max_nodes() const {
1365
+ return std::max<uint32_t>(1024u, 8u*model.n_tensors());
1243
1366
  }
1244
1367
 
1245
- ggml_cgraph * llama_context::graph_init() {
1246
- ggml_init_params params = {
1247
- /*.mem_size =*/ buf_compute_meta.size(),
1248
- /*.mem_buffer =*/ buf_compute_meta.data(),
1249
- /*.no_alloc =*/ true,
1250
- };
1368
+ llm_graph_result * llama_context::get_gf_res_reserve() const {
1369
+ return static_cast<llm_graph_result *>(gf_res_reserve.get());
1370
+ }
1251
1371
 
1252
- ctx_compute.reset(ggml_init(params));
1372
+ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only) {
1373
+ LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
1374
+ GGML_ASSERT(n_outputs >= 1);
1253
1375
 
1254
- return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
1255
- }
1376
+ if (n_tokens % n_seqs != 0) {
1377
+ n_tokens = ((n_tokens + (n_seqs - 1)) / n_seqs) * n_seqs; // round to next multiple of n_seqs
1378
+ n_outputs = std::min(n_outputs, n_tokens);
1256
1379
 
1257
- llm_graph_result_ptr llama_context::graph_build(
1258
- ggml_context * ctx,
1259
- ggml_cgraph * gf,
1260
- const llama_ubatch & ubatch,
1261
- llm_graph_type gtype) {
1262
- return model.build_graph(
1263
- {
1264
- /*.ctx =*/ ctx,
1265
- /*.arch =*/ model.arch,
1266
- /*.hparams =*/ model.hparams,
1267
- /*.cparams =*/ cparams,
1268
- /*.ubatch =*/ ubatch,
1269
- /*.sched =*/ sched.get(),
1270
- /*.backend_cpu =*/ backend_cpu,
1271
- /*.cvec =*/ &cvec,
1272
- /*.loras =*/ &loras,
1273
- /*.memory =*/ memory.get(),
1274
- /*.cross =*/ &cross,
1275
- /*.n_outputs =*/ n_outputs,
1276
- /*.cb =*/ graph_get_cb(),
1277
- }, gf, gtype);
1380
+ LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs);
1381
+ }
1382
+
1383
+ ggml_backend_sched_reset(sched.get());
1384
+
1385
+ // when the scheduler is reset, we cannot reuse the old graph, so we reset the previous graph result to prevent that
1386
+ gf_res_prev->reset();
1387
+
1388
+ // store the n_outputs as it is, and restore it afterwards
1389
+ // TODO: not sure if needed, might simplify in the future by removing this
1390
+ const auto save_n_outputs = this->n_outputs;
1391
+
1392
+ this->n_outputs = n_outputs;
1393
+
1394
+ llama_batch_allocr balloc(model.hparams.n_pos_per_embd());
1395
+ llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
1396
+
1397
+ auto * res = gf_res_reserve.get();
1398
+
1399
+ const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT);
1400
+
1401
+ res->reset();
1402
+
1403
+ auto * gf = model.build_graph(gparams);
1404
+
1405
+ this->n_outputs = save_n_outputs;
1406
+
1407
+ // initialize scheduler with the specified graph
1408
+ if (split_only) {
1409
+ ggml_backend_sched_split_graph(sched.get(), gf);
1410
+ } else if (!ggml_backend_sched_reserve(sched.get(), gf)) {
1411
+ LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
1412
+ return nullptr;
1413
+ }
1414
+
1415
+ return gf;
1416
+ }
1417
+
1418
+ llm_graph_params llama_context::graph_params(
1419
+ llm_graph_result * res,
1420
+ const llama_ubatch & ubatch,
1421
+ const llama_memory_context_i * mctx,
1422
+ llm_graph_type gtype) const {
1423
+ return {
1424
+ /*.arch =*/ model.arch,
1425
+ /*.hparams =*/ model.hparams,
1426
+ /*.cparams =*/ cparams,
1427
+ /*.ubatch =*/ ubatch,
1428
+ /*.gtype =*/ gtype,
1429
+ /*.sched =*/ sched.get(),
1430
+ /*.backend_cpu =*/ backend_cpu,
1431
+ /*.cvec =*/ &cvec,
1432
+ /*.loras =*/ &loras,
1433
+ /*.mctx =*/ mctx,
1434
+ /*.cross =*/ &cross,
1435
+ /*.n_outputs =*/ n_outputs,
1436
+ /*.cb =*/ graph_get_cb(),
1437
+ /*.res =*/ res,
1438
+ };
1278
1439
  }
1279
1440
 
1280
1441
  ggml_status llama_context::graph_compute(
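The rounding that graph_reserve() applies when n_tokens is not a multiple of n_seqs can be shown with a small, self-contained sketch; the helper name and values are hypothetical:

    #include <cstdint>

    // illustration only: round n_tokens up to the next multiple of n_seqs,
    // mirroring the logic added to graph_reserve() above
    static uint32_t round_up_to_seqs(uint32_t n_tokens, uint32_t n_seqs) {
        if (n_tokens % n_seqs != 0) {
            n_tokens = ((n_tokens + (n_seqs - 1)) / n_seqs) * n_seqs; // e.g. n_tokens = 100, n_seqs = 3 -> 102
        }
        return n_tokens;
    }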
@@ -1286,7 +1447,9 @@ ggml_status llama_context::graph_compute(
1286
1447
  if (backend_cpu != nullptr) {
1287
1448
  auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_cpu));
1288
1449
  auto * set_threadpool_fn = (decltype(ggml_backend_cpu_set_threadpool) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_set_threadpool");
1289
- set_threadpool_fn(backend_cpu, tp);
1450
+ if (set_threadpool_fn) {
1451
+ set_threadpool_fn(backend_cpu, tp);
1452
+ }
1290
1453
  }
1291
1454
 
1292
1455
  // set the number of threads for all the backends
@@ -1505,30 +1668,30 @@ size_t llama_context::state_set_data(const uint8_t * src, size_t size) {
1505
1668
  }
1506
1669
  }
1507
1670
 
1508
- size_t llama_context::state_seq_get_size(llama_seq_id seq_id) {
1671
+ size_t llama_context::state_seq_get_size(llama_seq_id seq_id, llama_state_seq_flags flags) {
1509
1672
  llama_io_write_dummy io;
1510
1673
  try {
1511
- return state_seq_write_data(io, seq_id);
1674
+ return state_seq_write_data(io, seq_id, flags);
1512
1675
  } catch (const std::exception & err) {
1513
1676
  LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
1514
1677
  return 0;
1515
1678
  }
1516
1679
  }
1517
1680
 
1518
- size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) {
1681
+ size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size, llama_state_seq_flags flags) {
1519
1682
  llama_io_write_buffer io(dst, size);
1520
1683
  try {
1521
- return state_seq_write_data(io, seq_id);
1684
+ return state_seq_write_data(io, seq_id, flags);
1522
1685
  } catch (const std::exception & err) {
1523
1686
  LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
1524
1687
  return 0;
1525
1688
  }
1526
1689
  }
1527
1690
 
1528
- size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) {
1691
+ size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size, llama_state_seq_flags flags) {
1529
1692
  llama_io_read_buffer io(src, size);
1530
1693
  try {
1531
- return state_seq_read_data(io, seq_id);
1694
+ return state_seq_read_data(io, seq_id, flags);
1532
1695
  } catch (const std::exception & err) {
1533
1696
  LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
1534
1697
  return 0;
@@ -1626,7 +1789,7 @@ size_t llama_context::state_seq_load_file(llama_seq_id seq_id, const char * file
1626
1789
  {
1627
1790
  const size_t state_size = file.size() - file.tell();
1628
1791
  llama_io_read_file io(&file);
1629
- const size_t nread = state_seq_read_data(io, seq_id);
1792
+ const size_t nread = state_seq_read_data(io, seq_id, 0);
1630
1793
  if (!nread) {
1631
1794
  LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__);
1632
1795
  return 0;
@@ -1650,7 +1813,7 @@ size_t llama_context::state_seq_save_file(llama_seq_id seq_id, const char * file
1650
1813
 
1651
1814
  // save the context state using stream saving
1652
1815
  llama_io_write_file io(&file);
1653
- state_seq_write_data(io, seq_id);
1816
+ state_seq_write_data(io, seq_id, 0);
1654
1817
 
1655
1818
  const size_t res = file.tell();
1656
1819
  GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + io.n_bytes());
@@ -1679,14 +1842,12 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
1679
1842
 
1680
1843
  std::vector<int32_t> w_output_pos;
1681
1844
 
1682
- GGML_ASSERT(n_outputs <= n_outputs_max);
1683
-
1684
1845
  w_output_pos.resize(n_outputs);
1685
1846
 
1686
1847
  // build a more compact representation of the output ids
1687
1848
  for (size_t i = 0; i < n_batch(); ++i) {
1688
1849
  // map an output id to a position in the batch
1689
- int32_t pos = output_ids[i];
1850
+ int64_t pos = output_ids[i];
1690
1851
  if (pos >= 0) {
1691
1852
  GGML_ASSERT(pos < n_outputs);
1692
1853
  w_output_pos[pos] = i;
@@ -1726,11 +1887,9 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
1726
1887
  }
1727
1888
  }
1728
1889
 
1729
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
1730
-
1731
- if (kv_self != nullptr) {
1732
- LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
1733
- kv_self->state_write(io);
1890
+ if (memory != nullptr) {
1891
+ LLAMA_LOG_DEBUG("%s: - writing memory module\n", __func__);
1892
+ memory->state_write(io);
1734
1893
  }
1735
1894
 
1736
1895
  return io.n_bytes();
@@ -1815,35 +1974,29 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
1815
1974
  }
1816
1975
 
1817
1976
  if (memory) {
1818
- LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
1819
-
1820
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
1977
+ LLAMA_LOG_DEBUG("%s: - reading memory module\n", __func__);
1821
1978
 
1822
- kv_self->state_read(io);
1979
+ memory->state_read(io);
1823
1980
  }
1824
1981
 
1825
1982
  return io.n_bytes();
1826
1983
  }
1827
1984
 
1828
- size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) {
1985
+ size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
1829
1986
  GGML_UNUSED(seq_id);
1830
1987
 
1831
1988
  if (memory) {
1832
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
1833
-
1834
- kv_self->state_write(io, seq_id);
1989
+ memory->state_write(io, seq_id, flags);
1835
1990
  }
1836
1991
 
1837
1992
  return io.n_bytes();
1838
1993
  }
1839
1994
 
1840
- size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) {
1995
+ size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
1841
1996
  GGML_UNUSED(seq_id);
1842
1997
 
1843
1998
  if (memory) {
1844
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
1845
-
1846
- kv_self->state_read(io, seq_id);
1999
+ memory->state_read(io, seq_id, flags);
1847
2000
  }
1848
2001
 
1849
2002
  return io.n_bytes();
@@ -1862,6 +2015,7 @@ llama_perf_context_data llama_context::perf_get_data() const {
1862
2015
  data.t_eval_ms = 1e-3 * t_eval_us;
1863
2016
  data.n_p_eval = std::max(1, n_p_eval);
1864
2017
  data.n_eval = std::max(1, n_eval);
2018
+ data.n_reused = std::max(0, n_reused);
1865
2019
 
1866
2020
  return data;
1867
2021
  }
@@ -1870,6 +2024,22 @@ void llama_context::perf_reset() {
1870
2024
  t_start_us = ggml_time_us();
1871
2025
  t_eval_us = n_eval = 0;
1872
2026
  t_p_eval_us = n_p_eval = 0;
2027
+ n_reused = 0;
2028
+ }
2029
+
2030
+ std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> llama_context::memory_breakdown() const {
2031
+ std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> ret;
2032
+ for (const auto & buft_size : model.memory_breakdown()) {
2033
+ ret[buft_size.first].model += buft_size.second;
2034
+ }
2035
+ for (const auto & buft_size : memory->memory_breakdown()) {
2036
+ ret[buft_size.first].context += buft_size.second;
2037
+ }
2038
+ for (const auto & backend_ptr : backends) {
2039
+ ggml_backend_t backend = backend_ptr.get();
2040
+ ret[ggml_backend_sched_get_buffer_type(sched.get(), backend)].compute += ggml_backend_sched_get_buffer_size(sched.get(), backend);
2041
+ }
2042
+ return ret;
1873
2043
  }
1874
2044
 
1875
2045
  //
@@ -1904,7 +2074,7 @@ void llama_context::opt_init(struct llama_model * model, struct llama_opt_params
1904
2074
  opt_params.opt_period = n_batch / n_ubatch;
1905
2075
  opt_params.get_opt_pars = lopt_params.get_opt_pars;
1906
2076
  opt_params.get_opt_pars_ud = lopt_params.get_opt_pars_ud;
1907
-
2077
+ opt_params.optimizer = lopt_params.optimizer_type;
1908
2078
  opt_ctx = ggml_opt_init(opt_params);
1909
2079
 
1910
2080
  llama_opt_param_filter param_filter = lopt_params.param_filter;
@@ -1948,10 +2118,7 @@ void llama_context::opt_epoch_iter(
1948
2118
  const uint32_t n_batch = std::min(this->n_batch(), n_ctx);
1949
2119
  const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch);
1950
2120
 
1951
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
1952
-
1953
- kv_self->clear();
1954
- llama_kv_cache_guard kv_guard(kv_self);
2121
+ memory->clear(true);
1955
2122
 
1956
2123
  for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) {
1957
2124
  batch.n_tokens = n_batch;
@@ -1963,39 +2130,49 @@ void llama_context::opt_epoch_iter(
1963
2130
  batch.logits [pos_batch] = true;
1964
2131
  }
1965
2132
 
1966
- const auto n_tokens_all = batch.n_tokens;
2133
+ if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
2134
+ LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
2135
+ return;
2136
+ }
1967
2137
 
1968
- n_queued_tokens += n_tokens_all;
2138
+ const uint32_t n_tokens_all = balloc->get_n_tokens();
1969
2139
 
1970
- // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
1971
- const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
2140
+ n_queued_tokens += n_tokens_all;
1972
2141
 
1973
2142
  embd_seq.clear();
1974
2143
 
1975
- int64_t n_outputs_all = n_tokens_all;
2144
+ uint32_t n_outputs_all = n_tokens_all;
1976
2145
 
1977
- llama_sbatch sbatch = kv_self->sbatch_init(batch, /*logits_all =*/ true);
2146
+ auto mctx = memory->init_batch(*balloc, cparams.n_ubatch, true);
2147
+ if (!mctx || mctx->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
2148
+ LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
2149
+ break;
2150
+ }
1978
2151
 
1979
2152
  // reserve output buffer
1980
2153
  if (output_reserve(n_outputs_all) < n_outputs_all) {
1981
- LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
2154
+ LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
1982
2155
  GGML_ABORT("TODO: handle this error");
1983
2156
  };
1984
2157
 
1985
- for (uint32_t pos_batch = 0; pos_batch < n_batch; pos_batch += n_ubatch) {
1986
- llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
2158
+ uint32_t pos_batch = 0;
2159
+ do {
2160
+ const auto & ubatch = mctx->get_ubatch();
1987
2161
 
1988
2162
  n_outputs = ubatch.n_tokens;
1989
2163
 
1990
- // TODO: not sure if this is needed
1991
- if (!kv_self->find_slot(ubatch)) {
1992
- LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
1993
-
1994
- GGML_ABORT("TODO: handle this error");
2164
+ if (!mctx->apply()) {
2165
+ LLAMA_LOG_ERROR("%s: failed to update the memory context\n", __func__);
2166
+ break;
1995
2167
  }
1996
2168
 
1997
- auto * gf = graph_init();
1998
- auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
2169
+ auto * res = gf_res_prev.get();
2170
+
2171
+ const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT);
2172
+
2173
+ res->reset();
2174
+
2175
+ auto * gf = model.build_graph(gparams);
1999
2176
 
2000
2177
  struct ggml_context * ctx_compute_opt;
2001
2178
  {
@@ -2010,6 +2187,7 @@ void llama_context::opt_epoch_iter(
2010
2187
  }
2011
2188
  ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits());
2012
2189
  ggml_opt_alloc(opt_ctx, train);
2190
+
2013
2191
  res->set_inputs(&ubatch);
2014
2192
  {
2015
2193
  struct ggml_tensor * labels = ggml_opt_labels(opt_ctx);
@@ -2027,10 +2205,10 @@ void llama_context::opt_epoch_iter(
2027
2205
  callback(train, opt_ctx, dataset, result, idata_in_loop + (pos_ctx + pos_batch)/n_ubatch + 1, ndata_in_loop, t_loop_start);
2028
2206
  }
2029
2207
  ggml_free(ctx_compute_opt);
2030
- }
2031
- }
2032
2208
 
2033
- kv_guard.commit();
2209
+ pos_batch += ubatch.n_tokens;
2210
+ } while (mctx->next());
2211
+ }
2034
2212
  }
2035
2213
 
2036
2214
  void llama_context::opt_epoch(
@@ -2096,12 +2274,13 @@ llama_context_params llama_context_default_params() {
  /*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
  /*.pooling_type =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
  /*.attention_type =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED,
+ /*.flash_attn_type =*/ LLAMA_FLASH_ATTN_TYPE_AUTO,
  /*.rope_freq_base =*/ 0.0f,
  /*.rope_freq_scale =*/ 0.0f,
  /*.yarn_ext_factor =*/ -1.0f,
- /*.yarn_attn_factor =*/ 1.0f,
- /*.yarn_beta_fast =*/ 32.0f,
- /*.yarn_beta_slow =*/ 1.0f,
+ /*.yarn_attn_factor =*/ -1.0f,
+ /*.yarn_beta_fast =*/ -1.0f,
+ /*.yarn_beta_slow =*/ -1.0f,
  /*.yarn_orig_ctx =*/ 0,
  /*.defrag_thold =*/ -1.0f,
  /*.cb_eval =*/ nullptr,
@@ -2112,10 +2291,10 @@ llama_context_params llama_context_default_params() {
  /*.abort_callback_data =*/ nullptr,
  /*.embeddings =*/ false,
  /*.offload_kqv =*/ true,
- /*.flash_attn =*/ false,
  /*.no_perf =*/ true,
  /*.op_offload =*/ true,
  /*.swa_full =*/ true,
+ /*.kv_unified =*/ false,
  };

  return result;
@@ -2139,12 +2318,30 @@ llama_context * llama_init_from_model(
  return nullptr;
  }

- if (params.flash_attn && model->arch == LLM_ARCH_GROK) {
+ if (params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED && model->arch == LLM_ARCH_GROK) {
  LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__);
- params.flash_attn = false;
+ params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED;
+ }
+
+ if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_k)) {
+ const uint32_t blck_size = ggml_blck_size(params.type_k);
+ if (model->hparams.n_embd_head_k % blck_size != 0) {
+ LLAMA_LOG_ERROR("%s: K cache type %s with block size %u does not divide n_embd_head_k=%u\n",
+ __func__, ggml_type_name(params.type_k), blck_size, model->hparams.n_embd_head_k);
+ return nullptr;
+ }
+ }
+
+ if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_v)) {
+ const uint32_t blck_size = ggml_blck_size(params.type_v);
+ if (model->hparams.n_embd_head_v % blck_size != 0) {
+ LLAMA_LOG_ERROR("%s: V cache type %s with block size %u does not divide n_embd_head_k=%u\n",
+ __func__, ggml_type_name(params.type_v), blck_size, model->hparams.n_embd_head_v);
+ return nullptr;
+ }
  }

- if (ggml_is_quantized(params.type_v) && !params.flash_attn) {
+ if (ggml_is_quantized(params.type_v) && params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_DISABLED) {
  LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
  return nullptr;
  }
@@ -2190,14 +2387,6 @@ const llama_model * llama_get_model(const llama_context * ctx) {
  return &ctx->get_model();
  }

- llama_kv_cache * llama_get_kv_self(llama_context * ctx) {
- return ctx->get_kv_self();
- }
-
- void llama_kv_self_update(llama_context * ctx) {
- ctx->kv_self_update();
- }
-
  enum llama_pooling_type llama_pooling_type(const llama_context * ctx) {
  return ctx->pooling_type();
  }
@@ -2311,160 +2500,108 @@ int32_t llama_apply_adapter_cvec(
  }

  //
- // kv cache
+ // memory
  //

- // deprecated
- int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
- const auto * kv = ctx->get_kv_self();
- if (!kv) {
- return 0;
- }
-
- int32_t res = 0;
-
- for (uint32_t s = 0; s < ctx->get_cparams().n_seq_max; s++) {
- const llama_pos p0 = kv->seq_pos_min(s);
- const llama_pos p1 = kv->seq_pos_max(s);
-
- if (p0 >= 0) {
- res += (p1 - p0) + 1;
- }
- }
-
- return res;
- }
-
- // deprecated
- // note: this is the same as above - will be removed anyway, so it's ok
- int32_t llama_kv_self_used_cells(const llama_context * ctx) {
- const auto * kv = ctx->get_kv_self();
- if (!kv) {
- return 0;
- }
-
- int32_t res = 0;
-
- for (uint32_t s = 0; s < ctx->get_cparams().n_seq_max; s++) {
- const llama_pos p0 = kv->seq_pos_min(s);
- const llama_pos p1 = kv->seq_pos_max(s);
-
- if (p0 >= 0) {
- res += (p1 - p0) + 1;
- }
- }
-
- return res;
+ llama_memory_t llama_get_memory(const struct llama_context * ctx) {
+ return ctx->get_memory();
  }

- void llama_kv_self_clear(llama_context * ctx) {
- auto * kv = ctx->get_kv_self();
- if (!kv) {
+ void llama_memory_clear(llama_memory_t mem, bool data) {
+ if (!mem) {
  return;
  }

- kv->clear();
+ mem->clear(data);
  }

- bool llama_kv_self_seq_rm(
- llama_context * ctx,
- llama_seq_id seq_id,
- llama_pos p0,
- llama_pos p1) {
- auto * kv = ctx->get_kv_self();
- if (!kv) {
+ bool llama_memory_seq_rm(
+ llama_memory_t mem,
+ llama_seq_id seq_id,
+ llama_pos p0,
+ llama_pos p1) {
+ if (!mem) {
  return true;
  }

- return kv->seq_rm(seq_id, p0, p1);
+ return mem->seq_rm(seq_id, p0, p1);
  }

- void llama_kv_self_seq_cp(
- llama_context * ctx,
- llama_seq_id seq_id_src,
- llama_seq_id seq_id_dst,
- llama_pos p0,
- llama_pos p1) {
- auto * kv = ctx->get_kv_self();
- if (!kv) {
+ void llama_memory_seq_cp(
+ llama_memory_t mem,
+ llama_seq_id seq_id_src,
+ llama_seq_id seq_id_dst,
+ llama_pos p0,
+ llama_pos p1) {
+ if (!mem) {
  return;
  }

- kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+ mem->seq_cp(seq_id_src, seq_id_dst, p0, p1);
  }

- void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
- auto * kv = ctx->get_kv_self();
- if (!kv) {
+ void llama_memory_seq_keep(
+ llama_memory_t mem,
+ llama_seq_id seq_id) {
+ if (!mem) {
  return;
  }

- kv->seq_keep(seq_id);
+ mem->seq_keep(seq_id);
  }

- void llama_kv_self_seq_add(
- llama_context * ctx,
- llama_seq_id seq_id,
- llama_pos p0,
- llama_pos p1,
- llama_pos delta) {
- auto * kv = ctx->get_kv_self();
- if (!kv) {
+ void llama_memory_seq_add(
+ llama_memory_t mem,
+ llama_seq_id seq_id,
+ llama_pos p0,
+ llama_pos p1,
+ llama_pos delta) {
+ if (!mem) {
  return;
  }

- kv->seq_add(seq_id, p0, p1, delta);
+ mem->seq_add(seq_id, p0, p1, delta);
  }

- void llama_kv_self_seq_div(
- llama_context * ctx,
- llama_seq_id seq_id,
- llama_pos p0,
- llama_pos p1,
- int d) {
- auto * kv = ctx->get_kv_self();
- if (!kv) {
+ void llama_memory_seq_div(
+ llama_memory_t mem,
+ llama_seq_id seq_id,
+ llama_pos p0,
+ llama_pos p1,
+ int d) {
+ if (!mem) {
  return;
  }

- kv->seq_div(seq_id, p0, p1, d);
+ mem->seq_div(seq_id, p0, p1, d);
  }

- llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) {
- const auto * kv = ctx->get_kv_self();
- if (!kv) {
+ llama_pos llama_memory_seq_pos_min(
+ llama_memory_t mem,
+ llama_seq_id seq_id) {
+ if (!mem) {
  return -1;
  }

- return kv->seq_pos_min(seq_id);
+ return mem->seq_pos_min(seq_id);
  }

- llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
- const auto * kv = ctx->get_kv_self();
- if (!kv) {
+ llama_pos llama_memory_seq_pos_max(
+ llama_memory_t mem,
+ llama_seq_id seq_id) {
+ if (!mem) {
  return -1;
  }

- return kv->seq_pos_max(seq_id);
+ return mem->seq_pos_max(seq_id);
  }

- void llama_kv_self_defrag(llama_context * ctx) {
- auto * kv = ctx->get_kv_self();
- if (!kv) {
- return;
- }
-
- // force defrag
- kv->defrag_sched(-1.0f);
- }
-
- bool llama_kv_self_can_shift(const llama_context * ctx) {
- const auto * kv = ctx->get_kv_self();
- if (!kv) {
+ bool llama_memory_can_shift(llama_memory_t mem) {
+ if (!mem) {
  return false;
  }

- return kv->get_can_shift();
+ return mem->get_can_shift();
  }

  // llama state API
@@ -2536,19 +2673,31 @@ bool llama_state_save_file(llama_context * ctx, const char * path_session, const
  }

  size_t llama_state_seq_get_size(llama_context * ctx, llama_seq_id seq_id) {
- return ctx->state_seq_get_size(seq_id);
+ return llama_state_seq_get_size_ext(ctx, seq_id, 0);
  }

  size_t llama_state_seq_get_data(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) {
+ return llama_state_seq_get_data_ext(ctx, dst, size, seq_id, 0);
+ }
+
+ size_t llama_state_seq_set_data(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id) {
+ return llama_state_seq_set_data_ext(ctx, src, size, seq_id, 0);
+ }
+
+ size_t llama_state_seq_get_size_ext(llama_context * ctx, llama_seq_id seq_id, llama_state_seq_flags flags) {
+ return ctx->state_seq_get_size(seq_id, flags);
+ }
+
+ size_t llama_state_seq_get_data_ext(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) {
  ctx->synchronize();

- return ctx->state_seq_get_data(seq_id, dst, size);
+ return ctx->state_seq_get_data(seq_id, dst, size, flags);
  }

- size_t llama_state_seq_set_data(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id) {
+ size_t llama_state_seq_set_data_ext(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) {
  ctx->synchronize();

- return ctx->state_seq_set_data(seq_id, src, size);
+ return ctx->state_seq_set_data(seq_id, src, size, flags);
  }

  size_t llama_state_seq_save_file(llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) {
@@ -2589,22 +2738,8 @@ int32_t llama_encode(
  int32_t llama_decode(
  llama_context * ctx,
  llama_batch batch) {
- int ret = ctx->decode(batch);
-
- // defrag and try again
- // TODO: distinguish return code when we are sure that even after defrag there is no space available
- if (ret == 1) {
- llama_kv_self_defrag(ctx);
- ret = ctx->decode(batch);
-
- if (ret == 1) {
- LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
-
- return ret;
- }
- }
-
- if (ret != 0) {
+ const int ret = ctx->decode(batch);
+ if (ret != 0 && ret != 1) {
  LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
  }

@@ -2638,12 +2773,149 @@ void llama_perf_context_print(const llama_context * ctx) {
  LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
  __func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval);
  LLAMA_LOG_INFO("%s: total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval));
+ LLAMA_LOG_INFO("%s: graphs reused = %10d\n", __func__, data.n_reused);
  }

  void llama_perf_context_reset(llama_context * ctx) {
  ctx->perf_reset();
  }

+ void llama_memory_breakdown_print(const struct llama_context * ctx) {
+ const std::vector<ggml_backend_dev_t> & devices = ctx->get_model().devices;
+
+ std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> memory_breakdown = ctx->memory_breakdown();
+
+ std::vector<std::array<std::string, 9>> table_data;
+ table_data.reserve(devices.size());
+ const std::string template_header = "%s: | %s | %s %s %s %s %s %s %s |\n";
+ const std::string template_gpu = "%s: | %s | %s = %s + (%s = %s + %s + %s) + %s |\n";
+ const std::string template_other = "%s: | %s | %s %s %s = %s + %s + %s %s |\n";
+
+ table_data.push_back({template_header, "memory breakdown [MiB]", "total", "free", "self", "model", "context", "compute", "unaccounted"});
+
+ constexpr size_t MiB = 1024 * 1024;
+ const std::vector<std::string> desc_prefixes_strip = {"NVIDIA ", "GeForce ", "Tesla ", "AMD ", "Radeon ", "Instinct "};
+
+ // track seen buffer types to avoid double counting:
+ std::set<ggml_backend_buffer_type_t> seen_buffer_types;
+
+ // accumulative memory breakdown for each device and for host:
+ std::vector<llama_memory_breakdown_data> mb_dev(devices.size());
+ llama_memory_breakdown_data mb_host;
+
+ for (const auto & buft_mb : memory_breakdown) {
+ ggml_backend_buffer_type_t buft = buft_mb.first;
+ const llama_memory_breakdown_data & mb = buft_mb.second;
+ if (ggml_backend_buft_is_host(buft)) {
+ mb_host.model += mb.model;
+ mb_host.context += mb.context;
+ mb_host.compute += mb.compute;
+ seen_buffer_types.insert(buft);
+ continue;
+ }
+ ggml_backend_dev_t dev = ggml_backend_buft_get_device(buft);
+ if (dev) {
+ int i_dev = -1;
+ for (size_t i = 0; i < devices.size(); i++) {
+ if (devices[i] == dev) {
+ i_dev = i;
+ break;
+ }
+ }
+ if (i_dev != -1) {
+ mb_dev[i_dev].model += mb.model;
+ mb_dev[i_dev].context += mb.context;
+ mb_dev[i_dev].compute += mb.compute;
+ seen_buffer_types.insert(buft);
+ continue;
+ }
+ }
+ }
+
+ // print memory breakdown for each device:
+ for (size_t i = 0; i < devices.size(); i++) {
+ ggml_backend_dev_t dev = devices[i];
+ llama_memory_breakdown_data mb = mb_dev[i];
+
+ const std::string name = ggml_backend_dev_name(dev);
+ std::string desc = ggml_backend_dev_description(dev);
+ for (const std::string & prefix : desc_prefixes_strip) {
+ if (desc.length() >= prefix.length() && desc.substr(0, prefix.length()) == prefix) {
+ desc = desc.substr(prefix.length());
+ }
+ }
+
+ size_t free, total;
+ ggml_backend_dev_memory(dev, &free, &total);
+
+ const size_t self = mb.model + mb.context + mb.compute;
+ const size_t unaccounted = total - self - free;
+
+ table_data.push_back({
+ template_gpu,
+ " - " + name + " (" + desc + ")",
+ std::to_string(total / MiB),
+ std::to_string(free / MiB),
+ std::to_string(self / MiB),
+ std::to_string(mb.model / MiB),
+ std::to_string(mb.context / MiB),
+ std::to_string(mb.compute / MiB),
+ std::to_string(unaccounted / MiB)});
+ }
+
+ // print memory breakdown for host:
+ {
+ const size_t self = mb_host.model + mb_host.context + mb_host.compute;
+ table_data.push_back({
+ template_other,
+ " - Host",
+ "", // total
+ "", // free
+ std::to_string(self / MiB),
+ std::to_string(mb_host.model / MiB),
+ std::to_string(mb_host.context / MiB),
+ std::to_string(mb_host.compute / MiB),
+ ""}); // unaccounted
+ }
+
+ // print memory breakdown for all remaining buffer types:
+ for (const auto & buft_mb : memory_breakdown) {
+ ggml_backend_buffer_type_t buft = buft_mb.first;
+ const llama_memory_breakdown_data & mb = buft_mb.second;
+ if (seen_buffer_types.count(buft) == 1) {
+ continue;
+ }
+ const std::string name = ggml_backend_buft_name(buft);
+ const size_t self = mb.model + mb.context + mb.compute;
+ table_data.push_back({
+ template_other,
+ " - " + name,
+ "", // total
+ "", // free
+ std::to_string(self / MiB),
+ std::to_string(mb.model / MiB),
+ std::to_string(mb.context / MiB),
+ std::to_string(mb.compute / MiB),
+ ""}); // unaccounted
+ seen_buffer_types.insert(buft);
+ }
+
+ for (size_t j = 1; j < table_data[0].size(); j++) {
+ size_t max_len = 0;
+ for (const auto & td : table_data) {
+ max_len = std::max(max_len, td[j].length());
+ }
+ for (auto & td : table_data) {
+ td[j].insert(j == 1 ? td[j].length() : 0, max_len - td[j].length(), ' ');
+ }
+ }
+ for (const auto & td : table_data) {
+ LLAMA_LOG_INFO(td[0].c_str(),
+ __func__, td[1].c_str(), td[2].c_str(), td[3].c_str(), td[4].c_str(), td[5].c_str(),
+ td[6].c_str(), td[7].c_str(), td[8].c_str());
+ }
+ }
+
  //
  // training
  //