whispercpp 1.3.5 → 1.3.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (610)
  1. checksums.yaml +4 -4
  2. data/LICENSE +1 -1
  3. data/README.md +99 -2
  4. data/ext/extconf.rb +1 -0
  5. data/ext/ruby_whisper.c +20 -4
  6. data/ext/ruby_whisper.h +30 -2
  7. data/ext/ruby_whisper_context.c +216 -124
  8. data/ext/ruby_whisper_context_params.c +163 -0
  9. data/ext/ruby_whisper_model.c +0 -1
  10. data/ext/ruby_whisper_params.c +0 -1
  11. data/ext/ruby_whisper_segment.c +0 -1
  12. data/ext/ruby_whisper_token.c +29 -9
  13. data/ext/ruby_whisper_transcribe.cpp +4 -1
  14. data/ext/ruby_whisper_vad_context.c +48 -1
  15. data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
  16. data/ext/ruby_whisper_vad_params.c +0 -1
  17. data/ext/ruby_whisper_vad_segment.c +0 -1
  18. data/ext/ruby_whisper_vad_segments.c +0 -1
  19. data/ext/sources/CMakeLists.txt +1 -1
  20. data/ext/sources/bindings/javascript/package.json +1 -1
  21. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  22. data/ext/sources/examples/bench/bench.cpp +23 -18
  23. data/ext/sources/examples/cli/cli.cpp +8 -0
  24. data/ext/sources/examples/common-ggml.cpp +2 -0
  25. data/ext/sources/examples/miniaudio.h +4507 -2131
  26. data/ext/sources/examples/server/server.cpp +18 -4
  27. data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -2
  28. data/ext/sources/examples/talk-llama/llama-adapter.cpp +7 -13
  29. data/ext/sources/examples/talk-llama/llama-adapter.h +4 -3
  30. data/ext/sources/examples/talk-llama/llama-arch.cpp +335 -17
  31. data/ext/sources/examples/talk-llama/llama-arch.h +42 -0
  32. data/ext/sources/examples/talk-llama/llama-batch.cpp +3 -1
  33. data/ext/sources/examples/talk-llama/llama-chat.cpp +21 -1
  34. data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
  35. data/ext/sources/examples/talk-llama/llama-context.cpp +508 -520
  36. data/ext/sources/examples/talk-llama/llama-context.h +27 -28
  37. data/ext/sources/examples/talk-llama/llama-cparams.h +5 -0
  38. data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
  39. data/ext/sources/examples/talk-llama/llama-grammar.cpp +8 -8
  40. data/ext/sources/examples/talk-llama/llama-graph.cpp +583 -130
  41. data/ext/sources/examples/talk-llama/llama-graph.h +131 -10
  42. data/ext/sources/examples/talk-llama/llama-hparams.cpp +57 -40
  43. data/ext/sources/examples/talk-llama/llama-hparams.h +79 -10
  44. data/ext/sources/examples/talk-llama/llama-impl.cpp +4 -4
  45. data/ext/sources/examples/talk-llama/llama-impl.h +13 -1
  46. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +3 -1
  47. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +274 -89
  48. data/ext/sources/examples/talk-llama/llama-kv-cache.h +2 -3
  49. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
  50. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
  51. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +11 -13
  52. data/ext/sources/examples/talk-llama/llama-mmap.cpp +28 -11
  53. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +527 -119
  54. data/ext/sources/examples/talk-llama/llama-model-loader.h +35 -5
  55. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +60 -46
  56. data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
  57. data/ext/sources/examples/talk-llama/llama-model.cpp +1365 -647
  58. data/ext/sources/examples/talk-llama/llama-model.h +72 -19
  59. data/ext/sources/examples/talk-llama/llama-quant.cpp +578 -346
  60. data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +190 -76
  61. data/ext/sources/examples/talk-llama/{llama-sampling.h → llama-sampler.h} +0 -2
  62. data/ext/sources/examples/talk-llama/llama-vocab.cpp +118 -48
  63. data/ext/sources/examples/talk-llama/llama-vocab.h +5 -0
  64. data/ext/sources/examples/talk-llama/llama.cpp +76 -22
  65. data/ext/sources/examples/talk-llama/llama.h +63 -30
  66. data/ext/sources/examples/talk-llama/models/afmoe.cpp +2 -3
  67. data/ext/sources/examples/talk-llama/models/apertus.cpp +3 -3
  68. data/ext/sources/examples/talk-llama/models/arcee.cpp +3 -3
  69. data/ext/sources/examples/talk-llama/models/arctic.cpp +4 -5
  70. data/ext/sources/examples/talk-llama/models/baichuan.cpp +4 -3
  71. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +1 -2
  72. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +3 -5
  73. data/ext/sources/examples/talk-llama/models/bert.cpp +13 -7
  74. data/ext/sources/examples/talk-llama/models/bitnet.cpp +9 -24
  75. data/ext/sources/examples/talk-llama/models/bloom.cpp +2 -2
  76. data/ext/sources/examples/talk-llama/models/chameleon.cpp +3 -3
  77. data/ext/sources/examples/talk-llama/models/chatglm.cpp +2 -2
  78. data/ext/sources/examples/talk-llama/models/codeshell.cpp +3 -3
  79. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +3 -3
  80. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +2 -2
  81. data/ext/sources/examples/talk-llama/models/command-r.cpp +2 -2
  82. data/ext/sources/examples/talk-llama/models/dbrx.cpp +4 -5
  83. data/ext/sources/examples/talk-llama/models/deci.cpp +3 -3
  84. data/ext/sources/examples/talk-llama/models/deepseek.cpp +4 -6
  85. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +24 -21
  86. data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
  87. data/ext/sources/examples/talk-llama/models/dots1.cpp +4 -6
  88. data/ext/sources/examples/talk-llama/models/dream.cpp +3 -3
  89. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +4 -6
  90. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +3 -3
  91. data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
  92. data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
  93. data/ext/sources/examples/talk-llama/models/exaone.cpp +3 -3
  94. data/ext/sources/examples/talk-llama/models/exaone4.cpp +3 -3
  95. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +2 -4
  96. data/ext/sources/examples/talk-llama/models/falcon.cpp +3 -3
  97. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +1 -1
  98. data/ext/sources/examples/talk-llama/models/gemma.cpp +1 -1
  99. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +1 -1
  100. data/ext/sources/examples/talk-llama/models/gemma3.cpp +1 -1
  101. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +7 -7
  102. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +3 -3
  103. data/ext/sources/examples/talk-llama/models/glm4.cpp +14 -7
  104. data/ext/sources/examples/talk-llama/models/gpt2.cpp +2 -2
  105. data/ext/sources/examples/talk-llama/models/gptneox.cpp +2 -2
  106. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +4 -5
  107. data/ext/sources/examples/talk-llama/models/granite.cpp +4 -5
  108. data/ext/sources/examples/talk-llama/models/grok.cpp +4 -4
  109. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +5 -7
  110. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +3 -3
  111. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +4 -5
  112. data/ext/sources/examples/talk-llama/models/internlm2.cpp +3 -3
  113. data/ext/sources/examples/talk-llama/models/jais.cpp +2 -2
  114. data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
  115. data/ext/sources/examples/talk-llama/models/jamba.cpp +3 -3
  116. data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
  117. data/ext/sources/examples/talk-llama/models/lfm2.cpp +145 -124
  118. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +4 -4
  119. data/ext/sources/examples/talk-llama/models/llada.cpp +3 -3
  120. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +4 -4
  121. data/ext/sources/examples/talk-llama/models/llama.cpp +18 -11
  122. data/ext/sources/examples/talk-llama/models/maincoder.cpp +3 -3
  123. data/ext/sources/examples/talk-llama/models/{graph-context-mamba.cpp → mamba-base.cpp} +9 -3
  124. data/ext/sources/examples/talk-llama/models/mamba.cpp +1 -2
  125. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +11 -5
  126. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +14 -13
  127. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +4 -5
  128. data/ext/sources/examples/talk-llama/models/mistral3.cpp +4 -4
  129. data/ext/sources/examples/talk-llama/models/models.h +181 -46
  130. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +2 -9
  131. data/ext/sources/examples/talk-llama/models/mpt.cpp +2 -2
  132. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +26 -14
  133. data/ext/sources/examples/talk-llama/models/nemotron.cpp +3 -3
  134. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +2 -2
  135. data/ext/sources/examples/talk-llama/models/olmo.cpp +3 -3
  136. data/ext/sources/examples/talk-llama/models/olmo2.cpp +3 -3
  137. data/ext/sources/examples/talk-llama/models/olmoe.cpp +4 -4
  138. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +1 -1
  139. data/ext/sources/examples/talk-llama/models/openelm.cpp +3 -3
  140. data/ext/sources/examples/talk-llama/models/orion.cpp +3 -3
  141. data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
  142. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +3 -3
  143. data/ext/sources/examples/talk-llama/models/phi2.cpp +2 -2
  144. data/ext/sources/examples/talk-llama/models/phi3.cpp +3 -3
  145. data/ext/sources/examples/talk-llama/models/plamo.cpp +3 -3
  146. data/ext/sources/examples/talk-llama/models/plamo2.cpp +9 -5
  147. data/ext/sources/examples/talk-llama/models/plamo3.cpp +2 -2
  148. data/ext/sources/examples/talk-llama/models/plm.cpp +15 -14
  149. data/ext/sources/examples/talk-llama/models/qwen.cpp +2 -2
  150. data/ext/sources/examples/talk-llama/models/qwen2.cpp +3 -3
  151. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +4 -4
  152. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +3 -3
  153. data/ext/sources/examples/talk-llama/models/qwen3.cpp +12 -9
  154. data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
  155. data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
  156. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +15 -8
  157. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +84 -432
  158. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +9 -18
  159. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +8 -17
  160. data/ext/sources/examples/talk-llama/models/refact.cpp +2 -2
  161. data/ext/sources/examples/talk-llama/models/rnd1.cpp +4 -4
  162. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +2 -0
  163. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +2 -0
  164. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +3 -3
  165. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +4 -4
  166. data/ext/sources/examples/talk-llama/models/smollm3.cpp +3 -3
  167. data/ext/sources/examples/talk-llama/models/stablelm.cpp +2 -2
  168. data/ext/sources/examples/talk-llama/models/starcoder.cpp +2 -2
  169. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +3 -3
  170. data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
  171. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +2 -2
  172. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +2 -2
  173. data/ext/sources/examples/talk-llama/models/xverse.cpp +3 -3
  174. data/ext/sources/examples/talk-llama/unicode.cpp +21 -65
  175. data/ext/sources/ggml/CMakeLists.txt +9 -3
  176. data/ext/sources/ggml/include/ggml-backend.h +1 -1
  177. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  178. data/ext/sources/ggml/include/ggml-cpu.h +5 -0
  179. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  180. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  181. data/ext/sources/ggml/include/ggml-rpc.h +6 -1
  182. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  183. data/ext/sources/ggml/include/ggml.h +56 -9
  184. data/ext/sources/ggml/src/CMakeLists.txt +3 -0
  185. data/ext/sources/ggml/src/ggml-alloc.c +4 -9
  186. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  187. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  188. data/ext/sources/ggml/src/ggml-backend-reg.cpp +28 -86
  189. data/ext/sources/ggml/src/ggml-backend.cpp +5 -2
  190. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  191. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +6 -2
  192. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
  193. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
  194. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +348 -189
  195. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +40 -85
  196. data/ext/sources/ggml/src/ggml-cann/common.h +3 -4
  197. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +44 -62
  198. data/ext/sources/ggml/src/ggml-common.h +11 -0
  199. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +16 -11
  200. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -19
  201. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  202. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
  203. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +85 -1
  204. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2744 -548
  205. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1653 -0
  206. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
  207. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
  208. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
  209. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +118 -18
  210. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +107 -26
  211. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  212. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  213. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +3 -0
  214. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +59 -12
  215. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +15 -0
  216. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
  217. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +965 -252
  218. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +584 -197
  219. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +903 -188
  220. data/ext/sources/ggml/src/ggml-cpu/ops.h +1 -0
  221. data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
  222. data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
  223. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2890 -679
  224. data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
  225. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
  226. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +111 -3
  227. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
  228. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +17 -0
  229. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +1 -1
  230. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +19 -10
  231. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +32 -30
  232. data/ext/sources/ggml/src/ggml-cuda/common.cuh +134 -18
  233. data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
  234. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +6 -3
  235. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +78 -64
  236. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +384 -143
  237. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +36 -22
  238. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +3 -3
  239. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +26 -5
  240. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
  241. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +127 -12
  242. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
  243. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  244. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +595 -200
  245. data/ext/sources/ggml/src/ggml-cuda/mean.cu +9 -8
  246. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +173 -6
  247. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
  248. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +158 -85
  249. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +34 -22
  250. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +127 -67
  251. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
  252. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +157 -65
  253. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -0
  254. data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
  255. data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
  256. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +1 -1
  257. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
  258. data/ext/sources/ggml/src/ggml-cuda/rope.cu +233 -133
  259. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
  260. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
  261. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +56 -32
  262. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  263. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
  264. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
  265. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
  266. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
  267. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
  268. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
  269. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +3 -3
  270. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +0 -1
  271. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +199 -135
  272. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
  273. data/ext/sources/ggml/src/ggml-cuda/unary.cu +55 -0
  274. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
  275. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
  276. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +10 -0
  277. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +82 -45
  278. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +334 -160
  279. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +7 -5
  280. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +328 -197
  281. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
  282. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +765 -234
  283. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
  284. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +412 -265
  285. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +23 -23
  286. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +1 -1
  287. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.h → hex-dma.h} +28 -3
  288. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
  289. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  290. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
  291. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +1 -1
  292. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +27 -37
  293. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +6 -35
  294. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  295. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
  296. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
  297. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  302. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  303. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  304. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
  305. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  306. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  307. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +20 -1347
  308. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +211 -13
  309. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +1119 -952
  310. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +254 -244
  311. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +36 -36
  312. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +155 -138
  313. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
  314. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  315. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +209 -114
  316. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
  317. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  318. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  319. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  320. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
  321. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +6 -0
  322. data/ext/sources/ggml/src/ggml-impl.h +62 -0
  323. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  324. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
  325. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  326. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
  327. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +274 -73
  328. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +22 -4
  329. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +102 -36
  330. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +174 -23
  331. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +580 -280
  332. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +5 -4
  333. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +320 -107
  334. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1068 -825
  335. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +19 -1
  336. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +3108 -636
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +204 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  365. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
  366. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  367. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  368. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  369. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
  370. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  371. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
  372. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  373. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
  374. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
  375. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  376. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  377. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  378. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  379. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  380. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  381. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  382. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  383. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  384. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  385. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  386. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  387. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  388. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  389. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  390. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  391. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  392. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
  393. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  394. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  395. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  396. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  397. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  398. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  399. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
  400. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
  401. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
  402. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
  403. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  404. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  405. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  406. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  407. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  408. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
  409. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  410. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
  411. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
  412. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
  413. data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
  414. data/ext/sources/ggml/src/ggml-quants.c +96 -5
  415. data/ext/sources/ggml/src/ggml-quants.h +3 -0
  416. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +15 -88
  417. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
  418. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +1 -0
  419. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
  420. data/ext/sources/ggml/src/ggml-sycl/common.hpp +315 -10
  421. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +69 -1
  422. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
  423. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
  424. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
  425. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +78 -68
  426. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
  427. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
  428. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
  429. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
  430. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
  431. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
  432. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  433. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
  434. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
  435. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +316 -51
  436. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
  437. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  438. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
  439. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
  440. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
  441. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  442. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
  443. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  444. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  445. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  446. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  447. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  448. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  449. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  450. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  451. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  452. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
  453. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
  454. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
  455. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
  456. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
  457. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
  458. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
  459. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
  460. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
  461. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
  462. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
  463. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
  464. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
  465. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
  466. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
  467. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
  468. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
  469. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
  470. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
  471. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
  472. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
  473. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
  474. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
  475. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
  476. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
  477. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
  478. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
  479. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
  480. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
  481. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
  482. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
  483. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
  484. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
  485. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
  486. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
  487. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
  488. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +13 -0
  489. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  490. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  491. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  492. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  493. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  494. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  495. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  496. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  497. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  498. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  499. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  500. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  501. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  502. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  503. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  504. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  505. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  506. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  507. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  508. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  509. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  510. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  511. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
  512. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
  513. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  514. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
  515. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  516. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  517. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  518. data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
  519. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  520. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  521. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  522. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  523. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  524. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  525. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  526. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
  527. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  528. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  529. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  530. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
  531. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
  532. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +1 -1
  533. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1250 -465
  534. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
  535. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  536. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +374 -170
  537. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +66 -22
  538. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +389 -201
  539. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +106 -58
  540. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  541. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
  542. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
  543. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
  544. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
  545. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -3
  546. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +8 -4
  547. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
  548. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
  549. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +3 -3
  550. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
  551. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +36 -63
  552. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
  553. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
  554. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
  555. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +10 -5
  556. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
  557. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  558. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +16 -10
  559. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +55 -35
  560. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1314 -109
  561. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1660 -1371
  562. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  563. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  564. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  565. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
  566. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
  567. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
  568. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +6 -0
  569. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  570. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +40 -5
  571. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +105 -60
  572. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
  573. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +68 -257
  574. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +692 -23
  575. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_reg_tile.tmpl.wgsl → mul_mat_reg_tile.wgsl} +28 -128
  576. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +31 -137
  577. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
  578. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  579. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  580. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +9 -36
  581. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
  582. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  583. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
  584. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
  585. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
  586. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +9 -6
  587. data/ext/sources/ggml/src/ggml.c +167 -33
  588. data/ext/sources/ggml/src/gguf.cpp +229 -44
  589. data/ext/sources/src/whisper.cpp +6 -28
  590. data/sig/whisper.rbs +43 -2
  591. data/test/test_context_params.rb +82 -0
  592. data/test/test_token.rb +11 -0
  593. data/test/test_vad_context.rb +58 -8
  594. data/test/test_whisper.rb +20 -0
  595. data/whispercpp.gemspec +1 -1
  596. metadata +240 -28
  597. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  598. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
  599. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
  600. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
  601. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
  602. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
  603. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
  604. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
  605. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
  606. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
  607. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  608. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
  609. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
  610. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
@@ -7,6 +7,7 @@
7
7
  #include "llama-memory.h"
8
8
  #include "llama-mmap.h"
9
9
  #include "llama-model.h"
10
+ #include "llama-ext.h"
10
11
 
11
12
  #include <cinttypes>
12
13
  #include <cmath>
@@ -22,6 +23,8 @@ llama_context::llama_context(
22
23
  const llama_model & model,
23
24
  llama_context_params params) :
24
25
  model(model),
26
+ cvec(std::make_unique<llama_adapter_cvec>()),
27
+ loras(std::make_unique<llama_adapter_loras>()),
25
28
  balloc(std::make_unique<llama_batch_allocr>(model.hparams.n_pos_per_embd())) {
26
29
  // TODO warning when creating llama_context with awkward ctx size that is not a power of 2,
27
30
  // may need to be backend-dependent
@@ -146,6 +149,11 @@ llama_context::llama_context(
146
149
  }
147
150
 
148
151
  cparams.flash_attn = params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED;
152
+ cparams.auto_fa = params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO;
153
+
154
+ cparams.fused_gdn_ar = true;
155
+ cparams.fused_gdn_ch = true;
156
+ cparams.auto_fgdn = true;
149
157
 
150
158
  // with causal attention, the batch size is limited by the context size
151
159
  cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
@@ -155,6 +163,9 @@ llama_context::llama_context(
155
163
  cparams.op_offload = params.op_offload;
156
164
  cparams.kv_unified = params.kv_unified;
157
165
 
166
+ // initialized later
167
+ cparams.pipeline_parallel = false;
168
+
158
169
  {
159
170
  const char * LLAMA_GRAPH_REUSE_DISABLE = getenv("LLAMA_GRAPH_REUSE_DISABLE");
160
171
  graph_reuse_disable = LLAMA_GRAPH_REUSE_DISABLE ? (atoi(LLAMA_GRAPH_REUSE_DISABLE) != 0) : graph_reuse_disable;
@@ -249,11 +260,7 @@ llama_context::llama_context(
249
260
 
250
261
  // graph outputs buffer
251
262
  {
252
- // resized during inference when a batch uses more outputs
253
- // Create a dummy batch for initialization.
254
- llama_batch dummy_batch = {};
255
- dummy_batch.n_tokens = 0;
256
- if (output_reserve(params.n_seq_max, dummy_batch) < params.n_seq_max) {
263
+ if (output_reserve(params.n_seq_max) < params.n_seq_max) {
257
264
  throw std::runtime_error("failed to reserve initial output buffer");
258
265
  }
259
266
 
@@ -302,16 +309,6 @@ llama_context::llama_context(
302
309
 
303
310
  LLAMA_LOG_DEBUG("%s: backend_ptrs.size() = %zu\n", __func__, backend_ptrs.size());
304
311
 
305
- const uint32_t n_seqs = cparams.n_seq_max;
306
- const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
307
-
308
- const size_t max_nodes = this->graph_max_nodes(n_tokens);
309
-
310
- LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes);
311
-
312
- gf_res_prev.reset(new llm_graph_result(max_nodes));
313
- gf_res_reserve.reset(new llm_graph_result(max_nodes));
314
-
315
312
  // TODO: move these checks to ggml_backend_sched
316
313
  // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
317
314
  bool pipeline_parallel =
@@ -327,6 +324,7 @@ llama_context::llama_context(
327
324
  auto dev_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get()));
328
325
  if (dev_type == GGML_BACKEND_DEVICE_TYPE_CPU) {
329
326
  // ignore CPU backend
327
+ // TODO: should we ignore ACCEL types too?
330
328
  continue;
331
329
  }
332
330
  auto * dev = ggml_backend_get_device(backend.get());
@@ -340,177 +338,308 @@ llama_context::llama_context(
340
338
  }
341
339
  }
342
340
 
343
- sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel, cparams.op_offload));
341
+ cparams.pipeline_parallel = pipeline_parallel;
344
342
 
345
- if (pipeline_parallel) {
346
- LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(sched.get()));
343
+ if (cparams.pipeline_parallel) {
344
+ LLAMA_LOG_INFO("%s: pipeline parallelism enabled\n", __func__);
345
+
346
+ if (!graph_reuse_disable) {
347
+ // TODO: figure out a way to make graph reuse work with pipeline parallelism
348
+ // ref: https://github.com/ggml-org/llama.cpp/pull/20463
349
+ LLAMA_LOG_WARN("%s: graph reuse is currently not compatible with pipeline parallelism - disabling\n", __func__);
350
+
351
+ graph_reuse_disable = true;
352
+ }
347
353
  }
348
354
 
349
- llama_memory_context_ptr mctx;
350
- if (memory) {
351
- LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__);
352
- mctx = memory->init_full();
353
- if (!mctx) {
354
- throw std::runtime_error("failed to initialize memory module");
355
+ sched_reserve();
356
+
357
+ if (!cparams.flash_attn) {
358
+ if (ggml_is_quantized(params.type_v)) {
359
+ throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention");
355
360
  }
356
361
  }
362
+ }
363
+
364
+ // Initialize the full vocabulary token ids for backend samplers.
365
+ {
366
+ const int n_vocab = model.vocab.n_tokens();
357
367
 
358
- cross.v_embd.clear();
368
+ sampling.token_ids_full_vocab.resize(n_vocab);
369
+ for (int i = 0; i < n_vocab; ++i) {
370
+ sampling.token_ids_full_vocab[i] = i;
371
+ }
372
+ }
373
+ }
359
374
 
360
- // avoid reserving graphs with zero outputs - assume one output per sequence
361
- n_outputs = n_seqs;
375
+ llama_context::~llama_context() {
376
+ if (!model.hparams.no_alloc) {
377
+ for (size_t i = 0; i < backend_ptrs.size(); ++i) {
378
+ ggml_backend_t backend = backend_ptrs[i];
379
+ ggml_backend_buffer_type_t buft = backend_buft[i];
362
380
 
363
- LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
381
+ const size_t size_exp = backend_buf_exp_size[i];
382
+ const size_t size_act = ggml_backend_sched_get_buffer_size(sched.get(), backend);
383
+ if (size_exp == size_act) {
384
+ LLAMA_LOG_DEBUG("%s: %10s compute buffer size is %8.4f MiB, matches expectation of %8.4f MiB\n",
385
+ __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0));
386
+ } else {
387
+ LLAMA_LOG_WARN("%s: %10s compute buffer size of %8.4f MiB, does not match expectation of %8.4f MiB\n",
388
+ __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0));
389
+ }
390
+ }
391
+ }
392
+ ggml_opt_free(opt_ctx);
393
+ }
364
394
 
365
- // resolve automatic Flash Attention use
366
- if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) {
395
+ void llama_context::sched_reserve() {
396
+ if (!sched_need_reserve) {
397
+ return;
398
+ }
399
+
400
+ sched_need_reserve = false;
401
+
402
+ LLAMA_LOG_INFO("%s: reserving ...\n", __func__);
403
+
404
+ synchronize();
405
+
406
+ const int64_t t_start_us = ggml_time_us();
407
+
408
+ const uint32_t n_seqs = cparams.n_seq_max;
409
+ const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
410
+
411
+ const size_t max_nodes = this->graph_max_nodes(n_tokens);
412
+
413
+ LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes);
414
+
415
+ gf_res_prev.reset(new llm_graph_result(max_nodes));
416
+ gf_res_reserve.reset(new llm_graph_result(max_nodes));
417
+
418
+ sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, cparams.pipeline_parallel, cparams.op_offload));
419
+
420
+ llama_memory_context_ptr mctx;
421
+ if (memory) {
422
+ LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__);
423
+ mctx = memory->init_full();
424
+ if (!mctx) {
425
+ throw std::runtime_error("failed to initialize memory module");
426
+ }
427
+ }
428
+
429
+ // avoid reserving graphs with zero outputs - assume one output per sequence
430
+ const int n_outputs = n_seqs;
431
+
432
+ LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
433
+
434
+ // resolve automatic Flash Attention use
435
+ if (cparams.auto_fa) {
436
+ auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true);
437
+ if (!gf) {
438
+ throw std::runtime_error("failed to reserve graph for Flash Attention check");
439
+ }
440
+
441
+ const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1;
442
+ bool fa_device_mismatch = false;
443
+ for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
444
+ ggml_tensor * n = ggml_graph_node(gf, i);
445
+ if (n->op != GGML_OP_FLASH_ATTN_EXT) {
446
+ continue;
447
+ }
448
+ ggml_backend_dev_t device_fa = ggml_backend_get_device(ggml_backend_sched_get_tensor_backend(sched.get(), n));
449
+
450
+ // TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer
451
+ GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0);
452
+ const int il = std::stoi(n->name + prefix_len);
453
+ ggml_backend_dev_t device_kv = model.dev_layer(il);
454
+ if (device_fa != device_kv) {
455
+ LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor "
456
+ "is assigned to device %s (usually due to missing support)\n",
457
+ __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_fa));
458
+ // FIXME: fa_device_mismatch logic is wrong for --no-kv-offload, but this is broken anyways
459
+ fa_device_mismatch = true;
460
+ break;
461
+ }
462
+ }
463
+
464
+ if (fa_device_mismatch) {
465
+ cparams.flash_attn = false;
466
+ LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__);
467
+ } else {
468
+ cparams.flash_attn = true;
469
+ LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__);
470
+ }
471
+
472
+ cparams.auto_fa = false;
473
+ }
474
+
475
+ if (cparams.auto_fgdn) {
476
+ LLAMA_LOG_INFO("%s: resolving fused Gated Delta Net support:\n", __func__);
477
+
478
+ if (cparams.fused_gdn_ar) {
367
479
  auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true);
368
480
  if (!gf) {
369
- throw std::runtime_error("failed to split graph for Flash Attention check");
481
+ throw std::runtime_error("failed to reserve graph for fused Gated Delta Net check (autoregressive)");
370
482
  }
371
483
 
372
- const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1;
373
- bool fa_device_mismatch = false;
484
+ const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FGDN_AR) + 1;
485
+ bool gdn_device_mismatch = false;
374
486
  for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
375
487
  ggml_tensor * n = ggml_graph_node(gf, i);
376
- if (n->op != GGML_OP_FLASH_ATTN_EXT) {
488
+ if (n->op != GGML_OP_GATED_DELTA_NET) {
377
489
  continue;
378
490
  }
379
- ggml_backend_dev_t device_fa = ggml_backend_get_device(
380
- ggml_backend_sched_get_tensor_backend(sched.get(), n));
491
+ ggml_backend_dev_t device_gdn = ggml_backend_get_device(ggml_backend_sched_get_tensor_backend(sched.get(), n));
381
492
 
382
- // TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer
383
- GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0);
493
+ GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FGDN_AR "-", prefix_len) == 0);
384
494
  const int il = std::stoi(n->name + prefix_len);
385
495
  ggml_backend_dev_t device_kv = model.dev_layer(il);
386
- if (device_fa != device_kv) {
387
- LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor "
388
- "is assigned to device %s (usually due to missing support)\n",
389
- __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_fa));
390
- // FIXME: fa_device_mismatch logic is wrong for --no-kv-offload, but this is broken anyways
391
- fa_device_mismatch = true;
496
+ if (device_gdn != device_kv) {
497
+ LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the fused Gated Delta Net tensor "
498
+ "is assigned to device %s (usually due to missing support)\n",
499
+ __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_gdn));
500
+ gdn_device_mismatch = true;
392
501
  break;
393
502
  }
394
503
  }
395
- if (fa_device_mismatch) {
396
- cparams.flash_attn = false;
397
- LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__);
398
- if (ggml_is_quantized(params.type_v)) {
399
- throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention");
400
- }
504
+
505
+ if (gdn_device_mismatch) {
506
+ cparams.fused_gdn_ar = false;
507
+ LLAMA_LOG_WARN("%s: fused Gated Delta Net (autoregressive) not supported, set to disabled\n", __func__);
401
508
  } else {
402
- cparams.flash_attn = true;
403
- LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__);
509
+ LLAMA_LOG_INFO("%s: fused Gated Delta Net (autoregressive) enabled\n", __func__);
404
510
  }
405
511
  }
406
512
 
407
- // reserve worst-case graph
408
- int n_splits_pp = -1;
409
- int n_nodes_pp = -1;
410
-
411
- int n_splits_tg = -1;
412
- int n_nodes_tg = -1;
413
-
414
- // reserve pp (prompt processing) graph first so that buffers are only allocated once
415
- {
416
- auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(),
417
- model.hparams.no_alloc, model.hparams.no_alloc ? backend_buf_exp_size.data() : nullptr);
513
+ if (cparams.fused_gdn_ch) {
514
+ // more than one token in the batch per sequence in order to take the chunked path
515
+ // note: n_outputs must match n_tokens for embedding models with mean/rank pooling,
516
+ // because build_pooling creates inp_mean with shape [n_tokens, n_seqs] and multiplies
517
+ // it with t_embd which is reduced to [n_outputs, ...] via out_ids. if n_outputs != n_tokens,
518
+ // the ggml_mul_mat assertion fails. this matches the pp reservation below (line ~553).
519
+ const uint32_t n_tokens_ch = 16*n_seqs;
520
+ auto * gf = graph_reserve(n_tokens_ch, n_seqs, n_tokens_ch, mctx.get(), true);
418
521
  if (!gf) {
419
- if (pipeline_parallel) {
420
- LLAMA_LOG_WARN("%s: compute buffer allocation failed, retrying without pipeline parallelism\n", __func__);
421
- sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, false, cparams.op_offload));
422
- gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
522
+ throw std::runtime_error("failed to reserve graph for fused Gated Delta Net check (chunked)");
523
+ }
524
+
525
+ const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FGDN_CH) + 1;
526
+ bool gdn_device_mismatch = false;
527
+ for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
528
+ ggml_tensor * n = ggml_graph_node(gf, i);
529
+ if (n->op != GGML_OP_GATED_DELTA_NET) {
530
+ continue;
423
531
  }
424
- if (!gf) {
425
- throw std::runtime_error("failed to allocate compute pp buffers");
532
+ ggml_backend_dev_t device_gdn = ggml_backend_get_device(ggml_backend_sched_get_tensor_backend(sched.get(), n));
533
+
534
+ GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FGDN_CH "-", prefix_len) == 0);
535
+ const int il = std::stoi(n->name + prefix_len);
536
+ ggml_backend_dev_t device_kv = model.dev_layer(il);
537
+ if (device_gdn != device_kv) {
538
+ LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the fused Gated Delta Net tensor "
539
+ "is assigned to device %s (usually due to missing support)\n",
540
+ __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_gdn));
541
+ gdn_device_mismatch = true;
542
+ break;
426
543
  }
427
544
  }
428
545
 
429
- n_splits_pp = ggml_backend_sched_get_n_splits(sched.get());
430
- n_nodes_pp = ggml_graph_n_nodes(gf);
546
+ if (gdn_device_mismatch) {
547
+ cparams.fused_gdn_ch = false;
548
+ LLAMA_LOG_WARN("%s: fused Gated Delta Net (chunked) not supported, set to disabled\n", __func__);
549
+ } else {
550
+ LLAMA_LOG_INFO("%s: fused Gated Delta Net (chunked) enabled\n", __func__);
551
+ }
431
552
  }
432
553
 
433
- // reserve with tg (token generation) graph to get the number of splits and nodes
434
- {
435
- auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get(), model.hparams.no_alloc);
436
- if (!gf) {
437
- throw std::runtime_error("failed to allocate compute tg buffers");
438
- }
554
+ cparams.auto_fgdn = false;
555
+ }
439
556
 
440
- n_splits_tg = ggml_backend_sched_get_n_splits(sched.get());
441
- n_nodes_tg = ggml_graph_n_nodes(gf);
442
- }
557
+ // reserve worst-case graph
558
+ int n_splits_pp = -1;
559
+ int n_nodes_pp = -1;
443
560
 
444
- // reserve again with pp graph to avoid ggml-alloc reallocations during inference
445
- {
446
- // TODO: not sure if the following graph would be worster case for multi-stream KV caches:
447
- //
448
- // auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get());
449
- //
450
- auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), model.hparams.no_alloc);
561
+ int n_splits_tg = -1;
562
+ int n_nodes_tg = -1;
563
+
564
+ // reserve pp (prompt processing) graph first so that buffers are only allocated once
565
+ {
566
+ auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(),
567
+ model.hparams.no_alloc, model.hparams.no_alloc ? backend_buf_exp_size.data() : nullptr);
568
+ if (!gf) {
569
+ if (cparams.pipeline_parallel) {
570
+ LLAMA_LOG_WARN("%s: compute buffer allocation failed, retrying without pipeline parallelism\n", __func__);
571
+ cparams.pipeline_parallel = false;
572
+ sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, false, cparams.op_offload));
573
+ gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
574
+ }
451
575
  if (!gf) {
452
576
  throw std::runtime_error("failed to allocate compute pp buffers");
453
577
  }
454
578
  }
455
579
 
456
- for (size_t i = 0; i < backend_ptrs.size(); ++i) {
457
- ggml_backend_t backend = backend_ptrs[i];
458
- ggml_backend_buffer_type_t buft = backend_buft[i];
459
- if (!model.hparams.no_alloc) {
460
- backend_buf_exp_size[i] = ggml_backend_sched_get_buffer_size(sched.get(), backend);
461
- }
462
- if (backend_buf_exp_size[i] > 1) {
463
- LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__,
464
- ggml_backend_buft_name(buft),
465
- backend_buf_exp_size[i] / 1024.0 / 1024.0);
466
- }
467
- }
580
+ n_splits_pp = ggml_backend_sched_get_n_splits(sched.get());
581
+ n_nodes_pp = ggml_graph_n_nodes(gf);
582
+ }
468
583
 
469
- if (n_nodes_pp == n_nodes_tg) {
470
- LLAMA_LOG_INFO("%s: graph nodes = %d\n", __func__, n_nodes_pp);
471
- } else {
472
- LLAMA_LOG_INFO("%s: graph nodes = %d (with bs=%d), %d (with bs=1)\n", __func__, n_nodes_pp, n_tokens, n_nodes_tg);
584
+ // reserve with tg (token generation) graph to get the number of splits and nodes
585
+ {
586
+ auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get(), model.hparams.no_alloc);
587
+ if (!gf) {
588
+ throw std::runtime_error("failed to allocate compute tg buffers");
473
589
  }
474
590
 
475
- if (n_splits_pp == n_splits_tg) {
476
- LLAMA_LOG_INFO("%s: graph splits = %d\n", __func__, n_splits_pp);
477
- } else {
478
- LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg);
479
- }
591
+ n_splits_tg = ggml_backend_sched_get_n_splits(sched.get());
592
+ n_nodes_tg = ggml_graph_n_nodes(gf);
480
593
  }
481
594
 
482
- // Initialize the full vocabulary token ids for backend samplers.
595
+ // reserve again with pp graph to avoid ggml-alloc reallocations during inference
483
596
  {
484
- const int n_vocab = model.vocab.n_tokens();
597
+ // TODO: not sure if the following graph would be worster case for multi-stream KV caches:
598
+ //
599
+ // auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get());
600
+ //
601
+ auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), model.hparams.no_alloc);
602
+ if (!gf) {
603
+ throw std::runtime_error("failed to allocate compute pp buffers");
604
+ }
605
+ }
485
606
 
486
- sampling.token_ids_full_vocab.resize(n_vocab);
487
- for (int i = 0; i < n_vocab; ++i) {
488
- sampling.token_ids_full_vocab[i] = i;
607
+ for (size_t i = 0; i < backend_ptrs.size(); ++i) {
608
+ ggml_backend_t backend = backend_ptrs[i];
609
+ ggml_backend_buffer_type_t buft = backend_buft[i];
610
+ if (!model.hparams.no_alloc) {
611
+ backend_buf_exp_size[i] = ggml_backend_sched_get_buffer_size(sched.get(), backend);
612
+ }
613
+ if (backend_buf_exp_size[i] > 1) {
614
+ LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__,
615
+ ggml_backend_buft_name(buft),
616
+ backend_buf_exp_size[i] / 1024.0 / 1024.0);
489
617
  }
490
618
  }
491
- }
492
619
 
493
- llama_context::~llama_context() {
494
- if (!model.hparams.no_alloc) {
495
- for (size_t i = 0; i < backend_ptrs.size(); ++i) {
496
- ggml_backend_t backend = backend_ptrs[i];
497
- ggml_backend_buffer_type_t buft = backend_buft[i];
620
+ if (n_nodes_pp == n_nodes_tg) {
621
+ LLAMA_LOG_INFO("%s: graph nodes = %d\n", __func__, n_nodes_pp);
622
+ } else {
623
+ LLAMA_LOG_INFO("%s: graph nodes = %d (with bs=%d), %d (with bs=1)\n", __func__, n_nodes_pp, n_tokens, n_nodes_tg);
624
+ }
498
625
 
499
- const size_t size_exp = backend_buf_exp_size[i];
500
- const size_t size_act = ggml_backend_sched_get_buffer_size(sched.get(), backend);
501
- if (size_exp == size_act) {
502
- LLAMA_LOG_DEBUG("%s: %10s compute buffer size is %8.4f MiB, matches expectation of %8.4f MiB\n",
503
- __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0));
504
- } else {
505
- LLAMA_LOG_WARN("%s: %10s compute buffer size of %8.4f MiB, does not match expectation of %8.4f MiB\n",
506
- __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0));
507
- }
508
- }
626
+ if (n_splits_pp == n_splits_tg) {
627
+ LLAMA_LOG_INFO("%s: graph splits = %d\n", __func__, n_splits_pp);
628
+ } else {
629
+ LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg);
509
630
  }
510
- ggml_opt_free(opt_ctx);
631
+
632
+ const int64_t t_end_us = ggml_time_us();
633
+
634
+ LLAMA_LOG_INFO("%s: reserve took %.2f ms, sched copies = %d\n",
635
+ __func__, (t_end_us - t_start_us)/1000.0, ggml_backend_sched_get_n_copies(sched.get()));
511
636
  }
512
637
 
513
638
  void llama_context::synchronize() {
639
+ if (!sched) {
640
+ return;
641
+ }
642
+
514
643
  ggml_backend_sched_synchronize(sched.get());
515
644
 
516
645
  // FIXME: if multiple single tokens are evaluated without a synchronization,
@@ -645,7 +774,7 @@ enum llama_pooling_type llama_context::pooling_type() const {
645
774
  float * llama_context::get_logits() {
646
775
  output_reorder();
647
776
 
648
- return logits;
777
+ return logits.data;
649
778
  }
650
779
 
651
780
  int64_t llama_context::output_resolve_row(int32_t i) const {
@@ -678,36 +807,15 @@ int64_t llama_context::output_resolve_row(int32_t i) const {
678
807
  }
679
808
 
680
809
  float * llama_context::get_logits_ith(int32_t i) {
681
- int64_t j = -1;
682
-
683
810
  output_reorder();
684
811
 
685
812
  try {
686
- if (logits == nullptr) {
813
+ if (logits.data == nullptr) {
687
814
  throw std::runtime_error("no logits");
688
815
  }
689
816
 
690
- // TODO: use output_resolve_row()
691
- if (i < 0) {
692
- j = n_outputs + i;
693
- if (j < 0) {
694
- throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs));
695
- }
696
- } else if ((size_t) i >= output_ids.size()) {
697
- throw std::runtime_error(format("out of range [0, %zu)", output_ids.size()));
698
- } else {
699
- j = output_ids[i];
700
- }
701
-
702
- if (j < 0) {
703
- throw std::runtime_error(format("batch.logits[%d] != true", i));
704
- }
705
- if (j >= n_outputs) {
706
- // This should not happen
707
- throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
708
- }
709
-
710
- return logits + j*model.vocab.n_tokens();
817
+ const int64_t j = output_resolve_row(i);
818
+ return logits.data + j*model.vocab.n_tokens();
711
819
  } catch (const std::exception & err) {
712
820
  LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
713
821
  #ifndef NDEBUG
@@ -721,45 +829,24 @@ float * llama_context::get_logits_ith(int32_t i) {
721
829
  float * llama_context::get_embeddings() {
722
830
  output_reorder();
723
831
 
724
- return embd;
832
+ return embd.data;
725
833
  }
726
834
 
727
835
  llama_token * llama_context::get_sampled_tokens() const{
728
- return sampling.sampled;
836
+ return sampling.sampled.data;
729
837
  }
730
838
 
731
839
  float * llama_context::get_embeddings_ith(int32_t i) {
732
- int64_t j = -1;
733
-
734
840
  output_reorder();
735
841
 
736
842
  try {
737
- if (embd == nullptr) {
843
+ if (embd.data == nullptr) {
738
844
  throw std::runtime_error("no embeddings");
739
845
  }
740
846
 
741
- // TODO: use output_resolve_row()
742
- if (i < 0) {
743
- j = n_outputs + i;
744
- if (j < 0) {
745
- throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs));
746
- }
747
- } else if ((size_t) i >= output_ids.size()) {
748
- throw std::runtime_error(format("out of range [0, %zu)", output_ids.size()));
749
- } else {
750
- j = output_ids[i];
751
- }
752
-
753
- if (j < 0) {
754
- throw std::runtime_error(format("batch.logits[%d] != true", i));
755
- }
756
- if (j >= n_outputs) {
757
- // This should not happen
758
- throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
759
- }
760
-
761
- const uint32_t n_embd_out = model.hparams.get_n_embd_out();
762
- return embd + j*n_embd_out;
847
+ const int64_t j = output_resolve_row(i);
848
+ const uint32_t n_embd_out = model.hparams.n_embd_out();
849
+ return embd.data + j*n_embd_out;
763
850
  } catch (const std::exception & err) {
764
851
  LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what());
765
852
  #ifndef NDEBUG
@@ -782,14 +869,14 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) {
782
869
  llama_token llama_context::get_sampled_token_ith(int32_t idx) {
783
870
  output_reorder();
784
871
 
785
- if (sampling.sampled == nullptr) {
872
+ if (!sampling.sampled.has_data()) {
786
873
  return LLAMA_TOKEN_NULL;
787
874
  }
788
875
 
789
876
  try {
790
877
  const int64_t row = output_resolve_row(idx);
791
- GGML_ASSERT(row < (int64_t) sampling.sampled_size);
792
- return sampling.sampled[row];
878
+ GGML_ASSERT(row < (int64_t) sampling.sampled.size);
879
+ return sampling.sampled.data[row];
793
880
  } catch (const std::exception & err) {
794
881
  LLAMA_LOG_ERROR("%s: invalid backend sampled token id %d, reason: %s\n", __func__, idx, err.what());
795
882
  return LLAMA_TOKEN_NULL;
@@ -799,7 +886,7 @@ llama_token llama_context::get_sampled_token_ith(int32_t idx) {
799
886
  float * llama_context::get_sampled_probs_ith(int32_t idx) {
800
887
  output_reorder();
801
888
 
802
- if (sampling.probs == nullptr) {
889
+ if (!sampling.probs.has_data()) {
803
890
  return nullptr;
804
891
  }
805
892
 
@@ -808,7 +895,7 @@ float * llama_context::get_sampled_probs_ith(int32_t idx) {
808
895
  if ((size_t) row >= sampling.probs_count.size() || sampling.probs_count[row] == 0) {
809
896
  return nullptr;
810
897
  }
811
- return sampling.probs + row*model.vocab.n_tokens();
898
+ return sampling.probs.data + row*model.vocab.n_tokens();
812
899
  } catch (const std::exception & err) {
813
900
  LLAMA_LOG_ERROR("%s: invalid backend sampled probs id %d, reason: %s\n", __func__, idx, err.what());
814
901
  return nullptr;
@@ -818,7 +905,7 @@ float * llama_context::get_sampled_probs_ith(int32_t idx) {
818
905
  float * llama_context::get_sampled_logits_ith(int32_t idx) {
819
906
  output_reorder();
820
907
 
821
- if (sampling.logits == nullptr) {
908
+ if (!sampling.logits.has_data()) {
822
909
  return nullptr;
823
910
  }
824
911
 
@@ -827,7 +914,7 @@ float * llama_context::get_sampled_logits_ith(int32_t idx) {
827
914
  if ((size_t) row >= sampling.logits_count.size() || sampling.logits_count[row] == 0) {
828
915
  return nullptr;
829
916
  }
830
- return sampling.logits + row*model.vocab.n_tokens();
917
+ return sampling.logits.data + row*model.vocab.n_tokens();
831
918
  } catch (const std::exception & err) {
832
919
  LLAMA_LOG_ERROR("%s: invalid backend sampled logits id %d, reason: %s\n", __func__, idx, err.what());
833
920
  return nullptr;
@@ -839,13 +926,14 @@ const llama_token * llama_context::get_sampled_candidates_ith(int32_t idx) {
839
926
 
840
927
  try {
841
928
  const int64_t row = output_resolve_row(idx);
842
- if (sampling.candidates != nullptr &&
929
+ if (sampling.candidates.has_data() &&
843
930
  (size_t) row < sampling.candidates_count.size() &&
844
931
  sampling.candidates_count[row] > 0) {
845
- return sampling.candidates + row*model.vocab.n_tokens();
932
+ return sampling.candidates.data + row*model.vocab.n_tokens();
846
933
  }
847
934
  } catch (const std::exception & err) {
848
935
  // fallback to full vocab list
936
+ GGML_UNUSED(err);
849
937
  }
850
938
 
851
939
  return sampling.token_ids_full_vocab.data();
@@ -854,7 +942,7 @@ const llama_token * llama_context::get_sampled_candidates_ith(int32_t idx) {
854
942
  size_t llama_context::get_sampled_candidates_count(int32_t idx) {
855
943
  output_reorder();
856
944
 
857
- if (sampling.candidates == nullptr) {
945
+ if (!sampling.candidates.has_data()) {
858
946
  return 0;
859
947
  }
860
948
 
@@ -873,7 +961,7 @@ size_t llama_context::get_sampled_candidates_count(int32_t idx) {
873
961
  size_t llama_context::get_sampled_logits_count(int32_t idx) {
874
962
  output_reorder();
875
963
 
876
- if (sampling.logits == nullptr) {
964
+ if (!sampling.logits.has_data()) {
877
965
  return model.vocab.n_tokens();
878
966
  }
879
967
 
@@ -892,7 +980,7 @@ size_t llama_context::get_sampled_logits_count(int32_t idx) {
892
980
  size_t llama_context::get_sampled_probs_count(int32_t idx) {
893
981
  output_reorder();
894
982
 
895
- if (sampling.probs == nullptr) {
983
+ if (!sampling.probs.has_data()) {
896
984
  return 0;
897
985
  }
898
986
 
@@ -951,21 +1039,41 @@ void llama_context::set_embeddings(bool value) {
951
1039
  LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
952
1040
 
953
1041
  cparams.embeddings = value;
1042
+
1043
+ // TODO: not sure yet if we want to reserve here
1044
+ //sched_need_reserve = true;
954
1045
  }
955
1046
 
956
1047
  void llama_context::set_causal_attn(bool value) {
957
1048
  LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
958
1049
 
1050
+ if (cparams.causal_attn == value) {
1051
+ return;
1052
+ }
1053
+
959
1054
  cparams.causal_attn = value;
1055
+
1056
+ sched_need_reserve = true;
960
1057
  }
961
1058
 
962
1059
  void llama_context::set_warmup(bool value) {
963
1060
  LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
964
1061
 
1062
+ if (cparams.warmup == value) {
1063
+ return;
1064
+ }
1065
+
965
1066
  cparams.warmup = value;
1067
+
1068
+ // warmups are usually with small batches, so no need to reserve
1069
+ //sched_need_reserve = true;
966
1070
  }
967
1071
 
968
1072
  bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
1073
+ if (!sampler && sampling.samplers.count(seq_id) == 0) {
1074
+ return true;
1075
+ }
1076
+
969
1077
  LLAMA_LOG_DEBUG("%s: seq_id = %d, sampler = %p\n", __func__, (int) seq_id, (void *) sampler);
970
1078
 
971
1079
  const bool can_offload =
@@ -975,22 +1083,24 @@ bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
975
1083
  llama_sampler_chain_n(sampler) > 0;
976
1084
 
977
1085
  if (sampler && can_offload) {
978
- ggml_backend_buffer_type_t buft = ggml_backend_dev_buffer_type(model.dev_output());
979
- auto * host_buft = ggml_backend_dev_host_buffer_type(model.dev_output());
980
- if (host_buft) {
981
- buft = host_buft;
982
- }
1086
+ auto * buft = ggml_backend_dev_buffer_type(model.dev_output());
983
1087
 
984
1088
  sampler->iface->backend_init(sampler, buft);
985
1089
 
986
1090
  sampling.samplers[seq_id] = sampler;
987
1091
 
1092
+ sched_need_reserve = true;
1093
+
988
1094
  return true;
989
1095
  }
990
1096
 
991
1097
  if (sampler && !can_offload) {
992
1098
  LLAMA_LOG_WARN("%s: sampler '%s' for seq_id = %d, cannot be offloaded to the backend\n", __func__, llama_sampler_name(sampler), seq_id);
993
1099
 
1100
+ if (sampling.samplers.count(seq_id) > 0) {
1101
+ sched_need_reserve = true;
1102
+ }
1103
+
994
1104
  sampling.samplers.erase(seq_id);
995
1105
 
996
1106
  return false;
@@ -998,37 +1108,56 @@ bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
998
1108
 
999
1109
  sampling.samplers.erase(seq_id);
1000
1110
 
1111
+ sched_need_reserve = true;
1112
+
1001
1113
  return true;
1002
1114
  }
1003
1115
 
1004
- void llama_context::set_adapter_lora(
1005
- llama_adapter_lora * adapter,
1006
- float scale) {
1007
- LLAMA_LOG_DEBUG("%s: adapter = %p, scale = %f\n", __func__, (void *) adapter, scale);
1116
+ void llama_context::set_adapters_lora(llama_adapter_lora ** adapters, size_t n_adapters, float * scales) {
1117
+ LLAMA_LOG_DEBUG("%s: adapters = %p\n", __func__, (void *) adapters);
1008
1118
 
1009
- loras[adapter] = scale;
1010
- }
1119
+ if (adapters_lora_are_same(adapters, n_adapters, scales)) {
1120
+ return;
1121
+ }
1011
1122
 
1012
- bool llama_context::rm_adapter_lora(
1013
- llama_adapter_lora * adapter) {
1014
- LLAMA_LOG_DEBUG("%s: adapter = %p\n", __func__, (void *) adapter);
1123
+ loras.reset(new llama_adapter_loras());
1015
1124
 
1016
- auto pos = loras.find(adapter);
1017
- if (pos != loras.end()) {
1018
- loras.erase(pos);
1019
- return true;
1125
+ for (size_t i = 0; i < n_adapters; i ++) {
1126
+ if (scales[i] != 0.0f) {
1127
+ loras->insert({adapters[i], scales[i]});
1128
+ }
1020
1129
  }
1021
1130
 
1022
- return false;
1131
+ sched_need_reserve = true;
1023
1132
  }
1024
1133
 
1025
- void llama_context::clear_adapter_lora() {
1026
- LLAMA_LOG_DEBUG("%s: call\n", __func__);
1134
+ bool llama_context::adapters_lora_are_same(llama_adapter_lora ** adapters, size_t n_adapters, float * scales) {
1135
+ LLAMA_LOG_DEBUG("%s: adapters = %p\n", __func__, (void *) adapters);
1136
+
1137
+ // Adapters with a zero scale are never added to `loras`, so also ignore them for the comparison.
1138
+ size_t n_non_zero = 0;
1139
+
1140
+ for (size_t i = 0; i < n_adapters; i ++) {
1141
+ if (scales[i] == 0.0f) {
1142
+ continue;
1143
+ }
1144
+ n_non_zero++;
1145
+
1146
+ auto it = loras->find(adapters[i]);
1027
1147
 
1028
- loras.clear();
1148
+ if (it == loras->end() || it->second != scales[i]) {
1149
+ return false;
1150
+ }
1151
+ }
1152
+
1153
+ if (n_non_zero != loras->size()) {
1154
+ return false;
1155
+ }
1156
+
1157
+ return true;
1029
1158
  }
1030
1159
 
1031
- bool llama_context::apply_adapter_cvec(
1160
+ bool llama_context::set_adapter_cvec(
1032
1161
  const float * data,
1033
1162
  size_t len,
1034
1163
  int32_t n_embd,
@@ -1036,7 +1165,9 @@ bool llama_context::apply_adapter_cvec(
1036
1165
  int32_t il_end) {
1037
1166
  LLAMA_LOG_DEBUG("%s: il_start = %d, il_end = %d\n", __func__, il_start, il_end);
1038
1167
 
1039
- return cvec.apply(model, data, len, n_embd, il_start, il_end);
1168
+ // TODO: should we reserve?
1169
+
1170
+ return cvec->apply(model, data, len, n_embd, il_start, il_end);
1040
1171
  }
1041
1172
 
1042
1173
  llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
@@ -1086,6 +1217,7 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
1086
1217
  {
1087
1218
  //const auto t_start_us = ggml_time_us();
1088
1219
 
1220
+ // FIXME this call causes a crash if any model inputs were not used in the graph and were therefore not allocated
1089
1221
  res->set_inputs(&ubatch);
1090
1222
 
1091
1223
  //LLAMA_LOG_INFO("graph set inputs time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);
@@ -1138,10 +1270,12 @@ int llama_context::encode(const llama_batch & batch_inp) {
1138
1270
  // TODO: this clear of the buffer can easily be forgotten - need something better
1139
1271
  embd_seq.clear();
1140
1272
 
1273
+ sched_reserve();
1274
+
1141
1275
  n_queued_tokens += n_tokens;
1142
1276
 
1143
1277
  // reserve output buffer
1144
- if (output_reserve(n_tokens, batch_inp) < n_tokens) {
1278
+ if (output_reserve(n_tokens) < n_tokens) {
1145
1279
  LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
1146
1280
  return -2;
1147
1281
  };
@@ -1177,16 +1311,16 @@ int llama_context::encode(const llama_batch & batch_inp) {
1177
1311
  auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
1178
1312
 
1179
1313
  // extract logits
1180
- if (logits && t_logits) {
1314
+ if (logits.data && t_logits) {
1181
1315
  ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
1182
1316
  GGML_ASSERT(backend_res != nullptr);
1183
- GGML_ASSERT(logits != nullptr);
1317
+ GGML_ASSERT(logits.data != nullptr);
1184
1318
 
1185
- ggml_backend_tensor_get_async(backend_res, t_logits, logits, 0, n_tokens*n_vocab*sizeof(float));
1319
+ ggml_backend_tensor_get_async(backend_res, t_logits, logits.data, 0, n_tokens*n_vocab*sizeof(float));
1186
1320
  }
1187
1321
 
1188
1322
  // extract embeddings
1189
- if (embd && t_embd) {
1323
+ if (embd.data && t_embd) {
1190
1324
  ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
1191
1325
  GGML_ASSERT(backend_embd != nullptr);
1192
1326
 
@@ -1194,11 +1328,11 @@ int llama_context::encode(const llama_batch & batch_inp) {
1194
1328
  case LLAMA_POOLING_TYPE_NONE:
1195
1329
  {
1196
1330
  // extract token embeddings
1197
- GGML_ASSERT(embd != nullptr);
1198
- const uint32_t n_embd_out = hparams.get_n_embd_out();
1331
+ GGML_ASSERT(embd.data != nullptr);
1332
+ const uint32_t n_embd_out = hparams.n_embd_out();
1199
1333
 
1200
- GGML_ASSERT(n_tokens*n_embd_out <= (int64_t) embd_size);
1201
- ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd_out*sizeof(float));
1334
+ GGML_ASSERT(n_tokens*n_embd_out <= (int64_t) embd.size);
1335
+ ggml_backend_tensor_get_async(backend_embd, t_embd, embd.data, 0, n_tokens*n_embd_out*sizeof(float));
1202
1336
  } break;
1203
1337
  case LLAMA_POOLING_TYPE_MEAN:
1204
1338
  case LLAMA_POOLING_TYPE_CLS:
@@ -1246,7 +1380,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
1246
1380
  cross.n_embd = t_embd->ne[0];
1247
1381
  cross.n_enc = t_embd->ne[1];
1248
1382
  cross.v_embd.resize(cross.n_embd*cross.n_enc);
1249
- memcpy(cross.v_embd.data(), embd, ggml_nbytes(t_embd));
1383
+ memcpy(cross.v_embd.data(), embd.data, ggml_nbytes(t_embd));
1250
1384
 
1251
1385
  const auto & batch = balloc->get_batch();
1252
1386
 
@@ -1286,11 +1420,10 @@ static std::map<llama_seq_id, uint32_t> build_seq_to_output_row(const llama_ubat
1286
1420
 
1287
1421
  static void copy_tensor_async_ints(
1288
1422
  const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
1289
- llama_token * sampled,
1290
- size_t sampled_size,
1423
+ const buffer_view<llama_token> & sampled,
1291
1424
  const std::map<llama_seq_id, uint32_t> & seq_to_row,
1292
1425
  ggml_backend_sched_t sched) {
1293
- if (sampled == nullptr) {
1426
+ if (!sampled.has_data()) {
1294
1427
  return;
1295
1428
  }
1296
1429
 
@@ -1301,23 +1434,23 @@ static void copy_tensor_async_ints(
1301
1434
  }
1302
1435
 
1303
1436
  const uint32_t row = it->second;
1304
- GGML_ASSERT(row < sampled_size);
1437
+ GGML_ASSERT(row < sampled.size);
1305
1438
 
1306
1439
  GGML_ASSERT(ggml_is_contiguous(tensor) && "sampled tokens tensor must be contiguous for async copy");
1307
1440
 
1308
1441
  ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
1309
- ggml_backend_tensor_get_async(backend, tensor, sampled + row, 0, sizeof(sampled[row]));
1442
+ ggml_backend_tensor_get_async(backend, tensor, sampled.data + row, 0, sizeof(sampled.data[row]));
1310
1443
  }
1311
1444
  }
1312
1445
 
1313
1446
  static void copy_tensor_async_floats(
1314
1447
  const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
1315
- float * dst,
1448
+ const buffer_view<float> & dst,
1316
1449
  size_t stride,
1317
1450
  std::vector<uint32_t> & counts,
1318
1451
  const std::map<llama_seq_id, uint32_t> & seq_to_row,
1319
1452
  ggml_backend_sched_t sched) {
1320
- if (dst == nullptr) {
1453
+ if (!dst.has_data()) {
1321
1454
  return;
1322
1455
  }
1323
1456
 
@@ -1333,7 +1466,7 @@ static void copy_tensor_async_floats(
1333
1466
  GGML_ASSERT(ggml_is_contiguous(tensor) && "logits/probs tensor must be contiguous for async copy");
1334
1467
 
1335
1468
  ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
1336
- float * row_ptr = dst + (size_t) row * stride;
1469
+ float * row_ptr = dst.data + (size_t) row * stride;
1337
1470
  ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor));
1338
1471
 
1339
1472
  // Update the actual number of logits/probabilities that were written for this row.
@@ -1343,12 +1476,12 @@ static void copy_tensor_async_floats(
1343
1476
 
1344
1477
  static void copy_tensor_async_candidates(
1345
1478
  const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
1346
- llama_token * dst,
1479
+ const buffer_view<llama_token> & dst,
1347
1480
  size_t stride,
1348
1481
  std::vector<uint32_t> & counts,
1349
1482
  const std::map<llama_seq_id, uint32_t> & seq_to_row,
1350
1483
  ggml_backend_sched_t sched) {
1351
- if (dst == nullptr) {
1484
+ if (!dst.has_data()) {
1352
1485
  return;
1353
1486
  }
1354
1487
 
@@ -1364,7 +1497,7 @@ static void copy_tensor_async_candidates(
1364
1497
  GGML_ASSERT(ggml_is_contiguous(tensor) && "candidates tensor must be contiguous for async copy");
1365
1498
 
1366
1499
  ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
1367
- llama_token * row_ptr = dst + (size_t) row * stride;
1500
+ llama_token * row_ptr = dst.data + (size_t) row * stride;
1368
1501
  ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor));
1369
1502
 
1370
1503
  // Update the actual number of candidates that were written.
@@ -1372,6 +1505,23 @@ static void copy_tensor_async_candidates(
1372
1505
  }
1373
1506
  }
1374
1507
 
1508
+ static bool needs_raw_logits(const llama_ubatch & ubatch, const std::map<llama_seq_id, llama_sampler *> & samplers) {
1509
+ for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
1510
+ if (!ubatch.output[i]) {
1511
+ continue;
1512
+ }
1513
+
1514
+ // Check if the output token has at least one sequence without a backend sampler.
1515
+ for (int32_t j = 0; j < ubatch.n_seq_id[i]; ++j) {
1516
+ llama_seq_id seq_id = ubatch.seq_id[i][j];
1517
+ if (samplers.find(seq_id) == samplers.end()) {
1518
+ return true;
1519
+ }
1520
+ }
1521
+ }
1522
+ return false; // all sequences use backend sampling
1523
+ }
1524
+
1375
1525
  int llama_context::decode(const llama_batch & batch_inp) {
1376
1526
  GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
1377
1527
 
@@ -1451,6 +1601,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
1451
1601
  embd_seq.clear();
1452
1602
  output_swaps.clear();
1453
1603
 
1604
+ sched_reserve();
1605
+
1454
1606
  bool did_optimize = false;
1455
1607
 
1456
1608
  // handle any pending shifts/copies
@@ -1502,7 +1654,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
1502
1654
  }
1503
1655
 
1504
1656
  // reserve output buffer
1505
- if (output_reserve(n_outputs_all, balloc->get_batch()) < n_outputs_all) {
1657
+ if (output_reserve(n_outputs_all) < n_outputs_all) {
1506
1658
  LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
1507
1659
  return -2;
1508
1660
  };
@@ -1575,25 +1727,22 @@ int llama_context::decode(const llama_batch & batch_inp) {
1575
1727
  }
1576
1728
 
1577
1729
  // extract logits
1578
- // For multi-sequence batches that mix backend samplers and CPU sampler
1579
- // this is currently inefficient as we copy all logits even for the
1580
- // backend sampled tokens.
1581
- if (logits && t_logits && n_outputs > 0) {
1730
+ if (logits.data && t_logits && n_outputs > 0 && needs_raw_logits(ubatch, sampling.samplers)) {
1582
1731
  ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
1583
1732
  GGML_ASSERT(backend_res != nullptr);
1584
- GGML_ASSERT(logits != nullptr);
1733
+ GGML_ASSERT(logits.data != nullptr);
1585
1734
 
1586
- float * logits_out = logits + n_outputs_prev*n_vocab;
1735
+ float * logits_out = logits.data + n_outputs_prev*n_vocab;
1587
1736
 
1588
1737
  if (n_outputs) {
1589
1738
  GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
1590
- GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size);
1739
+ GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits.size);
1591
1740
  ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float));
1592
1741
  }
1593
1742
  }
1594
1743
 
1595
1744
  // extract embeddings
1596
- if (embd && t_embd && n_outputs > 0) {
1745
+ if (embd.data && t_embd && n_outputs > 0) {
1597
1746
  ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
1598
1747
  GGML_ASSERT(backend_embd != nullptr);
1599
1748
 
@@ -1601,13 +1750,13 @@ int llama_context::decode(const llama_batch & batch_inp) {
1601
1750
  case LLAMA_POOLING_TYPE_NONE:
1602
1751
  {
1603
1752
  // extract token embeddings
1604
- GGML_ASSERT(embd != nullptr);
1605
- const uint32_t n_embd_out = hparams.get_n_embd_out();
1606
- float * embd_out = embd + n_outputs_prev*n_embd_out;
1753
+ GGML_ASSERT(embd.data != nullptr);
1754
+ const uint32_t n_embd_out = hparams.n_embd_out();
1755
+ float * embd_out = embd.data + n_outputs_prev*n_embd_out;
1607
1756
 
1608
1757
  if (n_outputs) {
1609
1758
  GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
1610
- GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd_out <= (int64_t) embd_size);
1759
+ GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd_out <= (int64_t) embd.size);
1611
1760
  ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd_out*sizeof(float));
1612
1761
  }
1613
1762
  } break;
@@ -1648,16 +1797,13 @@ int llama_context::decode(const llama_batch & batch_inp) {
1648
1797
  }
1649
1798
  }
1650
1799
 
1651
- // This flag indicates whether a backend sampler has actually sampled a specific
1652
- // token, or if it has produced probabilites. If true, we can skip the normal copying of logits and embeddings.
1653
- const bool has_sampled = !res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty();
1654
-
1655
- if (has_samplers && has_sampled) {
1800
+ // Copy backend sampling output if this ubatch produced any sampling tensors.
1801
+ if (has_samplers && (!res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty())) {
1656
1802
  const auto seq_to_output_row = build_seq_to_output_row(ubatch, n_outputs_prev);
1657
1803
  const auto stride = n_vocab;
1658
1804
 
1659
1805
  // async copy the sampling data from the backend to the host
1660
- copy_tensor_async_ints(res->t_sampled, sampling.sampled, sampling.sampled_size, seq_to_output_row, sched.get());
1806
+ copy_tensor_async_ints(res->t_sampled, sampling.sampled, seq_to_output_row, sched.get());
1661
1807
 
1662
1808
  copy_tensor_async_floats (res->t_sampled_logits, sampling.logits, stride, sampling.logits_count, seq_to_output_row, sched.get());
1663
1809
  copy_tensor_async_floats (res->t_sampled_probs, sampling.probs, stride, sampling.probs_count, seq_to_output_row, sched.get());
@@ -1727,7 +1873,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
1727
1873
  // output
1728
1874
  //
1729
1875
 
1730
- uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & batch) {
1876
+ uint32_t llama_context::output_reserve(int32_t n_outputs) {
1731
1877
  const auto & hparams = model.hparams;
1732
1878
  const auto & vocab = model.vocab;
1733
1879
 
@@ -1735,7 +1881,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba
1735
1881
 
1736
1882
  const auto n_batch = cparams.n_batch;
1737
1883
  const auto n_vocab = vocab.n_tokens();
1738
- const auto n_embd_out = hparams.get_n_embd_out();
1884
+ const auto n_embd_out = hparams.n_embd_out();
1739
1885
 
1740
1886
  bool has_logits = true;
1741
1887
  bool has_embd = cparams.embeddings;
@@ -1746,52 +1892,18 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba
1746
1892
  has_embd = true;
1747
1893
  }
1748
1894
 
1749
- // Check which sampling modes are needed for the current batch.
1750
- // TODO: avoid this branching by working with the worst-case
1751
- bool has_sampling = false;
1752
- bool cpu_logits = false;
1753
-
1754
- if (batch.logits) {
1755
- for (int32_t i = 0; i < batch.n_tokens; i++) {
1756
- if (!batch.logits[i]) {
1757
- continue;
1758
- }
1759
- for (int32_t j = 0; j < batch.n_seq_id[i]; j++) {
1760
- llama_seq_id seq_id = batch.seq_id[i][j];
1761
- if (sampling.samplers.find(seq_id) != sampling.samplers.end()) {
1762
- has_sampling = true;
1763
- } else {
1764
- cpu_logits = true;
1765
- }
1766
- }
1767
- }
1768
- } else {
1769
- // When batch.logits is nullptr (when loading state with a dummy batch),
1770
- // allocate CPU logits.
1771
- cpu_logits = true;
1772
- }
1773
1895
 
1774
1896
  size_t backend_float_count = 0;
1775
1897
  size_t backend_token_count = 0;
1776
1898
 
1777
- // Allocate CPU logits buffer only if needed by sequences in this batch
1778
- logits_size = (has_logits && cpu_logits) ? n_vocab*n_outputs_max : 0;
1779
- embd_size = has_embd ? n_embd_out*n_outputs_max : 0;
1780
-
1781
- // TODO: avoid this branching by working with the worst-case
1782
- if (!has_sampling) {
1783
- sampling.logits_size = 0;
1784
- sampling.probs_size = 0;
1785
- sampling.sampled_size = 0;
1786
- sampling.candidates_size = 0;
1787
- } else {
1788
- sampling.logits_size = n_vocab*n_outputs_max;
1789
- sampling.probs_size = n_vocab*n_outputs_max;
1790
- sampling.sampled_size = n_outputs_max;
1791
- sampling.candidates_size = n_vocab*n_outputs_max;
1899
+ logits.size = has_logits ? n_vocab*n_outputs_max : 0;
1900
+ embd.size = has_embd ? n_embd_out*n_outputs_max : 0;
1792
1901
 
1793
- backend_float_count = sampling.logits_size + sampling.probs_size;
1794
- backend_token_count = sampling.sampled_size + sampling.candidates_size;
1902
+ // Allocate backend sampling output buffers if there are backend samplers configured.
1903
+ const bool has_sampling = !sampling.samplers.empty();
1904
+ if (has_sampling) {
1905
+ backend_float_count = 2 * n_vocab * n_outputs_max; // logits + probs
1906
+ backend_token_count = (1 + n_vocab) * n_outputs_max; // sampled + candidates
1795
1907
  }
1796
1908
 
1797
1909
  if (output_ids.empty()) {
@@ -1801,7 +1913,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba
1801
1913
 
1802
1914
  const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0;
1803
1915
  const size_t new_size =
1804
- (logits_size + embd_size + backend_float_count) * sizeof(float) +
1916
+ (logits.size + embd.size + backend_float_count) * sizeof(float) +
1805
1917
  ( backend_token_count) * sizeof(llama_token);
1806
1918
 
1807
1919
  // alloc only when more than the current capacity is required
@@ -1816,8 +1928,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba
1816
1928
 
1817
1929
  // TODO: not needed?
1818
1930
  buf_output = nullptr;
1819
- logits = nullptr;
1820
- embd = nullptr;
1931
+ logits.data = nullptr;
1932
+ embd.data = nullptr;
1821
1933
  }
1822
1934
 
1823
1935
  auto * buft = ggml_backend_cpu_buffer_type();
@@ -1836,35 +1948,27 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba
1836
1948
 
1837
1949
  float * output_base = (float *) ggml_backend_buffer_get_base(buf_output.get());
1838
1950
 
1839
- logits = nullptr;
1840
- embd = nullptr;
1841
-
1842
1951
  size_t offset = 0;
1843
1952
  uint8_t * base = (uint8_t *) output_base;
1844
1953
 
1845
- logits = (has_logits && cpu_logits) ? output_base : nullptr;
1846
- offset += logits_size * sizeof(float);
1954
+ logits = has_logits ? buffer_view<float>{output_base, logits.size} : buffer_view<float>{nullptr, 0};
1955
+ offset += logits.size * sizeof(float);
1847
1956
 
1848
- embd = has_embd ? (float *) (base + offset) : nullptr;
1849
- offset += embd_size * sizeof(float);
1850
-
1851
- sampling.logits = nullptr;
1852
- sampling.probs = nullptr;
1853
- sampling.sampled = nullptr;
1854
- sampling.candidates = nullptr;
1957
+ embd = has_embd ? buffer_view<float>{(float *) (base + offset), embd.size} : buffer_view<float>{nullptr, 0};
1958
+ offset += embd.size * sizeof(float);
1855
1959
 
1856
1960
  if (has_sampling) {
1857
- sampling.logits = (float *) (base + offset);
1858
- offset += sampling.logits_size * sizeof(float);
1961
+ sampling.logits = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)};
1962
+ offset += sampling.logits.size * sizeof(float);
1859
1963
 
1860
- sampling.probs = (float *) (base + offset);
1861
- offset += sampling.probs_size * sizeof(float);
1964
+ sampling.probs = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)};
1965
+ offset += sampling.probs.size * sizeof(float);
1862
1966
 
1863
- sampling.sampled = (llama_token *) (base + offset);
1864
- offset += sampling.sampled_size * sizeof(llama_token);
1967
+ sampling.sampled = {(llama_token *) (base + offset), (size_t)n_outputs_max};
1968
+ offset += sampling.sampled.size * sizeof(llama_token);
1865
1969
 
1866
- sampling.candidates = (llama_token *) (base + offset);
1867
- offset += sampling.candidates_size * sizeof(llama_token);
1970
+ sampling.candidates = {(llama_token *) (base + offset), (size_t)(n_vocab*n_outputs_max)};
1971
+ offset += sampling.candidates.size * sizeof(llama_token);
1868
1972
 
1869
1973
  // The count vectors keep track of the actual number of logits/probs/candidates
1870
1974
  // copied from the backend for each output row.
@@ -1877,7 +1981,16 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba
1877
1981
  std::fill(sampling.probs_count.begin(), sampling.probs_count.end(), 0);
1878
1982
  std::fill(sampling.candidates_count.begin(), sampling.candidates_count.end(), 0);
1879
1983
 
1880
- std::fill_n(sampling.sampled, sampling.sampled_size, LLAMA_TOKEN_NULL);
1984
+ std::fill_n(sampling.sampled.data, sampling.sampled.size, LLAMA_TOKEN_NULL);
1985
+ } else {
1986
+ sampling.logits = {nullptr, 0};
1987
+ sampling.probs = {nullptr, 0};
1988
+ sampling.sampled = {nullptr, 0};
1989
+ sampling.candidates = {nullptr, 0};
1990
+
1991
+ sampling.logits_count.clear();
1992
+ sampling.probs_count.clear();
1993
+ sampling.candidates_count.clear();
1881
1994
  }
1882
1995
 
1883
1996
  // set all ids as invalid (negative)
@@ -1896,49 +2009,42 @@ void llama_context::output_reorder() {
1896
2009
  const uint64_t i0 = output_swaps[s].i0;
1897
2010
  const uint64_t i1 = output_swaps[s].i1;
1898
2011
 
1899
- if (logits_size > 0) {
2012
+ if (logits.size > 0) {
1900
2013
  for (uint64_t k = 0; k < n_vocab; k++) {
1901
- std::swap(logits[i0*n_vocab + k], logits[i1*n_vocab + k]);
2014
+ std::swap(logits.data[i0*n_vocab + k], logits.data[i1*n_vocab + k]);
1902
2015
  }
1903
2016
  }
1904
2017
 
1905
- if (embd_size > 0) {
2018
+ if (embd.size > 0) {
1906
2019
  for (uint64_t k = 0; k < n_embd; k++) {
1907
- std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]);
2020
+ std::swap(embd.data[i0*n_embd + k], embd.data[i1*n_embd + k]);
1908
2021
  }
1909
2022
  }
1910
2023
 
1911
- if (sampling.logits && sampling.logits_size > 0) {
2024
+ if (!sampling.samplers.empty()) {
2025
+ assert(sampling.logits.size > 0);
2026
+ assert(sampling.probs.size > 0);
2027
+ assert(sampling.candidates.size > 0);
2028
+ assert(sampling.sampled.size > 0);
2029
+ assert(sampling.logits_count.size() > 0);
2030
+ assert(sampling.probs_count.size() > 0);
2031
+ assert(sampling.candidates_count.size() > 0);
2032
+
1912
2033
  for (uint64_t k = 0; k < n_vocab; ++k) {
1913
- std::swap(sampling.logits[i0*n_vocab + k], sampling.logits[i1*n_vocab + k]);
2034
+ std::swap(sampling.logits.data[i0*n_vocab + k], sampling.logits.data[i1*n_vocab + k]);
1914
2035
  }
1915
- }
1916
2036
 
1917
- if (sampling.probs && sampling.probs_size > 0) {
1918
2037
  for (uint64_t k = 0; k < n_vocab; ++k) {
1919
- std::swap(sampling.probs[i0*n_vocab + k], sampling.probs[i1*n_vocab + k]);
2038
+ std::swap(sampling.probs.data[i0*n_vocab + k], sampling.probs.data[i1*n_vocab + k]);
1920
2039
  }
1921
- }
1922
2040
 
1923
- if (sampling.candidates && sampling.candidates_size > 0) {
1924
2041
  for (uint64_t k = 0; k < n_vocab; ++k) {
1925
- std::swap(sampling.candidates[i0*n_vocab + k], sampling.candidates[i1*n_vocab + k]);
2042
+ std::swap(sampling.candidates.data[i0*n_vocab + k], sampling.candidates.data[i1*n_vocab + k]);
1926
2043
  }
1927
- }
1928
2044
 
1929
- if (sampling.sampled && sampling.sampled_size > 0) {
1930
- std::swap(sampling.sampled[i0], sampling.sampled[i1]);
1931
- }
1932
-
1933
- if (!sampling.logits_count.empty()) {
1934
- std::swap(sampling.logits_count[i0], sampling.logits_count[i1]);
1935
- }
1936
-
1937
- if (!sampling.probs_count.empty()) {
1938
- std::swap(sampling.probs_count[i0], sampling.probs_count[i1]);
1939
- }
1940
-
1941
- if (!sampling.candidates_count.empty()) {
2045
+ std::swap(sampling.sampled.data[i0], sampling.sampled.data[i1]);
2046
+ std::swap(sampling.logits_count[i0], sampling.logits_count[i1]);
2047
+ std::swap(sampling.probs_count[i0], sampling.probs_count[i1]);
1942
2048
  std::swap(sampling.candidates_count[i0], sampling.candidates_count[i1]);
1943
2049
  }
1944
2050
  }
@@ -1951,11 +2057,13 @@ void llama_context::output_reorder() {
1951
2057
  //
1952
2058
 
1953
2059
  uint32_t llama_context::graph_max_nodes(uint32_t n_tokens) const {
1954
- if (model.arch == LLM_ARCH_QWEN3NEXT) {
2060
+ if (model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_KIMI_LINEAR || model.arch == LLM_ARCH_QWEN35 || model.arch == LLM_ARCH_QWEN35MOE) {
1955
2061
  return std::max<uint32_t>(n_tokens * 40, 32u * model.n_tensors());
1956
2062
  }
1957
2063
  uint32_t res = std::max<uint32_t>(1024u, 8u*model.n_tensors());
1958
- res += model.n_lora_nodes;
2064
+ for (const auto & lora : model.loras) {
2065
+ res += lora->get_n_nodes();
2066
+ }
1959
2067
  return res;
1960
2068
  }
1961
2069
 
@@ -1977,7 +2085,7 @@ ggml_cgraph * llama_context::graph_reserve(
1977
2085
 
1978
2086
  ggml_backend_sched_reset(sched.get());
1979
2087
 
1980
- // when the scheduler is reset, we cannnot reuse the old graph, so we reset the previous graph result to prevent that
2088
+ // when the scheduler is reset, we cannot reuse the old graph, so we reset the previous graph result to prevent that
1981
2089
  gf_res_prev->reset();
1982
2090
 
1983
2091
  // store the n_outputs as it is, and restore it afterwards
@@ -2037,8 +2145,8 @@ llm_graph_params llama_context::graph_params(
2037
2145
  /*.gtype =*/ gtype,
2038
2146
  /*.sched =*/ sched.get(),
2039
2147
  /*.backend_cpu =*/ backend_cpu,
2040
- /*.cvec =*/ &cvec,
2041
- /*.loras =*/ &loras,
2148
+ /*.cvec =*/ cvec.get(),
2149
+ /*.loras =*/ loras.get(),
2042
2150
  /*.mctx =*/ mctx,
2043
2151
  /*.cross =*/ &cross,
2044
2152
  /*.samplers =*/ sampling.samplers,
@@ -2085,13 +2193,6 @@ llm_graph_cb llama_context::graph_get_cb() const {
2085
2193
  ggml_set_name(cur, name);
2086
2194
  }
2087
2195
 
2088
- if (!cparams.offload_kqv) {
2089
- if (strcmp(name, "kqv_merged_cont") == 0) {
2090
- // all nodes between the KV store and the attention output are run on the CPU
2091
- ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend_cpu);
2092
- }
2093
- }
2094
-
2095
2196
  // norm may be automatically assigned to the backend of the previous layer, increasing data transfer between backends
2096
2197
  // FIXME: fix in ggml_backend_sched
2097
2198
  const bool full_offload = model.n_gpu_layers() > model.hparams.n_layer;
@@ -2443,63 +2544,6 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
2443
2544
  // TODO: add more model-specific info which should prevent loading the session file if not identical
2444
2545
  }
2445
2546
 
2446
- // write output ids
2447
- {
2448
- LLAMA_LOG_DEBUG("%s: - writing output ids\n", __func__);
2449
-
2450
- const auto n_outputs = this->n_outputs;
2451
- const auto & output_ids = this->output_ids;
2452
-
2453
- std::vector<int32_t> w_output_pos;
2454
-
2455
- w_output_pos.resize(n_outputs);
2456
-
2457
- // build a more compact representation of the output ids
2458
- for (size_t i = 0; i < n_batch(); ++i) {
2459
- // map an output id to a position in the batch
2460
- int64_t pos = output_ids[i];
2461
- if (pos >= 0) {
2462
- GGML_ASSERT(pos < n_outputs);
2463
- w_output_pos[pos] = i;
2464
- }
2465
- }
2466
-
2467
- io.write(&n_outputs, sizeof(n_outputs));
2468
-
2469
- if (n_outputs) {
2470
- io.write(w_output_pos.data(), n_outputs * sizeof(int32_t));
2471
- }
2472
- }
2473
-
2474
- // write logits
2475
- {
2476
- LLAMA_LOG_DEBUG("%s: - writing logits\n", __func__);
2477
-
2478
- const uint64_t logits_size = std::min((uint64_t) this->logits_size, (uint64_t) n_outputs * model.vocab.n_tokens());
2479
-
2480
- io.write(&logits_size, sizeof(logits_size));
2481
-
2482
- if (logits_size) {
2483
- io.write(logits, logits_size * sizeof(float));
2484
- }
2485
- }
2486
-
2487
- // write embeddings
2488
- {
2489
- LLAMA_LOG_DEBUG("%s: - writing embeddings\n", __func__);
2490
-
2491
- const uint64_t embd_size = std::min((uint64_t) this->embd_size, (uint64_t) n_outputs * model.hparams.n_embd);
2492
-
2493
- io.write(&embd_size, sizeof(embd_size));
2494
-
2495
- if (embd_size) {
2496
- io.write(embd, embd_size * sizeof(float));
2497
- }
2498
- }
2499
-
2500
- // TODO: handle sampling buffers and samplers state ?
2501
- // https://github.com/ggml-org/llama.cpp/pull/17004
2502
-
2503
2547
  if (memory != nullptr) {
2504
2548
  LLAMA_LOG_DEBUG("%s: - writing memory module\n", __func__);
2505
2549
  memory->state_write(io);
@@ -2525,73 +2569,6 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
2525
2569
  // TODO: add more info which needs to be identical but which is not verified otherwise
2526
2570
  }
2527
2571
 
2528
- // read output ids
2529
- {
2530
- LLAMA_LOG_DEBUG("%s: - reading output ids\n", __func__);
2531
-
2532
- auto n_outputs = this->n_outputs;
2533
- io.read_to(&n_outputs, sizeof(n_outputs));
2534
-
2535
- // Create a dummy batch for state loading.
2536
- llama_batch dummy_batch = {};
2537
- dummy_batch.n_tokens = 0;
2538
- if (n_outputs > output_reserve(n_outputs, dummy_batch)) {
2539
- throw std::runtime_error("could not reserve outputs");
2540
- }
2541
-
2542
- std::vector<int32_t> output_pos;
2543
-
2544
- if (n_outputs) {
2545
- output_pos.resize(n_outputs);
2546
- io.read_to(output_pos.data(), n_outputs * sizeof(int32_t));
2547
-
2548
- for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) {
2549
- int32_t id = output_pos[i];
2550
- if ((uint32_t) id >= n_batch()) {
2551
- throw std::runtime_error(format("invalid output id, %d does not fit in batch size of %u", id, n_batch()));
2552
- }
2553
- this->output_ids[id] = i;
2554
- }
2555
-
2556
- this->n_outputs = n_outputs;
2557
- }
2558
- }
2559
-
2560
- // read logits
2561
- {
2562
- LLAMA_LOG_DEBUG("%s: - reading logits\n", __func__);
2563
-
2564
- uint64_t logits_size;
2565
- io.read_to(&logits_size, sizeof(logits_size));
2566
-
2567
- if (this->logits_size < logits_size) {
2568
- throw std::runtime_error("logits buffer too small");
2569
- }
2570
-
2571
- if (logits_size) {
2572
- io.read_to(this->logits, logits_size * sizeof(float));
2573
- }
2574
- }
2575
-
2576
- // read embeddings
2577
- {
2578
- LLAMA_LOG_DEBUG("%s: - reading embeddings\n", __func__);
2579
-
2580
- uint64_t embd_size;
2581
- io.read_to(&embd_size, sizeof(embd_size));
2582
-
2583
- if (this->embd_size < embd_size) {
2584
- throw std::runtime_error("embeddings buffer too small");
2585
- }
2586
-
2587
- if (embd_size) {
2588
- io.read_to(this->embd, embd_size * sizeof(float));
2589
- }
2590
- }
2591
-
2592
- // TODO: handle sampling buffers and samplers state ?
2593
- // https://github.com/ggml-org/llama.cpp/pull/17004
2594
-
2595
2572
  if (memory) {
2596
2573
  LLAMA_LOG_DEBUG("%s: - reading memory module\n", __func__);
2597
2574
 
@@ -2724,6 +2701,7 @@ void llama_context::opt_init(struct llama_model * model, struct llama_opt_params
2724
2701
  llama_set_param(model->cls_b, param_filter, param_filter_ud);
2725
2702
  llama_set_param(model->cls_out, param_filter, param_filter_ud);
2726
2703
  llama_set_param(model->cls_out_b, param_filter, param_filter_ud);
2704
+ llama_set_param(model->cls_norm, param_filter, param_filter_ud);
2727
2705
 
2728
2706
  for (struct llama_layer & layer : model->layers) {
2729
2707
  for (size_t i = 0; i < sizeof(layer)/sizeof(struct ggml_tensor *); ++i) {
@@ -2780,7 +2758,7 @@ void llama_context::opt_epoch_iter(
2780
2758
  }
2781
2759
 
2782
2760
  // reserve output buffer
2783
- if (output_reserve(n_outputs_all, balloc->get_batch()) < n_outputs_all) {
2761
+ if (output_reserve(n_outputs_all) < n_outputs_all) {
2784
2762
  LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
2785
2763
  GGML_ABORT("TODO: handle this error");
2786
2764
  };
@@ -2815,7 +2793,7 @@ void llama_context::opt_epoch_iter(
2815
2793
  };
2816
2794
  ctx_compute_opt = ggml_init(params);
2817
2795
  }
2818
- ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits());
2796
+ ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_inp_tokens(), res->get_logits());
2819
2797
  ggml_opt_alloc(opt_ctx, train);
2820
2798
 
2821
2799
  res->set_inputs(&ubatch);
@@ -2957,19 +2935,23 @@ llama_context * llama_init_from_model(
2957
2935
 
2958
2936
  if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_k)) {
2959
2937
  const uint32_t blck_size = ggml_blck_size(params.type_k);
2960
- if (model->hparams.n_embd_head_k % blck_size != 0) {
2961
- LLAMA_LOG_ERROR("%s: K cache type %s with block size %u does not divide n_embd_head_k=%u\n",
2962
- __func__, ggml_type_name(params.type_k), blck_size, model->hparams.n_embd_head_k);
2963
- return nullptr;
2938
+ for (uint32_t il = 0; il < model->hparams.n_layer; ++il) {
2939
+ if (model->hparams.n_embd_head_k(il) % blck_size != 0) {
2940
+ LLAMA_LOG_ERROR("%s: K cache type %s with block size %u does not divide n_embd_head_k=%u\n",
2941
+ __func__, ggml_type_name(params.type_k), blck_size, model->hparams.n_embd_head_k(il));
2942
+ return nullptr;
2943
+ }
2964
2944
  }
2965
2945
  }
2966
2946
 
2967
2947
  if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_v)) {
2968
2948
  const uint32_t blck_size = ggml_blck_size(params.type_v);
2969
- if (model->hparams.n_embd_head_v % blck_size != 0) {
2970
- LLAMA_LOG_ERROR("%s: V cache type %s with block size %u does not divide n_embd_head_k=%u\n",
2971
- __func__, ggml_type_name(params.type_v), blck_size, model->hparams.n_embd_head_v);
2972
- return nullptr;
2949
+ for (uint32_t il = 0; il < model->hparams.n_layer; ++il) {
2950
+ if (model->hparams.n_embd_head_v(il) % blck_size != 0) {
2951
+ LLAMA_LOG_ERROR("%s: V cache type %s with block size %u does not divide n_embd_head_v=%u\n",
2952
+ __func__, ggml_type_name(params.type_v), blck_size, model->hparams.n_embd_head_v(il));
2953
+ return nullptr;
2954
+ }
2973
2955
  }
2974
2956
  }
2975
2957
 
@@ -3161,37 +3143,43 @@ uint32_t llama_get_sampled_probs_count_ith(llama_context * ctx, int32_t i) {
3161
3143
  return static_cast<uint32_t>(ctx->get_sampled_probs_count(i));
3162
3144
  }
3163
3145
 
3164
- // llama adapter API
3165
-
3166
- int32_t llama_set_adapter_lora(
3167
- llama_context * ctx,
3168
- llama_adapter_lora * adapter,
3169
- float scale) {
3170
- ctx->set_adapter_lora(adapter, scale);
3171
-
3172
- return 0;
3146
+ struct ggml_cgraph * llama_graph_reserve(
3147
+ struct llama_context * ctx,
3148
+ uint32_t n_tokens,
3149
+ uint32_t n_seqs,
3150
+ uint32_t n_outputs) {
3151
+ auto * memory = ctx->get_memory();
3152
+ llama_memory_context_ptr mctx;
3153
+ if (memory) {
3154
+ mctx = memory->init_full();
3155
+ }
3156
+ return ctx->graph_reserve(n_tokens, n_seqs, n_outputs, mctx.get());
3173
3157
  }
3174
3158
 
3175
- int32_t llama_rm_adapter_lora(
3159
+ // llama adapter API
3160
+
3161
+ int32_t llama_set_adapters_lora(
3176
3162
  llama_context * ctx,
3177
- llama_adapter_lora * adapter) {
3178
- bool res = ctx->rm_adapter_lora(adapter);
3163
+ llama_adapter_lora ** adapters,
3164
+ size_t n_adapters,
3165
+ float * scales) {
3166
+ if (adapters == nullptr || scales == nullptr) {
3167
+ GGML_ASSERT(n_adapters == 0 && "invalid llama_set_adapters_lora call");
3168
+ }
3179
3169
 
3180
- return res ? 0 : -1;
3181
- }
3170
+ ctx->set_adapters_lora(adapters, n_adapters, scales);
3182
3171
 
3183
- void llama_clear_adapter_lora(llama_context * ctx) {
3184
- ctx->clear_adapter_lora();
3172
+ return 0;
3185
3173
  }
3186
3174
 
3187
- int32_t llama_apply_adapter_cvec(
3175
+ int32_t llama_set_adapter_cvec(
3188
3176
  llama_context * ctx,
3189
- const float * data,
3190
- size_t len,
3191
- int32_t n_embd,
3192
- int32_t il_start,
3193
- int32_t il_end) {
3194
- bool res = ctx->apply_adapter_cvec(data, len, n_embd, il_start, il_end);
3177
+ const float * data,
3178
+ size_t len,
3179
+ int32_t n_embd,
3180
+ int32_t il_start,
3181
+ int32_t il_end) {
3182
+ bool res = ctx->set_adapter_cvec(data, len, n_embd, il_start, il_end);
3195
3183
 
3196
3184
  return res ? 0 : -1;
3197
3185
  }