whispercpp 1.3.3 → 1.3.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (963)
  1. checksums.yaml +4 -4
  2. data/README.md +60 -43
  3. data/ext/extconf.rb +2 -2
  4. data/ext/ruby_whisper.c +14 -2
  5. data/ext/ruby_whisper.h +39 -0
  6. data/ext/ruby_whisper_context.c +22 -22
  7. data/ext/ruby_whisper_model.c +12 -12
  8. data/ext/ruby_whisper_params.c +79 -25
  9. data/ext/ruby_whisper_segment.c +84 -19
  10. data/ext/ruby_whisper_token.c +351 -0
  11. data/ext/ruby_whisper_transcribe.cpp +1 -1
  12. data/ext/ruby_whisper_vad_context.c +75 -0
  13. data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
  14. data/ext/ruby_whisper_vad_segment.c +139 -0
  15. data/ext/ruby_whisper_vad_segments.c +106 -0
  16. data/ext/sources/CMakeLists.txt +4 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
  19. data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
  20. data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  21. data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
  22. data/ext/sources/examples/CMakeLists.txt +1 -0
  23. data/ext/sources/examples/addon.node/addon.cpp +19 -19
  24. data/ext/sources/examples/addon.node/index.js +7 -5
  25. data/ext/sources/examples/addon.node/vad-example.js +2 -2
  26. data/ext/sources/examples/bench/bench.cpp +26 -16
  27. data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
  28. data/ext/sources/examples/cli/cli.cpp +122 -111
  29. data/ext/sources/examples/command/command.cpp +26 -24
  30. data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
  31. data/ext/sources/examples/common-ggml.cpp +2 -0
  32. data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
  33. data/ext/sources/examples/lsp/lsp.cpp +19 -17
  34. data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
  35. data/ext/sources/examples/server/server.cpp +34 -24
  36. data/ext/sources/examples/server.py +6 -1
  37. data/ext/sources/examples/stream/stream.cpp +4 -2
  38. data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
  39. data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
  40. data/ext/sources/examples/talk-llama/CMakeLists.txt +7 -3
  41. data/ext/sources/examples/talk-llama/llama-adapter.cpp +113 -7
  42. data/ext/sources/examples/talk-llama/llama-adapter.h +13 -1
  43. data/ext/sources/examples/talk-llama/llama-arch.cpp +2136 -1491
  44. data/ext/sources/examples/talk-llama/llama-arch.h +125 -3
  45. data/ext/sources/examples/talk-llama/llama-batch.cpp +174 -100
  46. data/ext/sources/examples/talk-llama/llama-batch.h +46 -20
  47. data/ext/sources/examples/talk-llama/llama-chat.cpp +199 -8
  48. data/ext/sources/examples/talk-llama/llama-chat.h +11 -0
  49. data/ext/sources/examples/talk-llama/llama-context.cpp +1213 -413
  50. data/ext/sources/examples/talk-llama/llama-context.h +99 -36
  51. data/ext/sources/examples/talk-llama/llama-cparams.h +5 -4
  52. data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
  53. data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
  54. data/ext/sources/examples/talk-llama/llama-graph.cpp +883 -294
  55. data/ext/sources/examples/talk-llama/llama-graph.h +361 -161
  56. data/ext/sources/examples/talk-llama/llama-hparams.cpp +144 -6
  57. data/ext/sources/examples/talk-llama/llama-hparams.h +100 -23
  58. data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
  59. data/ext/sources/examples/talk-llama/llama-impl.h +3 -1
  60. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +328 -0
  61. data/ext/sources/examples/talk-llama/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +38 -29
  62. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2100 -0
  63. data/ext/sources/examples/talk-llama/llama-kv-cache.h +373 -27
  64. data/ext/sources/examples/talk-llama/llama-kv-cells.h +124 -30
  65. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +63 -41
  66. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +30 -29
  67. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +77 -35
  68. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +15 -16
  69. data/ext/sources/examples/talk-llama/llama-memory.h +16 -10
  70. data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
  71. data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
  72. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +93 -9
  73. data/ext/sources/examples/talk-llama/llama-model-loader.h +9 -2
  74. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
  75. data/ext/sources/examples/talk-llama/llama-model.cpp +3369 -10145
  76. data/ext/sources/examples/talk-llama/llama-model.h +104 -12
  77. data/ext/sources/examples/talk-llama/llama-quant.cpp +53 -30
  78. data/ext/sources/examples/talk-llama/llama-sampling.cpp +1520 -324
  79. data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
  80. data/ext/sources/examples/talk-llama/llama-vocab.cpp +562 -39
  81. data/ext/sources/examples/talk-llama/llama-vocab.h +50 -0
  82. data/ext/sources/examples/talk-llama/llama.cpp +794 -12
  83. data/ext/sources/examples/talk-llama/llama.h +246 -190
  84. data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
  85. data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
  86. data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
  87. data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
  88. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
  89. data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
  90. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
  91. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
  92. data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
  93. data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
  94. data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
  95. data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
  96. data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
  97. data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
  98. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
  99. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
  100. data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
  101. data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
  102. data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
  103. data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
  104. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
  105. data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
  106. data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
  107. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
  108. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
  109. data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
  110. data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
  111. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
  112. data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
  113. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
  114. data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
  115. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
  116. data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
  117. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
  118. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
  119. data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
  120. data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
  121. data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
  122. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
  123. data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
  124. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
  125. data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
  126. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
  127. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
  128. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
  129. data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
  130. data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
  131. data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
  132. data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
  133. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
  134. data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
  135. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
  136. data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
  137. data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
  138. data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
  139. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
  140. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
  141. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
  142. data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
  143. data/ext/sources/examples/talk-llama/models/models.h +569 -0
  144. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
  145. data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
  146. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
  147. data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
  148. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
  149. data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
  150. data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
  151. data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
  152. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
  153. data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
  154. data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
  155. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
  156. data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
  157. data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
  158. data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
  159. data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
  160. data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
  161. data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
  162. data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
  163. data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
  164. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
  165. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
  166. data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
  167. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
  168. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
  169. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
  170. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
  171. data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
  172. data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
  173. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
  174. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
  175. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
  176. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
  177. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
  178. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
  179. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
  180. data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
  181. data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
  182. data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
  183. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
  184. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
  185. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
  186. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
  187. data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
  188. data/ext/sources/examples/talk-llama/talk-llama.cpp +9 -6
  189. data/ext/sources/examples/talk-llama/unicode.cpp +309 -16
  190. data/ext/sources/examples/talk-llama/unicode.h +45 -0
  191. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
  192. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +4 -2
  193. data/ext/sources/examples/whisper.wasm/index-tmpl.html +18 -17
  194. data/ext/sources/ggml/CMakeLists.txt +135 -79
  195. data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
  196. data/ext/sources/ggml/include/ggml-alloc.h +9 -0
  197. data/ext/sources/ggml/include/ggml-backend.h +21 -2
  198. data/ext/sources/ggml/include/ggml-cpu.h +2 -1
  199. data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
  200. data/ext/sources/ggml/include/ggml-metal.h +1 -6
  201. data/ext/sources/ggml/include/ggml-opt.h +25 -6
  202. data/ext/sources/ggml/include/ggml-rpc.h +8 -11
  203. data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
  204. data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
  205. data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
  206. data/ext/sources/ggml/include/ggml.h +406 -23
  207. data/ext/sources/ggml/src/CMakeLists.txt +99 -13
  208. data/ext/sources/ggml/src/ggml-alloc.c +368 -161
  209. data/ext/sources/ggml/src/ggml-backend-impl.h +5 -5
  210. data/ext/sources/ggml/src/ggml-backend-reg.cpp +55 -14
  211. data/ext/sources/ggml/src/ggml-backend.cpp +290 -57
  212. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
  213. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -13
  214. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +14 -0
  215. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +59 -45
  216. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
  217. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2586 -1917
  218. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +348 -309
  219. data/ext/sources/ggml/src/ggml-cann/common.h +350 -133
  220. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +894 -625
  221. data/ext/sources/ggml/src/ggml-common.h +17 -0
  222. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +167 -75
  223. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +5 -2
  224. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  225. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +560 -622
  226. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1002 -270
  227. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +107 -587
  228. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
  229. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  230. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +373 -486
  231. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
  232. data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  233. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +521 -353
  234. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
  235. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
  236. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +4682 -1660
  237. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +82 -4
  238. data/ext/sources/ggml/src/ggml-cpu/common.h +14 -0
  239. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +18 -9
  240. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +263 -111
  241. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +39 -28
  242. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +683 -82
  243. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +38 -43
  244. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +435 -119
  245. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
  246. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1234 -1182
  247. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
  248. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +2167 -1480
  249. data/ext/sources/ggml/src/ggml-cpu/ops.h +10 -12
  250. data/ext/sources/ggml/src/ggml-cpu/quants.c +35 -0
  251. data/ext/sources/ggml/src/ggml-cpu/quants.h +8 -0
  252. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1132 -81
  253. data/ext/sources/ggml/src/ggml-cpu/repack.h +36 -0
  254. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +120 -93
  255. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1025 -0
  256. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
  257. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
  258. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
  259. data/ext/sources/ggml/src/ggml-cpu/traits.cpp +2 -2
  260. data/ext/sources/ggml/src/ggml-cpu/traits.h +1 -1
  261. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
  262. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
  263. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +294 -27
  264. data/ext/sources/ggml/src/ggml-cpu/vec.h +606 -48
  265. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +92 -17
  266. data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
  267. data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
  268. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
  269. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
  270. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
  271. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
  272. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
  273. data/ext/sources/ggml/src/ggml-cuda/common.cuh +588 -128
  274. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
  275. data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
  276. data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
  277. data/ext/sources/ggml/src/ggml-cuda/convert.cu +95 -22
  278. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +25 -0
  279. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
  280. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +335 -485
  281. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
  282. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  283. data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
  284. data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
  285. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
  286. data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
  287. data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
  288. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +519 -378
  289. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +750 -637
  290. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +49 -0
  291. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1244 -0
  292. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +586 -0
  293. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +98 -61
  294. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  295. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +230 -197
  296. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  297. data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
  298. data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
  299. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
  300. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1557 -294
  301. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
  302. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
  303. data/ext/sources/ggml/src/ggml-cuda/mean.cu +57 -2
  304. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +915 -69
  305. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +171 -0
  306. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +835 -0
  307. data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
  308. data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
  309. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +109 -67
  310. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1601 -733
  311. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +802 -0
  312. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +12 -0
  313. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +286 -149
  314. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
  315. data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
  316. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
  317. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
  318. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
  319. data/ext/sources/ggml/src/ggml-cuda/pad.cu +86 -32
  320. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
  321. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
  322. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +163 -10
  323. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
  324. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
  325. data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
  326. data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
  327. data/ext/sources/ggml/src/ggml-cuda/rope.cu +207 -98
  328. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
  329. data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
  330. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +330 -0
  331. data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  332. data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
  333. data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
  334. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
  335. data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
  336. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +325 -61
  337. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
  338. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
  339. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -12
  340. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +291 -104
  341. data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
  342. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +21 -4
  343. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  344. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  345. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  346. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  347. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  348. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  349. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  350. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  351. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  352. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
  353. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
  354. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
  355. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
  356. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
  357. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
  358. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
  359. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
  360. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
  361. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
  362. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
  363. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
  364. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
  365. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
  366. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
  367. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
  368. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
  369. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
  370. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
  371. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
  372. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
  373. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
  374. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
  375. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
  376. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
  377. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
  378. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
  379. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
  380. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
  381. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
  382. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
  383. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
  384. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
  385. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
  386. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
  387. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
  388. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +40 -19
  389. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
  390. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
  391. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
  392. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
  393. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
  394. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
  395. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
  396. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
  397. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
  398. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
  399. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
  400. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
  401. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
  402. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
  403. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
  404. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
  405. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
  406. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
  407. data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
  408. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +351 -0
  409. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +21 -0
  410. data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
  411. data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
  412. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
  413. data/ext/sources/ggml/src/ggml-cuda/unary.cu +189 -5
  414. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +44 -0
  415. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +248 -6
  416. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
  417. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +8 -0
  418. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +70 -37
  419. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +10 -3
  420. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
  421. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
  422. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
  423. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
  424. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
  425. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  426. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
  427. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
  428. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  429. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
  430. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
  431. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
  432. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
  433. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  434. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
  435. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
  436. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
  437. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
  438. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
  439. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
  440. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
  441. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
  442. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
  443. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
  444. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
  445. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
  446. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
  447. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  448. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
  449. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
  450. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
  451. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +16 -13
  452. data/ext/sources/ggml/src/ggml-impl.h +186 -15
  453. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -7
  454. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
  455. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
  456. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
  457. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +609 -0
  458. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1743 -0
  459. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +273 -0
  460. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1686 -0
  461. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +356 -61
  462. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +4161 -0
  463. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +94 -0
  464. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +724 -0
  465. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +4495 -1876
  466. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +21 -9
  467. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +29 -0
  468. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4005 -427
  469. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
  470. data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  471. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  472. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  473. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +147 -0
  474. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +66 -0
  475. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
  476. data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  477. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
  478. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +371 -0
  479. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
  480. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  481. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  482. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  483. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  484. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +177 -0
  485. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
  486. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  487. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  488. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
  489. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
  490. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  491. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  492. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +146 -0
  493. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +147 -0
  494. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  495. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  496. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
  497. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
  498. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
  499. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  500. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
  501. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
  502. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
  503. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
  504. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  505. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +94 -0
  506. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  507. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  508. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +208 -0
  509. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
  510. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
  511. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
  512. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
  513. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
  514. data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  515. data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  516. data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  517. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
  518. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +33 -0
  519. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +2 -2
  520. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  521. data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
  522. data/ext/sources/ggml/src/ggml-quants.c +111 -16
  523. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  524. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +497 -195
  525. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
  526. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
  527. data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
  528. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +8 -0
  529. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +6 -5
  530. data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
  531. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +50 -30
  532. data/ext/sources/ggml/src/ggml-sycl/conv.cpp +10 -4
  533. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +200 -99
  534. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  535. data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  536. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +72 -309
  537. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
  538. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
  539. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +67 -49
  540. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +77 -34
  541. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +397 -314
  542. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +12 -2
  543. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +14 -26
  544. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +9 -6
  545. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +643 -413
  546. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
  547. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +2 -2
  548. data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +80 -60
  549. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +223 -132
  550. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +230 -55
  551. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
  552. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
  553. data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
  554. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
  555. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
  556. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
  557. data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
  558. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +8 -9
  559. data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  560. data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  561. data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
  562. data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
  563. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +65 -59
  564. data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
  565. data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
  566. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
  567. data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  568. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +330 -165
  569. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
  570. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  571. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  572. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +12 -6
  573. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +60 -6
  574. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +16 -12
  575. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
  576. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +7398 -2635
  577. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  578. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
  579. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +43 -3
  580. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  581. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
  582. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  583. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +15 -6
  584. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +56 -39
  585. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  586. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  587. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  588. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  589. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  590. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  591. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +347 -0
  592. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  593. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  594. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +5 -5
  595. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +67 -13
  596. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  597. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  598. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  599. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
  600. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
  601. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
  602. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
  603. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  604. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +158 -16
  605. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +38 -3
  606. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  607. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  608. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  609. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +2 -2
  610. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  611. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +3 -2
  612. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +7 -6
  613. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +5 -3
  614. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  615. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +4 -4
  618. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +2 -2
  619. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  620. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  621. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +4 -4
  622. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  623. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  624. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +4 -4
  625. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +2 -2
  626. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  627. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
  628. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  629. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  630. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
  631. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  632. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +103 -36
  633. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +220 -0
  634. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +139 -45
  635. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +113 -38
  636. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +75 -14
  637. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  638. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  639. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  640. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  641. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  642. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  643. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  644. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +19 -17
  645. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
  646. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +21 -12
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +28 -18
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +4 -0
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +33 -17
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +125 -0
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +227 -0
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
  663. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
  664. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
  665. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
  666. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
  667. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
  668. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
  669. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +20 -14
  670. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  671. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
  672. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
  673. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
  674. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
  675. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
  676. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +143 -0
  677. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
  678. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -556
  679. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +230 -51
  680. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +566 -0
  681. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
  682. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +90 -223
  683. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
  684. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  685. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +195 -0
  686. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  687. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  688. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  689. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
  690. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +41 -5
  691. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  692. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +59 -9
  693. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  694. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  695. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  696. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  697. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +104 -14
  698. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  699. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
  700. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  701. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
  702. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  703. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -52
  704. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -35
  705. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -35
  706. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
  707. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
  708. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  709. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +5 -0
  710. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +3 -3
  711. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  712. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  713. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  714. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  715. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +30 -8
  716. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +6 -2
  717. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
  718. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
  719. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
  720. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
  721. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  722. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
  723. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
  724. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  725. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
  726. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
  727. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  728. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  729. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +16 -6
  730. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
  731. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  732. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
  733. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  734. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +5 -4
  735. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
  736. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
  737. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
  738. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
  739. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  740. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +435 -24
  741. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +148 -6
  742. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.glsl +25 -0
  743. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +619 -177
  744. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
  745. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +80 -0
  746. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
  747. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3087 -0
  748. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  749. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
  750. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
  751. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
  752. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
  753. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +147 -0
  754. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
  755. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
  756. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  757. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  758. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
  759. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
  760. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
  761. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
  762. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
  763. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +123 -0
  764. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  765. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
  766. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
  767. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
  768. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  769. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
  770. data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  771. data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
  772. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
  773. data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
  774. data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
  775. data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
  776. data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
  777. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
  778. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
  779. data/ext/sources/ggml/src/ggml.c +901 -129
  780. data/ext/sources/ggml/src/gguf.cpp +8 -1
  781. data/ext/sources/include/whisper.h +1 -0
  782. data/ext/sources/src/CMakeLists.txt +3 -1
  783. data/ext/sources/src/whisper.cpp +124 -81
  784. data/ext/sources/tests/CMakeLists.txt +8 -1
  785. data/ext/sources/tests/test-vad-full.cpp +7 -5
  786. data/ext/sources/tests/test-vad.cpp +3 -3
  787. data/extsources.rb +1 -0
  788. data/lib/whisper/model/uri.rb +17 -18
  789. data/sig/whisper.rbs +126 -2
  790. data/test/test_params.rb +24 -8
  791. data/test/test_segment.rb +0 -1
  792. data/test/test_token.rb +70 -0
  793. data/test/test_vad.rb +1 -1
  794. data/test/test_vad_context.rb +50 -0
  795. data/test/test_vad_segment.rb +19 -0
  796. data/test/test_vad_segments.rb +16 -0
  797. data/test/test_whisper.rb +8 -1
  798. data/whispercpp.gemspec +1 -1
  799. metadata +439 -179
  800. data/ext/sources/build-xcframework.sh +0 -547
  801. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +0 -279
  802. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +0 -1841
  803. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +0 -303
  804. data/ext/sources/ggml/include/ggml-kompute.h +0 -50
  805. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
  806. data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
  807. data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
  808. data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
  809. data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
  810. data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
  811. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  812. data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  813. data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  814. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  815. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  816. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  817. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  818. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  819. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  820. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
  821. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
  822. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
  823. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
  824. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
  825. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
  826. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
  827. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +0 -506
  828. data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +0 -11
  829. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
  830. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
  831. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
  832. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
  833. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
  834. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
  835. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
  836. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
  837. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
  838. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
  839. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
  840. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
  841. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
  842. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
  843. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
  844. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
  845. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
  846. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
  847. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
  848. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
  849. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
  850. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
  851. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
  852. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
  853. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
  854. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
  855. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
  856. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
  857. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
  858. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
  859. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
  860. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
  861. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
  862. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
  863. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
  864. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
  865. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
  866. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
  867. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
  868. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
  869. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
  870. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
  871. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
  872. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
  873. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
  874. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
  875. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
  876. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
  877. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
  878. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
  879. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
  880. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
  881. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
  882. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
  883. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
  884. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
  885. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
  886. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
  887. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
  888. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
  889. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
  890. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
  891. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
  892. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
  893. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
  894. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
  895. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
  896. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
  897. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
  898. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
  899. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
  900. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
  901. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
  902. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
  903. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
  904. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
  905. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
  906. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
  907. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
  908. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
  909. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
  910. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
  911. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
  912. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
  913. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
  914. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
  915. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  916. data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  917. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  918. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  919. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  920. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  921. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  922. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  923. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  924. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  925. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  926. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  927. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  928. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  929. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  930. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  931. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  932. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  933. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  934. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  935. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  936. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  937. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  938. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  939. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  940. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  941. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  942. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  943. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  944. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  945. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  946. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  947. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  948. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  949. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  950. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  951. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  952. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  953. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
  954. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -6280
  955. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +0 -162
  956. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +0 -118
  957. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -99
  958. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -58
  959. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  960. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  961. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  962. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  963. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
@@ -1,5 +1,6 @@
1
1
  #include "llama-context.h"
2
2
 
3
+ #include "llama-arch.h"
3
4
  #include "llama-impl.h"
4
5
  #include "llama-batch.h"
5
6
  #include "llama-io.h"
@@ -8,6 +9,7 @@
8
9
  #include "llama-model.h"
9
10
 
10
11
  #include <cinttypes>
12
+ #include <cmath>
11
13
  #include <cstring>
12
14
  #include <limits>
13
15
  #include <stdexcept>
@@ -21,6 +23,8 @@ llama_context::llama_context(
21
23
  llama_context_params params) :
22
24
  model(model),
23
25
  balloc(std::make_unique<llama_batch_allocr>(model.hparams.n_pos_per_embd())) {
26
+ // TODO warning when creating llama_context with awkward ctx size that is not a power of 2,
27
+ // may need to be backend-dependent
24
28
  LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);
25
29
 
26
30
  t_start_us = model.t_start_us;
@@ -35,14 +39,12 @@ llama_context::llama_context(
35
39
 
36
40
  cparams.n_threads = params.n_threads;
37
41
  cparams.n_threads_batch = params.n_threads_batch;
38
- cparams.yarn_ext_factor = params.yarn_ext_factor;
39
- cparams.yarn_attn_factor = params.yarn_attn_factor;
40
- cparams.yarn_beta_fast = params.yarn_beta_fast;
41
- cparams.yarn_beta_slow = params.yarn_beta_slow;
42
- cparams.defrag_thold = params.defrag_thold;
42
+ cparams.yarn_ext_factor = params.yarn_ext_factor >= 0.0f ? params.yarn_ext_factor : hparams.yarn_ext_factor;
43
+ cparams.yarn_attn_factor = params.yarn_attn_factor >= 0.0f ? params.yarn_attn_factor : hparams.yarn_attn_factor;
44
+ cparams.yarn_beta_fast = params.yarn_beta_fast >= 0.0f ? params.yarn_beta_fast : hparams.yarn_beta_fast;
45
+ cparams.yarn_beta_slow = params.yarn_beta_slow >= 0.0f ? params.yarn_beta_slow : hparams.yarn_beta_slow;
43
46
  cparams.embeddings = params.embeddings;
44
47
  cparams.offload_kqv = params.offload_kqv;
45
- cparams.flash_attn = params.flash_attn;
46
48
  cparams.no_perf = params.no_perf;
47
49
  cparams.pooling_type = params.pooling_type;
48
50
  cparams.warmup = false;
@@ -58,6 +60,25 @@ llama_context::llama_context(
58
60
  cparams.cb_eval = params.cb_eval;
59
61
  cparams.cb_eval_user_data = params.cb_eval_user_data;
60
62
 
63
+ // Initialize backend samplers here so they are part of the sampling graph
64
+ // before the reserve passes run later in this function. This avoids a later
65
+ // re-reserve when graph nodes change.
66
+ if (params.samplers != nullptr && params.n_samplers > 0) {
67
+ for (size_t i = 0; i < params.n_samplers; ++i) {
68
+ const auto & config = params.samplers[i];
69
+
70
+ if (llama_sampler_chain_get(config.sampler, -1) == nullptr) {
71
+ throw std::runtime_error("the backend samplers must be of type llama_sampler_chain");
72
+ }
73
+
74
+ if (set_sampler(config.seq_id, config.sampler)) {
75
+ const int n_samplers = llama_sampler_chain_n(config.sampler);
76
+
77
+ LLAMA_LOG_INFO("%s: setting backend sampler for seq_id %d (n = %d)\n", __func__, config.seq_id, n_samplers);
78
+ }
79
+ }
80
+ }
81
+
61
82
  auto rope_scaling_type = params.rope_scaling_type;
62
83
  if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED) {
63
84
  rope_scaling_type = hparams.rope_scaling_type_train;
@@ -71,6 +92,43 @@ llama_context::llama_context(
71
92
  cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f;
72
93
  }
73
94
 
95
+ if (cparams.yarn_ext_factor != 0) {
96
+ static auto get_mscale = [](float scale, float mscale) {
97
+ return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f);
98
+ };
99
+
100
+ const float factor = 1.0f / cparams.rope_freq_scale;
101
+
102
+ // ref: https://github.com/huggingface/transformers/blob/6d00f6b0a5679c36510f203e4226e36f517c3032/src/transformers/modeling_rope_utils.py#L336-L348
103
+ if (hparams.rope_yarn_log_mul != 0.0f) {
104
+ // note: here we assume `mscale == 1.0f`
105
+ // TODO: start reading the actual value of mscale and handle the case where it is not 1.0f
106
+ float mscale = 1.0f;
107
+ const float mscale_all_dims = hparams.rope_yarn_log_mul;
108
+
109
+ // [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
110
+ // special-case DEEPSEEK v2:
111
+ // https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite-Chat/blob/main/config.json#L42-L43
112
+ if (model.arch == LLM_ARCH_DEEPSEEK2 && mscale_all_dims != 1.0f) {
113
+ mscale = mscale_all_dims;
114
+ }
115
+
116
+ cparams.yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims);
117
+
118
+ LLAMA_LOG_WARN("%s: setting new yarn_attn_factor = %.4f (mscale == %.1f, mscale_all_dim = %.1f)\n",
119
+ __func__, cparams.yarn_attn_factor, mscale, mscale_all_dims);
120
+ } else {
121
+ cparams.yarn_attn_factor = get_mscale(factor, 1.0f);
122
+ }
123
+
124
+ // when YARN is applied with yarn_ext_factor != 0.0f, we need to cancel this factor:
125
+ // https://github.com/ggml-org/llama.cpp/blob/a81a569577cc38b32558958b048228150be63eae/ggml/src/ggml-cpu/ops.cpp#L5541-L5544
126
+ //
127
+ // ref: https://github.com/ggml-org/llama.cpp/discussions/7416
128
+ // https://github.com/ggml-org/llama.cpp/pull/17945
129
+ cparams.yarn_attn_factor *= 1.0f / (1.0f + 0.1f * logf(factor));
130
+ }
131
+
74
132
  cparams.yarn_attn_factor *= hparams.rope_attn_factor;
75
133
 
76
134
  if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
@@ -87,47 +145,63 @@ llama_context::llama_context(
87
145
  cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL;
88
146
  }
89
147
 
148
+ cparams.flash_attn = params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED;
149
+
90
150
  // with causal attention, the batch size is limited by the context size
91
151
  cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
92
152
 
93
- // the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask
94
- // this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext)
95
- // ref: https://github.com/ggerganov/llama.cpp/pull/5021
96
- // TODO: this padding is not needed for the cache-less context so we should probably move it to llama_context_kv_self
97
- if (cparams.n_batch < GGML_KQ_MASK_PAD) {
98
- LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
99
- cparams.n_batch = GGML_KQ_MASK_PAD;
100
- }
101
-
102
153
  cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
103
154
 
104
155
  cparams.op_offload = params.op_offload;
156
+ cparams.kv_unified = params.kv_unified;
157
+
158
+ {
159
+ const char * LLAMA_GRAPH_REUSE_DISABLE = getenv("LLAMA_GRAPH_REUSE_DISABLE");
160
+ graph_reuse_disable = LLAMA_GRAPH_REUSE_DISABLE ? (atoi(LLAMA_GRAPH_REUSE_DISABLE) != 0) : graph_reuse_disable;
161
+
162
+ if (graph_reuse_disable) {
163
+ LLAMA_LOG_WARN("%s: graph reuse disabled\n", __func__);
164
+ }
165
+ }
166
+
167
+ // ref: https://github.com/ggml-org/llama.cpp/pull/17046#discussion_r2503085732
168
+ cparams.n_ctx = GGML_PAD(cparams.n_ctx, 256);
169
+
170
+ if (cparams.kv_unified) {
171
+ cparams.n_ctx_seq = cparams.n_ctx;
172
+ } else {
173
+ cparams.n_ctx_seq = cparams.n_ctx / cparams.n_seq_max;
174
+ cparams.n_ctx_seq = GGML_PAD(cparams.n_ctx_seq, 256);
175
+
176
+ if (cparams.n_ctx_seq == 0) {
177
+ throw std::runtime_error("n_ctx_seq == 0");
178
+ }
105
179
 
106
- const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
180
+ if (cparams.n_ctx != cparams.n_ctx_seq * cparams.n_seq_max) {
181
+ cparams.n_ctx = cparams.n_ctx_seq * cparams.n_seq_max;
182
+ LLAMA_LOG_WARN("%s: n_ctx is not divisible by n_seq_max - rounding down to %u\n", __func__, cparams.n_ctx);
183
+ }
184
+ }
107
185
 
108
186
  LLAMA_LOG_INFO("%s: n_seq_max = %u\n", __func__, cparams.n_seq_max);
109
187
  LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
110
- LLAMA_LOG_INFO("%s: n_ctx_per_seq = %u\n", __func__, n_ctx_per_seq);
188
+ LLAMA_LOG_INFO("%s: n_ctx_seq = %u\n", __func__, cparams.n_ctx_seq);
111
189
  LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch);
112
190
  LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
113
191
  LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn);
114
- LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn);
192
+ LLAMA_LOG_INFO("%s: flash_attn = %s\n", __func__, llama_flash_attn_type_name(params.flash_attn_type));
193
+ LLAMA_LOG_INFO("%s: kv_unified = %s\n", __func__, cparams.kv_unified ? "true" : "false");
115
194
  LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
116
195
  LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
117
196
 
118
- if (n_ctx_per_seq < hparams.n_ctx_train) {
119
- LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
120
- __func__, n_ctx_per_seq, hparams.n_ctx_train);
197
+ if (cparams.n_ctx_seq < hparams.n_ctx_train) {
198
+ LLAMA_LOG_WARN("%s: n_ctx_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
199
+ __func__, cparams.n_ctx_seq, hparams.n_ctx_train);
121
200
  }
122
201
 
123
- if (n_ctx_per_seq > hparams.n_ctx_train) {
124
- LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
125
- __func__, n_ctx_per_seq, hparams.n_ctx_train);
126
- }
127
-
128
- if (!params.swa_full && cparams.n_seq_max > 1 && hparams.is_swa_any()) {
129
- LLAMA_LOG_WARN("%s: requested n_seq_max (%u) > 1, but swa_full is not enabled -- performance may be degraded: %s\n",
130
- __func__, cparams.n_seq_max, "https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573");
202
+ if (cparams.n_ctx_seq > hparams.n_ctx_train) {
203
+ LLAMA_LOG_WARN("%s: n_ctx_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
204
+ __func__, cparams.n_ctx_seq, hparams.n_ctx_train);
131
205
  }
132
206
 
133
207
  if (!hparams.vocab_only) {
@@ -176,7 +250,10 @@ llama_context::llama_context(
176
250
  // graph outputs buffer
177
251
  {
178
252
  // resized during inference when a batch uses more outputs
179
- if ((uint32_t) output_reserve(params.n_seq_max) < params.n_seq_max) {
253
+ // Create a dummy batch for initialization.
254
+ llama_batch dummy_batch = {};
255
+ dummy_batch.n_tokens = 0;
256
+ if (output_reserve(params.n_seq_max, dummy_batch) < params.n_seq_max) {
180
257
  throw std::runtime_error("failed to reserve initial output buffer");
181
258
  }
182
259
 
@@ -203,6 +280,7 @@ llama_context::llama_context(
203
280
 
204
281
  backend_buft.clear();
205
282
  backend_ptrs.clear();
283
+ backend_buf_exp_size.clear();
206
284
 
207
285
  for (auto & backend : backends) {
208
286
  auto * buft = ggml_backend_get_default_buffer_type(backend.get());
@@ -219,23 +297,27 @@ llama_context::llama_context(
219
297
 
220
298
  backend_buft.push_back(buft);
221
299
  backend_ptrs.push_back(backend.get());
300
+ backend_buf_exp_size.push_back(0);
222
301
  }
223
302
 
224
303
  LLAMA_LOG_DEBUG("%s: backend_ptrs.size() = %zu\n", __func__, backend_ptrs.size());
225
304
 
226
- const size_t max_nodes = this->graph_max_nodes();
305
+ const uint32_t n_seqs = cparams.n_seq_max;
306
+ const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
307
+
308
+ const size_t max_nodes = this->graph_max_nodes(n_tokens);
227
309
 
228
310
  LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes);
229
311
 
230
- // buffer used to store the computation graph and the tensor meta data
231
- buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
312
+ gf_res_prev.reset(new llm_graph_result(max_nodes));
313
+ gf_res_reserve.reset(new llm_graph_result(max_nodes));
232
314
 
233
315
  // TODO: move these checks to ggml_backend_sched
234
316
  // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
235
317
  bool pipeline_parallel =
236
318
  model.n_devices() > 1 &&
237
- model.params.n_gpu_layers > (int) model.hparams.n_layer &&
238
- model.params.split_mode == LLAMA_SPLIT_MODE_LAYER &&
319
+ model.n_gpu_layers() > model.hparams.n_layer &&
320
+ model.split_mode() == LLAMA_SPLIT_MODE_LAYER &&
239
321
  cparams.offload_kqv &&
240
322
  !model.has_tensor_overrides();
241
323
 
@@ -263,44 +345,94 @@ llama_context::llama_context(
263
345
  if (pipeline_parallel) {
264
346
  LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(sched.get()));
265
347
  }
266
- }
267
348
 
268
- // reserve worst-case graph
269
- if (!hparams.vocab_only && memory) {
270
- const uint32_t n_seqs = cparams.n_seq_max;
271
- const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
349
+ llama_memory_context_ptr mctx;
350
+ if (memory) {
351
+ LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__);
352
+ mctx = memory->init_full();
353
+ if (!mctx) {
354
+ throw std::runtime_error("failed to initialize memory module");
355
+ }
356
+ }
357
+
358
+ cross.v_embd.clear();
359
+
360
+ // avoid reserving graphs with zero outputs - assume one output per sequence
361
+ n_outputs = n_seqs;
272
362
 
273
363
  LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
274
364
 
365
+ // resolve automatic Flash Attention use
366
+ if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) {
367
+ auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true);
368
+ if (!gf) {
369
+ throw std::runtime_error("failed to split graph for Flash Attention check");
370
+ }
371
+
372
+ const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1;
373
+ bool fa_device_mismatch = false;
374
+ for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
375
+ ggml_tensor * n = ggml_graph_node(gf, i);
376
+ if (n->op != GGML_OP_FLASH_ATTN_EXT) {
377
+ continue;
378
+ }
379
+ ggml_backend_dev_t device_fa = ggml_backend_get_device(
380
+ ggml_backend_sched_get_tensor_backend(sched.get(), n));
381
+
382
+ // TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer
383
+ GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0);
384
+ const int il = std::stoi(n->name + prefix_len);
385
+ ggml_backend_dev_t device_kv = model.dev_layer(il);
386
+ if (device_fa != device_kv) {
387
+ LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor "
388
+ "is assigned to device %s (usually due to missing support)\n",
389
+ __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_fa));
390
+ // FIXME: fa_device_mismatch logic is wrong for --no-kv-offload, but this is broken anyways
391
+ fa_device_mismatch = true;
392
+ break;
393
+ }
394
+ }
395
+ if (fa_device_mismatch) {
396
+ cparams.flash_attn = false;
397
+ LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__);
398
+ if (ggml_is_quantized(params.type_v)) {
399
+ throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention");
400
+ }
401
+ } else {
402
+ cparams.flash_attn = true;
403
+ LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__);
404
+ }
405
+ }
406
+
407
+ // reserve worst-case graph
275
408
  int n_splits_pp = -1;
276
409
  int n_nodes_pp = -1;
277
410
 
278
411
  int n_splits_tg = -1;
279
412
  int n_nodes_tg = -1;
280
413
 
281
- // simulate full KV cache
282
-
283
- const auto mctx = memory->init_full();
284
- if (!mctx) {
285
- throw std::runtime_error("failed to initialize KV cache");
286
- }
287
-
288
- cross.v_embd.clear();
289
-
290
- // reserve pp graph first so that buffers are only allocated once
414
+ // reserve pp (prompt processing) graph first so that buffers are only allocated once
291
415
  {
292
- auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
416
+ auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(),
417
+ model.hparams.no_alloc, model.hparams.no_alloc ? backend_buf_exp_size.data() : nullptr);
293
418
  if (!gf) {
294
- throw std::runtime_error("failed to allocate compute pp buffers");
419
+ if (pipeline_parallel) {
420
+ LLAMA_LOG_WARN("%s: compute buffer allocation failed, retrying without pipeline parallelism\n", __func__);
421
+ sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, false, cparams.op_offload));
422
+ gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
423
+ }
424
+ if (!gf) {
425
+ throw std::runtime_error("failed to allocate compute pp buffers");
426
+ }
295
427
  }
296
428
 
297
429
  n_splits_pp = ggml_backend_sched_get_n_splits(sched.get());
298
430
  n_nodes_pp = ggml_graph_n_nodes(gf);
299
431
  }
300
432
 
301
- // reserve with tg graph to get the number of splits and nodes
433
+ // reserve with tg (token generation) graph to get the number of splits and nodes
302
434
  {
303
- auto * gf = graph_reserve(1, 1, 1, mctx.get());
435
+ auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get(), model.hparams.no_alloc);
304
436
  if (!gf) {
305
437
  throw std::runtime_error("failed to allocate compute tg buffers");
306
438
  }
@@ -311,7 +443,11 @@ llama_context::llama_context(
311
443
 
312
444
  // reserve again with pp graph to avoid ggml-alloc reallocations during inference
313
445
  {
314
- auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
446
+ // TODO: not sure if the following graph would be worst case for multi-stream KV caches:
447
+ //
448
+ // auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get());
449
+ //
450
+ auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), model.hparams.no_alloc);
315
451
  if (!gf) {
316
452
  throw std::runtime_error("failed to allocate compute pp buffers");
317
453
  }
@@ -320,11 +456,13 @@ llama_context::llama_context(
320
456
  for (size_t i = 0; i < backend_ptrs.size(); ++i) {
321
457
  ggml_backend_t backend = backend_ptrs[i];
322
458
  ggml_backend_buffer_type_t buft = backend_buft[i];
323
- size_t size = ggml_backend_sched_get_buffer_size(sched.get(), backend);
324
- if (size > 1) {
459
+ if (!model.hparams.no_alloc) {
460
+ backend_buf_exp_size[i] = ggml_backend_sched_get_buffer_size(sched.get(), backend);
461
+ }
462
+ if (backend_buf_exp_size[i] > 1) {
325
463
  LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__,
326
464
  ggml_backend_buft_name(buft),
327
- size / 1024.0 / 1024.0);
465
+ backend_buf_exp_size[i] / 1024.0 / 1024.0);
328
466
  }
329
467
  }
330
468
 
@@ -340,9 +478,35 @@ llama_context::llama_context(
340
478
  LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg);
341
479
  }
342
480
  }
481
+
482
+ // Initialize the full vocabulary token ids for backend samplers.
483
+ {
484
+ const int n_vocab = model.vocab.n_tokens();
485
+
486
+ sampling.token_ids_full_vocab.resize(n_vocab);
487
+ for (int i = 0; i < n_vocab; ++i) {
488
+ sampling.token_ids_full_vocab[i] = i;
489
+ }
490
+ }
343
491
  }
344
492
 
345
493
  llama_context::~llama_context() {
494
+ if (!model.hparams.no_alloc) {
495
+ for (size_t i = 0; i < backend_ptrs.size(); ++i) {
496
+ ggml_backend_t backend = backend_ptrs[i];
497
+ ggml_backend_buffer_type_t buft = backend_buft[i];
498
+
499
+ const size_t size_exp = backend_buf_exp_size[i];
500
+ const size_t size_act = ggml_backend_sched_get_buffer_size(sched.get(), backend);
501
+ if (size_exp == size_act) {
502
+ LLAMA_LOG_DEBUG("%s: %10s compute buffer size is %8.4f MiB, matches expectation of %8.4f MiB\n",
503
+ __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0));
504
+ } else {
505
+ LLAMA_LOG_WARN("%s: %10s compute buffer size of %8.4f MiB, does not match expectation of %8.4f MiB\n",
506
+ __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0));
507
+ }
508
+ }
509
+ }
346
510
  ggml_opt_free(opt_ctx);
347
511
  }
348
512
 
@@ -388,16 +552,12 @@ ggml_backend_sched_t llama_context::get_sched() const {
388
552
  return sched.get();
389
553
  }
390
554
 
391
- ggml_context * llama_context::get_ctx_compute() const {
392
- return ctx_compute.get();
393
- }
394
-
395
555
  uint32_t llama_context::n_ctx() const {
396
556
  return cparams.n_ctx;
397
557
  }
398
558
 
399
- uint32_t llama_context::n_ctx_per_seq() const {
400
- return cparams.n_ctx / cparams.n_seq_max;
559
+ uint32_t llama_context::n_ctx_seq() const {
560
+ return cparams.n_ctx_seq;
401
561
  }
402
562
 
403
563
  uint32_t llama_context::n_batch() const {
@@ -424,26 +584,12 @@ llama_memory_t llama_context::get_memory() const {
424
584
  return memory.get();
425
585
  }
426
586
 
427
- // deprecated
428
- void llama_context::kv_self_defrag_sched() {
429
- if (!memory) {
430
- return;
431
- }
432
-
433
- memory_force_optimize = true;
434
- }
435
-
436
- // deprecated
437
- bool llama_context::kv_self_update(bool optimize) {
587
+ bool llama_context::memory_update(bool optimize) {
438
588
  if (!memory) {
439
589
  return false;
440
590
  }
441
591
 
442
592
  {
443
- // TODO: remove in the future
444
- optimize |= memory_force_optimize;
445
- memory_force_optimize = false;
446
-
447
593
  const auto mctx = memory->init_update(this, optimize);
448
594
  switch (mctx->get_status()) {
449
595
  case LLAMA_MEMORY_STATUS_SUCCESS:
@@ -463,6 +609,11 @@ bool llama_context::kv_self_update(bool optimize) {
463
609
  }
464
610
  }
465
611
 
612
+ // reset the previous graph result to make sure that it won't be reused
613
+ // TODO: change the mctx->apply() to return information if a graph reserve is needed
614
+ // reset the graph result only if the memory module did reset the scheduler
615
+ gf_res_prev->reset();
616
+
466
617
  if (!mctx->apply()) {
467
618
  LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
468
619
  }
@@ -475,7 +626,7 @@ bool llama_context::kv_self_update(bool optimize) {
475
626
  throw std::runtime_error("failed to initialize memory context");
476
627
  }
477
628
 
478
- const uint32_t n_seqs = cparams.n_seq_max;
629
+ const uint32_t n_seqs = cparams.n_seq_max;
479
630
  const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
480
631
 
481
632
  auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
@@ -492,17 +643,51 @@ enum llama_pooling_type llama_context::pooling_type() const {
492
643
  }
493
644
 
494
645
  float * llama_context::get_logits() {
646
+ output_reorder();
647
+
495
648
  return logits;
496
649
  }
497
650
 
651
+ int64_t llama_context::output_resolve_row(int32_t i) const {
652
+ int64_t j = -1;
653
+
654
+ // support negative indices (last output row)
655
+ if (i < 0) {
656
+ j = n_outputs + i;
657
+ if (j < 0) {
658
+ throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs));
659
+ }
660
+ } else if ((size_t) i >= output_ids.size()) {
661
+ throw std::runtime_error(format("out of range [0, %zu)", output_ids.size()));
662
+ } else {
663
+ // use output_ids to translate the batch token index into a row number
664
+ // that holds this token's data.
665
+ j = output_ids[i];
666
+ }
667
+
668
+ if (j < 0) {
669
+ // the batch token was not configured to output anything
670
+ throw std::runtime_error(format("batch.logits[%d] != true", i));
671
+ }
672
+
673
+ if (j >= n_outputs) {
674
+ throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
675
+ }
676
+
677
+ return j;
678
+ }
679
+
498
680
  float * llama_context::get_logits_ith(int32_t i) {
499
681
  int64_t j = -1;
500
682
 
683
+ output_reorder();
684
+
501
685
  try {
502
686
  if (logits == nullptr) {
503
687
  throw std::runtime_error("no logits");
504
688
  }
505
689
 
690
+ // TODO: use output_resolve_row()
506
691
  if (i < 0) {
507
692
  j = n_outputs + i;
508
693
  if (j < 0) {
@@ -534,17 +719,26 @@ float * llama_context::get_logits_ith(int32_t i) {
534
719
  }
535
720
 
536
721
  float * llama_context::get_embeddings() {
722
+ output_reorder();
723
+
537
724
  return embd;
538
725
  }
539
726
 
727
+ llama_token * llama_context::get_sampled_tokens() const{
728
+ return sampling.sampled;
729
+ }
730
+
540
731
  float * llama_context::get_embeddings_ith(int32_t i) {
541
732
  int64_t j = -1;
542
733
 
734
+ output_reorder();
735
+
543
736
  try {
544
737
  if (embd == nullptr) {
545
738
  throw std::runtime_error("no embeddings");
546
739
  }
547
740
 
741
+ // TODO: use output_resolve_row()
548
742
  if (i < 0) {
549
743
  j = n_outputs + i;
550
744
  if (j < 0) {
@@ -564,7 +758,8 @@ float * llama_context::get_embeddings_ith(int32_t i) {
564
758
  throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
565
759
  }
566
760
 
567
- return embd + j*model.hparams.n_embd;
761
+ const uint32_t n_embd_out = model.hparams.get_n_embd_out();
762
+ return embd + j*n_embd_out;
568
763
  } catch (const std::exception & err) {
569
764
  LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what());
570
765
  #ifndef NDEBUG
@@ -584,6 +779,136 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) {
584
779
  return it->second.data();
585
780
  }
586
781
 
782
+ llama_token llama_context::get_sampled_token_ith(int32_t idx) {
783
+ output_reorder();
784
+
785
+ if (sampling.sampled == nullptr) {
786
+ return LLAMA_TOKEN_NULL;
787
+ }
788
+
789
+ try {
790
+ const int64_t row = output_resolve_row(idx);
791
+ GGML_ASSERT(row < (int64_t) sampling.sampled_size);
792
+ return sampling.sampled[row];
793
+ } catch (const std::exception & err) {
794
+ LLAMA_LOG_ERROR("%s: invalid backend sampled token id %d, reason: %s\n", __func__, idx, err.what());
795
+ return LLAMA_TOKEN_NULL;
796
+ }
797
+ }
798
+
799
+ float * llama_context::get_sampled_probs_ith(int32_t idx) {
800
+ output_reorder();
801
+
802
+ if (sampling.probs == nullptr) {
803
+ return nullptr;
804
+ }
805
+
806
+ try {
807
+ const int64_t row = output_resolve_row(idx);
808
+ if ((size_t) row >= sampling.probs_count.size() || sampling.probs_count[row] == 0) {
809
+ return nullptr;
810
+ }
811
+ return sampling.probs + row*model.vocab.n_tokens();
812
+ } catch (const std::exception & err) {
813
+ LLAMA_LOG_ERROR("%s: invalid backend sampled probs id %d, reason: %s\n", __func__, idx, err.what());
814
+ return nullptr;
815
+ }
816
+ }
817
+
818
+ float * llama_context::get_sampled_logits_ith(int32_t idx) {
819
+ output_reorder();
820
+
821
+ if (sampling.logits == nullptr) {
822
+ return nullptr;
823
+ }
824
+
825
+ try {
826
+ const int64_t row = output_resolve_row(idx);
827
+ if ((size_t) row >= sampling.logits_count.size() || sampling.logits_count[row] == 0) {
828
+ return nullptr;
829
+ }
830
+ return sampling.logits + row*model.vocab.n_tokens();
831
+ } catch (const std::exception & err) {
832
+ LLAMA_LOG_ERROR("%s: invalid backend sampled logits id %d, reason: %s\n", __func__, idx, err.what());
833
+ return nullptr;
834
+ }
835
+ }
836
+
837
+ const llama_token * llama_context::get_sampled_candidates_ith(int32_t idx) {
838
+ output_reorder();
839
+
840
+ try {
841
+ const int64_t row = output_resolve_row(idx);
842
+ if (sampling.candidates != nullptr &&
843
+ (size_t) row < sampling.candidates_count.size() &&
844
+ sampling.candidates_count[row] > 0) {
845
+ return sampling.candidates + row*model.vocab.n_tokens();
846
+ }
847
+ } catch (const std::exception & err) {
848
+ // fallback to full vocab list
849
+ }
850
+
851
+ return sampling.token_ids_full_vocab.data();
852
+ }
853
+
854
+ size_t llama_context::get_sampled_candidates_count(int32_t idx) {
855
+ output_reorder();
856
+
857
+ if (sampling.candidates == nullptr) {
858
+ return 0;
859
+ }
860
+
861
+ try {
862
+ const int64_t row = output_resolve_row(idx);
863
+ if ((size_t) row >= sampling.candidates_count.size()) {
864
+ return 0;
865
+ }
866
+ return sampling.candidates_count[row];
867
+ } catch (const std::exception & err) {
868
+ LLAMA_LOG_ERROR("%s: invalid backend sampled candidates count id %d, reason: %s\n", __func__, idx, err.what());
869
+ return 0;
870
+ }
871
+ }
872
+
873
+ size_t llama_context::get_sampled_logits_count(int32_t idx) {
874
+ output_reorder();
875
+
876
+ if (sampling.logits == nullptr) {
877
+ return model.vocab.n_tokens();
878
+ }
879
+
880
+ try {
881
+ const int64_t row = output_resolve_row(idx);
882
+ if ((size_t) row >= sampling.logits_count.size()) {
883
+ return 0;
884
+ }
885
+ return sampling.logits_count[row];
886
+ } catch (const std::exception & err) {
887
+ LLAMA_LOG_ERROR("%s: invalid backend sampled logits count id %d, reason: %s\n", __func__, idx, err.what());
888
+ return 0;
889
+ }
890
+ }
891
+
892
+ size_t llama_context::get_sampled_probs_count(int32_t idx) {
893
+ output_reorder();
894
+
895
+ if (sampling.probs == nullptr) {
896
+ return 0;
897
+ }
898
+
899
+ try {
900
+ const int64_t row = output_resolve_row(idx);
901
+ if ((size_t) row >= sampling.probs_count.size()) {
902
+ return 0;
903
+ }
904
+ return sampling.probs_count[row];
905
+ } catch (const std::exception & err) {
906
+ LLAMA_LOG_ERROR("%s: invalid backend sampled probs count id %d, reason: %s\n", __func__, idx, err.what());
907
+ return 0;
908
+ }
909
+ }
910
+
911
+
587
912
  void llama_context::attach_threadpool(
588
913
  ggml_threadpool_t threadpool,
589
914
  ggml_threadpool_t threadpool_batch) {
@@ -640,6 +965,42 @@ void llama_context::set_warmup(bool value) {
640
965
  cparams.warmup = value;
641
966
  }
642
967
 
968
+ bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
969
+ LLAMA_LOG_DEBUG("%s: seq_id = %d, sampler = %p\n", __func__, (int) seq_id, (void *) sampler);
970
+
971
+ const bool can_offload =
972
+ sampler &&
973
+ sampler->iface->backend_init &&
974
+ sampler->iface->backend_apply &&
975
+ llama_sampler_chain_n(sampler) > 0;
976
+
977
+ if (sampler && can_offload) {
978
+ ggml_backend_buffer_type_t buft = ggml_backend_dev_buffer_type(model.dev_output());
979
+ auto * host_buft = ggml_backend_dev_host_buffer_type(model.dev_output());
980
+ if (host_buft) {
981
+ buft = host_buft;
982
+ }
983
+
984
+ sampler->iface->backend_init(sampler, buft);
985
+
986
+ sampling.samplers[seq_id] = sampler;
987
+
988
+ return true;
989
+ }
990
+
991
+ if (sampler && !can_offload) {
992
+ LLAMA_LOG_WARN("%s: sampler '%s' for seq_id = %d, cannot be offloaded to the backend\n", __func__, llama_sampler_name(sampler), seq_id);
993
+
994
+ sampling.samplers.erase(seq_id);
995
+
996
+ return false;
997
+ }
998
+
999
+ sampling.samplers.erase(seq_id);
1000
+
1001
+ return true;
1002
+ }
1003
+
643
1004
  void llama_context::set_adapter_lora(
644
1005
  llama_adapter_lora * adapter,
645
1006
  float scale) {
@@ -678,38 +1039,59 @@ bool llama_context::apply_adapter_cvec(
678
1039
  return cvec.apply(model, data, len, n_embd, il_start, il_end);
679
1040
  }
680
1041
 
681
- llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
1042
+ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
682
1043
  if (mctx && !mctx->apply()) {
683
1044
  LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
684
1045
  ret = GGML_STATUS_FAILED;
685
1046
  return nullptr;
686
1047
  }
687
1048
 
688
- auto * gf = graph_init();
689
- if (!gf) {
690
- LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
691
- ret = GGML_STATUS_FAILED;
692
- return nullptr;
693
- }
1049
+ auto * res = gf_res_prev.get();
1050
+ auto * gf = res->get_gf();
694
1051
 
695
- auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mctx);
696
- if (!res) {
697
- LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__);
698
- ret = GGML_STATUS_FAILED;
699
- return nullptr;
700
- }
1052
+ // the new graph parameters
1053
+ // in order to correctly reuse a graph, it's full topology has to be uniquely determined by these parameters
1054
+ const auto gparams = graph_params(res, ubatch, mctx, gtype);
701
1055
 
702
- // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
1056
+ if (!graph_reuse_disable && res->can_reuse(gparams)) {
1057
+ //LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__);
703
1058
 
704
- if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
705
- LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
706
- ret = GGML_STATUS_ALLOC_FAILED;
707
- return nullptr;
1059
+ n_reused++;
1060
+ } else {
1061
+ res->reset();
1062
+
1063
+ ggml_backend_sched_reset(sched.get());
1064
+ ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
1065
+
1066
+ //const auto t_start_us = ggml_time_us();
1067
+
1068
+ gf = model.build_graph(gparams);
1069
+
1070
+ //LLAMA_LOG_INFO("graph build time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);
1071
+
1072
+ if (!gf) {
1073
+ LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
1074
+ ret = GGML_STATUS_FAILED;
1075
+ return nullptr;
1076
+ }
1077
+
1078
+ if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
1079
+ LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
1080
+ ret = GGML_STATUS_ALLOC_FAILED;
1081
+ return nullptr;
1082
+ }
708
1083
  }
709
1084
 
710
- res->set_inputs(&ubatch);
1085
+ // set the input data for the input tensors
1086
+ {
1087
+ //const auto t_start_us = ggml_time_us();
1088
+
1089
+ res->set_inputs(&ubatch);
711
1090
 
712
- const auto status = graph_compute(gf, ubatch.n_tokens > 1);
1091
+ //LLAMA_LOG_INFO("graph set inputs time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);
1092
+ }
1093
+
1094
+ const auto status = graph_compute(res->get_gf(), ubatch.n_tokens > 1);
713
1095
  if (status != GGML_STATUS_SUCCESS) {
714
1096
  LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status);
715
1097
  ret = status;
@@ -731,16 +1113,19 @@ int llama_context::encode(const llama_batch & batch_inp) {
731
1113
 
732
1114
  const auto & hparams = model.hparams;
733
1115
 
734
- const int64_t n_embd = hparams.n_embd;
1116
+ const int64_t n_embd = hparams.n_embd_inp();
1117
+ const int64_t n_vocab = model.vocab.n_tokens();
735
1118
 
736
1119
  // note: during encode, we always pass the full sequence starting from pos = 0
737
- if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, true)) {
1120
+ if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
738
1121
  LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
739
1122
  return -1;
740
1123
  }
741
1124
 
742
1125
  const uint32_t n_tokens = balloc->get_n_tokens();
743
1126
 
1127
+ // [TAG_NO_CACHE_PAD]
1128
+ // TODO: add new split mode where we pad the input sequences so that ubatch.equal_seqs == true
744
1129
  const llama_ubatch ubatch = balloc->split_simple(n_tokens);
745
1130
 
746
1131
  // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
@@ -756,7 +1141,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
756
1141
  n_queued_tokens += n_tokens;
757
1142
 
758
1143
  // reserve output buffer
759
- if (output_reserve(n_tokens) < n_tokens) {
1144
+ if (output_reserve(n_tokens, batch_inp) < n_tokens) {
760
1145
  LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
761
1146
  return -2;
762
1147
  };
@@ -767,9 +1152,6 @@ int llama_context::encode(const llama_batch & batch_inp) {
767
1152
 
768
1153
  n_outputs = n_tokens;
769
1154
 
770
- ggml_backend_sched_reset(sched.get());
771
- ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
772
-
773
1155
  const auto causal_attn_org = cparams.causal_attn;
774
1156
 
775
1157
  // always use non-causal attention for encoder graphs
@@ -778,7 +1160,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
778
1160
  cparams.causal_attn = false;
779
1161
 
780
1162
  ggml_status status;
781
- const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);
1163
+ const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);
782
1164
 
783
1165
  cparams.causal_attn = causal_attn_org;
784
1166
 
@@ -791,10 +1173,20 @@ int llama_context::encode(const llama_batch & batch_inp) {
791
1173
  }
792
1174
  }
793
1175
 
1176
+ auto * t_logits = res->get_logits();
794
1177
  auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
795
1178
 
1179
+ // extract logits
1180
+ if (logits && t_logits) {
1181
+ ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
1182
+ GGML_ASSERT(backend_res != nullptr);
1183
+ GGML_ASSERT(logits != nullptr);
1184
+
1185
+ ggml_backend_tensor_get_async(backend_res, t_logits, logits, 0, n_tokens*n_vocab*sizeof(float));
1186
+ }
1187
+
796
1188
  // extract embeddings
797
- if (t_embd) {
1189
+ if (embd && t_embd) {
798
1190
  ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
799
1191
  GGML_ASSERT(backend_embd != nullptr);
800
1192
 
@@ -803,9 +1195,10 @@ int llama_context::encode(const llama_batch & batch_inp) {
803
1195
  {
804
1196
  // extract token embeddings
805
1197
  GGML_ASSERT(embd != nullptr);
1198
+ const uint32_t n_embd_out = hparams.get_n_embd_out();
806
1199
 
807
- GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size);
808
- ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd*sizeof(float));
1200
+ GGML_ASSERT(n_tokens*n_embd_out <= (int64_t) embd_size);
1201
+ ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd_out*sizeof(float));
809
1202
  } break;
810
1203
  case LLAMA_POOLING_TYPE_MEAN:
811
1204
  case LLAMA_POOLING_TYPE_CLS:
@@ -844,10 +1237,6 @@ int llama_context::encode(const llama_batch & batch_inp) {
844
1237
  }
845
1238
  }
846
1239
 
847
- // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
848
- // overlap with device computation.
849
- ggml_backend_sched_reset(sched.get());
850
-
851
1240
  // TODO: hacky solution
852
1241
  if (model.arch == LLM_ARCH_T5 && t_embd) {
853
1242
  //cross.t_embd = t_embd;
@@ -877,6 +1266,112 @@ int llama_context::encode(const llama_batch & batch_inp) {
877
1266
  return 0;
878
1267
  }
879
1268
 
1269
+ static std::map<llama_seq_id, uint32_t> build_seq_to_output_row(const llama_ubatch & ubatch, uint32_t row_offset) {
1270
+ std::map<llama_seq_id, uint32_t> seq_to_row;
1271
+ // how many output tokens we have seen so far for this ubatch.
1272
+ uint32_t local = 0;
1273
+ for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
1274
+ // skip tokens that are not output.
1275
+ if (!ubatch.output[i]) {
1276
+ continue;
1277
+ }
1278
+
1279
+ const llama_seq_id seq_id = ubatch.seq_id[i][0];
1280
+ // row_offset is the number of output tokens before this ubatch.
1281
+ seq_to_row[seq_id] = row_offset + local;
1282
+ ++local;
1283
+ }
1284
+ return seq_to_row;
1285
+ }
1286
+
1287
+ static void copy_tensor_async_ints(
1288
+ const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
1289
+ llama_token * sampled,
1290
+ size_t sampled_size,
1291
+ const std::map<llama_seq_id, uint32_t> & seq_to_row,
1292
+ ggml_backend_sched_t sched) {
1293
+ if (sampled == nullptr) {
1294
+ return;
1295
+ }
1296
+
1297
+ for (const auto & [seq_id, tensor] : tensor_map) {
1298
+ auto it = seq_to_row.find(seq_id);
1299
+ if (it == seq_to_row.end()) {
1300
+ continue;
1301
+ }
1302
+
1303
+ const uint32_t row = it->second;
1304
+ GGML_ASSERT(row < sampled_size);
1305
+
1306
+ GGML_ASSERT(ggml_is_contiguous(tensor) && "sampled tokens tensor must be contiguous for async copy");
1307
+
1308
+ ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
1309
+ ggml_backend_tensor_get_async(backend, tensor, sampled + row, 0, sizeof(sampled[row]));
1310
+ }
1311
+ }
1312
+
1313
+ static void copy_tensor_async_floats(
1314
+ const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
1315
+ float * dst,
1316
+ size_t stride,
1317
+ std::vector<uint32_t> & counts,
1318
+ const std::map<llama_seq_id, uint32_t> & seq_to_row,
1319
+ ggml_backend_sched_t sched) {
1320
+ if (dst == nullptr) {
1321
+ return;
1322
+ }
1323
+
1324
+ for (const auto & [seq_id, tensor] : tensor_map) {
1325
+ auto it = seq_to_row.find(seq_id);
1326
+ if (it == seq_to_row.end()) {
1327
+ continue;
1328
+ }
1329
+
1330
+ const uint32_t row = it->second;
1331
+ GGML_ASSERT(row < counts.size());
1332
+
1333
+ GGML_ASSERT(ggml_is_contiguous(tensor) && "logits/probs tensor must be contiguous for async copy");
1334
+
1335
+ ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
1336
+ float * row_ptr = dst + (size_t) row * stride;
1337
+ ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor));
1338
+
1339
+ // Update the actual number of logits/probabilities that were written for this row.
1340
+ counts[row] = ggml_nelements(tensor);
1341
+ }
1342
+ }
1343
+
1344
+ static void copy_tensor_async_candidates(
1345
+ const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
1346
+ llama_token * dst,
1347
+ size_t stride,
1348
+ std::vector<uint32_t> & counts,
1349
+ const std::map<llama_seq_id, uint32_t> & seq_to_row,
1350
+ ggml_backend_sched_t sched) {
1351
+ if (dst == nullptr) {
1352
+ return;
1353
+ }
1354
+
1355
+ for (const auto & [seq_id, tensor] : tensor_map) {
1356
+ auto it = seq_to_row.find(seq_id);
1357
+ if (it == seq_to_row.end()) {
1358
+ continue;
1359
+ }
1360
+
1361
+ const uint32_t row = it->second;
1362
+ GGML_ASSERT(row < counts.size());
1363
+
1364
+ GGML_ASSERT(ggml_is_contiguous(tensor) && "candidates tensor must be contiguous for async copy");
1365
+
1366
+ ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
1367
+ llama_token * row_ptr = dst + (size_t) row * stride;
1368
+ ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor));
1369
+
1370
+ // Update the actual number of candidates that were written.
1371
+ counts[row] = ggml_nelements(tensor);
1372
+ }
1373
+ }
1374
+
880
1375
  int llama_context::decode(const llama_batch & batch_inp) {
881
1376
  GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
882
1377
 
@@ -893,13 +1388,40 @@ int llama_context::decode(const llama_batch & batch_inp) {
893
1388
  const auto & vocab = model.vocab;
894
1389
  const auto & hparams = model.hparams;
895
1390
 
896
- const int32_t n_vocab = vocab.n_tokens();
897
- const int64_t n_embd = hparams.n_embd;
1391
+ const int64_t n_vocab = vocab.n_tokens();
1392
+ const int64_t n_embd = hparams.n_embd_inp();
898
1393
 
899
1394
  // when computing embeddings, all tokens are output
900
- const bool output_all = cparams.embeddings;
1395
+ const bool output_all = cparams.embeddings;
1396
+ const bool has_samplers = !sampling.samplers.empty();
1397
+
1398
+ const uint32_t n_seq_max = cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max;
901
1399
 
902
- if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, output_all)) {
1400
+ // TODO: avoid this workaround in the future
1401
+ if (has_samplers && batch_inp.logits) {
1402
+ std::vector<int32_t> seq_output_count(n_seq_max, 0);
1403
+
1404
+ for (int32_t i = 0; i < batch_inp.n_tokens; ++i) {
1405
+ if (batch_inp.logits[i] == 0) {
1406
+ continue;
1407
+ }
1408
+
1409
+ const int ns = batch_inp.n_seq_id ? batch_inp.n_seq_id[i] : 1;
1410
+
1411
+ for (int32_t s = 0; s < ns; ++s) {
1412
+ const llama_seq_id seq_id = batch_inp.seq_id ? batch_inp.seq_id[i][s] : 0;
1413
+
1414
+ seq_output_count[seq_id]++;
1415
+ if (seq_output_count[seq_id] > 1) {
1416
+ LLAMA_LOG_ERROR("%s: backend sampling requires at most one output token per sequence (seq_id %d had %d)\n",
1417
+ __func__, seq_id, seq_output_count[seq_id]);
1418
+ return -1;
1419
+ }
1420
+ }
1421
+ }
1422
+ }
1423
+
1424
+ if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, n_seq_max, output_all)) {
903
1425
  LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
904
1426
  return -1;
905
1427
  }
@@ -927,11 +1449,12 @@ int llama_context::decode(const llama_batch & batch_inp) {
927
1449
 
928
1450
  // TODO: this clear of the buffer can easily be forgotten - need something better
929
1451
  embd_seq.clear();
1452
+ output_swaps.clear();
930
1453
 
931
1454
  bool did_optimize = false;
932
1455
 
933
- // handle any pending defrags/shifts
934
- kv_self_update(false);
1456
+ // handle any pending shifts/copies
1457
+ memory_update(false);
935
1458
 
936
1459
  llama_memory_context_ptr mctx;
937
1460
 
@@ -956,7 +1479,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
956
1479
  if (!did_optimize) {
957
1480
  did_optimize = true;
958
1481
 
959
- if (kv_self_update(true)) {
1482
+ if (memory_update(true)) {
960
1483
  LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, balloc->get_n_tokens());
961
1484
 
962
1485
  continue;
@@ -979,7 +1502,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
979
1502
  }
980
1503
 
981
1504
  // reserve output buffer
982
- if (output_reserve(n_outputs_all) < n_outputs_all) {
1505
+ if (output_reserve(n_outputs_all, balloc->get_batch()) < n_outputs_all) {
983
1506
  LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
984
1507
  return -2;
985
1508
  };
@@ -1005,14 +1528,11 @@ int llama_context::decode(const llama_batch & batch_inp) {
1005
1528
  n_outputs = n_outputs_new;
1006
1529
  }
1007
1530
 
1008
- ggml_backend_sched_reset(sched.get());
1009
- ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
1010
-
1011
1531
  ggml_status status;
1012
- const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
1532
+ const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
1013
1533
 
1014
1534
  if (!res) {
1015
- // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
1535
+ // the last ubatch failed or was aborted -> remove all positions of that ubatch from the memory module
1016
1536
  llama_pos pos_min[LLAMA_MAX_SEQ];
1017
1537
  for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
1018
1538
  pos_min[s] = std::numeric_limits<llama_pos>::max();
@@ -1029,7 +1549,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
1029
1549
  continue;
1030
1550
  }
1031
1551
 
1032
- LLAMA_LOG_WARN("%s: removing KV cache entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]);
1552
+ LLAMA_LOG_WARN("%s: removing memory module entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]);
1033
1553
 
1034
1554
  memory->seq_rm(s, pos_min[s], -1);
1035
1555
  }
@@ -1055,7 +1575,10 @@ int llama_context::decode(const llama_batch & batch_inp) {
1055
1575
  }
1056
1576
 
1057
1577
  // extract logits
1058
- if (t_logits && n_outputs > 0) {
1578
+ // For multi-sequence batches that mix backend samplers and CPU sampler
1579
+ // this is currently inefficient as we copy all logits even for the
1580
+ // backend sampled tokens.
1581
+ if (logits && t_logits && n_outputs > 0) {
1059
1582
  ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
1060
1583
  GGML_ASSERT(backend_res != nullptr);
1061
1584
  GGML_ASSERT(logits != nullptr);
@@ -1070,7 +1593,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
1070
1593
  }
1071
1594
 
1072
1595
  // extract embeddings
1073
- if (t_embd && n_outputs > 0) {
1596
+ if (embd && t_embd && n_outputs > 0) {
1074
1597
  ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
1075
1598
  GGML_ASSERT(backend_embd != nullptr);
1076
1599
 
@@ -1079,12 +1602,13 @@ int llama_context::decode(const llama_batch & batch_inp) {
1079
1602
  {
1080
1603
  // extract token embeddings
1081
1604
  GGML_ASSERT(embd != nullptr);
1082
- float * embd_out = embd + n_outputs_prev*n_embd;
1605
+ const uint32_t n_embd_out = hparams.get_n_embd_out();
1606
+ float * embd_out = embd + n_outputs_prev*n_embd_out;
1083
1607
 
1084
1608
  if (n_outputs) {
1085
1609
  GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
1086
- GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size);
1087
- ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd*sizeof(float));
1610
+ GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd_out <= (int64_t) embd_size);
1611
+ ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd_out*sizeof(float));
1088
1612
  }
1089
1613
  } break;
1090
1614
  case LLAMA_POOLING_TYPE_MEAN:
@@ -1124,6 +1648,22 @@ int llama_context::decode(const llama_batch & batch_inp) {
1124
1648
  }
1125
1649
  }
1126
1650
 
1651
+ // This flag indicates whether a backend sampler has actually sampled a specific
1652
+ // token, or if it has produced probabilites. If true, we can skip the normal copying of logits and embeddings.
1653
+ const bool has_sampled = !res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty();
1654
+
1655
+ if (has_samplers && has_sampled) {
1656
+ const auto seq_to_output_row = build_seq_to_output_row(ubatch, n_outputs_prev);
1657
+ const auto stride = n_vocab;
1658
+
1659
+ // async copy the sampling data from the backend to the host
1660
+ copy_tensor_async_ints(res->t_sampled, sampling.sampled, sampling.sampled_size, seq_to_output_row, sched.get());
1661
+
1662
+ copy_tensor_async_floats (res->t_sampled_logits, sampling.logits, stride, sampling.logits_count, seq_to_output_row, sched.get());
1663
+ copy_tensor_async_floats (res->t_sampled_probs, sampling.probs, stride, sampling.probs_count, seq_to_output_row, sched.get());
1664
+ copy_tensor_async_candidates(res->t_candidates, sampling.candidates, stride, sampling.candidates_count, seq_to_output_row, sched.get());
1665
+ }
1666
+
1127
1667
  n_outputs_prev += n_outputs;
1128
1668
  } while (mctx->next());
1129
1669
 
@@ -1148,10 +1688,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
1148
1688
 
1149
1689
  // make the outputs have the same order they had in the user-provided batch
1150
1690
  // note: this is mostly relevant for recurrent models atm
1151
- if (!sorted_output) {
1152
- const uint32_t n_vocab = model.vocab.n_tokens();
1153
- const uint64_t n_embd = model.hparams.n_embd;
1154
-
1691
+ if (!sorted_output && n_outputs > 1) {
1155
1692
  GGML_ASSERT((size_t) n_outputs == out_ids.size());
1156
1693
 
1157
1694
  // TODO: is there something more efficient which also minimizes swaps?
@@ -1167,16 +1704,9 @@ int llama_context::decode(const llama_batch & batch_inp) {
1167
1704
  continue;
1168
1705
  }
1169
1706
  std::swap(out_ids[i], out_ids[j_min]);
1170
- if (logits_size > 0) {
1171
- for (uint32_t k = 0; k < n_vocab; k++) {
1172
- std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
1173
- }
1174
- }
1175
- if (embd_size > 0) {
1176
- for (uint32_t k = 0; k < n_embd; k++) {
1177
- std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]);
1178
- }
1179
- }
1707
+
1708
+ // remember the swaps and apply them lazily upon logits/embeddings access
1709
+ output_swaps.push_back({ i, j_min });
1180
1710
  }
1181
1711
 
1182
1712
  std::fill(output_ids.begin(), output_ids.end(), -1);
@@ -1190,10 +1720,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
1190
1720
  // wait for the computation to finish (automatically done when obtaining the model output)
1191
1721
  //synchronize();
1192
1722
 
1193
- // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
1194
- // overlap with device computation.
1195
- ggml_backend_sched_reset(sched.get());
1196
-
1197
1723
  return 0;
1198
1724
  }
1199
1725
 
@@ -1201,15 +1727,15 @@ int llama_context::decode(const llama_batch & batch_inp) {
1201
1727
  // output
1202
1728
  //
1203
1729
 
1204
- uint32_t llama_context::output_reserve(int32_t n_outputs) {
1730
+ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & batch) {
1205
1731
  const auto & hparams = model.hparams;
1206
1732
  const auto & vocab = model.vocab;
1207
1733
 
1208
1734
  const int64_t n_outputs_max = std::max<int64_t>(n_outputs, n_seq_max());
1209
1735
 
1210
- const auto n_batch = cparams.n_batch;
1211
- const auto n_vocab = vocab.n_tokens();
1212
- const auto n_embd = hparams.n_embd;
1736
+ const auto n_batch = cparams.n_batch;
1737
+ const auto n_vocab = vocab.n_tokens();
1738
+ const auto n_embd_out = hparams.get_n_embd_out();
1213
1739
 
1214
1740
  bool has_logits = true;
1215
1741
  bool has_embd = cparams.embeddings;
@@ -1220,8 +1746,53 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
1220
1746
  has_embd = true;
1221
1747
  }
1222
1748
 
1223
- logits_size = has_logits ? n_vocab*n_outputs_max : 0;
1224
- embd_size = has_embd ? n_embd*n_outputs_max : 0;
1749
+ // Check which sampling modes are needed for the current batch.
1750
+ // TODO: avoid this branching by working with the worst-case
1751
+ bool has_sampling = false;
1752
+ bool cpu_logits = false;
1753
+
1754
+ if (batch.logits) {
1755
+ for (int32_t i = 0; i < batch.n_tokens; i++) {
1756
+ if (!batch.logits[i]) {
1757
+ continue;
1758
+ }
1759
+ for (int32_t j = 0; j < batch.n_seq_id[i]; j++) {
1760
+ llama_seq_id seq_id = batch.seq_id[i][j];
1761
+ if (sampling.samplers.find(seq_id) != sampling.samplers.end()) {
1762
+ has_sampling = true;
1763
+ } else {
1764
+ cpu_logits = true;
1765
+ }
1766
+ }
1767
+ }
1768
+ } else {
1769
+ // When batch.logits is nullptr (when loading state with a dummy batch),
1770
+ // allocate CPU logits.
1771
+ cpu_logits = true;
1772
+ }
1773
+
1774
+ size_t backend_float_count = 0;
1775
+ size_t backend_token_count = 0;
1776
+
1777
+ // Allocate CPU logits buffer only if needed by sequences in this batch
1778
+ logits_size = (has_logits && cpu_logits) ? n_vocab*n_outputs_max : 0;
1779
+ embd_size = has_embd ? n_embd_out*n_outputs_max : 0;
1780
+
1781
+ // TODO: avoid this branching by working with the worst-case
1782
+ if (!has_sampling) {
1783
+ sampling.logits_size = 0;
1784
+ sampling.probs_size = 0;
1785
+ sampling.sampled_size = 0;
1786
+ sampling.candidates_size = 0;
1787
+ } else {
1788
+ sampling.logits_size = n_vocab*n_outputs_max;
1789
+ sampling.probs_size = n_vocab*n_outputs_max;
1790
+ sampling.sampled_size = n_outputs_max;
1791
+ sampling.candidates_size = n_vocab*n_outputs_max;
1792
+
1793
+ backend_float_count = sampling.logits_size + sampling.probs_size;
1794
+ backend_token_count = sampling.sampled_size + sampling.candidates_size;
1795
+ }
1225
1796
 
1226
1797
  if (output_ids.empty()) {
1227
1798
  // init, never resized afterwards
@@ -1229,7 +1800,9 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
1229
1800
  }
1230
1801
 
1231
1802
  const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0;
1232
- const size_t new_size = (logits_size + embd_size) * sizeof(float);
1803
+ const size_t new_size =
1804
+ (logits_size + embd_size + backend_float_count) * sizeof(float) +
1805
+ ( backend_token_count) * sizeof(llama_token);
1233
1806
 
1234
1807
  // alloc only when more than the current capacity is required
1235
1808
  // TODO: also consider shrinking the buffer
@@ -1237,8 +1810,11 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
1237
1810
  if (buf_output) {
1238
1811
  #ifndef NDEBUG
1239
1812
  // This doesn't happen often, but may be annoying in some cases (like the HellaSwag benchmark)
1240
- LLAMA_LOG_INFO("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
1813
+ LLAMA_LOG_DEBUG("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
1241
1814
  #endif
1815
+ synchronize();
1816
+
1817
+ // TODO: not needed?
1242
1818
  buf_output = nullptr;
1243
1819
  logits = nullptr;
1244
1820
  embd = nullptr;
@@ -1260,8 +1836,49 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
1260
1836
 
1261
1837
  float * output_base = (float *) ggml_backend_buffer_get_base(buf_output.get());
1262
1838
 
1263
- logits = has_logits ? output_base : nullptr;
1264
- embd = has_embd ? output_base + logits_size : nullptr;
1839
+ logits = nullptr;
1840
+ embd = nullptr;
1841
+
1842
+ size_t offset = 0;
1843
+ uint8_t * base = (uint8_t *) output_base;
1844
+
1845
+ logits = (has_logits && cpu_logits) ? output_base : nullptr;
1846
+ offset += logits_size * sizeof(float);
1847
+
1848
+ embd = has_embd ? (float *) (base + offset) : nullptr;
1849
+ offset += embd_size * sizeof(float);
1850
+
1851
+ sampling.logits = nullptr;
1852
+ sampling.probs = nullptr;
1853
+ sampling.sampled = nullptr;
1854
+ sampling.candidates = nullptr;
1855
+
1856
+ if (has_sampling) {
1857
+ sampling.logits = (float *) (base + offset);
1858
+ offset += sampling.logits_size * sizeof(float);
1859
+
1860
+ sampling.probs = (float *) (base + offset);
1861
+ offset += sampling.probs_size * sizeof(float);
1862
+
1863
+ sampling.sampled = (llama_token *) (base + offset);
1864
+ offset += sampling.sampled_size * sizeof(llama_token);
1865
+
1866
+ sampling.candidates = (llama_token *) (base + offset);
1867
+ offset += sampling.candidates_size * sizeof(llama_token);
1868
+
1869
+ // The count vectors keep track of the actual number of logits/probs/candidates
1870
+ // copied from the backend for each output row.
1871
+
1872
+ sampling.logits_count.resize(n_outputs_max);
1873
+ sampling.probs_count.resize(n_outputs_max);
1874
+ sampling.candidates_count.resize(n_outputs_max);
1875
+
1876
+ std::fill(sampling.logits_count.begin(), sampling.logits_count.end(), 0);
1877
+ std::fill(sampling.probs_count.begin(), sampling.probs_count.end(), 0);
1878
+ std::fill(sampling.candidates_count.begin(), sampling.candidates_count.end(), 0);
1879
+
1880
+ std::fill_n(sampling.sampled, sampling.sampled_size, LLAMA_TOKEN_NULL);
1881
+ }
1265
1882
 
1266
1883
  // set all ids as invalid (negative)
1267
1884
  std::fill(output_ids.begin(), output_ids.end(), -1);
@@ -1271,36 +1888,98 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
1271
1888
  return n_outputs_max;
1272
1889
  }
1273
1890
 
1891
+ void llama_context::output_reorder() {
1892
+ const uint64_t n_vocab = model.vocab.n_tokens();
1893
+ const uint64_t n_embd = model.hparams.n_embd;
1894
+
1895
+ for (size_t s = 0; s < output_swaps.size(); ++s) {
1896
+ const uint64_t i0 = output_swaps[s].i0;
1897
+ const uint64_t i1 = output_swaps[s].i1;
1898
+
1899
+ if (logits_size > 0) {
1900
+ for (uint64_t k = 0; k < n_vocab; k++) {
1901
+ std::swap(logits[i0*n_vocab + k], logits[i1*n_vocab + k]);
1902
+ }
1903
+ }
1904
+
1905
+ if (embd_size > 0) {
1906
+ for (uint64_t k = 0; k < n_embd; k++) {
1907
+ std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]);
1908
+ }
1909
+ }
1910
+
1911
+ if (sampling.logits && sampling.logits_size > 0) {
1912
+ for (uint64_t k = 0; k < n_vocab; ++k) {
1913
+ std::swap(sampling.logits[i0*n_vocab + k], sampling.logits[i1*n_vocab + k]);
1914
+ }
1915
+ }
1916
+
1917
+ if (sampling.probs && sampling.probs_size > 0) {
1918
+ for (uint64_t k = 0; k < n_vocab; ++k) {
1919
+ std::swap(sampling.probs[i0*n_vocab + k], sampling.probs[i1*n_vocab + k]);
1920
+ }
1921
+ }
1922
+
1923
+ if (sampling.candidates && sampling.candidates_size > 0) {
1924
+ for (uint64_t k = 0; k < n_vocab; ++k) {
1925
+ std::swap(sampling.candidates[i0*n_vocab + k], sampling.candidates[i1*n_vocab + k]);
1926
+ }
1927
+ }
1928
+
1929
+ if (sampling.sampled && sampling.sampled_size > 0) {
1930
+ std::swap(sampling.sampled[i0], sampling.sampled[i1]);
1931
+ }
1932
+
1933
+ if (!sampling.logits_count.empty()) {
1934
+ std::swap(sampling.logits_count[i0], sampling.logits_count[i1]);
1935
+ }
1936
+
1937
+ if (!sampling.probs_count.empty()) {
1938
+ std::swap(sampling.probs_count[i0], sampling.probs_count[i1]);
1939
+ }
1940
+
1941
+ if (!sampling.candidates_count.empty()) {
1942
+ std::swap(sampling.candidates_count[i0], sampling.candidates_count[i1]);
1943
+ }
1944
+ }
1945
+
1946
+ output_swaps.clear();
1947
+ }
1948
+
1274
1949
  //
1275
1950
  // graph
1276
1951
  //
1277
1952
 
1278
- int32_t llama_context::graph_max_nodes() const {
1279
- return std::max<int32_t>(65536, 5*model.n_tensors());
1953
+ uint32_t llama_context::graph_max_nodes(uint32_t n_tokens) const {
1954
+ if (model.arch == LLM_ARCH_QWEN3NEXT) {
1955
+ return std::max<uint32_t>(n_tokens * 40, 32u * model.n_tensors());
1956
+ }
1957
+ uint32_t res = std::max<uint32_t>(1024u, 8u*model.n_tensors());
1958
+ res += model.n_lora_nodes;
1959
+ return res;
1280
1960
  }
1281
1961
 
1282
- ggml_cgraph * llama_context::graph_init() {
1283
- ggml_init_params params = {
1284
- /*.mem_size =*/ buf_compute_meta.size(),
1285
- /*.mem_buffer =*/ buf_compute_meta.data(),
1286
- /*.no_alloc =*/ true,
1287
- };
1288
-
1289
- ctx_compute.reset(ggml_init(params));
1290
-
1291
- return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
1962
+ llm_graph_result * llama_context::get_gf_res_reserve() const {
1963
+ return static_cast<llm_graph_result *>(gf_res_reserve.get());
1292
1964
  }
1293
1965
 
1294
- ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
1966
+ ggml_cgraph * llama_context::graph_reserve(
1967
+ uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only, size_t * sizes) {
1295
1968
  LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
1969
+ GGML_ASSERT(n_outputs >= 1);
1296
1970
 
1297
1971
  if (n_tokens % n_seqs != 0) {
1298
1972
  n_tokens = ((n_tokens + (n_seqs - 1)) / n_seqs) * n_seqs; // round to next multiple of n_seqs
1299
- n_outputs = std::min(n_outputs, n_tokens);
1973
+ n_outputs = std::max(n_outputs, n_tokens);
1300
1974
 
1301
1975
  LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs);
1302
1976
  }
1303
1977
 
1978
+ ggml_backend_sched_reset(sched.get());
1979
+
1980
+ // when the scheduler is reset, we cannnot reuse the old graph, so we reset the previous graph result to prevent that
1981
+ gf_res_prev->reset();
1982
+
1304
1983
  // store the n_outputs as it is, and restore it afterwards
1305
1984
  // TODO: not sure if needed, might simplify in the future by removing this
1306
1985
  const auto save_n_outputs = this->n_outputs;
@@ -1310,20 +1989,34 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
1310
1989
  llama_batch_allocr balloc(model.hparams.n_pos_per_embd());
1311
1990
  llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
1312
1991
 
1313
- auto * gf = graph_init();
1314
- auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx);
1992
+ // set one output token per sequence in order to activate all backend samplers
1993
+ std::vector<llama_seq_id> seq_ids(n_seqs);
1994
+ for (uint32_t i = 0; i < n_seqs; ++i) {
1995
+ seq_ids[i] = i;
1996
+ ubatch.n_seq_id[i] = 1;
1997
+ ubatch.seq_id[i] = &seq_ids[i];
1998
+ ubatch.output[i] = true;
1999
+ }
1315
2000
 
1316
- this->n_outputs = save_n_outputs;
2001
+ auto * res = gf_res_reserve.get();
1317
2002
 
1318
- if (!res) {
1319
- LLAMA_LOG_ERROR("%s: failed to build worst-case graph\n", __func__);
1320
- return nullptr;
1321
- }
2003
+ const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT);
1322
2004
 
1323
- ggml_backend_sched_reset(sched.get());
2005
+ res->reset();
2006
+
2007
+ auto * gf = model.build_graph(gparams);
2008
+
2009
+ this->n_outputs = save_n_outputs;
1324
2010
 
1325
2011
  // initialize scheduler with the specified graph
1326
- if (!ggml_backend_sched_reserve(sched.get(), gf)) {
2012
+ if (split_only) {
2013
+ if (sizes) {
2014
+ ggml_backend_sched_reserve_size(sched.get(), gf, sizes);
2015
+ } else {
2016
+ ggml_backend_sched_split_graph(sched.get(), gf);
2017
+ }
2018
+ } else if (!ggml_backend_sched_reserve(sched.get(), gf)) {
2019
+ GGML_ASSERT(!sizes);
1327
2020
  LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
1328
2021
  return nullptr;
1329
2022
  }
@@ -1331,28 +2024,28 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
1331
2024
  return gf;
1332
2025
  }
1333
2026
 
1334
- llm_graph_result_ptr llama_context::graph_build(
1335
- ggml_context * ctx,
1336
- ggml_cgraph * gf,
1337
- const llama_ubatch & ubatch,
1338
- llm_graph_type gtype,
1339
- const llama_memory_context_i * mctx) {
1340
- return model.build_graph(
1341
- {
1342
- /*.ctx =*/ ctx,
1343
- /*.arch =*/ model.arch,
1344
- /*.hparams =*/ model.hparams,
1345
- /*.cparams =*/ cparams,
1346
- /*.ubatch =*/ ubatch,
1347
- /*.sched =*/ sched.get(),
1348
- /*.backend_cpu =*/ backend_cpu,
1349
- /*.cvec =*/ &cvec,
1350
- /*.loras =*/ &loras,
1351
- /*.mctx =*/ mctx,
1352
- /*.cross =*/ &cross,
1353
- /*.n_outputs =*/ n_outputs,
1354
- /*.cb =*/ graph_get_cb(),
1355
- }, gf, gtype);
2027
+ llm_graph_params llama_context::graph_params(
2028
+ llm_graph_result * res,
2029
+ const llama_ubatch & ubatch,
2030
+ const llama_memory_context_i * mctx,
2031
+ llm_graph_type gtype) const {
2032
+ return {
2033
+ /*.arch =*/ model.arch,
2034
+ /*.hparams =*/ model.hparams,
2035
+ /*.cparams =*/ cparams,
2036
+ /*.ubatch =*/ ubatch,
2037
+ /*.gtype =*/ gtype,
2038
+ /*.sched =*/ sched.get(),
2039
+ /*.backend_cpu =*/ backend_cpu,
2040
+ /*.cvec =*/ &cvec,
2041
+ /*.loras =*/ &loras,
2042
+ /*.mctx =*/ mctx,
2043
+ /*.cross =*/ &cross,
2044
+ /*.samplers =*/ sampling.samplers,
2045
+ /*.n_outputs =*/ n_outputs,
2046
+ /*.cb =*/ graph_get_cb(),
2047
+ /*.res =*/ res,
2048
+ };
1356
2049
  }
1357
2050
 
1358
2051
  ggml_status llama_context::graph_compute(
@@ -1364,7 +2057,9 @@ ggml_status llama_context::graph_compute(
1364
2057
  if (backend_cpu != nullptr) {
1365
2058
  auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_cpu));
1366
2059
  auto * set_threadpool_fn = (decltype(ggml_backend_cpu_set_threadpool) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_set_threadpool");
1367
- set_threadpool_fn(backend_cpu, tp);
2060
+ if (set_threadpool_fn) {
2061
+ set_threadpool_fn(backend_cpu, tp);
2062
+ }
1368
2063
  }
1369
2064
 
1370
2065
  // set the number of threads for all the backends
@@ -1399,7 +2094,7 @@ llm_graph_cb llama_context::graph_get_cb() const {
1399
2094
 
1400
2095
  // norm may be automatically assigned to the backend of the previous layer, increasing data transfer between backends
1401
2096
  // FIXME: fix in ggml_backend_sched
1402
- const bool full_offload = model.params.n_gpu_layers > (int) model.hparams.n_layer;
2097
+ const bool full_offload = model.n_gpu_layers() > model.hparams.n_layer;
1403
2098
  if (ubatch.n_tokens < 32 || full_offload) {
1404
2099
  if (il != -1 && strcmp(name, "norm") == 0) {
1405
2100
  const auto & dev_layer = model.dev_layer(il);
@@ -1583,30 +2278,30 @@ size_t llama_context::state_set_data(const uint8_t * src, size_t size) {
1583
2278
  }
1584
2279
  }
1585
2280
 
1586
- size_t llama_context::state_seq_get_size(llama_seq_id seq_id) {
2281
+ size_t llama_context::state_seq_get_size(llama_seq_id seq_id, llama_state_seq_flags flags) {
1587
2282
  llama_io_write_dummy io;
1588
2283
  try {
1589
- return state_seq_write_data(io, seq_id);
2284
+ return state_seq_write_data(io, seq_id, flags);
1590
2285
  } catch (const std::exception & err) {
1591
2286
  LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
1592
2287
  return 0;
1593
2288
  }
1594
2289
  }
1595
2290
 
1596
- size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) {
2291
+ size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size, llama_state_seq_flags flags) {
1597
2292
  llama_io_write_buffer io(dst, size);
1598
2293
  try {
1599
- return state_seq_write_data(io, seq_id);
2294
+ return state_seq_write_data(io, seq_id, flags);
1600
2295
  } catch (const std::exception & err) {
1601
2296
  LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
1602
2297
  return 0;
1603
2298
  }
1604
2299
  }
1605
2300
 
1606
- size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) {
2301
+ size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size, llama_state_seq_flags flags) {
1607
2302
  llama_io_read_buffer io(src, size);
1608
2303
  try {
1609
- return state_seq_read_data(io, seq_id);
2304
+ return state_seq_read_data(io, seq_id, flags);
1610
2305
  } catch (const std::exception & err) {
1611
2306
  LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
1612
2307
  return 0;
@@ -1704,7 +2399,7 @@ size_t llama_context::state_seq_load_file(llama_seq_id seq_id, const char * file
1704
2399
  {
1705
2400
  const size_t state_size = file.size() - file.tell();
1706
2401
  llama_io_read_file io(&file);
1707
- const size_t nread = state_seq_read_data(io, seq_id);
2402
+ const size_t nread = state_seq_read_data(io, seq_id, 0);
1708
2403
  if (!nread) {
1709
2404
  LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__);
1710
2405
  return 0;
@@ -1728,7 +2423,7 @@ size_t llama_context::state_seq_save_file(llama_seq_id seq_id, const char * file
1728
2423
 
1729
2424
  // save the context state using stream saving
1730
2425
  llama_io_write_file io(&file);
1731
- state_seq_write_data(io, seq_id);
2426
+ state_seq_write_data(io, seq_id, 0);
1732
2427
 
1733
2428
  const size_t res = file.tell();
1734
2429
  GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + io.n_bytes());
@@ -1802,8 +2497,11 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
1802
2497
  }
1803
2498
  }
1804
2499
 
2500
+ // TODO: handle sampling buffers and samplers state ?
2501
+ // https://github.com/ggml-org/llama.cpp/pull/17004
2502
+
1805
2503
  if (memory != nullptr) {
1806
- LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
2504
+ LLAMA_LOG_DEBUG("%s: - writing memory module\n", __func__);
1807
2505
  memory->state_write(io);
1808
2506
  }
1809
2507
 
@@ -1834,7 +2532,10 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
1834
2532
  auto n_outputs = this->n_outputs;
1835
2533
  io.read_to(&n_outputs, sizeof(n_outputs));
1836
2534
 
1837
- if (n_outputs > output_reserve(n_outputs)) {
2535
+ // Create a dummy batch for state loading.
2536
+ llama_batch dummy_batch = {};
2537
+ dummy_batch.n_tokens = 0;
2538
+ if (n_outputs > output_reserve(n_outputs, dummy_batch)) {
1838
2539
  throw std::runtime_error("could not reserve outputs");
1839
2540
  }
1840
2541
 
@@ -1888,8 +2589,11 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
1888
2589
  }
1889
2590
  }
1890
2591
 
2592
+ // TODO: handle sampling buffers and samplers state ?
2593
+ // https://github.com/ggml-org/llama.cpp/pull/17004
2594
+
1891
2595
  if (memory) {
1892
- LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
2596
+ LLAMA_LOG_DEBUG("%s: - reading memory module\n", __func__);
1893
2597
 
1894
2598
  memory->state_read(io);
1895
2599
  }
@@ -1897,21 +2601,21 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
1897
2601
  return io.n_bytes();
1898
2602
  }
1899
2603
 
1900
- size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) {
2604
+ size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
1901
2605
  GGML_UNUSED(seq_id);
1902
2606
 
1903
2607
  if (memory) {
1904
- memory->state_write(io, seq_id);
2608
+ memory->state_write(io, seq_id, flags);
1905
2609
  }
1906
2610
 
1907
2611
  return io.n_bytes();
1908
2612
  }
1909
2613
 
1910
- size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) {
2614
+ size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
1911
2615
  GGML_UNUSED(seq_id);
1912
2616
 
1913
2617
  if (memory) {
1914
- memory->state_read(io, seq_id);
2618
+ memory->state_read(io, seq_id, flags);
1915
2619
  }
1916
2620
 
1917
2621
  return io.n_bytes();
@@ -1930,6 +2634,7 @@ llama_perf_context_data llama_context::perf_get_data() const {
1930
2634
  data.t_eval_ms = 1e-3 * t_eval_us;
1931
2635
  data.n_p_eval = std::max(1, n_p_eval);
1932
2636
  data.n_eval = std::max(1, n_eval);
2637
+ data.n_reused = std::max(0, n_reused);
1933
2638
 
1934
2639
  return data;
1935
2640
  }
@@ -1938,6 +2643,33 @@ void llama_context::perf_reset() {
1938
2643
  t_start_us = ggml_time_us();
1939
2644
  t_eval_us = n_eval = 0;
1940
2645
  t_p_eval_us = n_p_eval = 0;
2646
+ n_reused = 0;
2647
+ }
2648
+
2649
+ std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> llama_context::memory_breakdown() const {
2650
+ std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> ret;
2651
+ for (const auto & [buft, size] : model.memory_breakdown()) {
2652
+ ret[buft].model += size;
2653
+ }
2654
+ if (memory) {
2655
+ for (const auto & [buft, size] : memory->memory_breakdown()) {
2656
+ ret[buft].context += size;
2657
+ }
2658
+ }
2659
+ if (model.hparams.no_alloc) {
2660
+ for (size_t i = 0; i < backends.size(); ++i) {
2661
+ ggml_backend_t backend = backends[i].get();
2662
+ ggml_backend_buffer_type_t buft = ggml_backend_sched_get_buffer_type(sched.get(), backend);
2663
+ ret[buft].compute += backend_buf_exp_size[i];
2664
+ }
2665
+ } else {
2666
+ for (const auto & backend_ptr : backends) {
2667
+ ggml_backend_t backend = backend_ptr.get();
2668
+ ggml_backend_buffer_type_t buft = ggml_backend_sched_get_buffer_type(sched.get(), backend);
2669
+ ret[buft].compute += ggml_backend_sched_get_buffer_size(sched.get(), backend);
2670
+ }
2671
+ }
2672
+ return ret;
1941
2673
  }
1942
2674
 
1943
2675
  //
@@ -1972,7 +2704,7 @@ void llama_context::opt_init(struct llama_model * model, struct llama_opt_params
1972
2704
  opt_params.opt_period = n_batch / n_ubatch;
1973
2705
  opt_params.get_opt_pars = lopt_params.get_opt_pars;
1974
2706
  opt_params.get_opt_pars_ud = lopt_params.get_opt_pars_ud;
1975
-
2707
+ opt_params.optimizer = lopt_params.optimizer_type;
1976
2708
  opt_ctx = ggml_opt_init(opt_params);
1977
2709
 
1978
2710
  llama_opt_param_filter param_filter = lopt_params.param_filter;
@@ -2028,7 +2760,7 @@ void llama_context::opt_epoch_iter(
2028
2760
  batch.logits [pos_batch] = true;
2029
2761
  }
2030
2762
 
2031
- if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, true)) {
2763
+ if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd_inp(), cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
2032
2764
  LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
2033
2765
  return;
2034
2766
  }
@@ -2048,7 +2780,7 @@ void llama_context::opt_epoch_iter(
2048
2780
  }
2049
2781
 
2050
2782
  // reserve output buffer
2051
- if (output_reserve(n_outputs_all) < n_outputs_all) {
2783
+ if (output_reserve(n_outputs_all, balloc->get_batch()) < n_outputs_all) {
2052
2784
  LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
2053
2785
  GGML_ABORT("TODO: handle this error");
2054
2786
  };
@@ -2064,8 +2796,13 @@ void llama_context::opt_epoch_iter(
2064
2796
  break;
2065
2797
  }
2066
2798
 
2067
- auto * gf = graph_init();
2068
- auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx.get());
2799
+ auto * res = gf_res_prev.get();
2800
+
2801
+ const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT);
2802
+
2803
+ res->reset();
2804
+
2805
+ auto * gf = model.build_graph(gparams);
2069
2806
 
2070
2807
  struct ggml_context * ctx_compute_opt;
2071
2808
  {
@@ -2167,12 +2904,13 @@ llama_context_params llama_context_default_params() {
2167
2904
  /*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
2168
2905
  /*.pooling_type =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
2169
2906
  /*.attention_type =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED,
2907
+ /*.flash_attn_type =*/ LLAMA_FLASH_ATTN_TYPE_AUTO,
2170
2908
  /*.rope_freq_base =*/ 0.0f,
2171
2909
  /*.rope_freq_scale =*/ 0.0f,
2172
2910
  /*.yarn_ext_factor =*/ -1.0f,
2173
- /*.yarn_attn_factor =*/ 1.0f,
2174
- /*.yarn_beta_fast =*/ 32.0f,
2175
- /*.yarn_beta_slow =*/ 1.0f,
2911
+ /*.yarn_attn_factor =*/ -1.0f,
2912
+ /*.yarn_beta_fast =*/ -1.0f,
2913
+ /*.yarn_beta_slow =*/ -1.0f,
2176
2914
  /*.yarn_orig_ctx =*/ 0,
2177
2915
  /*.defrag_thold =*/ -1.0f,
2178
2916
  /*.cb_eval =*/ nullptr,
@@ -2183,10 +2921,12 @@ llama_context_params llama_context_default_params() {
2183
2921
  /*.abort_callback_data =*/ nullptr,
2184
2922
  /*.embeddings =*/ false,
2185
2923
  /*.offload_kqv =*/ true,
2186
- /*.flash_attn =*/ false,
2187
2924
  /*.no_perf =*/ true,
2188
2925
  /*.op_offload =*/ true,
2189
2926
  /*.swa_full =*/ true,
2927
+ /*.kv_unified =*/ false,
2928
+ /*.sampler =*/ nullptr,
2929
+ /*.n_sampler =*/ 0,
2190
2930
  };
2191
2931
 
2192
2932
  return result;
@@ -2210,16 +2950,41 @@ llama_context * llama_init_from_model(
2210
2950
  return nullptr;
2211
2951
  }
2212
2952
 
2213
- if (params.flash_attn && model->arch == LLM_ARCH_GROK) {
2953
+ if (params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED && model->arch == LLM_ARCH_GROK) {
2214
2954
  LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__);
2215
- params.flash_attn = false;
2955
+ params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED;
2956
+ }
2957
+
2958
+ if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_k)) {
2959
+ const uint32_t blck_size = ggml_blck_size(params.type_k);
2960
+ if (model->hparams.n_embd_head_k % blck_size != 0) {
2961
+ LLAMA_LOG_ERROR("%s: K cache type %s with block size %u does not divide n_embd_head_k=%u\n",
2962
+ __func__, ggml_type_name(params.type_k), blck_size, model->hparams.n_embd_head_k);
2963
+ return nullptr;
2964
+ }
2216
2965
  }
2217
2966
 
2218
- if (ggml_is_quantized(params.type_v) && !params.flash_attn) {
2967
+ if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_v)) {
2968
+ const uint32_t blck_size = ggml_blck_size(params.type_v);
2969
+ if (model->hparams.n_embd_head_v % blck_size != 0) {
2970
+ LLAMA_LOG_ERROR("%s: V cache type %s with block size %u does not divide n_embd_head_k=%u\n",
2971
+ __func__, ggml_type_name(params.type_v), blck_size, model->hparams.n_embd_head_v);
2972
+ return nullptr;
2973
+ }
2974
+ }
2975
+
2976
+ if (ggml_is_quantized(params.type_v) && params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_DISABLED) {
2219
2977
  LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
2220
2978
  return nullptr;
2221
2979
  }
2222
2980
 
2981
+ if (params.pooling_type != LLAMA_POOLING_TYPE_UNSPECIFIED &&
2982
+ params.pooling_type != model->hparams.pooling_type) {
2983
+ //user-specified pooling-type is different from the model default
2984
+ LLAMA_LOG_WARN("%s: model default pooling_type is [%d], but [%d] was specified\n", __func__,
2985
+ model->hparams.pooling_type, params.pooling_type);
2986
+ }
2987
+
2223
2988
  try {
2224
2989
  auto * ctx = new llama_context(*model, params);
2225
2990
  return ctx;
@@ -2245,6 +3010,10 @@ uint32_t llama_n_ctx(const llama_context * ctx) {
2245
3010
  return ctx->n_ctx();
2246
3011
  }
2247
3012
 
3013
+ uint32_t llama_n_ctx_seq(const llama_context * ctx) {
3014
+ return ctx->n_ctx_seq();
3015
+ }
3016
+
2248
3017
  uint32_t llama_n_batch(const llama_context * ctx) {
2249
3018
  return ctx->n_batch();
2250
3019
  }
@@ -2261,16 +3030,6 @@ const llama_model * llama_get_model(const llama_context * ctx) {
2261
3030
  return &ctx->get_model();
2262
3031
  }
2263
3032
 
2264
- // deprecated
2265
- llama_kv_cache * llama_get_kv_self(llama_context * ctx) {
2266
- return dynamic_cast<llama_kv_cache *>(ctx->get_memory());
2267
- }
2268
-
2269
- // deprecated
2270
- void llama_kv_self_update(llama_context * ctx) {
2271
- ctx->kv_self_update(false);
2272
- }
2273
-
2274
3033
  enum llama_pooling_type llama_pooling_type(const llama_context * ctx) {
2275
3034
  return ctx->pooling_type();
2276
3035
  }
@@ -2327,7 +3086,15 @@ float * llama_get_logits(llama_context * ctx) {
2327
3086
  float * llama_get_logits_ith(llama_context * ctx, int32_t i) {
2328
3087
  ctx->synchronize();
2329
3088
 
2330
- return ctx->get_logits_ith(i);
3089
+ float * res = nullptr;
3090
+
3091
+ res = ctx->get_sampled_logits_ith(i);
3092
+
3093
+ if (!res) {
3094
+ res = ctx->get_logits_ith(i);
3095
+ }
3096
+
3097
+ return res;
2331
3098
  }
2332
3099
 
2333
3100
  float * llama_get_embeddings(llama_context * ctx) {
@@ -2348,6 +3115,52 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) {
2348
3115
  return ctx->get_embeddings_seq(seq_id);
2349
3116
  }
2350
3117
 
3118
+ bool llama_set_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) {
3119
+ return ctx->set_sampler(seq_id, smpl);
3120
+ }
3121
+
3122
+ llama_token llama_get_sampled_token_ith(llama_context * ctx, int32_t i) {
3123
+ ctx->synchronize();
3124
+
3125
+ return ctx->get_sampled_token_ith(i);
3126
+ }
3127
+
3128
+ float * llama_get_sampled_probs_ith(llama_context * ctx, int32_t i) {
3129
+ ctx->synchronize();
3130
+
3131
+ return ctx->get_sampled_probs_ith(i);
3132
+ }
3133
+
3134
+ float * llama_get_sampled_logits_ith(llama_context * ctx, int32_t i) {
3135
+ ctx->synchronize();
3136
+
3137
+ return ctx->get_sampled_logits_ith(i);
3138
+ }
3139
+
3140
+ llama_token * llama_get_sampled_candidates_ith(llama_context * ctx, int32_t i) {
3141
+ ctx->synchronize();
3142
+
3143
+ return const_cast<llama_token *>(ctx->get_sampled_candidates_ith(i));
3144
+ }
3145
+
3146
+ uint32_t llama_get_sampled_candidates_count_ith(llama_context * ctx, int32_t i) {
3147
+ ctx->synchronize();
3148
+
3149
+ return static_cast<uint32_t>(ctx->get_sampled_candidates_count(i));
3150
+ }
3151
+
3152
+ uint32_t llama_get_sampled_logits_count_ith(llama_context * ctx, int32_t i) {
3153
+ ctx->synchronize();
3154
+
3155
+ return static_cast<uint32_t>(ctx->get_sampled_logits_count(i));
3156
+ }
3157
+
3158
+ uint32_t llama_get_sampled_probs_count_ith(llama_context * ctx, int32_t i) {
3159
+ ctx->synchronize();
3160
+
3161
+ return static_cast<uint32_t>(ctx->get_sampled_probs_count(i));
3162
+ }
3163
+
2351
3164
  // llama adapter API
2352
3165
 
2353
3166
  int32_t llama_set_adapter_lora(
@@ -2488,168 +3301,6 @@ bool llama_memory_can_shift(llama_memory_t mem) {
2488
3301
  return mem->get_can_shift();
2489
3302
  }
2490
3303
 
2491
- //
2492
- // kv cache
2493
- //
2494
-
2495
- // deprecated
2496
- int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
2497
- const auto * kv = llama_get_memory(ctx);
2498
- if (!kv) {
2499
- return 0;
2500
- }
2501
-
2502
- int32_t res = 0;
2503
-
2504
- for (uint32_t s = 0; s < ctx->get_cparams().n_seq_max; s++) {
2505
- const llama_pos p0 = kv->seq_pos_min(s);
2506
- const llama_pos p1 = kv->seq_pos_max(s);
2507
-
2508
- if (p0 >= 0) {
2509
- res += (p1 - p0) + 1;
2510
- }
2511
- }
2512
-
2513
- return res;
2514
- }
2515
-
2516
- // deprecated
2517
- // note: this is the same as above - will be removed anyway, so it's ok
2518
- int32_t llama_kv_self_used_cells(const llama_context * ctx) {
2519
- const auto * kv = llama_get_memory(ctx);
2520
- if (!kv) {
2521
- return 0;
2522
- }
2523
-
2524
- int32_t res = 0;
2525
-
2526
- for (uint32_t s = 0; s < ctx->get_cparams().n_seq_max; s++) {
2527
- const llama_pos p0 = kv->seq_pos_min(s);
2528
- const llama_pos p1 = kv->seq_pos_max(s);
2529
-
2530
- if (p0 >= 0) {
2531
- res += (p1 - p0) + 1;
2532
- }
2533
- }
2534
-
2535
- return res;
2536
- }
2537
-
2538
- // deprecated
2539
- void llama_kv_self_clear(llama_context * ctx) {
2540
- auto * kv = llama_get_memory(ctx);
2541
- if (!kv) {
2542
- return;
2543
- }
2544
-
2545
- llama_memory_clear(kv, true);
2546
- }
2547
-
2548
- // deprecated
2549
- bool llama_kv_self_seq_rm(
2550
- llama_context * ctx,
2551
- llama_seq_id seq_id,
2552
- llama_pos p0,
2553
- llama_pos p1) {
2554
- auto * kv = llama_get_memory(ctx);
2555
- if (!kv) {
2556
- return true;
2557
- }
2558
-
2559
- return llama_memory_seq_rm(kv, seq_id, p0, p1);
2560
- }
2561
-
2562
- // deprecated
2563
- void llama_kv_self_seq_cp(
2564
- llama_context * ctx,
2565
- llama_seq_id seq_id_src,
2566
- llama_seq_id seq_id_dst,
2567
- llama_pos p0,
2568
- llama_pos p1) {
2569
- auto * kv = llama_get_memory(ctx);
2570
- if (!kv) {
2571
- return;
2572
- }
2573
-
2574
- llama_memory_seq_cp(kv, seq_id_src, seq_id_dst, p0, p1);
2575
- }
2576
-
2577
- // deprecated
2578
- void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
2579
- auto * kv = llama_get_memory(ctx);
2580
- if (!kv) {
2581
- return;
2582
- }
2583
-
2584
- llama_memory_seq_keep(kv, seq_id);
2585
- }
2586
-
2587
- // deprecated
2588
- void llama_kv_self_seq_add(
2589
- llama_context * ctx,
2590
- llama_seq_id seq_id,
2591
- llama_pos p0,
2592
- llama_pos p1,
2593
- llama_pos delta) {
2594
- auto * kv = llama_get_memory(ctx);
2595
- if (!kv) {
2596
- return;
2597
- }
2598
-
2599
- llama_memory_seq_add(kv, seq_id, p0, p1, delta);
2600
- }
2601
-
2602
- // deprecated
2603
- void llama_kv_self_seq_div(
2604
- llama_context * ctx,
2605
- llama_seq_id seq_id,
2606
- llama_pos p0,
2607
- llama_pos p1,
2608
- int d) {
2609
- auto * kv = llama_get_memory(ctx);
2610
- if (!kv) {
2611
- return;
2612
- }
2613
-
2614
- llama_memory_seq_div(kv, seq_id, p0, p1, d);
2615
- }
2616
-
2617
- // deprecated
2618
- llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) {
2619
- auto * kv = llama_get_memory(ctx);
2620
- if (!kv) {
2621
- return -1;
2622
- }
2623
-
2624
- return llama_memory_seq_pos_min(kv, seq_id);
2625
- }
2626
-
2627
- // deprecated
2628
- llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
2629
- auto * kv = llama_get_memory(ctx);
2630
- if (!kv) {
2631
- return -1;
2632
- }
2633
-
2634
- return llama_memory_seq_pos_max(kv, seq_id);
2635
- }
2636
-
2637
- // deprecated
2638
- void llama_kv_self_defrag(llama_context * ctx) {
2639
- // force defrag
2640
- ctx->kv_self_defrag_sched();
2641
- }
2642
-
2643
- // deprecated
2644
- bool llama_kv_self_can_shift(const llama_context * ctx) {
2645
- auto * kv = llama_get_memory(ctx);
2646
- if (!kv) {
2647
- return false;
2648
- }
2649
-
2650
- return llama_memory_can_shift(kv);
2651
- }
2652
-
2653
3304
  // llama state API
2654
3305
 
2655
3306
  // deprecated
@@ -2719,19 +3370,31 @@ bool llama_state_save_file(llama_context * ctx, const char * path_session, const
2719
3370
  }
2720
3371
 
2721
3372
  size_t llama_state_seq_get_size(llama_context * ctx, llama_seq_id seq_id) {
2722
- return ctx->state_seq_get_size(seq_id);
3373
+ return llama_state_seq_get_size_ext(ctx, seq_id, 0);
2723
3374
  }
2724
3375
 
2725
3376
  size_t llama_state_seq_get_data(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) {
3377
+ return llama_state_seq_get_data_ext(ctx, dst, size, seq_id, 0);
3378
+ }
3379
+
3380
+ size_t llama_state_seq_set_data(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id) {
3381
+ return llama_state_seq_set_data_ext(ctx, src, size, seq_id, 0);
3382
+ }
3383
+
3384
+ size_t llama_state_seq_get_size_ext(llama_context * ctx, llama_seq_id seq_id, llama_state_seq_flags flags) {
3385
+ return ctx->state_seq_get_size(seq_id, flags);
3386
+ }
3387
+
3388
+ size_t llama_state_seq_get_data_ext(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) {
2726
3389
  ctx->synchronize();
2727
3390
 
2728
- return ctx->state_seq_get_data(seq_id, dst, size);
3391
+ return ctx->state_seq_get_data(seq_id, dst, size, flags);
2729
3392
  }
2730
3393
 
2731
- size_t llama_state_seq_set_data(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id) {
3394
+ size_t llama_state_seq_set_data_ext(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) {
2732
3395
  ctx->synchronize();
2733
3396
 
2734
- return ctx->state_seq_set_data(seq_id, src, size);
3397
+ return ctx->state_seq_set_data(seq_id, src, size, flags);
2735
3398
  }
2736
3399
 
2737
3400
  size_t llama_state_seq_save_file(llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) {
@@ -2807,12 +3470,149 @@ void llama_perf_context_print(const llama_context * ctx) {
2807
3470
  LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
2808
3471
  __func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval);
2809
3472
  LLAMA_LOG_INFO("%s: total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval));
3473
+ LLAMA_LOG_INFO("%s: graphs reused = %10d\n", __func__, data.n_reused);
2810
3474
  }
2811
3475
 
2812
3476
  void llama_perf_context_reset(llama_context * ctx) {
2813
3477
  ctx->perf_reset();
2814
3478
  }
2815
3479
 
3480
+ void llama_memory_breakdown_print(const struct llama_context * ctx) {
3481
+ const std::vector<ggml_backend_dev_t> & devices = ctx->get_model().devices;
3482
+
3483
+ std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> memory_breakdown = ctx->memory_breakdown();
3484
+
3485
+ std::vector<std::array<std::string, 9>> table_data;
3486
+ table_data.reserve(devices.size());
3487
+ const std::string template_header = "%s: | %s | %s %s %s %s %s %s %s |\n";
3488
+ const std::string template_gpu = "%s: | %s | %s = %s + (%s = %s + %s + %s) + %s |\n";
3489
+ const std::string template_other = "%s: | %s | %s %s %s = %s + %s + %s %s |\n";
3490
+
3491
+ table_data.push_back({template_header, "memory breakdown [MiB]", "total", "free", "self", "model", "context", "compute", "unaccounted"});
3492
+
3493
+ constexpr size_t MiB = 1024 * 1024;
3494
+ const std::vector<std::string> desc_prefixes_strip = {"NVIDIA ", "GeForce ", "Tesla ", "AMD ", "Radeon ", "Instinct "};
3495
+
3496
+ // track seen buffer types to avoid double counting:
3497
+ std::set<ggml_backend_buffer_type_t> seen_buffer_types;
3498
+
3499
+ // accumulative memory breakdown for each device and for host:
3500
+ std::vector<llama_memory_breakdown_data> mb_dev(devices.size());
3501
+ llama_memory_breakdown_data mb_host;
3502
+
3503
+ for (const auto & buft_mb : memory_breakdown) {
3504
+ ggml_backend_buffer_type_t buft = buft_mb.first;
3505
+ const llama_memory_breakdown_data & mb = buft_mb.second;
3506
+ if (ggml_backend_buft_is_host(buft)) {
3507
+ mb_host.model += mb.model;
3508
+ mb_host.context += mb.context;
3509
+ mb_host.compute += mb.compute;
3510
+ seen_buffer_types.insert(buft);
3511
+ continue;
3512
+ }
3513
+ ggml_backend_dev_t dev = ggml_backend_buft_get_device(buft);
3514
+ if (dev) {
3515
+ int i_dev = -1;
3516
+ for (size_t i = 0; i < devices.size(); i++) {
3517
+ if (devices[i] == dev) {
3518
+ i_dev = i;
3519
+ break;
3520
+ }
3521
+ }
3522
+ if (i_dev != -1) {
3523
+ mb_dev[i_dev].model += mb.model;
3524
+ mb_dev[i_dev].context += mb.context;
3525
+ mb_dev[i_dev].compute += mb.compute;
3526
+ seen_buffer_types.insert(buft);
3527
+ continue;
3528
+ }
3529
+ }
3530
+ }
3531
+
3532
+ // print memory breakdown for each device:
3533
+ for (size_t i = 0; i < devices.size(); i++) {
3534
+ ggml_backend_dev_t dev = devices[i];
3535
+ llama_memory_breakdown_data mb = mb_dev[i];
3536
+
3537
+ const std::string name = ggml_backend_dev_name(dev);
3538
+ std::string desc = ggml_backend_dev_description(dev);
3539
+ for (const std::string & prefix : desc_prefixes_strip) {
3540
+ if (desc.length() >= prefix.length() && desc.substr(0, prefix.length()) == prefix) {
3541
+ desc = desc.substr(prefix.length());
3542
+ }
3543
+ }
3544
+
3545
+ size_t free, total;
3546
+ ggml_backend_dev_memory(dev, &free, &total);
3547
+
3548
+ const size_t self = mb.model + mb.context + mb.compute;
3549
+ const size_t unaccounted = total - self - free;
3550
+
3551
+ table_data.push_back({
3552
+ template_gpu,
3553
+ " - " + name + " (" + desc + ")",
3554
+ std::to_string(total / MiB),
3555
+ std::to_string(free / MiB),
3556
+ std::to_string(self / MiB),
3557
+ std::to_string(mb.model / MiB),
3558
+ std::to_string(mb.context / MiB),
3559
+ std::to_string(mb.compute / MiB),
3560
+ std::to_string(unaccounted / MiB)});
3561
+ }
3562
+
3563
+ // print memory breakdown for host:
3564
+ {
3565
+ const size_t self = mb_host.model + mb_host.context + mb_host.compute;
3566
+ table_data.push_back({
3567
+ template_other,
3568
+ " - Host",
3569
+ "", // total
3570
+ "", // free
3571
+ std::to_string(self / MiB),
3572
+ std::to_string(mb_host.model / MiB),
3573
+ std::to_string(mb_host.context / MiB),
3574
+ std::to_string(mb_host.compute / MiB),
3575
+ ""}); // unaccounted
3576
+ }
3577
+
3578
+ // print memory breakdown for all remaining buffer types:
3579
+ for (const auto & buft_mb : memory_breakdown) {
3580
+ ggml_backend_buffer_type_t buft = buft_mb.first;
3581
+ const llama_memory_breakdown_data & mb = buft_mb.second;
3582
+ if (seen_buffer_types.count(buft) == 1) {
3583
+ continue;
3584
+ }
3585
+ const std::string name = ggml_backend_buft_name(buft);
3586
+ const size_t self = mb.model + mb.context + mb.compute;
3587
+ table_data.push_back({
3588
+ template_other,
3589
+ " - " + name,
3590
+ "", // total
3591
+ "", // free
3592
+ std::to_string(self / MiB),
3593
+ std::to_string(mb.model / MiB),
3594
+ std::to_string(mb.context / MiB),
3595
+ std::to_string(mb.compute / MiB),
3596
+ ""}); // unaccounted
3597
+ seen_buffer_types.insert(buft);
3598
+ }
3599
+
3600
+ for (size_t j = 1; j < table_data[0].size(); j++) {
3601
+ size_t max_len = 0;
3602
+ for (const auto & td : table_data) {
3603
+ max_len = std::max(max_len, td[j].length());
3604
+ }
3605
+ for (auto & td : table_data) {
3606
+ td[j].insert(j == 1 ? td[j].length() : 0, max_len - td[j].length(), ' ');
3607
+ }
3608
+ }
3609
+ for (const auto & td : table_data) {
3610
+ LLAMA_LOG_INFO(td[0].c_str(),
3611
+ __func__, td[1].c_str(), td[2].c_str(), td[3].c_str(), td[4].c_str(), td[5].c_str(),
3612
+ td[6].c_str(), td[7].c_str(), td[8].c_str());
3613
+ }
3614
+ }
3615
+
2816
3616
  //
2817
3617
  // training
2818
3618
  //