whispercpp 1.3.4 → 1.3.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (891)
  1. checksums.yaml +4 -4
  2. data/LICENSE +1 -1
  3. data/README.md +158 -44
  4. data/ext/extconf.rb +3 -2
  5. data/ext/ruby_whisper.c +34 -6
  6. data/ext/ruby_whisper.h +67 -0
  7. data/ext/ruby_whisper_context.c +236 -144
  8. data/ext/ruby_whisper_context_params.c +163 -0
  9. data/ext/ruby_whisper_model.c +12 -13
  10. data/ext/ruby_whisper_params.c +47 -24
  11. data/ext/ruby_whisper_segment.c +84 -20
  12. data/ext/ruby_whisper_token.c +371 -0
  13. data/ext/ruby_whisper_transcribe.cpp +5 -2
  14. data/ext/ruby_whisper_vad_context.c +122 -0
  15. data/ext/ruby_whisper_vad_context_detect.cpp +51 -0
  16. data/ext/ruby_whisper_vad_params.c +0 -1
  17. data/ext/ruby_whisper_vad_segment.c +138 -0
  18. data/ext/ruby_whisper_vad_segments.c +105 -0
  19. data/ext/sources/CMakeLists.txt +4 -1
  20. data/ext/sources/bindings/javascript/package.json +1 -1
  21. data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
  22. data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
  23. data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  24. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  25. data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
  26. data/ext/sources/examples/addon.node/vad-example.js +2 -2
  27. data/ext/sources/examples/bench/bench.cpp +23 -18
  28. data/ext/sources/examples/cli/cli.cpp +129 -112
  29. data/ext/sources/examples/common-ggml.cpp +2 -0
  30. data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
  31. data/ext/sources/examples/miniaudio.h +4507 -2131
  32. data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
  33. data/ext/sources/examples/server/server.cpp +28 -15
  34. data/ext/sources/examples/talk-llama/CMakeLists.txt +8 -3
  35. data/ext/sources/examples/talk-llama/llama-adapter.cpp +5 -2
  36. data/ext/sources/examples/talk-llama/llama-adapter.h +7 -0
  37. data/ext/sources/examples/talk-llama/llama-arch.cpp +2378 -1988
  38. data/ext/sources/examples/talk-llama/llama-arch.h +109 -2
  39. data/ext/sources/examples/talk-llama/llama-batch.cpp +78 -34
  40. data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
  41. data/ext/sources/examples/talk-llama/llama-chat.cpp +100 -4
  42. data/ext/sources/examples/talk-llama/llama-chat.h +5 -0
  43. data/ext/sources/examples/talk-llama/llama-context.cpp +1088 -403
  44. data/ext/sources/examples/talk-llama/llama-context.h +70 -23
  45. data/ext/sources/examples/talk-llama/llama-cparams.h +6 -0
  46. data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
  47. data/ext/sources/examples/talk-llama/llama-grammar.cpp +295 -60
  48. data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
  49. data/ext/sources/examples/talk-llama/llama-graph.cpp +925 -155
  50. data/ext/sources/examples/talk-llama/llama-graph.h +234 -23
  51. data/ext/sources/examples/talk-llama/llama-hparams.cpp +79 -38
  52. data/ext/sources/examples/talk-llama/llama-hparams.h +118 -18
  53. data/ext/sources/examples/talk-llama/llama-impl.cpp +11 -7
  54. data/ext/sources/examples/talk-llama/llama-impl.h +14 -2
  55. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +8 -4
  56. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +405 -140
  57. data/ext/sources/examples/talk-llama/llama-kv-cache.h +24 -10
  58. data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
  59. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
  60. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
  61. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
  62. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +42 -31
  63. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
  64. data/ext/sources/examples/talk-llama/llama-mmap.cpp +197 -45
  65. data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
  66. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +606 -116
  67. data/ext/sources/examples/talk-llama/llama-model-loader.h +41 -5
  68. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +61 -44
  69. data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
  70. data/ext/sources/examples/talk-llama/llama-model.cpp +2756 -13643
  71. data/ext/sources/examples/talk-llama/llama-model.h +112 -18
  72. data/ext/sources/examples/talk-llama/llama-quant.cpp +582 -365
  73. data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +1409 -199
  74. data/ext/sources/examples/talk-llama/llama-sampler.h +42 -0
  75. data/ext/sources/examples/talk-llama/llama-vocab.cpp +248 -82
  76. data/ext/sources/examples/talk-llama/llama-vocab.h +50 -40
  77. data/ext/sources/examples/talk-llama/llama.cpp +802 -21
  78. data/ext/sources/examples/talk-llama/llama.h +210 -39
  79. data/ext/sources/examples/talk-llama/models/afmoe.cpp +190 -0
  80. data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
  81. data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
  82. data/ext/sources/examples/talk-llama/models/arctic.cpp +137 -0
  83. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
  84. data/ext/sources/examples/talk-llama/models/baichuan.cpp +123 -0
  85. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +143 -0
  86. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +133 -0
  87. data/ext/sources/examples/talk-llama/models/bert.cpp +184 -0
  88. data/ext/sources/examples/talk-llama/models/bitnet.cpp +145 -0
  89. data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
  90. data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
  91. data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
  92. data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
  93. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
  94. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
  95. data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
  96. data/ext/sources/examples/talk-llama/models/dbrx.cpp +122 -0
  97. data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
  98. data/ext/sources/examples/talk-llama/models/deepseek.cpp +142 -0
  99. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +262 -0
  100. data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
  101. data/ext/sources/examples/talk-llama/models/dots1.cpp +132 -0
  102. data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
  103. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +148 -0
  104. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
  105. data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
  106. data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
  107. data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
  108. data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
  109. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +111 -0
  110. data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
  111. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
  112. data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
  113. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
  114. data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
  115. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
  116. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
  117. data/ext/sources/examples/talk-llama/models/glm4.cpp +157 -0
  118. data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
  119. data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
  120. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +195 -0
  121. data/ext/sources/examples/talk-llama/models/granite.cpp +210 -0
  122. data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
  123. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +139 -0
  124. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
  125. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +153 -0
  126. data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
  127. data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
  128. data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
  129. data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
  130. data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
  131. data/ext/sources/examples/talk-llama/models/lfm2.cpp +196 -0
  132. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
  133. data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
  134. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
  135. data/ext/sources/examples/talk-llama/models/llama.cpp +175 -0
  136. data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
  137. data/ext/sources/examples/talk-llama/models/mamba-base.cpp +289 -0
  138. data/ext/sources/examples/talk-llama/models/mamba.cpp +54 -0
  139. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +129 -0
  140. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +200 -0
  141. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +123 -0
  142. data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
  143. data/ext/sources/examples/talk-llama/models/models.h +704 -0
  144. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +109 -0
  145. data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
  146. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +162 -0
  147. data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
  148. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
  149. data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
  150. data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
  151. data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
  152. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
  153. data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
  154. data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
  155. data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
  156. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
  157. data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
  158. data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
  159. data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
  160. data/ext/sources/examples/talk-llama/models/plamo2.cpp +320 -0
  161. data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
  162. data/ext/sources/examples/talk-llama/models/plm.cpp +169 -0
  163. data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
  164. data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
  165. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
  166. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
  167. data/ext/sources/examples/talk-llama/models/qwen3.cpp +120 -0
  168. data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
  169. data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
  170. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +131 -0
  171. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +525 -0
  172. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +140 -0
  173. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +132 -0
  174. data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
  175. data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
  176. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +164 -0
  177. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
  178. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
  179. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +137 -0
  180. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
  181. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
  182. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
  183. data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
  184. data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
  185. data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
  186. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
  187. data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
  188. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
  189. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
  190. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
  191. data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
  192. data/ext/sources/examples/talk-llama/unicode.cpp +121 -79
  193. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
  194. data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
  195. data/ext/sources/ggml/CMakeLists.txt +90 -56
  196. data/ext/sources/ggml/include/ggml-alloc.h +9 -0
  197. data/ext/sources/ggml/include/ggml-backend.h +5 -2
  198. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  199. data/ext/sources/ggml/include/ggml-cpu.h +6 -0
  200. data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
  201. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  202. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  203. data/ext/sources/ggml/include/ggml-rpc.h +14 -12
  204. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  205. data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
  206. data/ext/sources/ggml/include/ggml.h +246 -21
  207. data/ext/sources/ggml/src/CMakeLists.txt +85 -11
  208. data/ext/sources/ggml/src/ggml-alloc.c +128 -50
  209. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  210. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  211. data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
  212. data/ext/sources/ggml/src/ggml-backend-reg.cpp +54 -88
  213. data/ext/sources/ggml/src/ggml-backend.cpp +76 -23
  214. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +18 -4
  215. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +11 -11
  216. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +58 -46
  217. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +139 -48
  218. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2427 -1785
  219. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -362
  220. data/ext/sources/ggml/src/ggml-cann/common.h +285 -211
  221. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +663 -831
  222. data/ext/sources/ggml/src/ggml-common.h +11 -0
  223. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +170 -95
  224. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -18
  225. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  226. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
  227. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  228. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +513 -27
  229. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +4192 -992
  230. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
  231. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  232. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1761 -49
  233. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
  234. data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  235. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
  236. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
  237. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +124 -24
  238. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +157 -28
  239. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  240. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  241. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +8 -3
  242. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +251 -80
  243. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +19 -0
  244. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +587 -119
  245. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
  246. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1093 -194
  247. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1284 -203
  248. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
  249. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1519 -527
  250. data/ext/sources/ggml/src/ggml-cpu/ops.h +6 -4
  251. data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
  252. data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
  253. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +3632 -781
  254. data/ext/sources/ggml/src/ggml-cpu/repack.h +129 -4
  255. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
  256. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +152 -46
  257. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
  258. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +152 -1
  259. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
  260. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +140 -0
  261. data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
  262. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
  263. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
  264. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +132 -6
  265. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
  266. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +33 -31
  267. data/ext/sources/ggml/src/ggml-cuda/common.cuh +474 -85
  268. data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
  269. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  270. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
  271. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +342 -246
  272. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
  273. data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
  274. data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
  275. data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
  276. data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
  277. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +98 -74
  278. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +973 -665
  279. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
  280. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1255 -0
  281. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +33 -40
  282. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +40 -18
  283. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  284. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +206 -45
  285. data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
  286. data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
  287. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
  288. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  289. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1688 -302
  290. data/ext/sources/ggml/src/ggml-cuda/mean.cu +12 -10
  291. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +908 -48
  292. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +88 -20
  293. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +502 -90
  294. data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
  295. data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
  296. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
  297. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +532 -193
  298. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +460 -104
  299. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +5 -2
  300. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +360 -122
  301. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +2 -1
  302. data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
  303. data/ext/sources/ggml/src/ggml-cuda/pad.cu +73 -39
  304. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +152 -1
  305. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
  306. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
  307. data/ext/sources/ggml/src/ggml-cuda/rope.cu +364 -149
  308. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
  309. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
  310. data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
  311. data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
  312. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +163 -41
  313. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
  314. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
  315. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +68 -50
  316. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  317. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
  318. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
  319. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
  320. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
  321. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
  322. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
  323. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
  324. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  325. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  326. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  327. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  328. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  329. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  330. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  331. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  332. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  333. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +22 -4
  334. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +95 -0
  335. data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
  336. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +275 -119
  337. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -7
  338. data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
  339. data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
  340. data/ext/sources/ggml/src/ggml-cuda/unary.cu +160 -11
  341. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +38 -0
  342. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
  343. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
  344. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  345. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +22 -1
  346. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
  347. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +117 -0
  348. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3325 -0
  349. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +46 -0
  350. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +813 -0
  351. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
  352. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +891 -0
  353. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  354. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
  355. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +713 -0
  356. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
  357. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.c +63 -0
  358. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +182 -0
  359. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
  360. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  361. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
  362. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  363. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +155 -0
  364. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +63 -0
  365. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  366. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  367. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
  368. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
  369. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
  370. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  371. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
  372. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  373. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  374. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  375. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  376. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
  377. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  378. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  379. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +26 -0
  380. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1199 -0
  381. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2670 -0
  382. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +497 -0
  383. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
  384. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +419 -0
  385. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
  386. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  387. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +382 -0
  388. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +293 -0
  389. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  390. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  391. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  392. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  393. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
  394. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
  395. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +14 -13
  396. data/ext/sources/ggml/src/ggml-impl.h +129 -6
  397. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  398. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +15 -4
  399. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  400. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +173 -34
  401. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +912 -344
  402. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +124 -59
  403. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +588 -144
  404. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +396 -23
  405. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1724 -421
  406. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +16 -3
  407. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +333 -114
  408. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +3050 -1539
  409. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
  410. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +30 -1
  411. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4279 -497
  412. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
  413. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
  414. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  415. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +267 -0
  416. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  417. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  418. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +113 -0
  419. data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  420. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
  421. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  422. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  423. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  424. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
  425. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  426. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  427. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  428. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +140 -0
  429. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  430. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
  431. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
  432. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  433. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  434. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  435. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
  436. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  437. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  438. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  439. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  440. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  441. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
  442. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  443. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  444. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  445. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
  446. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  447. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  448. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
  449. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +116 -0
  450. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  451. data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  452. data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  453. data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  454. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  455. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  456. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +39 -0
  457. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  458. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  459. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  460. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
  461. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  462. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
  463. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  464. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
  465. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
  466. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  467. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  468. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  469. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  470. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  471. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  472. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  473. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  474. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  475. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  476. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  477. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  478. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  479. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  480. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  481. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  482. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  483. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
  484. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  485. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  486. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  487. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  488. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  489. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  490. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
  491. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
  492. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
  493. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
  494. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  495. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  496. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  497. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  498. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  499. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
  500. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  501. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
  502. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
  503. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
  504. data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
  505. data/ext/sources/ggml/src/ggml-quants.c +96 -5
  506. data/ext/sources/ggml/src/ggml-quants.h +3 -0
  507. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
  508. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +59 -87
  509. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +81 -0
  510. data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
  511. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +7 -0
  512. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -29
  513. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
  514. data/ext/sources/ggml/src/ggml-sycl/common.hpp +427 -20
  515. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
  516. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +103 -1
  517. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
  518. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  519. data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  520. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
  521. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
  522. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +867 -50
  523. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +401 -358
  524. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +12 -2
  525. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
  526. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
  527. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
  528. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
  529. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
  530. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  531. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
  532. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
  533. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +645 -155
  534. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
  535. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +221 -66
  536. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
  537. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  538. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
  539. data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
  540. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
  541. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
  542. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +5 -0
  543. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
  544. data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  545. data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  546. data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
  547. data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
  548. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +457 -281
  549. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  550. data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
  551. data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
  552. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
  553. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
  554. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  555. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  556. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  557. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  558. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  559. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  560. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  561. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  562. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  563. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  564. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  565. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
  566. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
  567. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
  568. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
  569. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
  570. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
  571. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
  572. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
  573. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
  574. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
  575. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
  576. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
  577. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
  578. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
  579. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
  580. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
  581. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
  582. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
  583. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
  584. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
  585. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
  586. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
  587. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
  588. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
  589. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
  590. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
  591. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
  592. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
  593. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
  594. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
  595. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
  596. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
  597. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
  598. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
  599. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
  600. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
  601. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +71 -0
  602. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  603. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  604. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  605. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  606. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  607. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  608. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  609. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  610. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  611. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  612. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  613. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  614. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  615. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  616. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  617. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  618. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  619. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  620. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  621. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  622. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  623. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  624. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
  625. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
  626. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  627. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
  628. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  629. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  630. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  631. data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
  632. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  633. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  634. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  635. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  636. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  637. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  638. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  639. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
  640. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  641. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  642. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  643. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
  644. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
  645. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +39 -19
  646. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5994 -3055
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +18 -10
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  663. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  664. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
  665. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
  666. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  667. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  668. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  669. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
  670. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
  671. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
  672. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
  673. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  674. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
  675. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
  676. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  677. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  678. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  679. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  680. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  681. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
  682. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
  683. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
  684. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  685. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  686. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
  687. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
  688. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  689. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  690. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  691. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
  692. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  693. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  694. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
  695. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  696. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  697. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
  698. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  699. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  700. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  701. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
  702. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  703. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +386 -160
  704. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +82 -20
  705. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +400 -174
  706. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +123 -37
  707. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  708. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +10 -9
  709. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  710. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
  711. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  712. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
  713. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
  714. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  715. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
  716. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  717. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
  718. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
  719. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
  720. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
  721. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
  722. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
  723. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  724. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
  725. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
  726. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
  727. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
  728. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +13 -10
  729. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  730. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  731. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  732. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
  733. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +77 -29
  734. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
  735. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
  736. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
  737. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
  738. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
  739. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
  740. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
  741. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
  742. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
  743. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  744. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
  745. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
  746. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
  747. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
  748. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
  749. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
  750. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
  751. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +88 -105
  752. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +41 -26
  753. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
  754. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +74 -0
  755. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +92 -230
  756. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
  757. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  758. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
  759. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  760. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  761. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  762. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
  763. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
  764. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  765. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
  766. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  767. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  768. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  769. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  770. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +49 -4
  771. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  772. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
  773. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
  774. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +207 -0
  775. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  776. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +8 -49
  777. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +8 -32
  778. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +8 -32
  779. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +33 -0
  780. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +8 -38
  781. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  782. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
  783. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  784. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  785. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  786. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  787. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  788. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
  789. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
  790. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
  791. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
  792. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
  793. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
  794. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  795. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
  796. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
  797. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  798. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +50 -0
  799. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
  800. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  801. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  802. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
  803. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
  804. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  805. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
  806. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  807. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
  808. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
  809. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
  810. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
  811. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
  812. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  813. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
  814. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
  815. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +384 -180
  816. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
  817. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
  818. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1374 -0
  819. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +2544 -726
  820. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  821. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  822. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  823. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  824. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
  825. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
  826. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
  827. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +107 -0
  828. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  829. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +73 -15
  830. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +636 -0
  831. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
  832. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  833. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +72 -261
  834. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +766 -0
  835. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +147 -0
  836. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +196 -0
  837. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
  838. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  839. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  840. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
  841. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  842. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +63 -0
  843. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
  844. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  845. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  846. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
  847. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
  848. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +91 -0
  849. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +469 -0
  850. data/ext/sources/ggml/src/ggml.c +590 -64
  851. data/ext/sources/ggml/src/gguf.cpp +229 -44
  852. data/ext/sources/include/whisper.h +1 -0
  853. data/ext/sources/src/CMakeLists.txt +3 -1
  854. data/ext/sources/src/whisper.cpp +106 -62
  855. data/ext/sources/tests/CMakeLists.txt +2 -2
  856. data/ext/sources/tests/test-vad-full.cpp +4 -2
  857. data/ext/sources/tests/test-vad.cpp +1 -1
  858. data/extsources.rb +1 -0
  859. data/lib/whisper/model/uri.rb +17 -18
  860. data/sig/whisper.rbs +162 -4
  861. data/test/test_context_params.rb +82 -0
  862. data/test/test_params.rb +16 -8
  863. data/test/test_segment.rb +0 -1
  864. data/test/test_token.rb +81 -0
  865. data/test/test_vad.rb +1 -1
  866. data/test/test_vad_context.rb +100 -0
  867. data/test/test_vad_segment.rb +19 -0
  868. data/test/test_vad_segments.rb +16 -0
  869. data/test/test_whisper.rb +27 -0
  870. data/whispercpp.gemspec +1 -1
  871. metadata +502 -37
  872. data/ext/sources/build-xcframework.sh +0 -571
  873. data/ext/sources/examples/talk-llama/llama-sampling.h +0 -32
  874. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  875. data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
  876. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
  877. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
  878. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
  879. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
  880. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  881. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
  882. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
  883. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
  884. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
  885. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  886. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  887. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  888. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  889. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
  890. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
  891. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
@@ -1,13 +1,16 @@
1
1
  #include "llama-context.h"
2
2
 
3
+ #include "llama-arch.h"
3
4
  #include "llama-impl.h"
4
5
  #include "llama-batch.h"
5
6
  #include "llama-io.h"
6
7
  #include "llama-memory.h"
7
8
  #include "llama-mmap.h"
8
9
  #include "llama-model.h"
10
+ #include "llama-ext.h"
9
11
 
10
12
  #include <cinttypes>
13
+ #include <cmath>
11
14
  #include <cstring>
12
15
  #include <limits>
13
16
  #include <stdexcept>
@@ -20,7 +23,11 @@ llama_context::llama_context(
20
23
  const llama_model & model,
21
24
  llama_context_params params) :
22
25
  model(model),
26
+ cvec(std::make_unique<llama_adapter_cvec>()),
27
+ loras(std::make_unique<llama_adapter_loras>()),
23
28
  balloc(std::make_unique<llama_batch_allocr>(model.hparams.n_pos_per_embd())) {
29
+ // TODO warning when creating llama_context with awkward ctx size that is not a power of 2,
30
+ // may need to be backend-dependent
24
31
  LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);
25
32
 
26
33
  t_start_us = model.t_start_us;
@@ -56,6 +63,25 @@ llama_context::llama_context(
56
63
  cparams.cb_eval = params.cb_eval;
57
64
  cparams.cb_eval_user_data = params.cb_eval_user_data;
58
65
 
66
+ // Initialize backend samplers here so they are part of the sampling graph
67
+ // before the reserve passes run later in this function. This avoids a later
68
+ // re-reserve when graph nodes change.
69
+ if (params.samplers != nullptr && params.n_samplers > 0) {
70
+ for (size_t i = 0; i < params.n_samplers; ++i) {
71
+ const auto & config = params.samplers[i];
72
+
73
+ if (llama_sampler_chain_get(config.sampler, -1) == nullptr) {
74
+ throw std::runtime_error("the backend samplers must be of type llama_sampler_chain");
75
+ }
76
+
77
+ if (set_sampler(config.seq_id, config.sampler)) {
78
+ const int n_samplers = llama_sampler_chain_n(config.sampler);
79
+
80
+ LLAMA_LOG_INFO("%s: setting backend sampler for seq_id %d (n = %d)\n", __func__, config.seq_id, n_samplers);
81
+ }
82
+ }
83
+ }
84
+
59
85
  auto rope_scaling_type = params.rope_scaling_type;
60
86
  if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED) {
61
87
  rope_scaling_type = hparams.rope_scaling_type_train;
@@ -69,6 +95,43 @@ llama_context::llama_context(
69
95
  cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f;
70
96
  }
71
97
 
98
+ if (cparams.yarn_ext_factor != 0) {
99
+ static auto get_mscale = [](float scale, float mscale) {
100
+ return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f);
101
+ };
102
+
103
+ const float factor = 1.0f / cparams.rope_freq_scale;
104
+
105
+ // ref: https://github.com/huggingface/transformers/blob/6d00f6b0a5679c36510f203e4226e36f517c3032/src/transformers/modeling_rope_utils.py#L336-L348
106
+ if (hparams.rope_yarn_log_mul != 0.0f) {
107
+ // note: here we assume `mscale == 1.0f`
108
+ // TODO: start reading the actual value of mscale and handle the case where it is not 1.0f
109
+ float mscale = 1.0f;
110
+ const float mscale_all_dims = hparams.rope_yarn_log_mul;
111
+
112
+ // [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
113
+ // special-case DEEPSEEK v2:
114
+ // https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite-Chat/blob/main/config.json#L42-L43
115
+ if (model.arch == LLM_ARCH_DEEPSEEK2 && mscale_all_dims != 1.0f) {
116
+ mscale = mscale_all_dims;
117
+ }
118
+
119
+ cparams.yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims);
120
+
121
+ LLAMA_LOG_WARN("%s: setting new yarn_attn_factor = %.4f (mscale == %.1f, mscale_all_dim = %.1f)\n",
122
+ __func__, cparams.yarn_attn_factor, mscale, mscale_all_dims);
123
+ } else {
124
+ cparams.yarn_attn_factor = get_mscale(factor, 1.0f);
125
+ }
126
+
127
+ // when YARN is applied with yarn_ext_factor != 0.0f, we need to cancel this factor:
128
+ // https://github.com/ggml-org/llama.cpp/blob/a81a569577cc38b32558958b048228150be63eae/ggml/src/ggml-cpu/ops.cpp#L5541-L5544
129
+ //
130
+ // ref: https://github.com/ggml-org/llama.cpp/discussions/7416
131
+ // https://github.com/ggml-org/llama.cpp/pull/17945
132
+ cparams.yarn_attn_factor *= 1.0f / (1.0f + 0.1f * logf(factor));
133
+ }
134
+
72
135
  cparams.yarn_attn_factor *= hparams.rope_attn_factor;
73
136
 
74
137
  if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
@@ -86,23 +149,23 @@ llama_context::llama_context(
86
149
  }
87
150
 
88
151
  cparams.flash_attn = params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED;
152
+ cparams.auto_fa = params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO;
153
+
154
+ cparams.fused_gdn_ar = true;
155
+ cparams.fused_gdn_ch = true;
156
+ cparams.auto_fgdn = true;
89
157
 
90
158
  // with causal attention, the batch size is limited by the context size
91
159
  cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
92
160
 
93
- // the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask
94
- // this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext)
95
- // ref: https://github.com/ggerganov/llama.cpp/pull/5021
96
- // TODO: this padding is not needed for the cache-less context so we should probably move it to llama_memory
97
- if (cparams.n_batch < GGML_KQ_MASK_PAD) {
98
- LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
99
- cparams.n_batch = GGML_KQ_MASK_PAD;
100
- }
101
161
  cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
102
162
 
103
163
  cparams.op_offload = params.op_offload;
104
164
  cparams.kv_unified = params.kv_unified;
105
165
 
166
+ // initialized later
167
+ cparams.pipeline_parallel = false;
168
+
106
169
  {
107
170
  const char * LLAMA_GRAPH_REUSE_DISABLE = getenv("LLAMA_GRAPH_REUSE_DISABLE");
108
171
  graph_reuse_disable = LLAMA_GRAPH_REUSE_DISABLE ? (atoi(LLAMA_GRAPH_REUSE_DISABLE) != 0) : graph_reuse_disable;
@@ -112,11 +175,28 @@ llama_context::llama_context(
112
175
  }
113
176
  }
114
177
 
115
- const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
178
+ // ref: https://github.com/ggml-org/llama.cpp/pull/17046#discussion_r2503085732
179
+ cparams.n_ctx = GGML_PAD(cparams.n_ctx, 256);
180
+
181
+ if (cparams.kv_unified) {
182
+ cparams.n_ctx_seq = cparams.n_ctx;
183
+ } else {
184
+ cparams.n_ctx_seq = cparams.n_ctx / cparams.n_seq_max;
185
+ cparams.n_ctx_seq = GGML_PAD(cparams.n_ctx_seq, 256);
186
+
187
+ if (cparams.n_ctx_seq == 0) {
188
+ throw std::runtime_error("n_ctx_seq == 0");
189
+ }
190
+
191
+ if (cparams.n_ctx != cparams.n_ctx_seq * cparams.n_seq_max) {
192
+ cparams.n_ctx = cparams.n_ctx_seq * cparams.n_seq_max;
193
+ LLAMA_LOG_WARN("%s: n_ctx is not divisible by n_seq_max - rounding down to %u\n", __func__, cparams.n_ctx);
194
+ }
195
+ }
116
196
 
117
197
  LLAMA_LOG_INFO("%s: n_seq_max = %u\n", __func__, cparams.n_seq_max);
118
198
  LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
119
- LLAMA_LOG_INFO("%s: n_ctx_per_seq = %u\n", __func__, n_ctx_per_seq);
199
+ LLAMA_LOG_INFO("%s: n_ctx_seq = %u\n", __func__, cparams.n_ctx_seq);
120
200
  LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch);
121
201
  LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
122
202
  LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn);
@@ -125,14 +205,14 @@ llama_context::llama_context(
125
205
  LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
126
206
  LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
127
207
 
128
- if (n_ctx_per_seq < hparams.n_ctx_train) {
129
- LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
130
- __func__, n_ctx_per_seq, hparams.n_ctx_train);
208
+ if (cparams.n_ctx_seq < hparams.n_ctx_train) {
209
+ LLAMA_LOG_WARN("%s: n_ctx_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
210
+ __func__, cparams.n_ctx_seq, hparams.n_ctx_train);
131
211
  }
132
212
 
133
- if (n_ctx_per_seq > hparams.n_ctx_train) {
134
- LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
135
- __func__, n_ctx_per_seq, hparams.n_ctx_train);
213
+ if (cparams.n_ctx_seq > hparams.n_ctx_train) {
214
+ LLAMA_LOG_WARN("%s: n_ctx_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
215
+ __func__, cparams.n_ctx_seq, hparams.n_ctx_train);
136
216
  }
137
217
 
138
218
  if (!hparams.vocab_only) {
@@ -180,7 +260,6 @@ llama_context::llama_context(
180
260
 
181
261
  // graph outputs buffer
182
262
  {
183
- // resized during inference when a batch uses more outputs
184
263
  if (output_reserve(params.n_seq_max) < params.n_seq_max) {
185
264
  throw std::runtime_error("failed to reserve initial output buffer");
186
265
  }
@@ -208,6 +287,7 @@ llama_context::llama_context(
208
287
 
209
288
  backend_buft.clear();
210
289
  backend_ptrs.clear();
290
+ backend_buf_exp_size.clear();
211
291
 
212
292
  for (auto & backend : backends) {
213
293
  auto * buft = ggml_backend_get_default_buffer_type(backend.get());
@@ -224,23 +304,17 @@ llama_context::llama_context(
224
304
 
225
305
  backend_buft.push_back(buft);
226
306
  backend_ptrs.push_back(backend.get());
307
+ backend_buf_exp_size.push_back(0);
227
308
  }
228
309
 
229
310
  LLAMA_LOG_DEBUG("%s: backend_ptrs.size() = %zu\n", __func__, backend_ptrs.size());
230
311
 
231
- const size_t max_nodes = this->graph_max_nodes();
232
-
233
- LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes);
234
-
235
- gf_res_prev.reset(new llm_graph_result(max_nodes));
236
- gf_res_reserve.reset(new llm_graph_result(max_nodes));
237
-
238
312
  // TODO: move these checks to ggml_backend_sched
239
313
  // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
240
314
  bool pipeline_parallel =
241
315
  model.n_devices() > 1 &&
242
- model.params.n_gpu_layers > (int) model.hparams.n_layer &&
243
- model.params.split_mode == LLAMA_SPLIT_MODE_LAYER &&
316
+ model.n_gpu_layers() > model.hparams.n_layer &&
317
+ model.split_mode() == LLAMA_SPLIT_MODE_LAYER &&
244
318
  cparams.offload_kqv &&
245
319
  !model.has_tensor_overrides();
246
320
 
@@ -250,6 +324,7 @@ llama_context::llama_context(
250
324
  auto dev_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get()));
251
325
  if (dev_type == GGML_BACKEND_DEVICE_TYPE_CPU) {
252
326
  // ignore CPU backend
327
+ // TODO: should we ignore ACCEL types too?
253
328
  continue;
254
329
  }
255
330
  auto * dev = ggml_backend_get_device(backend.get());
@@ -263,146 +338,308 @@ llama_context::llama_context(
263
338
  }
264
339
  }
265
340
 
266
- sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel, cparams.op_offload));
341
+ cparams.pipeline_parallel = pipeline_parallel;
267
342
 
268
- if (pipeline_parallel) {
269
- LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(sched.get()));
343
+ if (cparams.pipeline_parallel) {
344
+ LLAMA_LOG_INFO("%s: pipeline parallelism enabled\n", __func__);
345
+
346
+ if (!graph_reuse_disable) {
347
+ // TODO: figure out a way to make graph reuse work with pipeline parallelism
348
+ // ref: https://github.com/ggml-org/llama.cpp/pull/20463
349
+ LLAMA_LOG_WARN("%s: graph reuse is currently not compatible with pipeline parallelism - disabling\n", __func__);
350
+
351
+ graph_reuse_disable = true;
352
+ }
353
+ }
354
+
355
+ sched_reserve();
356
+
357
+ if (!cparams.flash_attn) {
358
+ if (ggml_is_quantized(params.type_v)) {
359
+ throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention");
360
+ }
270
361
  }
271
362
  }
272
363
 
273
- if (!hparams.vocab_only) {
274
- llama_memory_context_ptr mctx;
275
- if (memory) {
276
- LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__);
277
- mctx = memory->init_full();
278
- if (!mctx) {
279
- throw std::runtime_error("failed to initialize memory module");
364
+ // Initialize the full vocabulary token ids for backend samplers.
365
+ {
366
+ const int n_vocab = model.vocab.n_tokens();
367
+
368
+ sampling.token_ids_full_vocab.resize(n_vocab);
369
+ for (int i = 0; i < n_vocab; ++i) {
370
+ sampling.token_ids_full_vocab[i] = i;
371
+ }
372
+ }
373
+ }
374
+
375
+ llama_context::~llama_context() {
376
+ if (!model.hparams.no_alloc) {
377
+ for (size_t i = 0; i < backend_ptrs.size(); ++i) {
378
+ ggml_backend_t backend = backend_ptrs[i];
379
+ ggml_backend_buffer_type_t buft = backend_buft[i];
380
+
381
+ const size_t size_exp = backend_buf_exp_size[i];
382
+ const size_t size_act = ggml_backend_sched_get_buffer_size(sched.get(), backend);
383
+ if (size_exp == size_act) {
384
+ LLAMA_LOG_DEBUG("%s: %10s compute buffer size is %8.4f MiB, matches expectation of %8.4f MiB\n",
385
+ __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0));
386
+ } else {
387
+ LLAMA_LOG_WARN("%s: %10s compute buffer size of %8.4f MiB, does not match expectation of %8.4f MiB\n",
388
+ __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0));
280
389
  }
281
390
  }
391
+ }
392
+ ggml_opt_free(opt_ctx);
393
+ }
394
+
395
+ void llama_context::sched_reserve() {
396
+ if (!sched_need_reserve) {
397
+ return;
398
+ }
399
+
400
+ sched_need_reserve = false;
401
+
402
+ LLAMA_LOG_INFO("%s: reserving ...\n", __func__);
403
+
404
+ synchronize();
405
+
406
+ const int64_t t_start_us = ggml_time_us();
407
+
408
+ const uint32_t n_seqs = cparams.n_seq_max;
409
+ const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
410
+
411
+ const size_t max_nodes = this->graph_max_nodes(n_tokens);
412
+
413
+ LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes);
414
+
415
+ gf_res_prev.reset(new llm_graph_result(max_nodes));
416
+ gf_res_reserve.reset(new llm_graph_result(max_nodes));
417
+
418
+ sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, cparams.pipeline_parallel, cparams.op_offload));
419
+
420
+ llama_memory_context_ptr mctx;
421
+ if (memory) {
422
+ LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__);
423
+ mctx = memory->init_full();
424
+ if (!mctx) {
425
+ throw std::runtime_error("failed to initialize memory module");
426
+ }
427
+ }
282
428
 
283
- cross.v_embd.clear();
429
+ // avoid reserving graphs with zero outputs - assume one output per sequence
430
+ const int n_outputs = n_seqs;
284
431
 
285
- const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
286
- const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
432
+ LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
287
433
 
288
- // avoid reserving graphs with zero outputs - assume one output per sequence
289
- n_outputs = n_seqs;
434
+ // resolve automatic Flash Attention use
435
+ if (cparams.auto_fa) {
436
+ auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true);
437
+ if (!gf) {
438
+ throw std::runtime_error("failed to reserve graph for Flash Attention check");
439
+ }
290
440
 
291
- LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
441
+ const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1;
442
+ bool fa_device_mismatch = false;
443
+ for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
444
+ ggml_tensor * n = ggml_graph_node(gf, i);
445
+ if (n->op != GGML_OP_FLASH_ATTN_EXT) {
446
+ continue;
447
+ }
448
+ ggml_backend_dev_t device_fa = ggml_backend_get_device(ggml_backend_sched_get_tensor_backend(sched.get(), n));
449
+
450
+ // TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer
451
+ GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0);
452
+ const int il = std::stoi(n->name + prefix_len);
453
+ ggml_backend_dev_t device_kv = model.dev_layer(il);
454
+ if (device_fa != device_kv) {
455
+ LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor "
456
+ "is assigned to device %s (usually due to missing support)\n",
457
+ __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_fa));
458
+ // FIXME: fa_device_mismatch logic is wrong for --no-kv-offload, but this is broken anyways
459
+ fa_device_mismatch = true;
460
+ break;
461
+ }
462
+ }
292
463
 
293
- // resolve automatic Flash Attention use
294
- if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) {
464
+ if (fa_device_mismatch) {
465
+ cparams.flash_attn = false;
466
+ LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__);
467
+ } else {
468
+ cparams.flash_attn = true;
469
+ LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__);
470
+ }
471
+
472
+ cparams.auto_fa = false;
473
+ }
474
+
475
+ if (cparams.auto_fgdn) {
476
+ LLAMA_LOG_INFO("%s: resolving fused Gated Delta Net support:\n", __func__);
477
+
478
+ if (cparams.fused_gdn_ar) {
295
479
  auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true);
296
480
  if (!gf) {
297
- throw std::runtime_error("failed to split graph for Flash Attention check");
481
+ throw std::runtime_error("failed to reserve graph for fused Gated Delta Net check (autoregressive)");
298
482
  }
299
483
 
300
- const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1;
301
- bool fa_device_mismatch = false;
484
+ const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FGDN_AR) + 1;
485
+ bool gdn_device_mismatch = false;
302
486
  for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
303
487
  ggml_tensor * n = ggml_graph_node(gf, i);
304
- if (n->op != GGML_OP_FLASH_ATTN_EXT) {
488
+ if (n->op != GGML_OP_GATED_DELTA_NET) {
305
489
  continue;
306
490
  }
307
- ggml_backend_dev_t device_fa = ggml_backend_get_device(
308
- ggml_backend_sched_get_tensor_backend(sched.get(), n));
491
+ ggml_backend_dev_t device_gdn = ggml_backend_get_device(ggml_backend_sched_get_tensor_backend(sched.get(), n));
309
492
 
310
- // TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer
311
- GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0);
493
+ GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FGDN_AR "-", prefix_len) == 0);
312
494
  const int il = std::stoi(n->name + prefix_len);
313
495
  ggml_backend_dev_t device_kv = model.dev_layer(il);
314
- if (device_fa != device_kv) {
315
- LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor "
316
- "is assigned to device %s (usually due to missing support)\n",
317
- __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_fa));
318
- // FIXME: fa_device_mismatch logic is wrong for --no-kv-offload, but this is broken anyways
319
- fa_device_mismatch = true;
496
+ if (device_gdn != device_kv) {
497
+ LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the fused Gated Delta Net tensor "
498
+ "is assigned to device %s (usually due to missing support)\n",
499
+ __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_gdn));
500
+ gdn_device_mismatch = true;
320
501
  break;
321
502
  }
322
503
  }
323
- if (fa_device_mismatch) {
324
- cparams.flash_attn = false;
325
- LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__);
326
- if (ggml_is_quantized(params.type_v)) {
327
- throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention");
328
- }
504
+
505
+ if (gdn_device_mismatch) {
506
+ cparams.fused_gdn_ar = false;
507
+ LLAMA_LOG_WARN("%s: fused Gated Delta Net (autoregressive) not supported, set to disabled\n", __func__);
329
508
  } else {
330
- cparams.flash_attn = true;
331
- LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__);
509
+ LLAMA_LOG_INFO("%s: fused Gated Delta Net (autoregressive) enabled\n", __func__);
332
510
  }
333
511
  }
334
512
 
335
- // reserve worst-case graph
336
- int n_splits_pp = -1;
337
- int n_nodes_pp = -1;
338
-
339
- int n_splits_tg = -1;
340
- int n_nodes_tg = -1;
341
-
342
- // reserve pp (prompt processing) graph first so that buffers are only allocated once
343
- {
344
- auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
513
+ if (cparams.fused_gdn_ch) {
514
+ // more than one token in the batch per sequence in order to take the chunked path
515
+ // note: n_outputs must match n_tokens for embedding models with mean/rank pooling,
516
+ // because build_pooling creates inp_mean with shape [n_tokens, n_seqs] and multiplies
517
+ // it with t_embd which is reduced to [n_outputs, ...] via out_ids. if n_outputs != n_tokens,
518
+ // the ggml_mul_mat assertion fails. this matches the pp reservation below (line ~553).
519
+ const uint32_t n_tokens_ch = 16*n_seqs;
520
+ auto * gf = graph_reserve(n_tokens_ch, n_seqs, n_tokens_ch, mctx.get(), true);
345
521
  if (!gf) {
346
- throw std::runtime_error("failed to allocate compute pp buffers");
522
+ throw std::runtime_error("failed to reserve graph for fused Gated Delta Net check (chunked)");
347
523
  }
348
524
 
349
- n_splits_pp = ggml_backend_sched_get_n_splits(sched.get());
350
- n_nodes_pp = ggml_graph_n_nodes(gf);
351
- }
525
+ const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FGDN_CH) + 1;
526
+ bool gdn_device_mismatch = false;
527
+ for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
528
+ ggml_tensor * n = ggml_graph_node(gf, i);
529
+ if (n->op != GGML_OP_GATED_DELTA_NET) {
530
+ continue;
531
+ }
532
+ ggml_backend_dev_t device_gdn = ggml_backend_get_device(ggml_backend_sched_get_tensor_backend(sched.get(), n));
352
533
 
353
- // reserve with tg (token generation) graph to get the number of splits and nodes
354
- {
355
- auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get());
356
- if (!gf) {
357
- throw std::runtime_error("failed to allocate compute tg buffers");
534
+ GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FGDN_CH "-", prefix_len) == 0);
535
+ const int il = std::stoi(n->name + prefix_len);
536
+ ggml_backend_dev_t device_kv = model.dev_layer(il);
537
+ if (device_gdn != device_kv) {
538
+ LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the fused Gated Delta Net tensor "
539
+ "is assigned to device %s (usually due to missing support)\n",
540
+ __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_gdn));
541
+ gdn_device_mismatch = true;
542
+ break;
543
+ }
358
544
  }
359
545
 
360
- n_splits_tg = ggml_backend_sched_get_n_splits(sched.get());
361
- n_nodes_tg = ggml_graph_n_nodes(gf);
546
+ if (gdn_device_mismatch) {
547
+ cparams.fused_gdn_ch = false;
548
+ LLAMA_LOG_WARN("%s: fused Gated Delta Net (chunked) not supported, set to disabled\n", __func__);
549
+ } else {
550
+ LLAMA_LOG_INFO("%s: fused Gated Delta Net (chunked) enabled\n", __func__);
551
+ }
362
552
  }
363
553
 
364
- // reserve again with pp graph to avoid ggml-alloc reallocations during inference
365
- {
366
- // TODO: not sure if the following graph would be worster case for multi-stream KV caches:
367
- //
368
- // auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get());
369
- //
370
- auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
554
+ cparams.auto_fgdn = false;
555
+ }
556
+
557
+ // reserve worst-case graph
558
+ int n_splits_pp = -1;
559
+ int n_nodes_pp = -1;
560
+
561
+ int n_splits_tg = -1;
562
+ int n_nodes_tg = -1;
563
+
564
+ // reserve pp (prompt processing) graph first so that buffers are only allocated once
565
+ {
566
+ auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(),
567
+ model.hparams.no_alloc, model.hparams.no_alloc ? backend_buf_exp_size.data() : nullptr);
568
+ if (!gf) {
569
+ if (cparams.pipeline_parallel) {
570
+ LLAMA_LOG_WARN("%s: compute buffer allocation failed, retrying without pipeline parallelism\n", __func__);
571
+ cparams.pipeline_parallel = false;
572
+ sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, false, cparams.op_offload));
573
+ gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
574
+ }
371
575
  if (!gf) {
372
576
  throw std::runtime_error("failed to allocate compute pp buffers");
373
577
  }
374
578
  }
375
579
 
376
- for (size_t i = 0; i < backend_ptrs.size(); ++i) {
377
- ggml_backend_t backend = backend_ptrs[i];
378
- ggml_backend_buffer_type_t buft = backend_buft[i];
379
- size_t size = ggml_backend_sched_get_buffer_size(sched.get(), backend);
380
- if (size > 1) {
381
- LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__,
382
- ggml_backend_buft_name(buft),
383
- size / 1024.0 / 1024.0);
384
- }
580
+ n_splits_pp = ggml_backend_sched_get_n_splits(sched.get());
581
+ n_nodes_pp = ggml_graph_n_nodes(gf);
582
+ }
583
+
584
+ // reserve with tg (token generation) graph to get the number of splits and nodes
585
+ {
586
+ auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get(), model.hparams.no_alloc);
587
+ if (!gf) {
588
+ throw std::runtime_error("failed to allocate compute tg buffers");
385
589
  }
386
590
 
387
- if (n_nodes_pp == n_nodes_tg) {
388
- LLAMA_LOG_INFO("%s: graph nodes = %d\n", __func__, n_nodes_pp);
389
- } else {
390
- LLAMA_LOG_INFO("%s: graph nodes = %d (with bs=%d), %d (with bs=1)\n", __func__, n_nodes_pp, n_tokens, n_nodes_tg);
591
+ n_splits_tg = ggml_backend_sched_get_n_splits(sched.get());
592
+ n_nodes_tg = ggml_graph_n_nodes(gf);
593
+ }
594
+
595
+ // reserve again with pp graph to avoid ggml-alloc reallocations during inference
596
+ {
597
+ // TODO: not sure if the following graph would be worster case for multi-stream KV caches:
598
+ //
599
+ // auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get());
600
+ //
601
+ auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), model.hparams.no_alloc);
602
+ if (!gf) {
603
+ throw std::runtime_error("failed to allocate compute pp buffers");
391
604
  }
605
+ }
392
606
 
393
- if (n_splits_pp == n_splits_tg) {
394
- LLAMA_LOG_INFO("%s: graph splits = %d\n", __func__, n_splits_pp);
395
- } else {
396
- LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg);
607
+ for (size_t i = 0; i < backend_ptrs.size(); ++i) {
608
+ ggml_backend_t backend = backend_ptrs[i];
609
+ ggml_backend_buffer_type_t buft = backend_buft[i];
610
+ if (!model.hparams.no_alloc) {
611
+ backend_buf_exp_size[i] = ggml_backend_sched_get_buffer_size(sched.get(), backend);
612
+ }
613
+ if (backend_buf_exp_size[i] > 1) {
614
+ LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__,
615
+ ggml_backend_buft_name(buft),
616
+ backend_buf_exp_size[i] / 1024.0 / 1024.0);
397
617
  }
398
618
  }
399
- }
400
619
 
401
- llama_context::~llama_context() {
402
- ggml_opt_free(opt_ctx);
620
+ if (n_nodes_pp == n_nodes_tg) {
621
+ LLAMA_LOG_INFO("%s: graph nodes = %d\n", __func__, n_nodes_pp);
622
+ } else {
623
+ LLAMA_LOG_INFO("%s: graph nodes = %d (with bs=%d), %d (with bs=1)\n", __func__, n_nodes_pp, n_tokens, n_nodes_tg);
624
+ }
625
+
626
+ if (n_splits_pp == n_splits_tg) {
627
+ LLAMA_LOG_INFO("%s: graph splits = %d\n", __func__, n_splits_pp);
628
+ } else {
629
+ LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg);
630
+ }
631
+
632
+ const int64_t t_end_us = ggml_time_us();
633
+
634
+ LLAMA_LOG_INFO("%s: reserve took %.2f ms, sched copies = %d\n",
635
+ __func__, (t_end_us - t_start_us)/1000.0, ggml_backend_sched_get_n_copies(sched.get()));
403
636
  }
404
637
 
405
638
  void llama_context::synchronize() {
639
+ if (!sched) {
640
+ return;
641
+ }
642
+
406
643
  ggml_backend_sched_synchronize(sched.get());
407
644
 
408
645
  // FIXME: if multiple single tokens are evaluated without a synchronization,
@@ -448,8 +685,8 @@ uint32_t llama_context::n_ctx() const {
448
685
  return cparams.n_ctx;
449
686
  }
450
687
 
451
- uint32_t llama_context::n_ctx_per_seq() const {
452
- return cparams.n_ctx / cparams.n_seq_max;
688
+ uint32_t llama_context::n_ctx_seq() const {
689
+ return cparams.n_ctx_seq;
453
690
  }
454
691
 
455
692
  uint32_t llama_context::n_batch() const {
@@ -518,7 +755,7 @@ bool llama_context::memory_update(bool optimize) {
518
755
  throw std::runtime_error("failed to initialize memory context");
519
756
  }
520
757
 
521
- const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
758
+ const uint32_t n_seqs = cparams.n_seq_max;
522
759
  const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
523
760
 
524
761
  auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
@@ -537,39 +774,48 @@ enum llama_pooling_type llama_context::pooling_type() const {
537
774
  float * llama_context::get_logits() {
538
775
  output_reorder();
539
776
 
540
- return logits;
777
+ return logits.data;
541
778
  }
542
779
 
543
- float * llama_context::get_logits_ith(int32_t i) {
780
+ int64_t llama_context::output_resolve_row(int32_t i) const {
544
781
  int64_t j = -1;
545
782
 
783
+ // support negative indices (last output row)
784
+ if (i < 0) {
785
+ j = n_outputs + i;
786
+ if (j < 0) {
787
+ throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs));
788
+ }
789
+ } else if ((size_t) i >= output_ids.size()) {
790
+ throw std::runtime_error(format("out of range [0, %zu)", output_ids.size()));
791
+ } else {
792
+ // use output_ids to translate the batch token index into a row number
793
+ // that holds this token's data.
794
+ j = output_ids[i];
795
+ }
796
+
797
+ if (j < 0) {
798
+ // the batch token was not configured to output anything
799
+ throw std::runtime_error(format("batch.logits[%d] != true", i));
800
+ }
801
+
802
+ if (j >= n_outputs) {
803
+ throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
804
+ }
805
+
806
+ return j;
807
+ }
808
+
809
+ float * llama_context::get_logits_ith(int32_t i) {
546
810
  output_reorder();
547
811
 
548
812
  try {
549
- if (logits == nullptr) {
813
+ if (logits.data == nullptr) {
550
814
  throw std::runtime_error("no logits");
551
815
  }
552
816
 
553
- if (i < 0) {
554
- j = n_outputs + i;
555
- if (j < 0) {
556
- throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs));
557
- }
558
- } else if ((size_t) i >= output_ids.size()) {
559
- throw std::runtime_error(format("out of range [0, %zu)", output_ids.size()));
560
- } else {
561
- j = output_ids[i];
562
- }
563
-
564
- if (j < 0) {
565
- throw std::runtime_error(format("batch.logits[%d] != true", i));
566
- }
567
- if (j >= n_outputs) {
568
- // This should not happen
569
- throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
570
- }
571
-
572
- return logits + j*model.vocab.n_tokens();
817
+ const int64_t j = output_resolve_row(i);
818
+ return logits.data + j*model.vocab.n_tokens();
573
819
  } catch (const std::exception & err) {
574
820
  LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
575
821
  #ifndef NDEBUG
@@ -583,39 +829,24 @@ float * llama_context::get_logits_ith(int32_t i) {
583
829
  float * llama_context::get_embeddings() {
584
830
  output_reorder();
585
831
 
586
- return embd;
832
+ return embd.data;
587
833
  }
588
834
 
589
- float * llama_context::get_embeddings_ith(int32_t i) {
590
- int64_t j = -1;
835
+ llama_token * llama_context::get_sampled_tokens() const{
836
+ return sampling.sampled.data;
837
+ }
591
838
 
839
+ float * llama_context::get_embeddings_ith(int32_t i) {
592
840
  output_reorder();
593
841
 
594
842
  try {
595
- if (embd == nullptr) {
843
+ if (embd.data == nullptr) {
596
844
  throw std::runtime_error("no embeddings");
597
845
  }
598
846
 
599
- if (i < 0) {
600
- j = n_outputs + i;
601
- if (j < 0) {
602
- throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs));
603
- }
604
- } else if ((size_t) i >= output_ids.size()) {
605
- throw std::runtime_error(format("out of range [0, %zu)", output_ids.size()));
606
- } else {
607
- j = output_ids[i];
608
- }
609
-
610
- if (j < 0) {
611
- throw std::runtime_error(format("batch.logits[%d] != true", i));
612
- }
613
- if (j >= n_outputs) {
614
- // This should not happen
615
- throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
616
- }
617
-
618
- return embd + j*model.hparams.n_embd;
847
+ const int64_t j = output_resolve_row(i);
848
+ const uint32_t n_embd_out = model.hparams.n_embd_out();
849
+ return embd.data + j*n_embd_out;
619
850
  } catch (const std::exception & err) {
620
851
  LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what());
621
852
  #ifndef NDEBUG
@@ -635,6 +866,137 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) {
635
866
  return it->second.data();
636
867
  }
637
868
 
869
+ llama_token llama_context::get_sampled_token_ith(int32_t idx) {
870
+ output_reorder();
871
+
872
+ if (!sampling.sampled.has_data()) {
873
+ return LLAMA_TOKEN_NULL;
874
+ }
875
+
876
+ try {
877
+ const int64_t row = output_resolve_row(idx);
878
+ GGML_ASSERT(row < (int64_t) sampling.sampled.size);
879
+ return sampling.sampled.data[row];
880
+ } catch (const std::exception & err) {
881
+ LLAMA_LOG_ERROR("%s: invalid backend sampled token id %d, reason: %s\n", __func__, idx, err.what());
882
+ return LLAMA_TOKEN_NULL;
883
+ }
884
+ }
885
+
886
+ float * llama_context::get_sampled_probs_ith(int32_t idx) {
887
+ output_reorder();
888
+
889
+ if (!sampling.probs.has_data()) {
890
+ return nullptr;
891
+ }
892
+
893
+ try {
894
+ const int64_t row = output_resolve_row(idx);
895
+ if ((size_t) row >= sampling.probs_count.size() || sampling.probs_count[row] == 0) {
896
+ return nullptr;
897
+ }
898
+ return sampling.probs.data + row*model.vocab.n_tokens();
899
+ } catch (const std::exception & err) {
900
+ LLAMA_LOG_ERROR("%s: invalid backend sampled probs id %d, reason: %s\n", __func__, idx, err.what());
901
+ return nullptr;
902
+ }
903
+ }
904
+
905
+ float * llama_context::get_sampled_logits_ith(int32_t idx) {
906
+ output_reorder();
907
+
908
+ if (!sampling.logits.has_data()) {
909
+ return nullptr;
910
+ }
911
+
912
+ try {
913
+ const int64_t row = output_resolve_row(idx);
914
+ if ((size_t) row >= sampling.logits_count.size() || sampling.logits_count[row] == 0) {
915
+ return nullptr;
916
+ }
917
+ return sampling.logits.data + row*model.vocab.n_tokens();
918
+ } catch (const std::exception & err) {
919
+ LLAMA_LOG_ERROR("%s: invalid backend sampled logits id %d, reason: %s\n", __func__, idx, err.what());
920
+ return nullptr;
921
+ }
922
+ }
923
+
924
+ const llama_token * llama_context::get_sampled_candidates_ith(int32_t idx) {
925
+ output_reorder();
926
+
927
+ try {
928
+ const int64_t row = output_resolve_row(idx);
929
+ if (sampling.candidates.has_data() &&
930
+ (size_t) row < sampling.candidates_count.size() &&
931
+ sampling.candidates_count[row] > 0) {
932
+ return sampling.candidates.data + row*model.vocab.n_tokens();
933
+ }
934
+ } catch (const std::exception & err) {
935
+ // fallback to full vocab list
936
+ GGML_UNUSED(err);
937
+ }
938
+
939
+ return sampling.token_ids_full_vocab.data();
940
+ }
941
+
942
+ size_t llama_context::get_sampled_candidates_count(int32_t idx) {
943
+ output_reorder();
944
+
945
+ if (!sampling.candidates.has_data()) {
946
+ return 0;
947
+ }
948
+
949
+ try {
950
+ const int64_t row = output_resolve_row(idx);
951
+ if ((size_t) row >= sampling.candidates_count.size()) {
952
+ return 0;
953
+ }
954
+ return sampling.candidates_count[row];
955
+ } catch (const std::exception & err) {
956
+ LLAMA_LOG_ERROR("%s: invalid backend sampled candidates count id %d, reason: %s\n", __func__, idx, err.what());
957
+ return 0;
958
+ }
959
+ }
960
+
961
+ size_t llama_context::get_sampled_logits_count(int32_t idx) {
962
+ output_reorder();
963
+
964
+ if (!sampling.logits.has_data()) {
965
+ return model.vocab.n_tokens();
966
+ }
967
+
968
+ try {
969
+ const int64_t row = output_resolve_row(idx);
970
+ if ((size_t) row >= sampling.logits_count.size()) {
971
+ return 0;
972
+ }
973
+ return sampling.logits_count[row];
974
+ } catch (const std::exception & err) {
975
+ LLAMA_LOG_ERROR("%s: invalid backend sampled logits count id %d, reason: %s\n", __func__, idx, err.what());
976
+ return 0;
977
+ }
978
+ }
979
+
980
+ size_t llama_context::get_sampled_probs_count(int32_t idx) {
981
+ output_reorder();
982
+
983
+ if (!sampling.probs.has_data()) {
984
+ return 0;
985
+ }
986
+
987
+ try {
988
+ const int64_t row = output_resolve_row(idx);
989
+ if ((size_t) row >= sampling.probs_count.size()) {
990
+ return 0;
991
+ }
992
+ return sampling.probs_count[row];
993
+ } catch (const std::exception & err) {
994
+ LLAMA_LOG_ERROR("%s: invalid backend sampled probs count id %d, reason: %s\n", __func__, idx, err.what());
995
+ return 0;
996
+ }
997
+ }
998
+
999
+
638
1000
  void llama_context::attach_threadpool(
639
1001
  ggml_threadpool_t threadpool,
640
1002
  ggml_threadpool_t threadpool_batch) {
@@ -671,54 +1033,131 @@ void llama_context::set_abort_callback(bool (*abort_callback)(void * data), void
671
1033
  set_abort_callback_fn(backend.get(), this->abort_callback, this->abort_callback_data);
672
1034
  }
673
1035
  }
674
- }
1036
+ }
1037
+
1038
+ void llama_context::set_embeddings(bool value) {
1039
+ LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
1040
+
1041
+ cparams.embeddings = value;
1042
+
1043
+ // TODO: not sure yet if we want to reserve here
1044
+ //sched_need_reserve = true;
1045
+ }
1046
+
1047
+ void llama_context::set_causal_attn(bool value) {
1048
+ LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
1049
+
1050
+ if (cparams.causal_attn == value) {
1051
+ return;
1052
+ }
1053
+
1054
+ cparams.causal_attn = value;
1055
+
1056
+ sched_need_reserve = true;
1057
+ }
1058
+
1059
+ void llama_context::set_warmup(bool value) {
1060
+ LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
1061
+
1062
+ if (cparams.warmup == value) {
1063
+ return;
1064
+ }
1065
+
1066
+ cparams.warmup = value;
1067
+
1068
+ // warmups are usually with small batches, so no need to reserve
1069
+ //sched_need_reserve = true;
1070
+ }
1071
+
1072
+ bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
1073
+ if (!sampler && sampling.samplers.count(seq_id) == 0) {
1074
+ return true;
1075
+ }
1076
+
1077
+ LLAMA_LOG_DEBUG("%s: seq_id = %d, sampler = %p\n", __func__, (int) seq_id, (void *) sampler);
1078
+
1079
+ const bool can_offload =
1080
+ sampler &&
1081
+ sampler->iface->backend_init &&
1082
+ sampler->iface->backend_apply &&
1083
+ llama_sampler_chain_n(sampler) > 0;
1084
+
1085
+ if (sampler && can_offload) {
1086
+ auto * buft = ggml_backend_dev_buffer_type(model.dev_output());
1087
+
1088
+ sampler->iface->backend_init(sampler, buft);
1089
+
1090
+ sampling.samplers[seq_id] = sampler;
1091
+
1092
+ sched_need_reserve = true;
1093
+
1094
+ return true;
1095
+ }
675
1096
 
676
- void llama_context::set_embeddings(bool value) {
677
- LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
1097
+ if (sampler && !can_offload) {
1098
+ LLAMA_LOG_WARN("%s: sampler '%s' for seq_id = %d, cannot be offloaded to the backend\n", __func__, llama_sampler_name(sampler), seq_id);
678
1099
 
679
- cparams.embeddings = value;
680
- }
1100
+ if (sampling.samplers.count(seq_id) > 0) {
1101
+ sched_need_reserve = true;
1102
+ }
681
1103
 
682
- void llama_context::set_causal_attn(bool value) {
683
- LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
1104
+ sampling.samplers.erase(seq_id);
684
1105
 
685
- cparams.causal_attn = value;
686
- }
1106
+ return false;
1107
+ }
687
1108
 
688
- void llama_context::set_warmup(bool value) {
689
- LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
1109
+ sampling.samplers.erase(seq_id);
690
1110
 
691
- cparams.warmup = value;
1111
+ sched_need_reserve = true;
1112
+
1113
+ return true;
692
1114
  }
693
1115
 
694
- void llama_context::set_adapter_lora(
695
- llama_adapter_lora * adapter,
696
- float scale) {
697
- LLAMA_LOG_DEBUG("%s: adapter = %p, scale = %f\n", __func__, (void *) adapter, scale);
1116
+ void llama_context::set_adapters_lora(llama_adapter_lora ** adapters, size_t n_adapters, float * scales) {
1117
+ LLAMA_LOG_DEBUG("%s: adapters = %p\n", __func__, (void *) adapters);
698
1118
 
699
- loras[adapter] = scale;
700
- }
1119
+ if (adapters_lora_are_same(adapters, n_adapters, scales)) {
1120
+ return;
1121
+ }
701
1122
 
702
- bool llama_context::rm_adapter_lora(
703
- llama_adapter_lora * adapter) {
704
- LLAMA_LOG_DEBUG("%s: adapter = %p\n", __func__, (void *) adapter);
1123
+ loras.reset(new llama_adapter_loras());
705
1124
 
706
- auto pos = loras.find(adapter);
707
- if (pos != loras.end()) {
708
- loras.erase(pos);
709
- return true;
1125
+ for (size_t i = 0; i < n_adapters; i ++) {
1126
+ if (scales[i] != 0.0f) {
1127
+ loras->insert({adapters[i], scales[i]});
1128
+ }
710
1129
  }
711
1130
 
712
- return false;
1131
+ sched_need_reserve = true;
713
1132
  }
714
1133
 
715
- void llama_context::clear_adapter_lora() {
716
- LLAMA_LOG_DEBUG("%s: call\n", __func__);
1134
+ bool llama_context::adapters_lora_are_same(llama_adapter_lora ** adapters, size_t n_adapters, float * scales) {
1135
+ LLAMA_LOG_DEBUG("%s: adapters = %p\n", __func__, (void *) adapters);
1136
+
1137
+ // Adapters with a zero scale are never added to `loras`, so also ignore them for the comparison.
1138
+ size_t n_non_zero = 0;
1139
+
1140
+ for (size_t i = 0; i < n_adapters; i ++) {
1141
+ if (scales[i] == 0.0f) {
1142
+ continue;
1143
+ }
1144
+ n_non_zero++;
1145
+
1146
+ auto it = loras->find(adapters[i]);
1147
+
1148
+ if (it == loras->end() || it->second != scales[i]) {
1149
+ return false;
1150
+ }
1151
+ }
1152
+
1153
+ if (n_non_zero != loras->size()) {
1154
+ return false;
1155
+ }
717
1156
 
718
- loras.clear();
1157
+ return true;
719
1158
  }
720
1159
 
721
- bool llama_context::apply_adapter_cvec(
1160
+ bool llama_context::set_adapter_cvec(
722
1161
  const float * data,
723
1162
  size_t len,
724
1163
  int32_t n_embd,
@@ -726,7 +1165,9 @@ bool llama_context::apply_adapter_cvec(
726
1165
  int32_t il_end) {
727
1166
  LLAMA_LOG_DEBUG("%s: il_start = %d, il_end = %d\n", __func__, il_start, il_end);
728
1167
 
729
- return cvec.apply(model, data, len, n_embd, il_start, il_end);
1168
+ // TODO: should we reserve?
1169
+
1170
+ return cvec->apply(model, data, len, n_embd, il_start, il_end);
730
1171
  }
731
1172
 
732
1173
  llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
@@ -776,6 +1217,7 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
776
1217
  {
777
1218
  //const auto t_start_us = ggml_time_us();
778
1219
 
1220
+ // FIXME this call causes a crash if any model inputs were not used in the graph and were therefore not allocated
779
1221
  res->set_inputs(&ubatch);
780
1222
 
781
1223
  //LLAMA_LOG_INFO("graph set inputs time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);
@@ -803,7 +1245,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
803
1245
 
804
1246
  const auto & hparams = model.hparams;
805
1247
 
806
- const int64_t n_embd = hparams.n_embd;
1248
+ const int64_t n_embd = hparams.n_embd_inp();
807
1249
  const int64_t n_vocab = model.vocab.n_tokens();
808
1250
 
809
1251
  // note: during encode, we always pass the full sequence starting from pos = 0
@@ -828,6 +1270,8 @@ int llama_context::encode(const llama_batch & batch_inp) {
828
1270
  // TODO: this clear of the buffer can easily be forgotten - need something better
829
1271
  embd_seq.clear();
830
1272
 
1273
+ sched_reserve();
1274
+
831
1275
  n_queued_tokens += n_tokens;
832
1276
 
833
1277
  // reserve output buffer
@@ -867,16 +1311,16 @@ int llama_context::encode(const llama_batch & batch_inp) {
867
1311
  auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
868
1312
 
869
1313
  // extract logits
870
- if (logits && t_logits) {
1314
+ if (logits.data && t_logits) {
871
1315
  ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
872
1316
  GGML_ASSERT(backend_res != nullptr);
873
- GGML_ASSERT(logits != nullptr);
1317
+ GGML_ASSERT(logits.data != nullptr);
874
1318
 
875
- ggml_backend_tensor_get_async(backend_res, t_logits, logits, 0, n_tokens*n_vocab*sizeof(float));
1319
+ ggml_backend_tensor_get_async(backend_res, t_logits, logits.data, 0, n_tokens*n_vocab*sizeof(float));
876
1320
  }
877
1321
 
878
1322
  // extract embeddings
879
- if (embd && t_embd) {
1323
+ if (embd.data && t_embd) {
880
1324
  ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
881
1325
  GGML_ASSERT(backend_embd != nullptr);
882
1326
 
@@ -884,10 +1328,11 @@ int llama_context::encode(const llama_batch & batch_inp) {
884
1328
  case LLAMA_POOLING_TYPE_NONE:
885
1329
  {
886
1330
  // extract token embeddings
887
- GGML_ASSERT(embd != nullptr);
1331
+ GGML_ASSERT(embd.data != nullptr);
1332
+ const uint32_t n_embd_out = hparams.n_embd_out();
888
1333
 
889
- GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size);
890
- ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd*sizeof(float));
1334
+ GGML_ASSERT(n_tokens*n_embd_out <= (int64_t) embd.size);
1335
+ ggml_backend_tensor_get_async(backend_embd, t_embd, embd.data, 0, n_tokens*n_embd_out*sizeof(float));
891
1336
  } break;
892
1337
  case LLAMA_POOLING_TYPE_MEAN:
893
1338
  case LLAMA_POOLING_TYPE_CLS:
@@ -935,7 +1380,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
935
1380
  cross.n_embd = t_embd->ne[0];
936
1381
  cross.n_enc = t_embd->ne[1];
937
1382
  cross.v_embd.resize(cross.n_embd*cross.n_enc);
938
- memcpy(cross.v_embd.data(), embd, ggml_nbytes(t_embd));
1383
+ memcpy(cross.v_embd.data(), embd.data, ggml_nbytes(t_embd));
939
1384
 
940
1385
  const auto & batch = balloc->get_batch();
941
1386
 
@@ -955,6 +1400,128 @@ int llama_context::encode(const llama_batch & batch_inp) {
955
1400
  return 0;
956
1401
  }
957
1402
 
1403
+ static std::map<llama_seq_id, uint32_t> build_seq_to_output_row(const llama_ubatch & ubatch, uint32_t row_offset) {
1404
+ std::map<llama_seq_id, uint32_t> seq_to_row;
1405
+ // how many output tokens we have seen so far for this ubatch.
1406
+ uint32_t local = 0;
1407
+ for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
1408
+ // skip tokens that are not output.
1409
+ if (!ubatch.output[i]) {
1410
+ continue;
1411
+ }
1412
+
1413
+ const llama_seq_id seq_id = ubatch.seq_id[i][0];
1414
+ // row_offset is the number of output tokens before this ubatch.
1415
+ seq_to_row[seq_id] = row_offset + local;
1416
+ ++local;
1417
+ }
1418
+ return seq_to_row;
1419
+ }
1420
+
1421
+ static void copy_tensor_async_ints(
1422
+ const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
1423
+ const buffer_view<llama_token> & sampled,
1424
+ const std::map<llama_seq_id, uint32_t> & seq_to_row,
1425
+ ggml_backend_sched_t sched) {
1426
+ if (!sampled.has_data()) {
1427
+ return;
1428
+ }
1429
+
1430
+ for (const auto & [seq_id, tensor] : tensor_map) {
1431
+ auto it = seq_to_row.find(seq_id);
1432
+ if (it == seq_to_row.end()) {
1433
+ continue;
1434
+ }
1435
+
1436
+ const uint32_t row = it->second;
1437
+ GGML_ASSERT(row < sampled.size);
1438
+
1439
+ GGML_ASSERT(ggml_is_contiguous(tensor) && "sampled tokens tensor must be contiguous for async copy");
1440
+
1441
+ ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
1442
+ ggml_backend_tensor_get_async(backend, tensor, sampled.data + row, 0, sizeof(sampled.data[row]));
1443
+ }
1444
+ }
1445
+
1446
+ static void copy_tensor_async_floats(
1447
+ const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
1448
+ const buffer_view<float> & dst,
1449
+ size_t stride,
1450
+ std::vector<uint32_t> & counts,
1451
+ const std::map<llama_seq_id, uint32_t> & seq_to_row,
1452
+ ggml_backend_sched_t sched) {
1453
+ if (!dst.has_data()) {
1454
+ return;
1455
+ }
1456
+
1457
+ for (const auto & [seq_id, tensor] : tensor_map) {
1458
+ auto it = seq_to_row.find(seq_id);
1459
+ if (it == seq_to_row.end()) {
1460
+ continue;
1461
+ }
1462
+
1463
+ const uint32_t row = it->second;
1464
+ GGML_ASSERT(row < counts.size());
1465
+
1466
+ GGML_ASSERT(ggml_is_contiguous(tensor) && "logits/probs tensor must be contiguous for async copy");
1467
+
1468
+ ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
1469
+ float * row_ptr = dst.data + (size_t) row * stride;
1470
+ ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor));
1471
+
1472
+ // Update the actual number of logits/probabilities that were written for this row.
1473
+ counts[row] = ggml_nelements(tensor);
1474
+ }
1475
+ }
1476
+
1477
+ static void copy_tensor_async_candidates(
1478
+ const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
1479
+ const buffer_view<llama_token> & dst,
1480
+ size_t stride,
1481
+ std::vector<uint32_t> & counts,
1482
+ const std::map<llama_seq_id, uint32_t> & seq_to_row,
1483
+ ggml_backend_sched_t sched) {
1484
+ if (!dst.has_data()) {
1485
+ return;
1486
+ }
1487
+
1488
+ for (const auto & [seq_id, tensor] : tensor_map) {
1489
+ auto it = seq_to_row.find(seq_id);
1490
+ if (it == seq_to_row.end()) {
1491
+ continue;
1492
+ }
1493
+
1494
+ const uint32_t row = it->second;
1495
+ GGML_ASSERT(row < counts.size());
1496
+
1497
+ GGML_ASSERT(ggml_is_contiguous(tensor) && "candidates tensor must be contiguous for async copy");
1498
+
1499
+ ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
1500
+ llama_token * row_ptr = dst.data + (size_t) row * stride;
1501
+ ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor));
1502
+
1503
+ // Update the actual number of candidates that were written.
1504
+ counts[row] = ggml_nelements(tensor);
1505
+ }
1506
+ }
1507
+
1508
+ static bool needs_raw_logits(const llama_ubatch & ubatch, const std::map<llama_seq_id, llama_sampler *> & samplers) {
1509
+ for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
1510
+ if (!ubatch.output[i]) {
1511
+ continue;
1512
+ }
1513
+
1514
+ // Check if the output token has at least one sequence without a backend sampler.
1515
+ for (int32_t j = 0; j < ubatch.n_seq_id[i]; ++j) {
1516
+ llama_seq_id seq_id = ubatch.seq_id[i][j];
1517
+ if (samplers.find(seq_id) == samplers.end()) {
1518
+ return true;
1519
+ }
1520
+ }
1521
+ }
1522
+ return false; // all sequences use backend sampling
1523
+ }
1524
+
958
1525
  int llama_context::decode(const llama_batch & batch_inp) {
959
1526
  GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
960
1527
 
@@ -972,12 +1539,39 @@ int llama_context::decode(const llama_batch & batch_inp) {
972
1539
  const auto & hparams = model.hparams;
973
1540
 
974
1541
  const int64_t n_vocab = vocab.n_tokens();
975
- const int64_t n_embd = hparams.n_embd;
1542
+ const int64_t n_embd = hparams.n_embd_inp();
976
1543
 
977
1544
  // when computing embeddings, all tokens are output
978
- const bool output_all = cparams.embeddings;
1545
+ const bool output_all = cparams.embeddings;
1546
+ const bool has_samplers = !sampling.samplers.empty();
1547
+
1548
+ const uint32_t n_seq_max = cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max;
1549
+
1550
+ // TODO: avoid this workaround in the future
1551
+ if (has_samplers && batch_inp.logits) {
1552
+ std::vector<int32_t> seq_output_count(n_seq_max, 0);
1553
+
1554
+ for (int32_t i = 0; i < batch_inp.n_tokens; ++i) {
1555
+ if (batch_inp.logits[i] == 0) {
1556
+ continue;
1557
+ }
1558
+
1559
+ const int ns = batch_inp.n_seq_id ? batch_inp.n_seq_id[i] : 1;
1560
+
1561
+ for (int32_t s = 0; s < ns; ++s) {
1562
+ const llama_seq_id seq_id = batch_inp.seq_id ? batch_inp.seq_id[i][s] : 0;
1563
+
1564
+ seq_output_count[seq_id]++;
1565
+ if (seq_output_count[seq_id] > 1) {
1566
+ LLAMA_LOG_ERROR("%s: backend sampling requires at most one output token per sequence (seq_id %d had %d)\n",
1567
+ __func__, seq_id, seq_output_count[seq_id]);
1568
+ return -1;
1569
+ }
1570
+ }
1571
+ }
1572
+ }
979
1573
 
980
- if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, output_all)) {
1574
+ if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, n_seq_max, output_all)) {
981
1575
  LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
982
1576
  return -1;
983
1577
  }
@@ -1007,6 +1601,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
1007
1601
  embd_seq.clear();
1008
1602
  output_swaps.clear();
1009
1603
 
1604
+ sched_reserve();
1605
+
1010
1606
  bool did_optimize = false;
1011
1607
 
1012
1608
  // handle any pending shifts/copies
@@ -1131,22 +1727,22 @@ int llama_context::decode(const llama_batch & batch_inp) {
1131
1727
  }
1132
1728
 
1133
1729
  // extract logits
1134
- if (t_logits && n_outputs > 0) {
1730
+ if (logits.data && t_logits && n_outputs > 0 && needs_raw_logits(ubatch, sampling.samplers)) {
1135
1731
  ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
1136
1732
  GGML_ASSERT(backend_res != nullptr);
1137
- GGML_ASSERT(logits != nullptr);
1733
+ GGML_ASSERT(logits.data != nullptr);
1138
1734
 
1139
- float * logits_out = logits + n_outputs_prev*n_vocab;
1735
+ float * logits_out = logits.data + n_outputs_prev*n_vocab;
1140
1736
 
1141
1737
  if (n_outputs) {
1142
1738
  GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
1143
- GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size);
1739
+ GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits.size);
1144
1740
  ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float));
1145
1741
  }
1146
1742
  }
1147
1743
 
1148
1744
  // extract embeddings
1149
- if (t_embd && n_outputs > 0) {
1745
+ if (embd.data && t_embd && n_outputs > 0) {
1150
1746
  ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
1151
1747
  GGML_ASSERT(backend_embd != nullptr);
1152
1748
 
@@ -1154,13 +1750,14 @@ int llama_context::decode(const llama_batch & batch_inp) {
1154
1750
  case LLAMA_POOLING_TYPE_NONE:
1155
1751
  {
1156
1752
  // extract token embeddings
1157
- GGML_ASSERT(embd != nullptr);
1158
- float * embd_out = embd + n_outputs_prev*n_embd;
1753
+ GGML_ASSERT(embd.data != nullptr);
1754
+ const uint32_t n_embd_out = hparams.n_embd_out();
1755
+ float * embd_out = embd.data + n_outputs_prev*n_embd_out;
1159
1756
 
1160
1757
  if (n_outputs) {
1161
1758
  GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
1162
- GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size);
1163
- ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd*sizeof(float));
1759
+ GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd_out <= (int64_t) embd.size);
1760
+ ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd_out*sizeof(float));
1164
1761
  }
1165
1762
  } break;
1166
1763
  case LLAMA_POOLING_TYPE_MEAN:
@@ -1200,6 +1797,19 @@ int llama_context::decode(const llama_batch & batch_inp) {
1200
1797
  }
1201
1798
  }
1202
1799
 
1800
+ // Copy backend sampling output if this ubatch produced any sampling tensors.
1801
+ if (has_samplers && (!res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty())) {
1802
+ const auto seq_to_output_row = build_seq_to_output_row(ubatch, n_outputs_prev);
1803
+ const auto stride = n_vocab;
1804
+
1805
+ // async copy the sampling data from the backend to the host
1806
+ copy_tensor_async_ints(res->t_sampled, sampling.sampled, seq_to_output_row, sched.get());
1807
+
1808
+ copy_tensor_async_floats (res->t_sampled_logits, sampling.logits, stride, sampling.logits_count, seq_to_output_row, sched.get());
1809
+ copy_tensor_async_floats (res->t_sampled_probs, sampling.probs, stride, sampling.probs_count, seq_to_output_row, sched.get());
1810
+ copy_tensor_async_candidates(res->t_candidates, sampling.candidates, stride, sampling.candidates_count, seq_to_output_row, sched.get());
1811
+ }
1812
+
1203
1813
  n_outputs_prev += n_outputs;
1204
1814
  } while (mctx->next());
1205
1815
 
@@ -1224,7 +1834,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
1224
1834
 
1225
1835
  // make the outputs have the same order they had in the user-provided batch
1226
1836
  // note: this is mostly relevant for recurrent models atm
1227
- if (!sorted_output) {
1837
+ if (!sorted_output && n_outputs > 1) {
1228
1838
  GGML_ASSERT((size_t) n_outputs == out_ids.size());
1229
1839
 
1230
1840
  // TODO: is there something more efficient which also minimizes swaps?
@@ -1269,9 +1879,9 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
1269
1879
 
1270
1880
  const int64_t n_outputs_max = std::max<int64_t>(n_outputs, n_seq_max());
1271
1881
 
1272
- const auto n_batch = cparams.n_batch;
1273
- const auto n_vocab = vocab.n_tokens();
1274
- const auto n_embd = hparams.n_embd;
1882
+ const auto n_batch = cparams.n_batch;
1883
+ const auto n_vocab = vocab.n_tokens();
1884
+ const auto n_embd_out = hparams.n_embd_out();
1275
1885
 
1276
1886
  bool has_logits = true;
1277
1887
  bool has_embd = cparams.embeddings;
@@ -1282,8 +1892,19 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
1282
1892
  has_embd = true;
1283
1893
  }
1284
1894
 
1285
- logits_size = has_logits ? n_vocab*n_outputs_max : 0;
1286
- embd_size = has_embd ? n_embd*n_outputs_max : 0;
1895
+
1896
+ size_t backend_float_count = 0;
1897
+ size_t backend_token_count = 0;
1898
+
1899
+ logits.size = has_logits ? n_vocab*n_outputs_max : 0;
1900
+ embd.size = has_embd ? n_embd_out*n_outputs_max : 0;
1901
+
1902
+ // Allocate backend sampling output buffers if there are backend samplers configured.
1903
+ const bool has_sampling = !sampling.samplers.empty();
1904
+ if (has_sampling) {
1905
+ backend_float_count = 2 * n_vocab * n_outputs_max; // logits + probs
1906
+ backend_token_count = (1 + n_vocab) * n_outputs_max; // sampled + candidates
1907
+ }
1287
1908
 
1288
1909
  if (output_ids.empty()) {
1289
1910
  // init, never resized afterwards
@@ -1291,7 +1912,9 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
1291
1912
  }
1292
1913
 
1293
1914
  const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0;
1294
- const size_t new_size = (logits_size + embd_size) * sizeof(float);
1915
+ const size_t new_size =
1916
+ (logits.size + embd.size + backend_float_count) * sizeof(float) +
1917
+ ( backend_token_count) * sizeof(llama_token);
1295
1918
 
1296
1919
  // alloc only when more than the current capacity is required
1297
1920
  // TODO: also consider shrinking the buffer
@@ -1299,11 +1922,14 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
1299
1922
  if (buf_output) {
1300
1923
  #ifndef NDEBUG
1301
1924
  // This doesn't happen often, but may be annoying in some cases (like the HellaSwag benchmark)
1302
- LLAMA_LOG_INFO("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
1925
+ LLAMA_LOG_DEBUG("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
1303
1926
  #endif
1927
+ synchronize();
1928
+
1929
+ // TODO: not needed?
1304
1930
  buf_output = nullptr;
1305
- logits = nullptr;
1306
- embd = nullptr;
1931
+ logits.data = nullptr;
1932
+ embd.data = nullptr;
1307
1933
  }
1308
1934
 
1309
1935
  auto * buft = ggml_backend_cpu_buffer_type();
@@ -1322,8 +1948,50 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
1322
1948
 
1323
1949
  float * output_base = (float *) ggml_backend_buffer_get_base(buf_output.get());
1324
1950
 
1325
- logits = has_logits ? output_base : nullptr;
1326
- embd = has_embd ? output_base + logits_size : nullptr;
1951
+ size_t offset = 0;
1952
+ uint8_t * base = (uint8_t *) output_base;
1953
+
1954
+ logits = has_logits ? buffer_view<float>{output_base, logits.size} : buffer_view<float>{nullptr, 0};
1955
+ offset += logits.size * sizeof(float);
1956
+
1957
+ embd = has_embd ? buffer_view<float>{(float *) (base + offset), embd.size} : buffer_view<float>{nullptr, 0};
1958
+ offset += embd.size * sizeof(float);
1959
+
1960
+ if (has_sampling) {
1961
+ sampling.logits = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)};
1962
+ offset += sampling.logits.size * sizeof(float);
1963
+
1964
+ sampling.probs = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)};
1965
+ offset += sampling.probs.size * sizeof(float);
1966
+
1967
+ sampling.sampled = {(llama_token *) (base + offset), (size_t)n_outputs_max};
1968
+ offset += sampling.sampled.size * sizeof(llama_token);
1969
+
1970
+ sampling.candidates = {(llama_token *) (base + offset), (size_t)(n_vocab*n_outputs_max)};
1971
+ offset += sampling.candidates.size * sizeof(llama_token);
1972
+
1973
+ // The count vectors keep track of the actual number of logits/probs/candidates
1974
+ // copied from the backend for each output row.
1975
+
1976
+ sampling.logits_count.resize(n_outputs_max);
1977
+ sampling.probs_count.resize(n_outputs_max);
1978
+ sampling.candidates_count.resize(n_outputs_max);
1979
+
1980
+ std::fill(sampling.logits_count.begin(), sampling.logits_count.end(), 0);
1981
+ std::fill(sampling.probs_count.begin(), sampling.probs_count.end(), 0);
1982
+ std::fill(sampling.candidates_count.begin(), sampling.candidates_count.end(), 0);
1983
+
1984
+ std::fill_n(sampling.sampled.data, sampling.sampled.size, LLAMA_TOKEN_NULL);
1985
+ } else {
1986
+ sampling.logits = {nullptr, 0};
1987
+ sampling.probs = {nullptr, 0};
1988
+ sampling.sampled = {nullptr, 0};
1989
+ sampling.candidates = {nullptr, 0};
1990
+
1991
+ sampling.logits_count.clear();
1992
+ sampling.probs_count.clear();
1993
+ sampling.candidates_count.clear();
1994
+ }
1327
1995
 
1328
1996
  // set all ids as invalid (negative)
1329
1997
  std::fill(output_ids.begin(), output_ids.end(), -1);
@@ -1341,16 +2009,43 @@ void llama_context::output_reorder() {
1341
2009
  const uint64_t i0 = output_swaps[s].i0;
1342
2010
  const uint64_t i1 = output_swaps[s].i1;
1343
2011
 
1344
- if (logits_size > 0) {
2012
+ if (logits.size > 0) {
1345
2013
  for (uint64_t k = 0; k < n_vocab; k++) {
1346
- std::swap(logits[i0*n_vocab + k], logits[i1*n_vocab + k]);
2014
+ std::swap(logits.data[i0*n_vocab + k], logits.data[i1*n_vocab + k]);
1347
2015
  }
1348
2016
  }
1349
2017
 
1350
- if (embd_size > 0) {
2018
+ if (embd.size > 0) {
1351
2019
  for (uint64_t k = 0; k < n_embd; k++) {
1352
- std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]);
2020
+ std::swap(embd.data[i0*n_embd + k], embd.data[i1*n_embd + k]);
2021
+ }
2022
+ }
2023
+
2024
+ if (!sampling.samplers.empty()) {
2025
+ assert(sampling.logits.size > 0);
2026
+ assert(sampling.probs.size > 0);
2027
+ assert(sampling.candidates.size > 0);
2028
+ assert(sampling.sampled.size > 0);
2029
+ assert(sampling.logits_count.size() > 0);
2030
+ assert(sampling.probs_count.size() > 0);
2031
+ assert(sampling.candidates_count.size() > 0);
2032
+
2033
+ for (uint64_t k = 0; k < n_vocab; ++k) {
2034
+ std::swap(sampling.logits.data[i0*n_vocab + k], sampling.logits.data[i1*n_vocab + k]);
2035
+ }
2036
+
2037
+ for (uint64_t k = 0; k < n_vocab; ++k) {
2038
+ std::swap(sampling.probs.data[i0*n_vocab + k], sampling.probs.data[i1*n_vocab + k]);
2039
+ }
2040
+
2041
+ for (uint64_t k = 0; k < n_vocab; ++k) {
2042
+ std::swap(sampling.candidates.data[i0*n_vocab + k], sampling.candidates.data[i1*n_vocab + k]);
1353
2043
  }
2044
+
2045
+ std::swap(sampling.sampled.data[i0], sampling.sampled.data[i1]);
2046
+ std::swap(sampling.logits_count[i0], sampling.logits_count[i1]);
2047
+ std::swap(sampling.probs_count[i0], sampling.probs_count[i1]);
2048
+ std::swap(sampling.candidates_count[i0], sampling.candidates_count[i1]);
1354
2049
  }
1355
2050
  }
1356
2051
 
@@ -1361,28 +2056,36 @@ void llama_context::output_reorder() {
1361
2056
  // graph
1362
2057
  //
1363
2058
 
1364
- uint32_t llama_context::graph_max_nodes() const {
1365
- return std::max<uint32_t>(1024u, 8u*model.n_tensors());
2059
+ uint32_t llama_context::graph_max_nodes(uint32_t n_tokens) const {
2060
+ if (model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_KIMI_LINEAR || model.arch == LLM_ARCH_QWEN35 || model.arch == LLM_ARCH_QWEN35MOE) {
2061
+ return std::max<uint32_t>(n_tokens * 40, 32u * model.n_tensors());
2062
+ }
2063
+ uint32_t res = std::max<uint32_t>(1024u, 8u*model.n_tensors());
2064
+ for (const auto & lora : model.loras) {
2065
+ res += lora->get_n_nodes();
2066
+ }
2067
+ return res;
1366
2068
  }
1367
2069
 
1368
2070
  llm_graph_result * llama_context::get_gf_res_reserve() const {
1369
2071
  return static_cast<llm_graph_result *>(gf_res_reserve.get());
1370
2072
  }
1371
2073
 
1372
- ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only) {
2074
+ ggml_cgraph * llama_context::graph_reserve(
2075
+ uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only, size_t * sizes) {
1373
2076
  LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
1374
2077
  GGML_ASSERT(n_outputs >= 1);
1375
2078
 
1376
2079
  if (n_tokens % n_seqs != 0) {
1377
2080
  n_tokens = ((n_tokens + (n_seqs - 1)) / n_seqs) * n_seqs; // round to next multiple of n_seqs
1378
- n_outputs = std::min(n_outputs, n_tokens);
2081
+ n_outputs = std::max(n_outputs, n_tokens);
1379
2082
 
1380
2083
  LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs);
1381
2084
  }
1382
2085
 
1383
2086
  ggml_backend_sched_reset(sched.get());
1384
2087
 
1385
- // when the scheduler is reset, we cannnot reuse the old graph, so we reset the previous graph result to prevent that
2088
+ // when the scheduler is reset, we cannot reuse the old graph, so we reset the previous graph result to prevent that
1386
2089
  gf_res_prev->reset();
1387
2090
 
1388
2091
  // store the n_outputs as it is, and restore it afterwards
@@ -1394,6 +2097,15 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
1394
2097
  llama_batch_allocr balloc(model.hparams.n_pos_per_embd());
1395
2098
  llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
1396
2099
 
2100
+ // set one output token per sequence in order to activate all backend samplers
2101
+ std::vector<llama_seq_id> seq_ids(n_seqs);
2102
+ for (uint32_t i = 0; i < n_seqs; ++i) {
2103
+ seq_ids[i] = i;
2104
+ ubatch.n_seq_id[i] = 1;
2105
+ ubatch.seq_id[i] = &seq_ids[i];
2106
+ ubatch.output[i] = true;
2107
+ }
2108
+
1397
2109
  auto * res = gf_res_reserve.get();
1398
2110
 
1399
2111
  const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT);
@@ -1406,8 +2118,13 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
1406
2118
 
1407
2119
  // initialize scheduler with the specified graph
1408
2120
  if (split_only) {
1409
- ggml_backend_sched_split_graph(sched.get(), gf);
2121
+ if (sizes) {
2122
+ ggml_backend_sched_reserve_size(sched.get(), gf, sizes);
2123
+ } else {
2124
+ ggml_backend_sched_split_graph(sched.get(), gf);
2125
+ }
1410
2126
  } else if (!ggml_backend_sched_reserve(sched.get(), gf)) {
2127
+ GGML_ASSERT(!sizes);
1411
2128
  LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
1412
2129
  return nullptr;
1413
2130
  }
@@ -1419,7 +2136,7 @@ llm_graph_params llama_context::graph_params(
1419
2136
  llm_graph_result * res,
1420
2137
  const llama_ubatch & ubatch,
1421
2138
  const llama_memory_context_i * mctx,
1422
- llm_graph_type gtype) const {
2139
+ llm_graph_type gtype) const {
1423
2140
  return {
1424
2141
  /*.arch =*/ model.arch,
1425
2142
  /*.hparams =*/ model.hparams,
@@ -1428,10 +2145,11 @@ llm_graph_params llama_context::graph_params(
1428
2145
  /*.gtype =*/ gtype,
1429
2146
  /*.sched =*/ sched.get(),
1430
2147
  /*.backend_cpu =*/ backend_cpu,
1431
- /*.cvec =*/ &cvec,
1432
- /*.loras =*/ &loras,
2148
+ /*.cvec =*/ cvec.get(),
2149
+ /*.loras =*/ loras.get(),
1433
2150
  /*.mctx =*/ mctx,
1434
2151
  /*.cross =*/ &cross,
2152
+ /*.samplers =*/ sampling.samplers,
1435
2153
  /*.n_outputs =*/ n_outputs,
1436
2154
  /*.cb =*/ graph_get_cb(),
1437
2155
  /*.res =*/ res,
@@ -1475,16 +2193,9 @@ llm_graph_cb llama_context::graph_get_cb() const {
1475
2193
  ggml_set_name(cur, name);
1476
2194
  }
1477
2195
 
1478
- if (!cparams.offload_kqv) {
1479
- if (strcmp(name, "kqv_merged_cont") == 0) {
1480
- // all nodes between the KV store and the attention output are run on the CPU
1481
- ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend_cpu);
1482
- }
1483
- }
1484
-
1485
2196
  // norm may be automatically assigned to the backend of the previous layer, increasing data transfer between backends
1486
2197
  // FIXME: fix in ggml_backend_sched
1487
- const bool full_offload = model.params.n_gpu_layers > (int) model.hparams.n_layer;
2198
+ const bool full_offload = model.n_gpu_layers() > model.hparams.n_layer;
1488
2199
  if (ubatch.n_tokens < 32 || full_offload) {
1489
2200
  if (il != -1 && strcmp(name, "norm") == 0) {
1490
2201
  const auto & dev_layer = model.dev_layer(il);
@@ -1833,60 +2544,6 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
1833
2544
  // TODO: add more model-specific info which should prevent loading the session file if not identical
1834
2545
  }
1835
2546
 
1836
- // write output ids
1837
- {
1838
- LLAMA_LOG_DEBUG("%s: - writing output ids\n", __func__);
1839
-
1840
- const auto n_outputs = this->n_outputs;
1841
- const auto & output_ids = this->output_ids;
1842
-
1843
- std::vector<int32_t> w_output_pos;
1844
-
1845
- w_output_pos.resize(n_outputs);
1846
-
1847
- // build a more compact representation of the output ids
1848
- for (size_t i = 0; i < n_batch(); ++i) {
1849
- // map an output id to a position in the batch
1850
- int64_t pos = output_ids[i];
1851
- if (pos >= 0) {
1852
- GGML_ASSERT(pos < n_outputs);
1853
- w_output_pos[pos] = i;
1854
- }
1855
- }
1856
-
1857
- io.write(&n_outputs, sizeof(n_outputs));
1858
-
1859
- if (n_outputs) {
1860
- io.write(w_output_pos.data(), n_outputs * sizeof(int32_t));
1861
- }
1862
- }
1863
-
1864
- // write logits
1865
- {
1866
- LLAMA_LOG_DEBUG("%s: - writing logits\n", __func__);
1867
-
1868
- const uint64_t logits_size = std::min((uint64_t) this->logits_size, (uint64_t) n_outputs * model.vocab.n_tokens());
1869
-
1870
- io.write(&logits_size, sizeof(logits_size));
1871
-
1872
- if (logits_size) {
1873
- io.write(logits, logits_size * sizeof(float));
1874
- }
1875
- }
1876
-
1877
- // write embeddings
1878
- {
1879
- LLAMA_LOG_DEBUG("%s: - writing embeddings\n", __func__);
1880
-
1881
- const uint64_t embd_size = std::min((uint64_t) this->embd_size, (uint64_t) n_outputs * model.hparams.n_embd);
1882
-
1883
- io.write(&embd_size, sizeof(embd_size));
1884
-
1885
- if (embd_size) {
1886
- io.write(embd, embd_size * sizeof(float));
1887
- }
1888
- }
1889
-
1890
2547
  if (memory != nullptr) {
1891
2548
  LLAMA_LOG_DEBUG("%s: - writing memory module\n", __func__);
1892
2549
  memory->state_write(io);
@@ -1912,67 +2569,6 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
1912
2569
  // TODO: add more info which needs to be identical but which is not verified otherwise
1913
2570
  }
1914
2571
 
1915
- // read output ids
1916
- {
1917
- LLAMA_LOG_DEBUG("%s: - reading output ids\n", __func__);
1918
-
1919
- auto n_outputs = this->n_outputs;
1920
- io.read_to(&n_outputs, sizeof(n_outputs));
1921
-
1922
- if (n_outputs > output_reserve(n_outputs)) {
1923
- throw std::runtime_error("could not reserve outputs");
1924
- }
1925
-
1926
- std::vector<int32_t> output_pos;
1927
-
1928
- if (n_outputs) {
1929
- output_pos.resize(n_outputs);
1930
- io.read_to(output_pos.data(), n_outputs * sizeof(int32_t));
1931
-
1932
- for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) {
1933
- int32_t id = output_pos[i];
1934
- if ((uint32_t) id >= n_batch()) {
1935
- throw std::runtime_error(format("invalid output id, %d does not fit in batch size of %u", id, n_batch()));
1936
- }
1937
- this->output_ids[id] = i;
1938
- }
1939
-
1940
- this->n_outputs = n_outputs;
1941
- }
1942
- }
1943
-
1944
- // read logits
1945
- {
1946
- LLAMA_LOG_DEBUG("%s: - reading logits\n", __func__);
1947
-
1948
- uint64_t logits_size;
1949
- io.read_to(&logits_size, sizeof(logits_size));
1950
-
1951
- if (this->logits_size < logits_size) {
1952
- throw std::runtime_error("logits buffer too small");
1953
- }
1954
-
1955
- if (logits_size) {
1956
- io.read_to(this->logits, logits_size * sizeof(float));
1957
- }
1958
- }
1959
-
1960
- // read embeddings
1961
- {
1962
- LLAMA_LOG_DEBUG("%s: - reading embeddings\n", __func__);
1963
-
1964
- uint64_t embd_size;
1965
- io.read_to(&embd_size, sizeof(embd_size));
1966
-
1967
- if (this->embd_size < embd_size) {
1968
- throw std::runtime_error("embeddings buffer too small");
1969
- }
1970
-
1971
- if (embd_size) {
1972
- io.read_to(this->embd, embd_size * sizeof(float));
1973
- }
1974
- }
1975
-
1976
2572
  if (memory) {
1977
2573
  LLAMA_LOG_DEBUG("%s: - reading memory module\n", __func__);
1978
2574
 
@@ -2029,15 +2625,26 @@ void llama_context::perf_reset() {
2029
2625
 
2030
2626
  std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> llama_context::memory_breakdown() const {
2031
2627
  std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> ret;
2032
- for (const auto & buft_size : model.memory_breakdown()) {
2033
- ret[buft_size.first].model += buft_size.second;
2628
+ for (const auto & [buft, size] : model.memory_breakdown()) {
2629
+ ret[buft].model += size;
2034
2630
  }
2035
- for (const auto & buft_size : memory->memory_breakdown()) {
2036
- ret[buft_size.first].context += buft_size.second;
2631
+ if (memory) {
2632
+ for (const auto & [buft, size] : memory->memory_breakdown()) {
2633
+ ret[buft].context += size;
2634
+ }
2037
2635
  }
2038
- for (const auto & backend_ptr : backends) {
2039
- ggml_backend_t backend = backend_ptr.get();
2040
- ret[ggml_backend_sched_get_buffer_type(sched.get(), backend)].compute += ggml_backend_sched_get_buffer_size(sched.get(), backend);
2636
+ if (model.hparams.no_alloc) {
2637
+ for (size_t i = 0; i < backends.size(); ++i) {
2638
+ ggml_backend_t backend = backends[i].get();
2639
+ ggml_backend_buffer_type_t buft = ggml_backend_sched_get_buffer_type(sched.get(), backend);
2640
+ ret[buft].compute += backend_buf_exp_size[i];
2641
+ }
2642
+ } else {
2643
+ for (const auto & backend_ptr : backends) {
2644
+ ggml_backend_t backend = backend_ptr.get();
2645
+ ggml_backend_buffer_type_t buft = ggml_backend_sched_get_buffer_type(sched.get(), backend);
2646
+ ret[buft].compute += ggml_backend_sched_get_buffer_size(sched.get(), backend);
2647
+ }
2041
2648
  }
2042
2649
  return ret;
2043
2650
  }
@@ -2094,6 +2701,7 @@ void llama_context::opt_init(struct llama_model * model, struct llama_opt_params
2094
2701
  llama_set_param(model->cls_b, param_filter, param_filter_ud);
2095
2702
  llama_set_param(model->cls_out, param_filter, param_filter_ud);
2096
2703
  llama_set_param(model->cls_out_b, param_filter, param_filter_ud);
2704
+ llama_set_param(model->cls_norm, param_filter, param_filter_ud);
2097
2705
 
2098
2706
  for (struct llama_layer & layer : model->layers) {
2099
2707
  for (size_t i = 0; i < sizeof(layer)/sizeof(struct ggml_tensor *); ++i) {
@@ -2130,7 +2738,7 @@ void llama_context::opt_epoch_iter(
2130
2738
  batch.logits [pos_batch] = true;
2131
2739
  }
2132
2740
 
2133
- if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
2741
+ if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd_inp(), cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
2134
2742
  LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
2135
2743
  return;
2136
2744
  }
@@ -2185,7 +2793,7 @@ void llama_context::opt_epoch_iter(
2185
2793
  };
2186
2794
  ctx_compute_opt = ggml_init(params);
2187
2795
  }
2188
- ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits());
2796
+ ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_inp_tokens(), res->get_logits());
2189
2797
  ggml_opt_alloc(opt_ctx, train);
2190
2798
 
2191
2799
  res->set_inputs(&ubatch);
@@ -2295,6 +2903,8 @@ llama_context_params llama_context_default_params() {
2295
2903
  /*.op_offload =*/ true,
2296
2904
  /*.swa_full =*/ true,
2297
2905
  /*.kv_unified =*/ false,
2906
+ /*.sampler =*/ nullptr,
2907
+ /*.n_sampler =*/ 0,
2298
2908
  };
2299
2909
 
2300
2910
  return result;
@@ -2325,19 +2935,23 @@ llama_context * llama_init_from_model(
2325
2935
 
2326
2936
  if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_k)) {
2327
2937
  const uint32_t blck_size = ggml_blck_size(params.type_k);
2328
- if (model->hparams.n_embd_head_k % blck_size != 0) {
2329
- LLAMA_LOG_ERROR("%s: K cache type %s with block size %u does not divide n_embd_head_k=%u\n",
2330
- __func__, ggml_type_name(params.type_k), blck_size, model->hparams.n_embd_head_k);
2331
- return nullptr;
2938
+ for (uint32_t il = 0; il < model->hparams.n_layer; ++il) {
2939
+ if (model->hparams.n_embd_head_k(il) % blck_size != 0) {
2940
+ LLAMA_LOG_ERROR("%s: K cache type %s with block size %u does not divide n_embd_head_k=%u\n",
2941
+ __func__, ggml_type_name(params.type_k), blck_size, model->hparams.n_embd_head_k(il));
2942
+ return nullptr;
2943
+ }
2332
2944
  }
2333
2945
  }
2334
2946
 
2335
2947
  if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_v)) {
2336
2948
  const uint32_t blck_size = ggml_blck_size(params.type_v);
2337
- if (model->hparams.n_embd_head_v % blck_size != 0) {
2338
- LLAMA_LOG_ERROR("%s: V cache type %s with block size %u does not divide n_embd_head_k=%u\n",
2339
- __func__, ggml_type_name(params.type_v), blck_size, model->hparams.n_embd_head_v);
2340
- return nullptr;
2949
+ for (uint32_t il = 0; il < model->hparams.n_layer; ++il) {
2950
+ if (model->hparams.n_embd_head_v(il) % blck_size != 0) {
2951
+ LLAMA_LOG_ERROR("%s: V cache type %s with block size %u does not divide n_embd_head_v=%u\n",
2952
+ __func__, ggml_type_name(params.type_v), blck_size, model->hparams.n_embd_head_v(il));
2953
+ return nullptr;
2954
+ }
2341
2955
  }
2342
2956
  }
2343
2957
 
@@ -2346,6 +2960,13 @@ llama_context * llama_init_from_model(
2346
2960
  return nullptr;
2347
2961
  }
2348
2962
 
2963
+ if (params.pooling_type != LLAMA_POOLING_TYPE_UNSPECIFIED &&
2964
+ params.pooling_type != model->hparams.pooling_type) {
2965
+ //user-specified pooling-type is different from the model default
2966
+ LLAMA_LOG_WARN("%s: model default pooling_type is [%d], but [%d] was specified\n", __func__,
2967
+ model->hparams.pooling_type, params.pooling_type);
2968
+ }
2969
+
2349
2970
  try {
2350
2971
  auto * ctx = new llama_context(*model, params);
2351
2972
  return ctx;
@@ -2371,6 +2992,10 @@ uint32_t llama_n_ctx(const llama_context * ctx) {
2371
2992
  return ctx->n_ctx();
2372
2993
  }
2373
2994
 
2995
+ uint32_t llama_n_ctx_seq(const llama_context * ctx) {
2996
+ return ctx->n_ctx_seq();
2997
+ }
2998
+
2374
2999
  uint32_t llama_n_batch(const llama_context * ctx) {
2375
3000
  return ctx->n_batch();
2376
3001
  }
@@ -2443,7 +3068,15 @@ float * llama_get_logits(llama_context * ctx) {
2443
3068
  float * llama_get_logits_ith(llama_context * ctx, int32_t i) {
2444
3069
  ctx->synchronize();
2445
3070
 
2446
- return ctx->get_logits_ith(i);
3071
+ float * res = nullptr;
3072
+
3073
+ res = ctx->get_sampled_logits_ith(i);
3074
+
3075
+ if (!res) {
3076
+ res = ctx->get_logits_ith(i);
3077
+ }
3078
+
3079
+ return res;
2447
3080
  }
2448
3081
 
2449
3082
  float * llama_get_embeddings(llama_context * ctx) {
@@ -2464,37 +3097,89 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) {
2464
3097
  return ctx->get_embeddings_seq(seq_id);
2465
3098
  }
2466
3099
 
2467
- // llama adapter API
3100
+ bool llama_set_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) {
3101
+ return ctx->set_sampler(seq_id, smpl);
3102
+ }
2468
3103
 
2469
- int32_t llama_set_adapter_lora(
2470
- llama_context * ctx,
2471
- llama_adapter_lora * adapter,
2472
- float scale) {
2473
- ctx->set_adapter_lora(adapter, scale);
3104
+ llama_token llama_get_sampled_token_ith(llama_context * ctx, int32_t i) {
3105
+ ctx->synchronize();
2474
3106
 
2475
- return 0;
3107
+ return ctx->get_sampled_token_ith(i);
2476
3108
  }
2477
3109
 
2478
- int32_t llama_rm_adapter_lora(
2479
- llama_context * ctx,
2480
- llama_adapter_lora * adapter) {
2481
- bool res = ctx->rm_adapter_lora(adapter);
3110
+ float * llama_get_sampled_probs_ith(llama_context * ctx, int32_t i) {
3111
+ ctx->synchronize();
2482
3112
 
2483
- return res ? 0 : -1;
3113
+ return ctx->get_sampled_probs_ith(i);
2484
3114
  }
2485
3115
 
2486
- void llama_clear_adapter_lora(llama_context * ctx) {
2487
- ctx->clear_adapter_lora();
3116
+ float * llama_get_sampled_logits_ith(llama_context * ctx, int32_t i) {
3117
+ ctx->synchronize();
3118
+
3119
+ return ctx->get_sampled_logits_ith(i);
3120
+ }
3121
+
3122
+ llama_token * llama_get_sampled_candidates_ith(llama_context * ctx, int32_t i) {
3123
+ ctx->synchronize();
3124
+
3125
+ return const_cast<llama_token *>(ctx->get_sampled_candidates_ith(i));
3126
+ }
3127
+
3128
+ uint32_t llama_get_sampled_candidates_count_ith(llama_context * ctx, int32_t i) {
3129
+ ctx->synchronize();
3130
+
3131
+ return static_cast<uint32_t>(ctx->get_sampled_candidates_count(i));
3132
+ }
3133
+
3134
+ uint32_t llama_get_sampled_logits_count_ith(llama_context * ctx, int32_t i) {
3135
+ ctx->synchronize();
3136
+
3137
+ return static_cast<uint32_t>(ctx->get_sampled_logits_count(i));
3138
+ }
3139
+
3140
+ uint32_t llama_get_sampled_probs_count_ith(llama_context * ctx, int32_t i) {
3141
+ ctx->synchronize();
3142
+
3143
+ return static_cast<uint32_t>(ctx->get_sampled_probs_count(i));
3144
+ }
3145
+
3146
+ struct ggml_cgraph * llama_graph_reserve(
3147
+ struct llama_context * ctx,
3148
+ uint32_t n_tokens,
3149
+ uint32_t n_seqs,
3150
+ uint32_t n_outputs) {
3151
+ auto * memory = ctx->get_memory();
3152
+ llama_memory_context_ptr mctx;
3153
+ if (memory) {
3154
+ mctx = memory->init_full();
3155
+ }
3156
+ return ctx->graph_reserve(n_tokens, n_seqs, n_outputs, mctx.get());
3157
+ }
3158
+
3159
+ // llama adapter API
3160
+
3161
+ int32_t llama_set_adapters_lora(
3162
+ llama_context * ctx,
3163
+ llama_adapter_lora ** adapters,
3164
+ size_t n_adapters,
3165
+ float * scales) {
3166
+ if (adapters == nullptr || scales == nullptr) {
3167
+ GGML_ASSERT(n_adapters == 0 && "invalid llama_set_adapters_lora call");
3168
+ }
3169
+
3170
+ ctx->set_adapters_lora(adapters, n_adapters, scales);
3171
+
3172
+ return 0;
2488
3173
  }
2489
3174
 
2490
- int32_t llama_apply_adapter_cvec(
3175
+ int32_t llama_set_adapter_cvec(
2491
3176
  llama_context * ctx,
2492
- const float * data,
2493
- size_t len,
2494
- int32_t n_embd,
2495
- int32_t il_start,
2496
- int32_t il_end) {
2497
- bool res = ctx->apply_adapter_cvec(data, len, n_embd, il_start, il_end);
3177
+ const float * data,
3178
+ size_t len,
3179
+ int32_t n_embd,
3180
+ int32_t il_start,
3181
+ int32_t il_end) {
3182
+ bool res = ctx->set_adapter_cvec(data, len, n_embd, il_start, il_end);
2498
3183
 
2499
3184
  return res ? 0 : -1;
2500
3185
  }