whispercpp 1.3.4 → 1.3.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (630):
  1. checksums.yaml +4 -4
  2. data/README.md +60 -43
  3. data/ext/extconf.rb +2 -2
  4. data/ext/ruby_whisper.c +14 -2
  5. data/ext/ruby_whisper.h +39 -0
  6. data/ext/ruby_whisper_context.c +22 -22
  7. data/ext/ruby_whisper_model.c +12 -12
  8. data/ext/ruby_whisper_params.c +47 -23
  9. data/ext/ruby_whisper_segment.c +84 -19
  10. data/ext/ruby_whisper_token.c +351 -0
  11. data/ext/ruby_whisper_transcribe.cpp +1 -1
  12. data/ext/ruby_whisper_vad_context.c +75 -0
  13. data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
  14. data/ext/ruby_whisper_vad_segment.c +139 -0
  15. data/ext/ruby_whisper_vad_segments.c +106 -0
  16. data/ext/sources/CMakeLists.txt +4 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
  19. data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
  20. data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  21. data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
  22. data/ext/sources/examples/addon.node/vad-example.js +2 -2
  23. data/ext/sources/examples/cli/cli.cpp +121 -112
  24. data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
  25. data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
  26. data/ext/sources/examples/server/server.cpp +10 -11
  27. data/ext/sources/examples/talk-llama/CMakeLists.txt +5 -1
  28. data/ext/sources/examples/talk-llama/llama-adapter.cpp +12 -3
  29. data/ext/sources/examples/talk-llama/llama-adapter.h +7 -1
  30. data/ext/sources/examples/talk-llama/llama-arch.cpp +2046 -1974
  31. data/ext/sources/examples/talk-llama/llama-arch.h +67 -2
  32. data/ext/sources/examples/talk-llama/llama-batch.cpp +75 -33
  33. data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
  34. data/ext/sources/examples/talk-llama/llama-chat.cpp +79 -3
  35. data/ext/sources/examples/talk-llama/llama-chat.h +4 -0
  36. data/ext/sources/examples/talk-llama/llama-context.cpp +775 -78
  37. data/ext/sources/examples/talk-llama/llama-context.h +57 -9
  38. data/ext/sources/examples/talk-llama/llama-cparams.h +1 -0
  39. data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
  40. data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
  41. data/ext/sources/examples/talk-llama/llama-graph.cpp +381 -64
  42. data/ext/sources/examples/talk-llama/llama-graph.h +103 -13
  43. data/ext/sources/examples/talk-llama/llama-hparams.cpp +26 -2
  44. data/ext/sources/examples/talk-llama/llama-hparams.h +41 -10
  45. data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
  46. data/ext/sources/examples/talk-llama/llama-impl.h +1 -1
  47. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +5 -3
  48. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +145 -65
  49. data/ext/sources/examples/talk-llama/llama-kv-cache.h +22 -7
  50. data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
  51. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
  52. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +32 -19
  53. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
  54. data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
  55. data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
  56. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +91 -9
  57. data/ext/sources/examples/talk-llama/llama-model-loader.h +6 -0
  58. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
  59. data/ext/sources/examples/talk-llama/llama-model.cpp +1529 -13134
  60. data/ext/sources/examples/talk-llama/llama-model.h +44 -3
  61. data/ext/sources/examples/talk-llama/llama-quant.cpp +8 -23
  62. data/ext/sources/examples/talk-llama/llama-sampling.cpp +1294 -198
  63. data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
  64. data/ext/sources/examples/talk-llama/llama-vocab.cpp +133 -37
  65. data/ext/sources/examples/talk-llama/llama-vocab.h +45 -40
  66. data/ext/sources/examples/talk-llama/llama.cpp +729 -2
  67. data/ext/sources/examples/talk-llama/llama.h +152 -14
  68. data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
  69. data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
  70. data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
  71. data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
  72. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
  73. data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
  74. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
  75. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
  76. data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
  77. data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
  78. data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
  79. data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
  80. data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
  81. data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
  82. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
  83. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
  84. data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
  85. data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
  86. data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
  87. data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
  88. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
  89. data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
  90. data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
  91. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
  92. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
  93. data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
  94. data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
  95. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
  96. data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
  97. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
  98. data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
  99. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
  100. data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
  101. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
  102. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
  103. data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
  104. data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
  105. data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
  106. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
  107. data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
  108. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
  109. data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
  110. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
  111. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
  112. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
  113. data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
  114. data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
  115. data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
  116. data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
  117. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
  118. data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
  119. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
  120. data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
  121. data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
  122. data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
  123. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
  124. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
  125. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
  126. data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
  127. data/ext/sources/examples/talk-llama/models/models.h +569 -0
  128. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
  129. data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
  130. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
  131. data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
  132. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
  133. data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
  134. data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
  135. data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
  136. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
  137. data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
  138. data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
  139. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
  140. data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
  141. data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
  142. data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
  143. data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
  144. data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
  145. data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
  146. data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
  147. data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
  148. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
  149. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
  150. data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
  151. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
  152. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
  153. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
  154. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
  155. data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
  156. data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
  157. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
  158. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
  159. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
  160. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
  161. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
  162. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
  163. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
  164. data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
  165. data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
  166. data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
  167. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
  168. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
  169. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
  170. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
  171. data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
  172. data/ext/sources/examples/talk-llama/unicode.cpp +102 -16
  173. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
  174. data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
  175. data/ext/sources/ggml/CMakeLists.txt +82 -54
  176. data/ext/sources/ggml/include/ggml-alloc.h +9 -0
  177. data/ext/sources/ggml/include/ggml-backend.h +4 -1
  178. data/ext/sources/ggml/include/ggml-cpu.h +1 -0
  179. data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
  180. data/ext/sources/ggml/include/ggml-rpc.h +8 -11
  181. data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
  182. data/ext/sources/ggml/include/ggml.h +190 -12
  183. data/ext/sources/ggml/src/CMakeLists.txt +82 -11
  184. data/ext/sources/ggml/src/ggml-alloc.c +124 -41
  185. data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
  186. data/ext/sources/ggml/src/ggml-backend-reg.cpp +27 -3
  187. data/ext/sources/ggml/src/ggml-backend.cpp +71 -21
  188. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
  189. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -9
  190. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +57 -45
  191. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
  192. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2179 -1696
  193. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -317
  194. data/ext/sources/ggml/src/ggml-cann/common.h +283 -208
  195. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +626 -776
  196. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +156 -86
  197. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -0
  198. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  199. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +428 -26
  200. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1004 -0
  201. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
  202. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  203. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +108 -49
  204. data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  205. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6 -6
  206. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +50 -2
  207. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +5 -3
  208. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +195 -71
  209. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
  210. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +573 -106
  211. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
  212. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +298 -112
  213. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
  214. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +819 -125
  215. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
  216. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +708 -431
  217. data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -4
  218. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +671 -31
  219. data/ext/sources/ggml/src/ggml-cpu/repack.h +14 -0
  220. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +41 -43
  221. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
  222. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
  223. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
  224. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +124 -1
  225. data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
  226. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
  227. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
  228. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
  229. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
  230. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +1 -1
  231. data/ext/sources/ggml/src/ggml-cuda/common.cuh +353 -80
  232. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  233. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
  234. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +339 -246
  235. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
  236. data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
  237. data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
  238. data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
  239. data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
  240. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +31 -21
  241. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +663 -596
  242. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
  243. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1241 -0
  244. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +30 -37
  245. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +14 -13
  246. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  247. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +83 -37
  248. data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
  249. data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
  250. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1155 -164
  251. data/ext/sources/ggml/src/ggml-cuda/mean.cu +5 -4
  252. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +741 -48
  253. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +60 -12
  254. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +381 -42
  255. data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
  256. data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
  257. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
  258. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +498 -171
  259. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +375 -79
  260. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +3 -2
  261. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +241 -95
  262. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
  263. data/ext/sources/ggml/src/ggml-cuda/pad.cu +64 -33
  264. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +151 -0
  265. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
  266. data/ext/sources/ggml/src/ggml-cuda/rope.cu +192 -77
  267. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
  268. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
  269. data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
  270. data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
  271. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +203 -6
  272. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
  273. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
  274. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -20
  275. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
  276. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  277. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  278. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  279. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  280. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  281. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  282. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  283. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  284. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  285. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +19 -1
  286. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
  287. data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
  288. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +168 -76
  289. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +11 -4
  290. data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
  291. data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
  292. data/ext/sources/ggml/src/ggml-cuda/unary.cu +105 -11
  293. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +36 -0
  294. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
  295. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  296. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +12 -1
  297. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
  302. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
  303. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  304. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
  305. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
  306. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  307. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
  308. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
  309. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
  310. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
  311. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  312. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
  313. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
  314. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
  315. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
  316. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
  317. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
  318. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
  319. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
  320. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
  321. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
  322. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
  323. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
  324. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
  325. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  326. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
  327. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
  328. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
  329. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +8 -13
  330. data/ext/sources/ggml/src/ggml-impl.h +67 -6
  331. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +2 -2
  332. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +29 -20
  333. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +652 -285
  334. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +103 -56
  335. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +496 -118
  336. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +231 -9
  337. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1227 -224
  338. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +12 -0
  339. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +14 -8
  340. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1972 -704
  341. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
  342. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +11 -0
  343. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1430 -120
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +63 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +13 -0
  365. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
  366. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
  367. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
  368. data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
  369. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +6 -0
  370. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +0 -9
  371. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
  372. data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
  373. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
  374. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +34 -0
  375. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  376. data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  377. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
  378. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
  379. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +76 -3
  380. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +333 -300
  381. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +10 -2
  382. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +335 -110
  383. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
  384. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +156 -0
  385. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
  386. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
  387. data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
  388. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
  389. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
  390. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
  391. data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  392. data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  393. data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
  394. data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
  395. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +30 -17
  396. data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
  397. data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
  398. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
  399. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
  400. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  401. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  402. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +58 -0
  403. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
  404. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5013 -2859
  405. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  406. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
  407. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
  408. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  409. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
  410. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  411. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
  412. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
  413. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  414. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  415. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  416. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  417. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  418. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  419. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
  420. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  421. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  422. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
  423. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
  424. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  425. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  426. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  427. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
  428. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
  429. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
  430. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
  431. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  432. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
  433. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
  434. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  435. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  436. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  437. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  438. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  439. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
  440. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
  441. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
  442. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  443. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  444. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
  445. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
  446. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  447. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  448. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  449. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
  450. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  451. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  452. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
  453. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  454. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  455. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
  456. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  457. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  458. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
  459. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  460. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +39 -17
  461. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +19 -1
  462. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +45 -7
  463. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +50 -12
  464. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +1 -1
  465. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  466. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  467. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
  468. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
  469. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  470. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
  471. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  472. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
  473. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
  474. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
  475. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
  476. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
  477. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
  478. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  479. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
  480. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
  481. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
  482. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
  483. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
  484. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  485. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  486. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  487. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
  488. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +70 -25
  489. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
  490. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
  491. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
  492. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
  493. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
  494. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
  495. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
  496. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
  497. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
  498. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  499. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
  500. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
  501. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
  502. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
  503. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
  504. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
  505. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
  506. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +78 -103
  507. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +34 -23
  508. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
  509. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
  510. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +88 -228
  511. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
  512. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  513. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
  514. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  515. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  516. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  517. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
  518. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
  519. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  520. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
  521. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  522. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  523. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  524. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  525. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +50 -4
  526. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  527. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
  528. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
  529. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
  530. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  531. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -50
  532. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -33
  533. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -33
  534. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
  535. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
  536. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  537. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
  538. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  539. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  540. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  541. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  542. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
  543. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
  544. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
  545. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
  546. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
  547. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
  548. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  549. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
  550. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
  551. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  552. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
  553. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
  554. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  555. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  556. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
  557. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
  558. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  559. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
  560. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  561. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
  562. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
  563. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
  564. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
  565. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
  566. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  567. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
  568. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
  569. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -151
  570. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
  571. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
  572. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
  573. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1964 -435
  574. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  575. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
  576. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
  577. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +33 -10
  578. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
  579. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +1 -1
  580. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  581. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +6 -6
  582. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
  583. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
  584. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
  585. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
  586. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
  587. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  588. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
  589. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
  590. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  591. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
  592. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
  593. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
  594. data/ext/sources/ggml/src/ggml.c +425 -33
  595. data/ext/sources/include/whisper.h +1 -0
  596. data/ext/sources/src/CMakeLists.txt +3 -1
  597. data/ext/sources/src/whisper.cpp +101 -35
  598. data/ext/sources/tests/CMakeLists.txt +2 -2
  599. data/ext/sources/tests/test-vad-full.cpp +4 -2
  600. data/ext/sources/tests/test-vad.cpp +1 -1
  601. data/extsources.rb +1 -0
  602. data/lib/whisper/model/uri.rb +17 -18
  603. data/sig/whisper.rbs +119 -2
  604. data/test/test_params.rb +16 -8
  605. data/test/test_segment.rb +0 -1
  606. data/test/test_token.rb +70 -0
  607. data/test/test_vad.rb +1 -1
  608. data/test/test_vad_context.rb +50 -0
  609. data/test/test_vad_segment.rb +19 -0
  610. data/test/test_vad_segments.rb +16 -0
  611. data/test/test_whisper.rb +7 -0
  612. data/whispercpp.gemspec +1 -1
  613. metadata +287 -34
  614. data/ext/sources/build-xcframework.sh +0 -571
  615. data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
  618. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
  619. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
  620. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
  621. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
  622. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
  623. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
  624. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  625. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  626. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  627. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  628. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
  629. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
  630. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
@@ -1,5 +1,6 @@
1
1
  #include "llama-context.h"
2
2
 
3
+ #include "llama-arch.h"
3
4
  #include "llama-impl.h"
4
5
  #include "llama-batch.h"
5
6
  #include "llama-io.h"
@@ -8,6 +9,7 @@
8
9
  #include "llama-model.h"
9
10
 
10
11
  #include <cinttypes>
12
+ #include <cmath>
11
13
  #include <cstring>
12
14
  #include <limits>
13
15
  #include <stdexcept>
@@ -21,6 +23,8 @@ llama_context::llama_context(
21
23
  llama_context_params params) :
22
24
  model(model),
23
25
  balloc(std::make_unique<llama_batch_allocr>(model.hparams.n_pos_per_embd())) {
26
+ // TODO warning when creating llama_context with awkward ctx size that is not a power of 2,
27
+ // may need to be backend-dependent
24
28
  LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);
25
29
 
26
30
  t_start_us = model.t_start_us;
@@ -56,6 +60,25 @@ llama_context::llama_context(
56
60
  cparams.cb_eval = params.cb_eval;
57
61
  cparams.cb_eval_user_data = params.cb_eval_user_data;
58
62
 
63
+ // Initialize backend samplers here so they are part of the sampling graph
64
+ // before the reserve passes run later in this function. This avoids a later
65
+ // re-reserve when graph nodes change.
66
+ if (params.samplers != nullptr && params.n_samplers > 0) {
67
+ for (size_t i = 0; i < params.n_samplers; ++i) {
68
+ const auto & config = params.samplers[i];
69
+
70
+ if (llama_sampler_chain_get(config.sampler, -1) == nullptr) {
71
+ throw std::runtime_error("the backend samplers must be of type llama_sampler_chain");
72
+ }
73
+
74
+ if (set_sampler(config.seq_id, config.sampler)) {
75
+ const int n_samplers = llama_sampler_chain_n(config.sampler);
76
+
77
+ LLAMA_LOG_INFO("%s: setting backend sampler for seq_id %d (n = %d)\n", __func__, config.seq_id, n_samplers);
78
+ }
79
+ }
80
+ }
81
+
59
82
  auto rope_scaling_type = params.rope_scaling_type;
60
83
  if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED) {
61
84
  rope_scaling_type = hparams.rope_scaling_type_train;
@@ -69,6 +92,43 @@ llama_context::llama_context(
69
92
  cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f;
70
93
  }
71
94
 
95
+ if (cparams.yarn_ext_factor != 0) {
96
+ static auto get_mscale = [](float scale, float mscale) {
97
+ return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f);
98
+ };
99
+
100
+ const float factor = 1.0f / cparams.rope_freq_scale;
101
+
102
+ // ref: https://github.com/huggingface/transformers/blob/6d00f6b0a5679c36510f203e4226e36f517c3032/src/transformers/modeling_rope_utils.py#L336-L348
103
+ if (hparams.rope_yarn_log_mul != 0.0f) {
104
+ // note: here we assume `mscale == 1.0f`
105
+ // TODO: start reading the actual value of mscale and handle the case where it is not 1.0f
106
+ float mscale = 1.0f;
107
+ const float mscale_all_dims = hparams.rope_yarn_log_mul;
108
+
109
+ // [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
110
+ // special-case DEEPSEEK v2:
111
+ // https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite-Chat/blob/main/config.json#L42-L43
112
+ if (model.arch == LLM_ARCH_DEEPSEEK2 && mscale_all_dims != 1.0f) {
113
+ mscale = mscale_all_dims;
114
+ }
115
+
116
+ cparams.yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims);
117
+
118
+ LLAMA_LOG_WARN("%s: setting new yarn_attn_factor = %.4f (mscale == %.1f, mscale_all_dim = %.1f)\n",
119
+ __func__, cparams.yarn_attn_factor, mscale, mscale_all_dims);
120
+ } else {
121
+ cparams.yarn_attn_factor = get_mscale(factor, 1.0f);
122
+ }
123
+
124
+ // when YARN is applied with yarn_ext_factor != 0.0f, we need to cancel this factor:
125
+ // https://github.com/ggml-org/llama.cpp/blob/a81a569577cc38b32558958b048228150be63eae/ggml/src/ggml-cpu/ops.cpp#L5541-L5544
126
+ //
127
+ // ref: https://github.com/ggml-org/llama.cpp/discussions/7416
128
+ // https://github.com/ggml-org/llama.cpp/pull/17945
129
+ cparams.yarn_attn_factor *= 1.0f / (1.0f + 0.1f * logf(factor));
130
+ }
131
+
72
132
  cparams.yarn_attn_factor *= hparams.rope_attn_factor;
73
133
 
74
134
  if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
@@ -90,14 +150,6 @@ llama_context::llama_context(
90
150
  // with causal attention, the batch size is limited by the context size
91
151
  cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
92
152
 
93
- // the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask
94
- // this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext)
95
- // ref: https://github.com/ggerganov/llama.cpp/pull/5021
96
- // TODO: this padding is not needed for the cache-less context so we should probably move it to llama_memory
97
- if (cparams.n_batch < GGML_KQ_MASK_PAD) {
98
- LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
99
- cparams.n_batch = GGML_KQ_MASK_PAD;
100
- }
101
153
  cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
102
154
 
103
155
  cparams.op_offload = params.op_offload;
@@ -112,11 +164,28 @@ llama_context::llama_context(
112
164
  }
113
165
  }
114
166
 
115
- const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
167
+ // ref: https://github.com/ggml-org/llama.cpp/pull/17046#discussion_r2503085732
168
+ cparams.n_ctx = GGML_PAD(cparams.n_ctx, 256);
169
+
170
+ if (cparams.kv_unified) {
171
+ cparams.n_ctx_seq = cparams.n_ctx;
172
+ } else {
173
+ cparams.n_ctx_seq = cparams.n_ctx / cparams.n_seq_max;
174
+ cparams.n_ctx_seq = GGML_PAD(cparams.n_ctx_seq, 256);
175
+
176
+ if (cparams.n_ctx_seq == 0) {
177
+ throw std::runtime_error("n_ctx_seq == 0");
178
+ }
179
+
180
+ if (cparams.n_ctx != cparams.n_ctx_seq * cparams.n_seq_max) {
181
+ cparams.n_ctx = cparams.n_ctx_seq * cparams.n_seq_max;
182
+ LLAMA_LOG_WARN("%s: n_ctx is not divisible by n_seq_max - rounding down to %u\n", __func__, cparams.n_ctx);
183
+ }
184
+ }
116
185
 
117
186
  LLAMA_LOG_INFO("%s: n_seq_max = %u\n", __func__, cparams.n_seq_max);
118
187
  LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
119
- LLAMA_LOG_INFO("%s: n_ctx_per_seq = %u\n", __func__, n_ctx_per_seq);
188
+ LLAMA_LOG_INFO("%s: n_ctx_seq = %u\n", __func__, cparams.n_ctx_seq);
120
189
  LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch);
121
190
  LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
122
191
  LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn);
@@ -125,14 +194,14 @@ llama_context::llama_context(
125
194
  LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
126
195
  LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
127
196
 
128
- if (n_ctx_per_seq < hparams.n_ctx_train) {
129
- LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
130
- __func__, n_ctx_per_seq, hparams.n_ctx_train);
197
+ if (cparams.n_ctx_seq < hparams.n_ctx_train) {
198
+ LLAMA_LOG_WARN("%s: n_ctx_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
199
+ __func__, cparams.n_ctx_seq, hparams.n_ctx_train);
131
200
  }
132
201
 
133
- if (n_ctx_per_seq > hparams.n_ctx_train) {
134
- LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
135
- __func__, n_ctx_per_seq, hparams.n_ctx_train);
202
+ if (cparams.n_ctx_seq > hparams.n_ctx_train) {
203
+ LLAMA_LOG_WARN("%s: n_ctx_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
204
+ __func__, cparams.n_ctx_seq, hparams.n_ctx_train);
136
205
  }
137
206
 
138
207
  if (!hparams.vocab_only) {
@@ -181,7 +250,10 @@ llama_context::llama_context(
181
250
  // graph outputs buffer
182
251
  {
183
252
  // resized during inference when a batch uses more outputs
184
- if (output_reserve(params.n_seq_max) < params.n_seq_max) {
253
+ // Create a dummy batch for initialization.
254
+ llama_batch dummy_batch = {};
255
+ dummy_batch.n_tokens = 0;
256
+ if (output_reserve(params.n_seq_max, dummy_batch) < params.n_seq_max) {
185
257
  throw std::runtime_error("failed to reserve initial output buffer");
186
258
  }
187
259
 
@@ -208,6 +280,7 @@ llama_context::llama_context(
208
280
 
209
281
  backend_buft.clear();
210
282
  backend_ptrs.clear();
283
+ backend_buf_exp_size.clear();
211
284
 
212
285
  for (auto & backend : backends) {
213
286
  auto * buft = ggml_backend_get_default_buffer_type(backend.get());
@@ -224,11 +297,15 @@ llama_context::llama_context(
224
297
 
225
298
  backend_buft.push_back(buft);
226
299
  backend_ptrs.push_back(backend.get());
300
+ backend_buf_exp_size.push_back(0);
227
301
  }
228
302
 
229
303
  LLAMA_LOG_DEBUG("%s: backend_ptrs.size() = %zu\n", __func__, backend_ptrs.size());
230
304
 
231
- const size_t max_nodes = this->graph_max_nodes();
305
+ const uint32_t n_seqs = cparams.n_seq_max;
306
+ const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
307
+
308
+ const size_t max_nodes = this->graph_max_nodes(n_tokens);
232
309
 
233
310
  LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes);
234
311
 
@@ -239,8 +316,8 @@ llama_context::llama_context(
239
316
  // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
240
317
  bool pipeline_parallel =
241
318
  model.n_devices() > 1 &&
242
- model.params.n_gpu_layers > (int) model.hparams.n_layer &&
243
- model.params.split_mode == LLAMA_SPLIT_MODE_LAYER &&
319
+ model.n_gpu_layers() > model.hparams.n_layer &&
320
+ model.split_mode() == LLAMA_SPLIT_MODE_LAYER &&
244
321
  cparams.offload_kqv &&
245
322
  !model.has_tensor_overrides();
246
323
 
@@ -268,9 +345,7 @@ llama_context::llama_context(
268
345
  if (pipeline_parallel) {
269
346
  LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(sched.get()));
270
347
  }
271
- }
272
348
 
273
- if (!hparams.vocab_only) {
274
349
  llama_memory_context_ptr mctx;
275
350
  if (memory) {
276
351
  LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__);
@@ -282,9 +357,6 @@ llama_context::llama_context(
282
357
 
283
358
  cross.v_embd.clear();
284
359
 
285
- const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
286
- const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
287
-
288
360
  // avoid reserving graphs with zero outputs - assume one output per sequence
289
361
  n_outputs = n_seqs;
290
362
 
@@ -341,9 +413,17 @@ llama_context::llama_context(
341
413
 
342
414
  // reserve pp (prompt processing) graph first so that buffers are only allocated once
343
415
  {
344
- auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
416
+ auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(),
417
+ model.hparams.no_alloc, model.hparams.no_alloc ? backend_buf_exp_size.data() : nullptr);
345
418
  if (!gf) {
346
- throw std::runtime_error("failed to allocate compute pp buffers");
419
+ if (pipeline_parallel) {
420
+ LLAMA_LOG_WARN("%s: compute buffer allocation failed, retrying without pipeline parallelism\n", __func__);
421
+ sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, false, cparams.op_offload));
422
+ gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
423
+ }
424
+ if (!gf) {
425
+ throw std::runtime_error("failed to allocate compute pp buffers");
426
+ }
347
427
  }
348
428
 
349
429
  n_splits_pp = ggml_backend_sched_get_n_splits(sched.get());
@@ -352,7 +432,7 @@ llama_context::llama_context(
352
432
 
353
433
  // reserve with tg (token generation) graph to get the number of splits and nodes
354
434
  {
355
- auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get());
435
+ auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get(), model.hparams.no_alloc);
356
436
  if (!gf) {
357
437
  throw std::runtime_error("failed to allocate compute tg buffers");
358
438
  }
@@ -367,7 +447,7 @@ llama_context::llama_context(
367
447
  //
368
448
  // auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get());
369
449
  //
370
- auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
450
+ auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), model.hparams.no_alloc);
371
451
  if (!gf) {
372
452
  throw std::runtime_error("failed to allocate compute pp buffers");
373
453
  }
@@ -376,11 +456,13 @@ llama_context::llama_context(
376
456
  for (size_t i = 0; i < backend_ptrs.size(); ++i) {
377
457
  ggml_backend_t backend = backend_ptrs[i];
378
458
  ggml_backend_buffer_type_t buft = backend_buft[i];
379
- size_t size = ggml_backend_sched_get_buffer_size(sched.get(), backend);
380
- if (size > 1) {
459
+ if (!model.hparams.no_alloc) {
460
+ backend_buf_exp_size[i] = ggml_backend_sched_get_buffer_size(sched.get(), backend);
461
+ }
462
+ if (backend_buf_exp_size[i] > 1) {
381
463
  LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__,
382
464
  ggml_backend_buft_name(buft),
383
- size / 1024.0 / 1024.0);
465
+ backend_buf_exp_size[i] / 1024.0 / 1024.0);
384
466
  }
385
467
  }
386
468
 
@@ -396,9 +478,35 @@ llama_context::llama_context(
396
478
  LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg);
397
479
  }
398
480
  }
481
+
482
+ // Initialize the full vocabulary token ids for backend samplers.
483
+ {
484
+ const int n_vocab = model.vocab.n_tokens();
485
+
486
+ sampling.token_ids_full_vocab.resize(n_vocab);
487
+ for (int i = 0; i < n_vocab; ++i) {
488
+ sampling.token_ids_full_vocab[i] = i;
489
+ }
490
+ }
399
491
  }
400
492
 
401
493
  llama_context::~llama_context() {
494
+ if (!model.hparams.no_alloc) {
495
+ for (size_t i = 0; i < backend_ptrs.size(); ++i) {
496
+ ggml_backend_t backend = backend_ptrs[i];
497
+ ggml_backend_buffer_type_t buft = backend_buft[i];
498
+
499
+ const size_t size_exp = backend_buf_exp_size[i];
500
+ const size_t size_act = ggml_backend_sched_get_buffer_size(sched.get(), backend);
501
+ if (size_exp == size_act) {
502
+ LLAMA_LOG_DEBUG("%s: %10s compute buffer size is %8.4f MiB, matches expectation of %8.4f MiB\n",
503
+ __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0));
504
+ } else {
505
+ LLAMA_LOG_WARN("%s: %10s compute buffer size of %8.4f MiB, does not match expectation of %8.4f MiB\n",
506
+ __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0));
507
+ }
508
+ }
509
+ }
402
510
  ggml_opt_free(opt_ctx);
403
511
  }
404
512
 
@@ -448,8 +556,8 @@ uint32_t llama_context::n_ctx() const {
448
556
  return cparams.n_ctx;
449
557
  }
450
558
 
451
- uint32_t llama_context::n_ctx_per_seq() const {
452
- return cparams.n_ctx / cparams.n_seq_max;
559
+ uint32_t llama_context::n_ctx_seq() const {
560
+ return cparams.n_ctx_seq;
453
561
  }
454
562
 
455
563
  uint32_t llama_context::n_batch() const {
@@ -518,7 +626,7 @@ bool llama_context::memory_update(bool optimize) {
518
626
  throw std::runtime_error("failed to initialize memory context");
519
627
  }
520
628
 
521
- const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
629
+ const uint32_t n_seqs = cparams.n_seq_max;
522
630
  const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
523
631
 
524
632
  auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
@@ -540,6 +648,35 @@ float * llama_context::get_logits() {
540
648
  return logits;
541
649
  }
542
650
 
651
+ int64_t llama_context::output_resolve_row(int32_t i) const {
652
+ int64_t j = -1;
653
+
654
+ // support negative indices (last output row)
655
+ if (i < 0) {
656
+ j = n_outputs + i;
657
+ if (j < 0) {
658
+ throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs));
659
+ }
660
+ } else if ((size_t) i >= output_ids.size()) {
661
+ throw std::runtime_error(format("out of range [0, %zu)", output_ids.size()));
662
+ } else {
663
+ // use output_ids to translate the batch token index into a row number
664
+ // that holds this token's data.
665
+ j = output_ids[i];
666
+ }
667
+
668
+ if (j < 0) {
669
+ // the batch token was not configured to output anything
670
+ throw std::runtime_error(format("batch.logits[%d] != true", i));
671
+ }
672
+
673
+ if (j >= n_outputs) {
674
+ throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
675
+ }
676
+
677
+ return j;
678
+ }
679
+
543
680
  float * llama_context::get_logits_ith(int32_t i) {
544
681
  int64_t j = -1;
545
682
 
@@ -550,6 +687,7 @@ float * llama_context::get_logits_ith(int32_t i) {
550
687
  throw std::runtime_error("no logits");
551
688
  }
552
689
 
690
+ // TODO: use output_resolve_row()
553
691
  if (i < 0) {
554
692
  j = n_outputs + i;
555
693
  if (j < 0) {
@@ -586,6 +724,10 @@ float * llama_context::get_embeddings() {
586
724
  return embd;
587
725
  }
588
726
 
727
+ llama_token * llama_context::get_sampled_tokens() const{
728
+ return sampling.sampled;
729
+ }
730
+
589
731
  float * llama_context::get_embeddings_ith(int32_t i) {
590
732
  int64_t j = -1;
591
733
 
@@ -596,6 +738,7 @@ float * llama_context::get_embeddings_ith(int32_t i) {
596
738
  throw std::runtime_error("no embeddings");
597
739
  }
598
740
 
741
+ // TODO: use output_resolve_row()
599
742
  if (i < 0) {
600
743
  j = n_outputs + i;
601
744
  if (j < 0) {
@@ -615,7 +758,8 @@ float * llama_context::get_embeddings_ith(int32_t i) {
615
758
  throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
616
759
  }
617
760
 
618
- return embd + j*model.hparams.n_embd;
761
+ const uint32_t n_embd_out = model.hparams.get_n_embd_out();
762
+ return embd + j*n_embd_out;
619
763
  } catch (const std::exception & err) {
620
764
  LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what());
621
765
  #ifndef NDEBUG
@@ -635,6 +779,136 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) {
635
779
  return it->second.data();
636
780
  }
637
781
 
782
+ llama_token llama_context::get_sampled_token_ith(int32_t idx) {
783
+ output_reorder();
784
+
785
+ if (sampling.sampled == nullptr) {
786
+ return LLAMA_TOKEN_NULL;
787
+ }
788
+
789
+ try {
790
+ const int64_t row = output_resolve_row(idx);
791
+ GGML_ASSERT(row < (int64_t) sampling.sampled_size);
792
+ return sampling.sampled[row];
793
+ } catch (const std::exception & err) {
794
+ LLAMA_LOG_ERROR("%s: invalid backend sampled token id %d, reason: %s\n", __func__, idx, err.what());
795
+ return LLAMA_TOKEN_NULL;
796
+ }
797
+ }
798
+
799
+ float * llama_context::get_sampled_probs_ith(int32_t idx) {
800
+ output_reorder();
801
+
802
+ if (sampling.probs == nullptr) {
803
+ return nullptr;
804
+ }
805
+
806
+ try {
807
+ const int64_t row = output_resolve_row(idx);
808
+ if ((size_t) row >= sampling.probs_count.size() || sampling.probs_count[row] == 0) {
809
+ return nullptr;
810
+ }
811
+ return sampling.probs + row*model.vocab.n_tokens();
812
+ } catch (const std::exception & err) {
813
+ LLAMA_LOG_ERROR("%s: invalid backend sampled probs id %d, reason: %s\n", __func__, idx, err.what());
814
+ return nullptr;
815
+ }
816
+ }
817
+
818
+ float * llama_context::get_sampled_logits_ith(int32_t idx) {
819
+ output_reorder();
820
+
821
+ if (sampling.logits == nullptr) {
822
+ return nullptr;
823
+ }
824
+
825
+ try {
826
+ const int64_t row = output_resolve_row(idx);
827
+ if ((size_t) row >= sampling.logits_count.size() || sampling.logits_count[row] == 0) {
828
+ return nullptr;
829
+ }
830
+ return sampling.logits + row*model.vocab.n_tokens();
831
+ } catch (const std::exception & err) {
832
+ LLAMA_LOG_ERROR("%s: invalid backend sampled logits id %d, reason: %s\n", __func__, idx, err.what());
833
+ return nullptr;
834
+ }
835
+ }
836
+
837
+ const llama_token * llama_context::get_sampled_candidates_ith(int32_t idx) {
838
+ output_reorder();
839
+
840
+ try {
841
+ const int64_t row = output_resolve_row(idx);
842
+ if (sampling.candidates != nullptr &&
843
+ (size_t) row < sampling.candidates_count.size() &&
844
+ sampling.candidates_count[row] > 0) {
845
+ return sampling.candidates + row*model.vocab.n_tokens();
846
+ }
847
+ } catch (const std::exception & err) {
848
+ // fallback to full vocab list
849
+ }
850
+
851
+ return sampling.token_ids_full_vocab.data();
852
+ }
853
+
854
+ size_t llama_context::get_sampled_candidates_count(int32_t idx) {
855
+ output_reorder();
856
+
857
+ if (sampling.candidates == nullptr) {
858
+ return 0;
859
+ }
860
+
861
+ try {
862
+ const int64_t row = output_resolve_row(idx);
863
+ if ((size_t) row >= sampling.candidates_count.size()) {
864
+ return 0;
865
+ }
866
+ return sampling.candidates_count[row];
867
+ } catch (const std::exception & err) {
868
+ LLAMA_LOG_ERROR("%s: invalid backend sampled candidates count id %d, reason: %s\n", __func__, idx, err.what());
869
+ return 0;
870
+ }
871
+ }
872
+
873
+ size_t llama_context::get_sampled_logits_count(int32_t idx) {
874
+ output_reorder();
875
+
876
+ if (sampling.logits == nullptr) {
877
+ return model.vocab.n_tokens();
878
+ }
879
+
880
+ try {
881
+ const int64_t row = output_resolve_row(idx);
882
+ if ((size_t) row >= sampling.logits_count.size()) {
883
+ return 0;
884
+ }
885
+ return sampling.logits_count[row];
886
+ } catch (const std::exception & err) {
887
+ LLAMA_LOG_ERROR("%s: invalid backend sampled logits count id %d, reason: %s\n", __func__, idx, err.what());
888
+ return 0;
889
+ }
890
+ }
891
+
892
+ size_t llama_context::get_sampled_probs_count(int32_t idx) {
893
+ output_reorder();
894
+
895
+ if (sampling.probs == nullptr) {
896
+ return 0;
897
+ }
898
+
899
+ try {
900
+ const int64_t row = output_resolve_row(idx);
901
+ if ((size_t) row >= sampling.probs_count.size()) {
902
+ return 0;
903
+ }
904
+ return sampling.probs_count[row];
905
+ } catch (const std::exception & err) {
906
+ LLAMA_LOG_ERROR("%s: invalid backend sampled probs count id %d, reason: %s\n", __func__, idx, err.what());
907
+ return 0;
908
+ }
909
+ }
910
+
911
+
638
912
  void llama_context::attach_threadpool(
639
913
  ggml_threadpool_t threadpool,
640
914
  ggml_threadpool_t threadpool_batch) {
@@ -691,6 +965,42 @@ void llama_context::set_warmup(bool value) {
691
965
  cparams.warmup = value;
692
966
  }
693
967
 
968
+ bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
969
+ LLAMA_LOG_DEBUG("%s: seq_id = %d, sampler = %p\n", __func__, (int) seq_id, (void *) sampler);
970
+
971
+ const bool can_offload =
972
+ sampler &&
973
+ sampler->iface->backend_init &&
974
+ sampler->iface->backend_apply &&
975
+ llama_sampler_chain_n(sampler) > 0;
976
+
977
+ if (sampler && can_offload) {
978
+ ggml_backend_buffer_type_t buft = ggml_backend_dev_buffer_type(model.dev_output());
979
+ auto * host_buft = ggml_backend_dev_host_buffer_type(model.dev_output());
980
+ if (host_buft) {
981
+ buft = host_buft;
982
+ }
983
+
984
+ sampler->iface->backend_init(sampler, buft);
985
+
986
+ sampling.samplers[seq_id] = sampler;
987
+
988
+ return true;
989
+ }
990
+
991
+ if (sampler && !can_offload) {
992
+ LLAMA_LOG_WARN("%s: sampler '%s' for seq_id = %d, cannot be offloaded to the backend\n", __func__, llama_sampler_name(sampler), seq_id);
993
+
994
+ sampling.samplers.erase(seq_id);
995
+
996
+ return false;
997
+ }
998
+
999
+ sampling.samplers.erase(seq_id);
1000
+
1001
+ return true;
1002
+ }
1003
+
694
1004
  void llama_context::set_adapter_lora(
695
1005
  llama_adapter_lora * adapter,
696
1006
  float scale) {
@@ -803,7 +1113,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
803
1113
 
804
1114
  const auto & hparams = model.hparams;
805
1115
 
806
- const int64_t n_embd = hparams.n_embd;
1116
+ const int64_t n_embd = hparams.n_embd_inp();
807
1117
  const int64_t n_vocab = model.vocab.n_tokens();
808
1118
 
809
1119
  // note: during encode, we always pass the full sequence starting from pos = 0
@@ -831,7 +1141,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
831
1141
  n_queued_tokens += n_tokens;
832
1142
 
833
1143
  // reserve output buffer
834
- if (output_reserve(n_tokens) < n_tokens) {
1144
+ if (output_reserve(n_tokens, batch_inp) < n_tokens) {
835
1145
  LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
836
1146
  return -2;
837
1147
  };
@@ -885,9 +1195,10 @@ int llama_context::encode(const llama_batch & batch_inp) {
885
1195
  {
886
1196
  // extract token embeddings
887
1197
  GGML_ASSERT(embd != nullptr);
1198
+ const uint32_t n_embd_out = hparams.get_n_embd_out();
888
1199
 
889
- GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size);
890
- ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd*sizeof(float));
1200
+ GGML_ASSERT(n_tokens*n_embd_out <= (int64_t) embd_size);
1201
+ ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd_out*sizeof(float));
891
1202
  } break;
892
1203
  case LLAMA_POOLING_TYPE_MEAN:
893
1204
  case LLAMA_POOLING_TYPE_CLS:
@@ -955,6 +1266,112 @@ int llama_context::encode(const llama_batch & batch_inp) {
955
1266
  return 0;
956
1267
  }
957
1268
 
1269
+ static std::map<llama_seq_id, uint32_t> build_seq_to_output_row(const llama_ubatch & ubatch, uint32_t row_offset) {
1270
+ std::map<llama_seq_id, uint32_t> seq_to_row;
1271
+ // how many output tokens we have seen so far for this ubatch.
1272
+ uint32_t local = 0;
1273
+ for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
1274
+ // skip tokens that are not output.
1275
+ if (!ubatch.output[i]) {
1276
+ continue;
1277
+ }
1278
+
1279
+ const llama_seq_id seq_id = ubatch.seq_id[i][0];
1280
+ // row_offset is the number of output tokens before this ubatch.
1281
+ seq_to_row[seq_id] = row_offset + local;
1282
+ ++local;
1283
+ }
1284
+ return seq_to_row;
1285
+ }
1286
+
1287
+ static void copy_tensor_async_ints(
1288
+ const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
1289
+ llama_token * sampled,
1290
+ size_t sampled_size,
1291
+ const std::map<llama_seq_id, uint32_t> & seq_to_row,
1292
+ ggml_backend_sched_t sched) {
1293
+ if (sampled == nullptr) {
1294
+ return;
1295
+ }
1296
+
1297
+ for (const auto & [seq_id, tensor] : tensor_map) {
1298
+ auto it = seq_to_row.find(seq_id);
1299
+ if (it == seq_to_row.end()) {
1300
+ continue;
1301
+ }
1302
+
1303
+ const uint32_t row = it->second;
1304
+ GGML_ASSERT(row < sampled_size);
1305
+
1306
+ GGML_ASSERT(ggml_is_contiguous(tensor) && "sampled tokens tensor must be contiguous for async copy");
1307
+
1308
+ ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
1309
+ ggml_backend_tensor_get_async(backend, tensor, sampled + row, 0, sizeof(sampled[row]));
1310
+ }
1311
+ }
1312
+
1313
+ static void copy_tensor_async_floats(
1314
+ const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
1315
+ float * dst,
1316
+ size_t stride,
1317
+ std::vector<uint32_t> & counts,
1318
+ const std::map<llama_seq_id, uint32_t> & seq_to_row,
1319
+ ggml_backend_sched_t sched) {
1320
+ if (dst == nullptr) {
1321
+ return;
1322
+ }
1323
+
1324
+ for (const auto & [seq_id, tensor] : tensor_map) {
1325
+ auto it = seq_to_row.find(seq_id);
1326
+ if (it == seq_to_row.end()) {
1327
+ continue;
1328
+ }
1329
+
1330
+ const uint32_t row = it->second;
1331
+ GGML_ASSERT(row < counts.size());
1332
+
1333
+ GGML_ASSERT(ggml_is_contiguous(tensor) && "logits/probs tensor must be contiguous for async copy");
1334
+
1335
+ ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
1336
+ float * row_ptr = dst + (size_t) row * stride;
1337
+ ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor));
1338
+
1339
+ // Update the actual number of logits/probabilities that were written for this row.
1340
+ counts[row] = ggml_nelements(tensor);
1341
+ }
1342
+ }
1343
+
1344
+ static void copy_tensor_async_candidates(
1345
+ const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
1346
+ llama_token * dst,
1347
+ size_t stride,
1348
+ std::vector<uint32_t> & counts,
1349
+ const std::map<llama_seq_id, uint32_t> & seq_to_row,
1350
+ ggml_backend_sched_t sched) {
1351
+ if (dst == nullptr) {
1352
+ return;
1353
+ }
1354
+
1355
+ for (const auto & [seq_id, tensor] : tensor_map) {
1356
+ auto it = seq_to_row.find(seq_id);
1357
+ if (it == seq_to_row.end()) {
1358
+ continue;
1359
+ }
1360
+
1361
+ const uint32_t row = it->second;
1362
+ GGML_ASSERT(row < counts.size());
1363
+
1364
+ GGML_ASSERT(ggml_is_contiguous(tensor) && "candidates tensor must be contiguous for async copy");
1365
+
1366
+ ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
1367
+ llama_token * row_ptr = dst + (size_t) row * stride;
1368
+ ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor));
1369
+
1370
+ // Update the actual number of candidates that were written.
1371
+ counts[row] = ggml_nelements(tensor);
1372
+ }
1373
+ }
1374
+
958
1375
  int llama_context::decode(const llama_batch & batch_inp) {
959
1376
  GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
960
1377
 
@@ -972,12 +1389,39 @@ int llama_context::decode(const llama_batch & batch_inp) {
972
1389
  const auto & hparams = model.hparams;
973
1390
 
974
1391
  const int64_t n_vocab = vocab.n_tokens();
975
- const int64_t n_embd = hparams.n_embd;
1392
+ const int64_t n_embd = hparams.n_embd_inp();
976
1393
 
977
1394
  // when computing embeddings, all tokens are output
978
- const bool output_all = cparams.embeddings;
1395
+ const bool output_all = cparams.embeddings;
1396
+ const bool has_samplers = !sampling.samplers.empty();
1397
+
1398
+ const uint32_t n_seq_max = cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max;
1399
+
1400
+ // TODO: avoid this workaround in the future
1401
+ if (has_samplers && batch_inp.logits) {
1402
+ std::vector<int32_t> seq_output_count(n_seq_max, 0);
1403
+
1404
+ for (int32_t i = 0; i < batch_inp.n_tokens; ++i) {
1405
+ if (batch_inp.logits[i] == 0) {
1406
+ continue;
1407
+ }
979
1408
 
980
- if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, output_all)) {
1409
+ const int ns = batch_inp.n_seq_id ? batch_inp.n_seq_id[i] : 1;
1410
+
1411
+ for (int32_t s = 0; s < ns; ++s) {
1412
+ const llama_seq_id seq_id = batch_inp.seq_id ? batch_inp.seq_id[i][s] : 0;
1413
+
1414
+ seq_output_count[seq_id]++;
1415
+ if (seq_output_count[seq_id] > 1) {
1416
+ LLAMA_LOG_ERROR("%s: backend sampling requires at most one output token per sequence (seq_id %d had %d)\n",
1417
+ __func__, seq_id, seq_output_count[seq_id]);
1418
+ return -1;
1419
+ }
1420
+ }
1421
+ }
1422
+ }
1423
+
1424
+ if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, n_seq_max, output_all)) {
981
1425
  LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
982
1426
  return -1;
983
1427
  }
@@ -1058,7 +1502,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
1058
1502
  }
1059
1503
 
1060
1504
  // reserve output buffer
1061
- if (output_reserve(n_outputs_all) < n_outputs_all) {
1505
+ if (output_reserve(n_outputs_all, balloc->get_batch()) < n_outputs_all) {
1062
1506
  LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
1063
1507
  return -2;
1064
1508
  };
@@ -1131,7 +1575,10 @@ int llama_context::decode(const llama_batch & batch_inp) {
1131
1575
  }
1132
1576
 
1133
1577
  // extract logits
1134
- if (t_logits && n_outputs > 0) {
1578
+ // For multi-sequence batches that mix backend samplers and CPU sampler
1579
+ // this is currently inefficient as we copy all logits even for the
1580
+ // backend sampled tokens.
1581
+ if (logits && t_logits && n_outputs > 0) {
1135
1582
  ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
1136
1583
  GGML_ASSERT(backend_res != nullptr);
1137
1584
  GGML_ASSERT(logits != nullptr);
@@ -1146,7 +1593,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
1146
1593
  }
1147
1594
 
1148
1595
  // extract embeddings
1149
- if (t_embd && n_outputs > 0) {
1596
+ if (embd && t_embd && n_outputs > 0) {
1150
1597
  ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
1151
1598
  GGML_ASSERT(backend_embd != nullptr);
1152
1599
 
@@ -1155,12 +1602,13 @@ int llama_context::decode(const llama_batch & batch_inp) {
1155
1602
  {
1156
1603
  // extract token embeddings
1157
1604
  GGML_ASSERT(embd != nullptr);
1158
- float * embd_out = embd + n_outputs_prev*n_embd;
1605
+ const uint32_t n_embd_out = hparams.get_n_embd_out();
1606
+ float * embd_out = embd + n_outputs_prev*n_embd_out;
1159
1607
 
1160
1608
  if (n_outputs) {
1161
1609
  GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
1162
- GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size);
1163
- ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd*sizeof(float));
1610
+ GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd_out <= (int64_t) embd_size);
1611
+ ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd_out*sizeof(float));
1164
1612
  }
1165
1613
  } break;
1166
1614
  case LLAMA_POOLING_TYPE_MEAN:
@@ -1200,6 +1648,22 @@ int llama_context::decode(const llama_batch & batch_inp) {
1200
1648
  }
1201
1649
  }
1202
1650
 
1651
+ // This flag indicates whether a backend sampler has actually sampled a specific
1652
+ // token, or if it has produced probabilites. If true, we can skip the normal copying of logits and embeddings.
1653
+ const bool has_sampled = !res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty();
1654
+
1655
+ if (has_samplers && has_sampled) {
1656
+ const auto seq_to_output_row = build_seq_to_output_row(ubatch, n_outputs_prev);
1657
+ const auto stride = n_vocab;
1658
+
1659
+ // async copy the sampling data from the backend to the host
1660
+ copy_tensor_async_ints(res->t_sampled, sampling.sampled, sampling.sampled_size, seq_to_output_row, sched.get());
1661
+
1662
+ copy_tensor_async_floats (res->t_sampled_logits, sampling.logits, stride, sampling.logits_count, seq_to_output_row, sched.get());
1663
+ copy_tensor_async_floats (res->t_sampled_probs, sampling.probs, stride, sampling.probs_count, seq_to_output_row, sched.get());
1664
+ copy_tensor_async_candidates(res->t_candidates, sampling.candidates, stride, sampling.candidates_count, seq_to_output_row, sched.get());
1665
+ }
1666
+
1203
1667
  n_outputs_prev += n_outputs;
1204
1668
  } while (mctx->next());
1205
1669
 
@@ -1224,7 +1688,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
1224
1688
 
1225
1689
  // make the outputs have the same order they had in the user-provided batch
1226
1690
  // note: this is mostly relevant for recurrent models atm
1227
- if (!sorted_output) {
1691
+ if (!sorted_output && n_outputs > 1) {
1228
1692
  GGML_ASSERT((size_t) n_outputs == out_ids.size());
1229
1693
 
1230
1694
  // TODO: is there something more efficient which also minimizes swaps?
@@ -1263,15 +1727,15 @@ int llama_context::decode(const llama_batch & batch_inp) {
1263
1727
  // output
1264
1728
  //
1265
1729
 
1266
- uint32_t llama_context::output_reserve(int32_t n_outputs) {
1730
+ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & batch) {
1267
1731
  const auto & hparams = model.hparams;
1268
1732
  const auto & vocab = model.vocab;
1269
1733
 
1270
1734
  const int64_t n_outputs_max = std::max<int64_t>(n_outputs, n_seq_max());
1271
1735
 
1272
- const auto n_batch = cparams.n_batch;
1273
- const auto n_vocab = vocab.n_tokens();
1274
- const auto n_embd = hparams.n_embd;
1736
+ const auto n_batch = cparams.n_batch;
1737
+ const auto n_vocab = vocab.n_tokens();
1738
+ const auto n_embd_out = hparams.get_n_embd_out();
1275
1739
 
1276
1740
  bool has_logits = true;
1277
1741
  bool has_embd = cparams.embeddings;
@@ -1282,8 +1746,53 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
1282
1746
  has_embd = true;
1283
1747
  }
1284
1748
 
1285
- logits_size = has_logits ? n_vocab*n_outputs_max : 0;
1286
- embd_size = has_embd ? n_embd*n_outputs_max : 0;
1749
+ // Check which sampling modes are needed for the current batch.
1750
+ // TODO: avoid this branching by working with the worst-case
1751
+ bool has_sampling = false;
1752
+ bool cpu_logits = false;
1753
+
1754
+ if (batch.logits) {
1755
+ for (int32_t i = 0; i < batch.n_tokens; i++) {
1756
+ if (!batch.logits[i]) {
1757
+ continue;
1758
+ }
1759
+ for (int32_t j = 0; j < batch.n_seq_id[i]; j++) {
1760
+ llama_seq_id seq_id = batch.seq_id[i][j];
1761
+ if (sampling.samplers.find(seq_id) != sampling.samplers.end()) {
1762
+ has_sampling = true;
1763
+ } else {
1764
+ cpu_logits = true;
1765
+ }
1766
+ }
1767
+ }
1768
+ } else {
1769
+ // When batch.logits is nullptr (when loading state with a dummy batch),
1770
+ // allocate CPU logits.
1771
+ cpu_logits = true;
1772
+ }
1773
+
1774
+ size_t backend_float_count = 0;
1775
+ size_t backend_token_count = 0;
1776
+
1777
+ // Allocate CPU logits buffer only if needed by sequences in this batch
1778
+ logits_size = (has_logits && cpu_logits) ? n_vocab*n_outputs_max : 0;
1779
+ embd_size = has_embd ? n_embd_out*n_outputs_max : 0;
1780
+
1781
+ // TODO: avoid this branching by working with the worst-case
1782
+ if (!has_sampling) {
1783
+ sampling.logits_size = 0;
1784
+ sampling.probs_size = 0;
1785
+ sampling.sampled_size = 0;
1786
+ sampling.candidates_size = 0;
1787
+ } else {
1788
+ sampling.logits_size = n_vocab*n_outputs_max;
1789
+ sampling.probs_size = n_vocab*n_outputs_max;
1790
+ sampling.sampled_size = n_outputs_max;
1791
+ sampling.candidates_size = n_vocab*n_outputs_max;
1792
+
1793
+ backend_float_count = sampling.logits_size + sampling.probs_size;
1794
+ backend_token_count = sampling.sampled_size + sampling.candidates_size;
1795
+ }
1287
1796
 
1288
1797
  if (output_ids.empty()) {
1289
1798
  // init, never resized afterwards
@@ -1291,7 +1800,9 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
1291
1800
  }
1292
1801
 
1293
1802
  const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0;
1294
- const size_t new_size = (logits_size + embd_size) * sizeof(float);
1803
+ const size_t new_size =
1804
+ (logits_size + embd_size + backend_float_count) * sizeof(float) +
1805
+ ( backend_token_count) * sizeof(llama_token);
1295
1806
 
1296
1807
  // alloc only when more than the current capacity is required
1297
1808
  // TODO: also consider shrinking the buffer
@@ -1299,8 +1810,11 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
1299
1810
  if (buf_output) {
1300
1811
  #ifndef NDEBUG
1301
1812
  // This doesn't happen often, but may be annoying in some cases (like the HellaSwag benchmark)
1302
- LLAMA_LOG_INFO("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
1813
+ LLAMA_LOG_DEBUG("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
1303
1814
  #endif
1815
+ synchronize();
1816
+
1817
+ // TODO: not needed?
1304
1818
  buf_output = nullptr;
1305
1819
  logits = nullptr;
1306
1820
  embd = nullptr;
@@ -1322,8 +1836,49 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
1322
1836
 
1323
1837
  float * output_base = (float *) ggml_backend_buffer_get_base(buf_output.get());
1324
1838
 
1325
- logits = has_logits ? output_base : nullptr;
1326
- embd = has_embd ? output_base + logits_size : nullptr;
1839
+ logits = nullptr;
1840
+ embd = nullptr;
1841
+
1842
+ size_t offset = 0;
1843
+ uint8_t * base = (uint8_t *) output_base;
1844
+
1845
+ logits = (has_logits && cpu_logits) ? output_base : nullptr;
1846
+ offset += logits_size * sizeof(float);
1847
+
1848
+ embd = has_embd ? (float *) (base + offset) : nullptr;
1849
+ offset += embd_size * sizeof(float);
1850
+
1851
+ sampling.logits = nullptr;
1852
+ sampling.probs = nullptr;
1853
+ sampling.sampled = nullptr;
1854
+ sampling.candidates = nullptr;
1855
+
1856
+ if (has_sampling) {
1857
+ sampling.logits = (float *) (base + offset);
1858
+ offset += sampling.logits_size * sizeof(float);
1859
+
1860
+ sampling.probs = (float *) (base + offset);
1861
+ offset += sampling.probs_size * sizeof(float);
1862
+
1863
+ sampling.sampled = (llama_token *) (base + offset);
1864
+ offset += sampling.sampled_size * sizeof(llama_token);
1865
+
1866
+ sampling.candidates = (llama_token *) (base + offset);
1867
+ offset += sampling.candidates_size * sizeof(llama_token);
1868
+
1869
+ // The count vectors keep track of the actual number of logits/probs/candidates
1870
+ // copied from the backend for each output row.
1871
+
1872
+ sampling.logits_count.resize(n_outputs_max);
1873
+ sampling.probs_count.resize(n_outputs_max);
1874
+ sampling.candidates_count.resize(n_outputs_max);
1875
+
1876
+ std::fill(sampling.logits_count.begin(), sampling.logits_count.end(), 0);
1877
+ std::fill(sampling.probs_count.begin(), sampling.probs_count.end(), 0);
1878
+ std::fill(sampling.candidates_count.begin(), sampling.candidates_count.end(), 0);
1879
+
1880
+ std::fill_n(sampling.sampled, sampling.sampled_size, LLAMA_TOKEN_NULL);
1881
+ }
1327
1882
 
1328
1883
  // set all ids as invalid (negative)
1329
1884
  std::fill(output_ids.begin(), output_ids.end(), -1);
@@ -1352,6 +1907,40 @@ void llama_context::output_reorder() {
1352
1907
  std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]);
1353
1908
  }
1354
1909
  }
1910
+
1911
+ if (sampling.logits && sampling.logits_size > 0) {
1912
+ for (uint64_t k = 0; k < n_vocab; ++k) {
1913
+ std::swap(sampling.logits[i0*n_vocab + k], sampling.logits[i1*n_vocab + k]);
1914
+ }
1915
+ }
1916
+
1917
+ if (sampling.probs && sampling.probs_size > 0) {
1918
+ for (uint64_t k = 0; k < n_vocab; ++k) {
1919
+ std::swap(sampling.probs[i0*n_vocab + k], sampling.probs[i1*n_vocab + k]);
1920
+ }
1921
+ }
1922
+
1923
+ if (sampling.candidates && sampling.candidates_size > 0) {
1924
+ for (uint64_t k = 0; k < n_vocab; ++k) {
1925
+ std::swap(sampling.candidates[i0*n_vocab + k], sampling.candidates[i1*n_vocab + k]);
1926
+ }
1927
+ }
1928
+
1929
+ if (sampling.sampled && sampling.sampled_size > 0) {
1930
+ std::swap(sampling.sampled[i0], sampling.sampled[i1]);
1931
+ }
1932
+
1933
+ if (!sampling.logits_count.empty()) {
1934
+ std::swap(sampling.logits_count[i0], sampling.logits_count[i1]);
1935
+ }
1936
+
1937
+ if (!sampling.probs_count.empty()) {
1938
+ std::swap(sampling.probs_count[i0], sampling.probs_count[i1]);
1939
+ }
1940
+
1941
+ if (!sampling.candidates_count.empty()) {
1942
+ std::swap(sampling.candidates_count[i0], sampling.candidates_count[i1]);
1943
+ }
1355
1944
  }
1356
1945
 
1357
1946
  output_swaps.clear();
@@ -1361,21 +1950,27 @@ void llama_context::output_reorder() {
1361
1950
  // graph
1362
1951
  //
1363
1952
 
1364
- uint32_t llama_context::graph_max_nodes() const {
1365
- return std::max<uint32_t>(1024u, 8u*model.n_tensors());
1953
+ uint32_t llama_context::graph_max_nodes(uint32_t n_tokens) const {
1954
+ if (model.arch == LLM_ARCH_QWEN3NEXT) {
1955
+ return std::max<uint32_t>(n_tokens * 40, 32u * model.n_tensors());
1956
+ }
1957
+ uint32_t res = std::max<uint32_t>(1024u, 8u*model.n_tensors());
1958
+ res += model.n_lora_nodes;
1959
+ return res;
1366
1960
  }
1367
1961
 
1368
1962
  llm_graph_result * llama_context::get_gf_res_reserve() const {
1369
1963
  return static_cast<llm_graph_result *>(gf_res_reserve.get());
1370
1964
  }
1371
1965
 
1372
- ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only) {
1966
+ ggml_cgraph * llama_context::graph_reserve(
1967
+ uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only, size_t * sizes) {
1373
1968
  LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
1374
1969
  GGML_ASSERT(n_outputs >= 1);
1375
1970
 
1376
1971
  if (n_tokens % n_seqs != 0) {
1377
1972
  n_tokens = ((n_tokens + (n_seqs - 1)) / n_seqs) * n_seqs; // round to next multiple of n_seqs
1378
- n_outputs = std::min(n_outputs, n_tokens);
1973
+ n_outputs = std::max(n_outputs, n_tokens);
1379
1974
 
1380
1975
  LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs);
1381
1976
  }
@@ -1394,6 +1989,15 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
1394
1989
  llama_batch_allocr balloc(model.hparams.n_pos_per_embd());
1395
1990
  llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
1396
1991
 
1992
+ // set one output token per sequence in order to activate all backend samplers
1993
+ std::vector<llama_seq_id> seq_ids(n_seqs);
1994
+ for (uint32_t i = 0; i < n_seqs; ++i) {
1995
+ seq_ids[i] = i;
1996
+ ubatch.n_seq_id[i] = 1;
1997
+ ubatch.seq_id[i] = &seq_ids[i];
1998
+ ubatch.output[i] = true;
1999
+ }
2000
+
1397
2001
  auto * res = gf_res_reserve.get();
1398
2002
 
1399
2003
  const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT);
@@ -1406,8 +2010,13 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
1406
2010
 
1407
2011
  // initialize scheduler with the specified graph
1408
2012
  if (split_only) {
1409
- ggml_backend_sched_split_graph(sched.get(), gf);
2013
+ if (sizes) {
2014
+ ggml_backend_sched_reserve_size(sched.get(), gf, sizes);
2015
+ } else {
2016
+ ggml_backend_sched_split_graph(sched.get(), gf);
2017
+ }
1410
2018
  } else if (!ggml_backend_sched_reserve(sched.get(), gf)) {
2019
+ GGML_ASSERT(!sizes);
1411
2020
  LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
1412
2021
  return nullptr;
1413
2022
  }
@@ -1419,7 +2028,7 @@ llm_graph_params llama_context::graph_params(
1419
2028
  llm_graph_result * res,
1420
2029
  const llama_ubatch & ubatch,
1421
2030
  const llama_memory_context_i * mctx,
1422
- llm_graph_type gtype) const {
2031
+ llm_graph_type gtype) const {
1423
2032
  return {
1424
2033
  /*.arch =*/ model.arch,
1425
2034
  /*.hparams =*/ model.hparams,
@@ -1432,6 +2041,7 @@ llm_graph_params llama_context::graph_params(
1432
2041
  /*.loras =*/ &loras,
1433
2042
  /*.mctx =*/ mctx,
1434
2043
  /*.cross =*/ &cross,
2044
+ /*.samplers =*/ sampling.samplers,
1435
2045
  /*.n_outputs =*/ n_outputs,
1436
2046
  /*.cb =*/ graph_get_cb(),
1437
2047
  /*.res =*/ res,
@@ -1484,7 +2094,7 @@ llm_graph_cb llama_context::graph_get_cb() const {
1484
2094
 
1485
2095
  // norm may be automatically assigned to the backend of the previous layer, increasing data transfer between backends
1486
2096
  // FIXME: fix in ggml_backend_sched
1487
- const bool full_offload = model.params.n_gpu_layers > (int) model.hparams.n_layer;
2097
+ const bool full_offload = model.n_gpu_layers() > model.hparams.n_layer;
1488
2098
  if (ubatch.n_tokens < 32 || full_offload) {
1489
2099
  if (il != -1 && strcmp(name, "norm") == 0) {
1490
2100
  const auto & dev_layer = model.dev_layer(il);
@@ -1887,6 +2497,9 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
1887
2497
  }
1888
2498
  }
1889
2499
 
2500
+ // TODO: handle sampling buffers and samplers state ?
2501
+ // https://github.com/ggml-org/llama.cpp/pull/17004
2502
+
1890
2503
  if (memory != nullptr) {
1891
2504
  LLAMA_LOG_DEBUG("%s: - writing memory module\n", __func__);
1892
2505
  memory->state_write(io);
@@ -1919,7 +2532,10 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
1919
2532
  auto n_outputs = this->n_outputs;
1920
2533
  io.read_to(&n_outputs, sizeof(n_outputs));
1921
2534
 
1922
- if (n_outputs > output_reserve(n_outputs)) {
2535
+ // Create a dummy batch for state loading.
2536
+ llama_batch dummy_batch = {};
2537
+ dummy_batch.n_tokens = 0;
2538
+ if (n_outputs > output_reserve(n_outputs, dummy_batch)) {
1923
2539
  throw std::runtime_error("could not reserve outputs");
1924
2540
  }
1925
2541
 
@@ -1973,6 +2589,9 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
1973
2589
  }
1974
2590
  }
1975
2591
 
2592
+ // TODO: handle sampling buffers and samplers state ?
2593
+ // https://github.com/ggml-org/llama.cpp/pull/17004
2594
+
1976
2595
  if (memory) {
1977
2596
  LLAMA_LOG_DEBUG("%s: - reading memory module\n", __func__);
1978
2597
 
@@ -2029,15 +2648,26 @@ void llama_context::perf_reset() {
2029
2648
 
2030
2649
  std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> llama_context::memory_breakdown() const {
2031
2650
  std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> ret;
2032
- for (const auto & buft_size : model.memory_breakdown()) {
2033
- ret[buft_size.first].model += buft_size.second;
2651
+ for (const auto & [buft, size] : model.memory_breakdown()) {
2652
+ ret[buft].model += size;
2034
2653
  }
2035
- for (const auto & buft_size : memory->memory_breakdown()) {
2036
- ret[buft_size.first].context += buft_size.second;
2654
+ if (memory) {
2655
+ for (const auto & [buft, size] : memory->memory_breakdown()) {
2656
+ ret[buft].context += size;
2657
+ }
2037
2658
  }
2038
- for (const auto & backend_ptr : backends) {
2039
- ggml_backend_t backend = backend_ptr.get();
2040
- ret[ggml_backend_sched_get_buffer_type(sched.get(), backend)].compute += ggml_backend_sched_get_buffer_size(sched.get(), backend);
2659
+ if (model.hparams.no_alloc) {
2660
+ for (size_t i = 0; i < backends.size(); ++i) {
2661
+ ggml_backend_t backend = backends[i].get();
2662
+ ggml_backend_buffer_type_t buft = ggml_backend_sched_get_buffer_type(sched.get(), backend);
2663
+ ret[buft].compute += backend_buf_exp_size[i];
2664
+ }
2665
+ } else {
2666
+ for (const auto & backend_ptr : backends) {
2667
+ ggml_backend_t backend = backend_ptr.get();
2668
+ ggml_backend_buffer_type_t buft = ggml_backend_sched_get_buffer_type(sched.get(), backend);
2669
+ ret[buft].compute += ggml_backend_sched_get_buffer_size(sched.get(), backend);
2670
+ }
2041
2671
  }
2042
2672
  return ret;
2043
2673
  }
@@ -2130,7 +2760,7 @@ void llama_context::opt_epoch_iter(
2130
2760
  batch.logits [pos_batch] = true;
2131
2761
  }
2132
2762
 
2133
- if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
2763
+ if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd_inp(), cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
2134
2764
  LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
2135
2765
  return;
2136
2766
  }
@@ -2150,7 +2780,7 @@ void llama_context::opt_epoch_iter(
2150
2780
  }
2151
2781
 
2152
2782
  // reserve output buffer
2153
- if (output_reserve(n_outputs_all) < n_outputs_all) {
2783
+ if (output_reserve(n_outputs_all, balloc->get_batch()) < n_outputs_all) {
2154
2784
  LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
2155
2785
  GGML_ABORT("TODO: handle this error");
2156
2786
  };
@@ -2295,6 +2925,8 @@ llama_context_params llama_context_default_params() {
2295
2925
  /*.op_offload =*/ true,
2296
2926
  /*.swa_full =*/ true,
2297
2927
  /*.kv_unified =*/ false,
2928
+ /*.sampler =*/ nullptr,
2929
+ /*.n_sampler =*/ 0,
2298
2930
  };
2299
2931
 
2300
2932
  return result;
@@ -2346,6 +2978,13 @@ llama_context * llama_init_from_model(
2346
2978
  return nullptr;
2347
2979
  }
2348
2980
 
2981
+ if (params.pooling_type != LLAMA_POOLING_TYPE_UNSPECIFIED &&
2982
+ params.pooling_type != model->hparams.pooling_type) {
2983
+ //user-specified pooling-type is different from the model default
2984
+ LLAMA_LOG_WARN("%s: model default pooling_type is [%d], but [%d] was specified\n", __func__,
2985
+ model->hparams.pooling_type, params.pooling_type);
2986
+ }
2987
+
2349
2988
  try {
2350
2989
  auto * ctx = new llama_context(*model, params);
2351
2990
  return ctx;
@@ -2371,6 +3010,10 @@ uint32_t llama_n_ctx(const llama_context * ctx) {
2371
3010
  return ctx->n_ctx();
2372
3011
  }
2373
3012
 
3013
+ uint32_t llama_n_ctx_seq(const llama_context * ctx) {
3014
+ return ctx->n_ctx_seq();
3015
+ }
3016
+
2374
3017
  uint32_t llama_n_batch(const llama_context * ctx) {
2375
3018
  return ctx->n_batch();
2376
3019
  }
@@ -2443,7 +3086,15 @@ float * llama_get_logits(llama_context * ctx) {
2443
3086
  float * llama_get_logits_ith(llama_context * ctx, int32_t i) {
2444
3087
  ctx->synchronize();
2445
3088
 
2446
- return ctx->get_logits_ith(i);
3089
+ float * res = nullptr;
3090
+
3091
+ res = ctx->get_sampled_logits_ith(i);
3092
+
3093
+ if (!res) {
3094
+ res = ctx->get_logits_ith(i);
3095
+ }
3096
+
3097
+ return res;
2447
3098
  }
2448
3099
 
2449
3100
  float * llama_get_embeddings(llama_context * ctx) {
@@ -2464,6 +3115,52 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) {
2464
3115
  return ctx->get_embeddings_seq(seq_id);
2465
3116
  }
2466
3117
 
3118
+ bool llama_set_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) {
3119
+ return ctx->set_sampler(seq_id, smpl);
3120
+ }
3121
+
3122
+ llama_token llama_get_sampled_token_ith(llama_context * ctx, int32_t i) {
3123
+ ctx->synchronize();
3124
+
3125
+ return ctx->get_sampled_token_ith(i);
3126
+ }
3127
+
3128
+ float * llama_get_sampled_probs_ith(llama_context * ctx, int32_t i) {
3129
+ ctx->synchronize();
3130
+
3131
+ return ctx->get_sampled_probs_ith(i);
3132
+ }
3133
+
3134
+ float * llama_get_sampled_logits_ith(llama_context * ctx, int32_t i) {
3135
+ ctx->synchronize();
3136
+
3137
+ return ctx->get_sampled_logits_ith(i);
3138
+ }
3139
+
3140
+ llama_token * llama_get_sampled_candidates_ith(llama_context * ctx, int32_t i) {
3141
+ ctx->synchronize();
3142
+
3143
+ return const_cast<llama_token *>(ctx->get_sampled_candidates_ith(i));
3144
+ }
3145
+
3146
+ uint32_t llama_get_sampled_candidates_count_ith(llama_context * ctx, int32_t i) {
3147
+ ctx->synchronize();
3148
+
3149
+ return static_cast<uint32_t>(ctx->get_sampled_candidates_count(i));
3150
+ }
3151
+
3152
+ uint32_t llama_get_sampled_logits_count_ith(llama_context * ctx, int32_t i) {
3153
+ ctx->synchronize();
3154
+
3155
+ return static_cast<uint32_t>(ctx->get_sampled_logits_count(i));
3156
+ }
3157
+
3158
+ uint32_t llama_get_sampled_probs_count_ith(llama_context * ctx, int32_t i) {
3159
+ ctx->synchronize();
3160
+
3161
+ return static_cast<uint32_t>(ctx->get_sampled_probs_count(i));
3162
+ }
3163
+
2467
3164
  // llama adapter API
2468
3165
 
2469
3166
  int32_t llama_set_adapter_lora(