whispercpp 1.3.5 → 1.3.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (610)
  1. checksums.yaml +4 -4
  2. data/LICENSE +1 -1
  3. data/README.md +99 -2
  4. data/ext/extconf.rb +1 -0
  5. data/ext/ruby_whisper.c +20 -4
  6. data/ext/ruby_whisper.h +30 -2
  7. data/ext/ruby_whisper_context.c +216 -124
  8. data/ext/ruby_whisper_context_params.c +163 -0
  9. data/ext/ruby_whisper_model.c +0 -1
  10. data/ext/ruby_whisper_params.c +0 -1
  11. data/ext/ruby_whisper_segment.c +0 -1
  12. data/ext/ruby_whisper_token.c +29 -9
  13. data/ext/ruby_whisper_transcribe.cpp +4 -1
  14. data/ext/ruby_whisper_vad_context.c +48 -1
  15. data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
  16. data/ext/ruby_whisper_vad_params.c +0 -1
  17. data/ext/ruby_whisper_vad_segment.c +0 -1
  18. data/ext/ruby_whisper_vad_segments.c +0 -1
  19. data/ext/sources/CMakeLists.txt +1 -1
  20. data/ext/sources/bindings/javascript/package.json +1 -1
  21. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  22. data/ext/sources/examples/bench/bench.cpp +23 -18
  23. data/ext/sources/examples/cli/cli.cpp +8 -0
  24. data/ext/sources/examples/common-ggml.cpp +2 -0
  25. data/ext/sources/examples/miniaudio.h +4507 -2131
  26. data/ext/sources/examples/server/server.cpp +18 -4
  27. data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -2
  28. data/ext/sources/examples/talk-llama/llama-adapter.cpp +7 -13
  29. data/ext/sources/examples/talk-llama/llama-adapter.h +4 -3
  30. data/ext/sources/examples/talk-llama/llama-arch.cpp +335 -17
  31. data/ext/sources/examples/talk-llama/llama-arch.h +42 -0
  32. data/ext/sources/examples/talk-llama/llama-batch.cpp +3 -1
  33. data/ext/sources/examples/talk-llama/llama-chat.cpp +21 -1
  34. data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
  35. data/ext/sources/examples/talk-llama/llama-context.cpp +508 -520
  36. data/ext/sources/examples/talk-llama/llama-context.h +27 -28
  37. data/ext/sources/examples/talk-llama/llama-cparams.h +5 -0
  38. data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
  39. data/ext/sources/examples/talk-llama/llama-grammar.cpp +8 -8
  40. data/ext/sources/examples/talk-llama/llama-graph.cpp +583 -130
  41. data/ext/sources/examples/talk-llama/llama-graph.h +131 -10
  42. data/ext/sources/examples/talk-llama/llama-hparams.cpp +57 -40
  43. data/ext/sources/examples/talk-llama/llama-hparams.h +79 -10
  44. data/ext/sources/examples/talk-llama/llama-impl.cpp +4 -4
  45. data/ext/sources/examples/talk-llama/llama-impl.h +13 -1
  46. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +3 -1
  47. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +274 -89
  48. data/ext/sources/examples/talk-llama/llama-kv-cache.h +2 -3
  49. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
  50. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
  51. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +11 -13
  52. data/ext/sources/examples/talk-llama/llama-mmap.cpp +28 -11
  53. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +527 -119
  54. data/ext/sources/examples/talk-llama/llama-model-loader.h +35 -5
  55. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +60 -46
  56. data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
  57. data/ext/sources/examples/talk-llama/llama-model.cpp +1365 -647
  58. data/ext/sources/examples/talk-llama/llama-model.h +72 -19
  59. data/ext/sources/examples/talk-llama/llama-quant.cpp +578 -346
  60. data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +190 -76
  61. data/ext/sources/examples/talk-llama/{llama-sampling.h → llama-sampler.h} +0 -2
  62. data/ext/sources/examples/talk-llama/llama-vocab.cpp +118 -48
  63. data/ext/sources/examples/talk-llama/llama-vocab.h +5 -0
  64. data/ext/sources/examples/talk-llama/llama.cpp +76 -22
  65. data/ext/sources/examples/talk-llama/llama.h +63 -30
  66. data/ext/sources/examples/talk-llama/models/afmoe.cpp +2 -3
  67. data/ext/sources/examples/talk-llama/models/apertus.cpp +3 -3
  68. data/ext/sources/examples/talk-llama/models/arcee.cpp +3 -3
  69. data/ext/sources/examples/talk-llama/models/arctic.cpp +4 -5
  70. data/ext/sources/examples/talk-llama/models/baichuan.cpp +4 -3
  71. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +1 -2
  72. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +3 -5
  73. data/ext/sources/examples/talk-llama/models/bert.cpp +13 -7
  74. data/ext/sources/examples/talk-llama/models/bitnet.cpp +9 -24
  75. data/ext/sources/examples/talk-llama/models/bloom.cpp +2 -2
  76. data/ext/sources/examples/talk-llama/models/chameleon.cpp +3 -3
  77. data/ext/sources/examples/talk-llama/models/chatglm.cpp +2 -2
  78. data/ext/sources/examples/talk-llama/models/codeshell.cpp +3 -3
  79. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +3 -3
  80. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +2 -2
  81. data/ext/sources/examples/talk-llama/models/command-r.cpp +2 -2
  82. data/ext/sources/examples/talk-llama/models/dbrx.cpp +4 -5
  83. data/ext/sources/examples/talk-llama/models/deci.cpp +3 -3
  84. data/ext/sources/examples/talk-llama/models/deepseek.cpp +4 -6
  85. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +24 -21
  86. data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
  87. data/ext/sources/examples/talk-llama/models/dots1.cpp +4 -6
  88. data/ext/sources/examples/talk-llama/models/dream.cpp +3 -3
  89. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +4 -6
  90. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +3 -3
  91. data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
  92. data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
  93. data/ext/sources/examples/talk-llama/models/exaone.cpp +3 -3
  94. data/ext/sources/examples/talk-llama/models/exaone4.cpp +3 -3
  95. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +2 -4
  96. data/ext/sources/examples/talk-llama/models/falcon.cpp +3 -3
  97. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +1 -1
  98. data/ext/sources/examples/talk-llama/models/gemma.cpp +1 -1
  99. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +1 -1
  100. data/ext/sources/examples/talk-llama/models/gemma3.cpp +1 -1
  101. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +7 -7
  102. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +3 -3
  103. data/ext/sources/examples/talk-llama/models/glm4.cpp +14 -7
  104. data/ext/sources/examples/talk-llama/models/gpt2.cpp +2 -2
  105. data/ext/sources/examples/talk-llama/models/gptneox.cpp +2 -2
  106. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +4 -5
  107. data/ext/sources/examples/talk-llama/models/granite.cpp +4 -5
  108. data/ext/sources/examples/talk-llama/models/grok.cpp +4 -4
  109. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +5 -7
  110. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +3 -3
  111. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +4 -5
  112. data/ext/sources/examples/talk-llama/models/internlm2.cpp +3 -3
  113. data/ext/sources/examples/talk-llama/models/jais.cpp +2 -2
  114. data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
  115. data/ext/sources/examples/talk-llama/models/jamba.cpp +3 -3
  116. data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
  117. data/ext/sources/examples/talk-llama/models/lfm2.cpp +145 -124
  118. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +4 -4
  119. data/ext/sources/examples/talk-llama/models/llada.cpp +3 -3
  120. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +4 -4
  121. data/ext/sources/examples/talk-llama/models/llama.cpp +18 -11
  122. data/ext/sources/examples/talk-llama/models/maincoder.cpp +3 -3
  123. data/ext/sources/examples/talk-llama/models/{graph-context-mamba.cpp → mamba-base.cpp} +9 -3
  124. data/ext/sources/examples/talk-llama/models/mamba.cpp +1 -2
  125. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +11 -5
  126. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +14 -13
  127. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +4 -5
  128. data/ext/sources/examples/talk-llama/models/mistral3.cpp +4 -4
  129. data/ext/sources/examples/talk-llama/models/models.h +181 -46
  130. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +2 -9
  131. data/ext/sources/examples/talk-llama/models/mpt.cpp +2 -2
  132. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +26 -14
  133. data/ext/sources/examples/talk-llama/models/nemotron.cpp +3 -3
  134. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +2 -2
  135. data/ext/sources/examples/talk-llama/models/olmo.cpp +3 -3
  136. data/ext/sources/examples/talk-llama/models/olmo2.cpp +3 -3
  137. data/ext/sources/examples/talk-llama/models/olmoe.cpp +4 -4
  138. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +1 -1
  139. data/ext/sources/examples/talk-llama/models/openelm.cpp +3 -3
  140. data/ext/sources/examples/talk-llama/models/orion.cpp +3 -3
  141. data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
  142. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +3 -3
  143. data/ext/sources/examples/talk-llama/models/phi2.cpp +2 -2
  144. data/ext/sources/examples/talk-llama/models/phi3.cpp +3 -3
  145. data/ext/sources/examples/talk-llama/models/plamo.cpp +3 -3
  146. data/ext/sources/examples/talk-llama/models/plamo2.cpp +9 -5
  147. data/ext/sources/examples/talk-llama/models/plamo3.cpp +2 -2
  148. data/ext/sources/examples/talk-llama/models/plm.cpp +15 -14
  149. data/ext/sources/examples/talk-llama/models/qwen.cpp +2 -2
  150. data/ext/sources/examples/talk-llama/models/qwen2.cpp +3 -3
  151. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +4 -4
  152. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +3 -3
  153. data/ext/sources/examples/talk-llama/models/qwen3.cpp +12 -9
  154. data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
  155. data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
  156. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +15 -8
  157. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +84 -432
  158. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +9 -18
  159. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +8 -17
  160. data/ext/sources/examples/talk-llama/models/refact.cpp +2 -2
  161. data/ext/sources/examples/talk-llama/models/rnd1.cpp +4 -4
  162. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +2 -0
  163. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +2 -0
  164. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +3 -3
  165. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +4 -4
  166. data/ext/sources/examples/talk-llama/models/smollm3.cpp +3 -3
  167. data/ext/sources/examples/talk-llama/models/stablelm.cpp +2 -2
  168. data/ext/sources/examples/talk-llama/models/starcoder.cpp +2 -2
  169. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +3 -3
  170. data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
  171. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +2 -2
  172. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +2 -2
  173. data/ext/sources/examples/talk-llama/models/xverse.cpp +3 -3
  174. data/ext/sources/examples/talk-llama/unicode.cpp +21 -65
  175. data/ext/sources/ggml/CMakeLists.txt +9 -3
  176. data/ext/sources/ggml/include/ggml-backend.h +1 -1
  177. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  178. data/ext/sources/ggml/include/ggml-cpu.h +5 -0
  179. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  180. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  181. data/ext/sources/ggml/include/ggml-rpc.h +6 -1
  182. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  183. data/ext/sources/ggml/include/ggml.h +56 -9
  184. data/ext/sources/ggml/src/CMakeLists.txt +3 -0
  185. data/ext/sources/ggml/src/ggml-alloc.c +4 -9
  186. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  187. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  188. data/ext/sources/ggml/src/ggml-backend-reg.cpp +28 -86
  189. data/ext/sources/ggml/src/ggml-backend.cpp +5 -2
  190. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  191. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +6 -2
  192. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
  193. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
  194. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +348 -189
  195. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +40 -85
  196. data/ext/sources/ggml/src/ggml-cann/common.h +3 -4
  197. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +44 -62
  198. data/ext/sources/ggml/src/ggml-common.h +11 -0
  199. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +16 -11
  200. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -19
  201. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  202. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
  203. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +85 -1
  204. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2744 -548
  205. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1653 -0
  206. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
  207. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
  208. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
  209. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +118 -18
  210. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +107 -26
  211. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  212. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  213. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +3 -0
  214. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +59 -12
  215. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +15 -0
  216. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
  217. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +965 -252
  218. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +584 -197
  219. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +903 -188
  220. data/ext/sources/ggml/src/ggml-cpu/ops.h +1 -0
  221. data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
  222. data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
  223. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2890 -679
  224. data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
  225. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
  226. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +111 -3
  227. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
  228. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +17 -0
  229. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +1 -1
  230. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +19 -10
  231. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +32 -30
  232. data/ext/sources/ggml/src/ggml-cuda/common.cuh +134 -18
  233. data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
  234. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +6 -3
  235. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +78 -64
  236. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +384 -143
  237. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +36 -22
  238. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +3 -3
  239. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +26 -5
  240. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
  241. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +127 -12
  242. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
  243. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  244. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +595 -200
  245. data/ext/sources/ggml/src/ggml-cuda/mean.cu +9 -8
  246. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +173 -6
  247. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
  248. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +158 -85
  249. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +34 -22
  250. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +127 -67
  251. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
  252. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +157 -65
  253. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -0
  254. data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
  255. data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
  256. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +1 -1
  257. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
  258. data/ext/sources/ggml/src/ggml-cuda/rope.cu +233 -133
  259. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
  260. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
  261. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +56 -32
  262. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  263. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
  264. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
  265. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
  266. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
  267. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
  268. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
  269. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +3 -3
  270. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +0 -1
  271. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +199 -135
  272. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
  273. data/ext/sources/ggml/src/ggml-cuda/unary.cu +55 -0
  274. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
  275. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
  276. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +10 -0
  277. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +82 -45
  278. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +334 -160
  279. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +7 -5
  280. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +328 -197
  281. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
  282. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +765 -234
  283. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
  284. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +412 -265
  285. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +23 -23
  286. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +1 -1
  287. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.h → hex-dma.h} +28 -3
  288. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
  289. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  290. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
  291. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +1 -1
  292. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +27 -37
  293. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +6 -35
  294. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  295. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
  296. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
  297. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  302. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  303. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  304. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
  305. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  306. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  307. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +20 -1347
  308. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +211 -13
  309. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +1119 -952
  310. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +254 -244
  311. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +36 -36
  312. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +155 -138
  313. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
  314. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  315. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +209 -114
  316. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
  317. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  318. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  319. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  320. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
  321. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +6 -0
  322. data/ext/sources/ggml/src/ggml-impl.h +62 -0
  323. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  324. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
  325. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  326. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
  327. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +274 -73
  328. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +22 -4
  329. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +102 -36
  330. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +174 -23
  331. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +580 -280
  332. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +5 -4
  333. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +320 -107
  334. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1068 -825
  335. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +19 -1
  336. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +3108 -636
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +204 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  365. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
  366. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  367. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  368. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  369. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
  370. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  371. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
  372. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  373. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
  374. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
  375. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  376. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  377. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  378. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  379. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  380. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  381. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  382. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  383. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  384. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  385. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  386. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  387. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  388. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  389. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  390. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  391. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  392. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
  393. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  394. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  395. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  396. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  397. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  398. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  399. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
  400. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
  401. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
  402. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
  403. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  404. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  405. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  406. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  407. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  408. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
  409. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  410. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
  411. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
  412. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
  413. data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
  414. data/ext/sources/ggml/src/ggml-quants.c +96 -5
  415. data/ext/sources/ggml/src/ggml-quants.h +3 -0
  416. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +15 -88
  417. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
  418. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +1 -0
  419. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
  420. data/ext/sources/ggml/src/ggml-sycl/common.hpp +315 -10
  421. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +69 -1
  422. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
  423. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
  424. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
  425. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +78 -68
  426. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
  427. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
  428. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
  429. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
  430. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
  431. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
  432. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  433. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
  434. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
  435. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +316 -51
  436. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
  437. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  438. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
  439. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
  440. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
  441. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  442. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
  443. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  444. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  445. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  446. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  447. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  448. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  449. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  450. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  451. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  452. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
  453. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
  454. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
  455. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
  456. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
  457. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
  458. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
  459. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
  460. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
  461. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
  462. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
  463. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
  464. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
  465. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
  466. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
  467. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
  468. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
  469. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
  470. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
  471. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
  472. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
  473. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
  474. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
  475. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
  476. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
  477. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
  478. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
  479. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
  480. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
  481. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
  482. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
  483. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
  484. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
  485. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
  486. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
  487. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
  488. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +13 -0
  489. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  490. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  491. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  492. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  493. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  494. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  495. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  496. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  497. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  498. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  499. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  500. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  501. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  502. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  503. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  504. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  505. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  506. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  507. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  508. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  509. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  510. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  511. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
  512. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
  513. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  514. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
  515. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  516. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  517. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  518. data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
  519. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  520. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  521. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  522. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  523. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  524. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  525. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  526. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
  527. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  528. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  529. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  530. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
  531. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
  532. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +1 -1
  533. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1250 -465
  534. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
  535. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  536. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +374 -170
  537. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +66 -22
  538. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +389 -201
  539. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +106 -58
  540. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  541. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
  542. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
  543. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
  544. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
  545. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -3
  546. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +8 -4
  547. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
  548. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
  549. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +3 -3
  550. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
  551. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +36 -63
  552. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
  553. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
  554. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
  555. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +10 -5
  556. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
  557. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  558. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +16 -10
  559. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +55 -35
  560. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1314 -109
  561. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1660 -1371
  562. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  563. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  564. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  565. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
  566. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
  567. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
  568. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +6 -0
  569. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  570. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +40 -5
  571. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +105 -60
  572. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
  573. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +68 -257
  574. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +692 -23
  575. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_reg_tile.tmpl.wgsl → mul_mat_reg_tile.wgsl} +28 -128
  576. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +31 -137
  577. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
  578. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  579. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  580. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +9 -36
  581. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
  582. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  583. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
  584. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
  585. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
  586. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +9 -6
  587. data/ext/sources/ggml/src/ggml.c +167 -33
  588. data/ext/sources/ggml/src/gguf.cpp +229 -44
  589. data/ext/sources/src/whisper.cpp +6 -28
  590. data/sig/whisper.rbs +43 -2
  591. data/test/test_context_params.rb +82 -0
  592. data/test/test_token.rb +11 -0
  593. data/test/test_vad_context.rb +58 -8
  594. data/test/test_whisper.rb +20 -0
  595. data/whispercpp.gemspec +1 -1
  596. metadata +240 -28
  597. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  598. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
  599. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
  600. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
  601. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
  602. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
  603. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
  604. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
  605. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
  606. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
  607. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  608. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
  609. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
  610. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
@@ -7,13 +7,51 @@
7
7
  #include "llama-kv-cache.h"
8
8
  #include "llama-kv-cache-iswa.h"
9
9
  #include "llama-memory-hybrid.h"
10
+ #include "llama-memory-hybrid-iswa.h"
10
11
  #include "llama-memory-recurrent.h"
11
12
 
12
13
  #include <cassert>
13
14
  #include <cmath>
14
15
  #include <cstring>
16
+ #include <numeric>
17
+ #include <sstream>
15
18
  #include <unordered_set>
16
19
 
20
+ // dedup helpers
21
+
22
+ static ggml_tensor * build_kq_mask(
23
+ ggml_context * ctx,
24
+ const llama_kv_cache_context * mctx,
25
+ const llama_ubatch & ubatch,
26
+ const llama_cparams & cparams) {
27
+ const auto n_kv = mctx->get_n_kv();
28
+ const auto n_tokens = ubatch.n_tokens;
29
+ const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
30
+
31
+ return ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
32
+ }
33
+
34
+ static bool can_reuse_kq_mask(
35
+ ggml_tensor * kq_mask,
36
+ const llama_kv_cache_context * mctx,
37
+ const llama_ubatch & ubatch,
38
+ const llama_cparams & cparams) {
39
+ const auto n_kv = mctx->get_n_kv();
40
+ const auto n_tokens = ubatch.n_tokens;
41
+ const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
42
+
43
+ bool res = true;
44
+
45
+ res &= (kq_mask->ne[0] == n_kv);
46
+ res &= (kq_mask->ne[1] == n_tokens/n_stream);
47
+ res &= (kq_mask->ne[2] == 1);
48
+ res &= (kq_mask->ne[3] == n_stream);
49
+
50
+ return res;
51
+ }
52
+
53
+ // impl
54
+
17
55
  void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
18
56
  if (ubatch->token) {
19
57
  const int64_t n_tokens = ubatch->n_tokens;
@@ -22,7 +60,8 @@ void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
22
60
  }
23
61
 
24
62
  if (ubatch->embd) {
25
- const int64_t n_embd = embd->ne[0];
63
+ GGML_ASSERT(n_embd == embd->ne[0]);
64
+
26
65
  const int64_t n_tokens = ubatch->n_tokens;
27
66
 
28
67
  ggml_backend_tensor_set(embd, ubatch->embd, 0, n_tokens*n_embd*ggml_element_size(embd));
@@ -32,8 +71,8 @@ void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
32
71
  bool llm_graph_input_embd::can_reuse(const llm_graph_params & params) {
33
72
  bool res = true;
34
73
 
35
- res &= (!tokens && !params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens);
36
- res &= (!embd && !params.ubatch.embd) || (embd && embd->ne[1] == params.ubatch.n_tokens);
74
+ res &= (!params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens);
75
+ res &= (!params.ubatch.embd) || (embd && embd->ne[1] == params.ubatch.n_tokens);
37
76
 
38
77
  return res;
39
78
  }
@@ -96,11 +135,9 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
96
135
 
97
136
  int32_t * data = (int32_t *) pos_bucket->data;
98
137
 
99
- for (int h = 0; h < 1; ++h) {
100
- for (int j = 0; j < n_tokens; ++j) {
101
- for (int i = 0; i < n_tokens; ++i) {
102
- data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(ubatch->pos[i], ubatch->pos[j], hparams.n_rel_attn_bkts, true);
103
- }
138
+ for (int j = 0; j < n_tokens; ++j) {
139
+ for (int i = 0; i < n_tokens; ++i) {
140
+ data[j*n_tokens + i] = llama_relative_position_bucket(ubatch->pos[i], ubatch->pos[j], hparams.n_rel_attn_bkts, true);
104
141
  }
105
142
  }
106
143
  }
@@ -148,7 +185,10 @@ bool llm_graph_input_out_ids::can_reuse(const llm_graph_params & params) {
148
185
  }
149
186
 
150
187
  void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
151
- if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
188
+ if (cparams.embeddings &&
189
+ (cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN ||
190
+ cparams.pooling_type == LLAMA_POOLING_TYPE_RANK )) {
191
+
152
192
  const int64_t n_tokens = ubatch->n_tokens;
153
193
  const int64_t n_seq_tokens = ubatch->n_seq_tokens;
154
194
  const int64_t n_seqs_unq = ubatch->n_seqs_unq;
@@ -210,7 +250,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
210
250
 
211
251
  const bool last = (
212
252
  cparams.pooling_type == LLAMA_POOLING_TYPE_LAST ||
213
- (cparams.pooling_type == LLAMA_POOLING_TYPE_RANK && arch == LLM_ARCH_QWEN3) // qwen3 reranking & embedding models use last token
253
+ (cparams.pooling_type == LLAMA_POOLING_TYPE_RANK && (arch == LLM_ARCH_QWEN3 || arch == LLM_ARCH_QWEN3VL)) // qwen3 reranking & embedding models use last token
214
254
  );
215
255
 
216
256
  for (int i = 0; i < n_tokens; ++i) {
@@ -323,34 +363,32 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
323
363
  const int64_t n_tokens = ubatch->n_tokens;
324
364
 
325
365
  const auto fill_mask = [&](float * data, int n_swa, llama_swa_type swa_type) {
326
- for (int h = 0; h < 1; ++h) {
327
- for (int i1 = 0; i1 < n_tokens; ++i1) {
328
- const llama_seq_id s1 = ubatch->seq_id[i1][0];
329
- const llama_pos p1 = ubatch->pos[i1];
366
+ for (int i1 = 0; i1 < n_tokens; ++i1) {
367
+ const llama_seq_id s1 = ubatch->seq_id[i1][0];
368
+ const llama_pos p1 = ubatch->pos[i1];
330
369
 
331
- const uint64_t idst = h*(n_kv*n_tokens) + i1*n_kv;
370
+ const uint64_t idst = i1*n_kv;
332
371
 
333
- for (int i0 = 0; i0 < n_tokens; ++i0) {
334
- const llama_seq_id s0 = ubatch->seq_id[i0][0];
335
- const llama_pos p0 = ubatch->pos[i0];
372
+ for (int i0 = 0; i0 < n_tokens; ++i0) {
373
+ const llama_seq_id s0 = ubatch->seq_id[i0][0];
374
+ const llama_pos p0 = ubatch->pos[i0];
336
375
 
337
- // mask different sequences
338
- if (s0 != s1) {
339
- continue;
340
- }
341
-
342
- // mask future tokens
343
- if (cparams.causal_attn && p0 > p1) {
344
- continue;
345
- }
376
+ // mask different sequences
377
+ if (s0 != s1) {
378
+ continue;
379
+ }
346
380
 
347
- // apply SWA if any
348
- if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
349
- continue;
350
- }
381
+ // mask future tokens
382
+ if (cparams.causal_attn && p0 > p1) {
383
+ continue;
384
+ }
351
385
 
352
- data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
386
+ // apply SWA if any
387
+ if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
388
+ continue;
353
389
  }
390
+
391
+ data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
354
392
  }
355
393
  }
356
394
  };
@@ -403,8 +441,27 @@ bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
403
441
  res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
404
442
  //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
405
443
 
406
- res &= self_kq_mask->ne[0] == mctx->get_n_kv();
407
- res &= self_kq_mask->ne[1] == params.ubatch.n_tokens;
444
+ res &= can_reuse_kq_mask(self_kq_mask, mctx, params.ubatch, params.cparams);
445
+
446
+ return res;
447
+ }
448
+
449
+ void llm_graph_input_attn_k::set_input(const llama_ubatch * ubatch) {
450
+ mctx->set_input_k_idxs(self_k_idxs, ubatch);
451
+
452
+ mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
453
+ }
454
+
455
+ bool llm_graph_input_attn_k::can_reuse(const llm_graph_params & params) {
456
+ const auto * mctx = static_cast<const llama_kv_cache_context *>(params.mctx);
457
+
458
+ this->mctx = mctx;
459
+
460
+ bool res = true;
461
+
462
+ res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
463
+
464
+ res &= can_reuse_kq_mask(self_kq_mask, mctx, params.ubatch, params.cparams);
408
465
 
409
466
  return res;
410
467
  }
@@ -434,11 +491,8 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
434
491
  res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
435
492
  //res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
436
493
 
437
- res &= self_kq_mask->ne[0] == mctx->get_base()->get_n_kv();
438
- res &= self_kq_mask->ne[1] == params.ubatch.n_tokens;
439
-
440
- res &= self_kq_mask_swa->ne[0] == mctx->get_swa()->get_n_kv();
441
- res &= self_kq_mask_swa->ne[1] == params.ubatch.n_tokens;
494
+ res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams);
495
+ res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams);
442
496
 
443
497
  return res;
444
498
  }
@@ -454,27 +508,20 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
454
508
 
455
509
  float * data = (float *) cross_kq_mask->data;
456
510
 
457
- for (int h = 0; h < 1; ++h) {
458
- for (int i = 0; i < n_tokens; ++i) {
459
- for (int j = 0; j < n_enc; ++j) {
460
- float f = -INFINITY;
511
+ for (int i = 0; i < n_tokens; ++i) {
512
+ GGML_ASSERT(!cross->seq_ids_enc.empty() && "llama_encode must be called first");
513
+ for (int j = 0; j < n_enc; ++j) {
514
+ float f = -INFINITY;
461
515
 
462
- for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
463
- const llama_seq_id seq_id = ubatch->seq_id[i][s];
516
+ for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
517
+ const llama_seq_id seq_id = ubatch->seq_id[i][s];
464
518
 
465
- if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) {
466
- f = 0.0f;
467
- }
519
+ if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) {
520
+ f = 0.0f;
468
521
  }
469
-
470
- data[h*(n_enc*n_tokens) + i*n_enc + j] = f;
471
522
  }
472
- }
473
523
 
474
- for (int i = n_tokens; i < n_tokens; ++i) {
475
- for (int j = 0; j < n_enc; ++j) {
476
- data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
477
- }
524
+ data[i*n_enc + j] = f;
478
525
  }
479
526
  }
480
527
  }
@@ -508,8 +555,118 @@ bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
508
555
  res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
509
556
  //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
510
557
 
511
- res &= inp_attn->self_kq_mask->ne[0] == mctx->get_attn()->get_n_kv();
512
- res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens;
558
+ res &= can_reuse_kq_mask(inp_attn->self_kq_mask, mctx->get_attn(), params.ubatch, params.cparams);
559
+
560
+ res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
561
+
562
+ res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs;
563
+ res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;
564
+
565
+ res &= inp_rs->head == mctx->get_recr()->get_head();
566
+ res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z();
567
+
568
+ return res;
569
+ }
570
+
571
+ // TODO: Hybrid input classes are a bit redundant.
572
+ // Instead of creating a hybrid input, the graph can simply create 2 separate inputs.
573
+ // Refactoring is required in the future.
574
+ void llm_graph_input_mem_hybrid_k::set_input(const llama_ubatch * ubatch) {
575
+ mctx->get_attn()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
576
+
577
+ mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
578
+
579
+ const int64_t n_rs = mctx->get_recr()->get_n_rs();
580
+
581
+ if (inp_rs->s_copy) {
582
+ GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer));
583
+ int32_t * data = (int32_t *) inp_rs->s_copy->data;
584
+
585
+ // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
586
+ for (uint32_t i = 0; i < n_rs; ++i) {
587
+ data[i] = mctx->get_recr()->s_copy(i);
588
+ }
589
+ }
590
+ }
591
+
592
+ bool llm_graph_input_mem_hybrid_k::can_reuse(const llm_graph_params & params) {
593
+ const auto * mctx = static_cast<const llama_memory_hybrid_context *>(params.mctx);
594
+
595
+ this->mctx = mctx;
596
+
597
+ bool res = true;
598
+
599
+ res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
600
+
601
+ res &= can_reuse_kq_mask(inp_attn->self_kq_mask, mctx->get_attn(), params.ubatch, params.cparams);
602
+
603
+ res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
604
+
605
+ res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs;
606
+ res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;
607
+
608
+ res &= inp_rs->head == mctx->get_recr()->get_head();
609
+ res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z();
610
+
611
+ return res;
612
+ }
613
+
614
+ void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) {
615
+ const auto * attn_ctx = mctx->get_attn();
616
+
617
+ // base tensors may not be allocated if there are no non-SWA attention layers
618
+ if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) {
619
+ attn_ctx->get_base()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
620
+ attn_ctx->get_base()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch);
621
+
622
+ attn_ctx->get_base()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
623
+ }
624
+
625
+ // swa tensors may not be allocated if there are no SWA attention layers
626
+ if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) {
627
+ attn_ctx->get_swa()->set_input_k_idxs(inp_attn->self_k_idxs_swa, ubatch);
628
+ attn_ctx->get_swa()->set_input_v_idxs(inp_attn->self_v_idxs_swa, ubatch);
629
+
630
+ attn_ctx->get_swa()->set_input_kq_mask(inp_attn->self_kq_mask_swa, ubatch, cparams.causal_attn);
631
+ }
632
+
633
+ const int64_t n_rs = mctx->get_recr()->get_n_rs();
634
+
635
+ if (inp_rs->s_copy) {
636
+ GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer));
637
+ int32_t * data = (int32_t *) inp_rs->s_copy->data;
638
+
639
+ // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
640
+ for (uint32_t i = 0; i < n_rs; ++i) {
641
+ data[i] = mctx->get_recr()->s_copy(i);
642
+ }
643
+ }
644
+ }
645
+
646
+ bool llm_graph_input_mem_hybrid_iswa::can_reuse(const llm_graph_params & params) {
647
+ const auto * mctx = static_cast<const llama_memory_hybrid_iswa_context *>(params.mctx);
648
+
649
+ this->mctx = mctx;
650
+
651
+ bool res = true;
652
+
653
+ const auto * attn_ctx = mctx->get_attn();
654
+
655
+ // base tensors may not be allocated if there are no non-SWA attention layers
656
+ if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) {
657
+ res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
658
+ //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
659
+
660
+ res &= can_reuse_kq_mask(inp_attn->self_kq_mask, attn_ctx->get_base(), params.ubatch, params.cparams);
661
+ }
662
+
663
+ // swa tensors may not be allocated if there are no SWA attention layers
664
+ if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) {
665
+ res &= inp_attn->self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
666
+ //res &= inp_attn->self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
667
+
668
+ res &= can_reuse_kq_mask(inp_attn->self_kq_mask_swa, attn_ctx->get_swa(), params.ubatch, params.cparams);
669
+ }
513
670
 
514
671
  res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
515
672
 
@@ -575,7 +732,8 @@ int64_t llm_graph_result::get_max_nodes() const {
575
732
  }
576
733
 
577
734
  void llm_graph_result::reset() {
578
- t_tokens = nullptr;
735
+ t_inp_tokens = nullptr;
736
+ t_inp_embd = nullptr;
579
737
  t_logits = nullptr;
580
738
  t_embd = nullptr;
581
739
  t_embd_pooled = nullptr;
@@ -691,13 +849,13 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
691
849
  ubatch (params.ubatch),
692
850
  n_embd (hparams.n_embd),
693
851
  n_layer (hparams.n_layer),
694
- n_rot (hparams.n_rot),
852
+ n_rot (hparams.n_rot()),
695
853
  n_ctx (cparams.n_ctx),
696
854
  n_head (hparams.n_head()),
697
855
  n_head_kv (hparams.n_head_kv()),
698
- n_embd_head_k (hparams.n_embd_head_k),
856
+ n_embd_head_k (hparams.n_embd_head_k()),
699
857
  n_embd_k_gqa (hparams.n_embd_k_gqa()),
700
- n_embd_head_v (hparams.n_embd_head_v),
858
+ n_embd_head_v (hparams.n_embd_head_v()),
701
859
  n_embd_v_gqa (hparams.n_embd_v_gqa()),
702
860
  n_expert (hparams.n_expert),
703
861
  n_expert_used (cparams.warmup ? hparams.n_expert : hparams.n_expert_used),
@@ -742,7 +900,8 @@ ggml_tensor * llm_graph_context::build_cvec(
742
900
 
743
901
  ggml_tensor * llm_graph_context::build_lora_mm(
744
902
  ggml_tensor * w,
745
- ggml_tensor * cur) const {
903
+ ggml_tensor * cur,
904
+ ggml_tensor * w_s) const {
746
905
  ggml_tensor * res = ggml_mul_mat(ctx0, w, cur);
747
906
 
748
907
  for (const auto & lora : *loras) {
@@ -763,6 +922,10 @@ ggml_tensor * llm_graph_context::build_lora_mm(
763
922
  res = ggml_add(ctx0, res, ab_cur);
764
923
  }
765
924
 
925
+ if (w_s) {
926
+ res = ggml_mul(ctx0, res, w_s);
927
+ }
928
+
766
929
  return res;
767
930
  }
768
931
 
@@ -888,6 +1051,26 @@ ggml_tensor * llm_graph_context::build_ffn(
888
1051
  switch (type_op) {
889
1052
  case LLM_FFN_SILU:
890
1053
  if (gate && type_gate == LLM_FFN_PAR) {
1054
+ // Step35: HF clamps gate (after SiLU) and up before multiplication
1055
+ if (arch == LLM_ARCH_STEP35 && il >= 0) {
1056
+ const float limit = hparams.swiglu_clamp_shexp[il];
1057
+ constexpr float eps = 1e-6f;
1058
+ if (limit > eps) {
1059
+ ggml_tensor * gate_act = ggml_silu(ctx0, cur);
1060
+ cb(gate_act, "ffn_silu", il);
1061
+ gate_act = ggml_clamp(ctx0, gate_act, -INFINITY, limit);
1062
+ cb(gate_act, "ffn_silu_clamped", il);
1063
+
1064
+ tmp = ggml_clamp(ctx0, tmp, -limit, limit);
1065
+ cb(tmp, "ffn_up_clamped", il);
1066
+
1067
+ cur = ggml_mul(ctx0, gate_act, tmp);
1068
+ cb(cur, "ffn_swiglu_limited", il);
1069
+ type_gate = LLM_FFN_SEQ;
1070
+ break;
1071
+ }
1072
+ }
1073
+
891
1074
  cur = ggml_swiglu_split(ctx0, cur, tmp);
892
1075
  cb(cur, "ffn_swiglu", il);
893
1076
  type_gate = LLM_FFN_SEQ;
@@ -951,8 +1134,8 @@ ggml_tensor * llm_graph_context::build_ffn(
951
1134
 
952
1135
  if (down) {
953
1136
  cur = build_lora_mm(down, cur);
954
- if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
955
- // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
1137
+ if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE || arch == LLM_ARCH_JAIS2) {
1138
+ // GLM4, GLM4_MOE, and JAIS2 seem to have numerical issues with half-precision accumulators
956
1139
  ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
957
1140
  }
958
1141
  }
@@ -984,11 +1167,14 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
984
1167
  int64_t n_expert_used,
985
1168
  llm_ffn_op_type type_op,
986
1169
  bool norm_w,
987
- bool scale_w,
988
1170
  float w_scale,
989
1171
  llama_expert_gating_func_type gating_op,
990
1172
  int il,
991
- ggml_tensor * probs_in) const {
1173
+ ggml_tensor * probs_in,
1174
+ ggml_tensor * gate_up_exps,
1175
+ ggml_tensor * up_exps_s,
1176
+ ggml_tensor * gate_exps_s,
1177
+ ggml_tensor * down_exps_s) const {
992
1178
  return build_moe_ffn(
993
1179
  cur,
994
1180
  gate_inp, /* gate_inp_b */ nullptr,
@@ -1000,11 +1186,15 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
1000
1186
  n_expert_used,
1001
1187
  type_op,
1002
1188
  norm_w,
1003
- scale_w,
1004
1189
  w_scale,
1005
1190
  gating_op,
1006
1191
  il,
1007
- probs_in
1192
+ probs_in,
1193
+ gate_up_exps,
1194
+ /* gate_up_exps_b */ nullptr,
1195
+ up_exps_s,
1196
+ gate_exps_s,
1197
+ down_exps_s
1008
1198
  );
1009
1199
  }
1010
1200
 
@@ -1023,11 +1213,15 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
1023
1213
  int64_t n_expert_used,
1024
1214
  llm_ffn_op_type type_op,
1025
1215
  bool norm_w,
1026
- bool scale_w,
1027
1216
  float w_scale,
1028
1217
  llama_expert_gating_func_type gating_op,
1029
1218
  int il,
1030
- ggml_tensor * probs_in) const {
1219
+ ggml_tensor * probs_in,
1220
+ ggml_tensor * gate_up_exps,
1221
+ ggml_tensor * gate_up_exps_b,
1222
+ ggml_tensor * up_exps_s,
1223
+ ggml_tensor * gate_exps_s,
1224
+ ggml_tensor * down_exps_s) const {
1031
1225
  const int64_t n_embd = cur->ne[0];
1032
1226
  const int64_t n_tokens = cur->ne[1];
1033
1227
  const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN
@@ -1149,7 +1343,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
1149
1343
 
1150
1344
  weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
1151
1345
  }
1152
- if (scale_w) {
1346
+ if (w_scale != 0.0f && w_scale != 1.0f) {
1153
1347
  weights = ggml_scale(ctx0, weights, w_scale);
1154
1348
  cb(weights, "ffn_moe_weights_scaled", il);
1155
1349
  }
@@ -1166,30 +1360,100 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
1166
1360
  cb(cur, "ffn_moe_weighted", il);
1167
1361
  }
1168
1362
 
1169
- ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
1170
- cb(up, "ffn_moe_up", il);
1363
+ ggml_tensor * up = nullptr;
1364
+ ggml_tensor * experts = nullptr;
1171
1365
 
1172
- if (up_exps_b) {
1173
- up = ggml_add_id(ctx0, up, up_exps_b, selected_experts);
1174
- cb(up, "ffn_moe_up_biased", il);
1175
- }
1366
+ if (gate_up_exps) {
1367
+ // merged gate_up path: one mul_mat_id, then split into gate and up views
1368
+ ggml_tensor * gate_up = build_lora_mm_id(gate_up_exps, cur, selected_experts); // [n_ff*2, n_expert_used, n_tokens]
1369
+ cb(gate_up, "ffn_moe_gate_up", il);
1176
1370
 
1177
- ggml_tensor * experts = nullptr;
1178
- if (gate_exps) {
1179
- cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
1371
+ if (gate_up_exps_b) {
1372
+ gate_up = ggml_add_id(ctx0, gate_up, gate_up_exps_b, selected_experts);
1373
+ cb(gate_up, "ffn_moe_gate_up_biased", il);
1374
+ }
1375
+
1376
+ // apply per-expert scale2 to merged gate_up (use up_exps_s since gate and up are fused)
1377
+ if (up_exps_s) {
1378
+ ggml_tensor * s = ggml_reshape_3d(ctx0, up_exps_s, 1, n_expert, 1);
1379
+ s = ggml_repeat_4d(ctx0, s, 1, n_expert, n_tokens, 1);
1380
+ s = ggml_get_rows(ctx0, s, selected_experts); // [1, n_expert_used, n_tokens]
1381
+ gate_up = ggml_mul(ctx0, gate_up, s);
1382
+ cb(gate_up, "ffn_moe_gate_up_scaled", il);
1383
+ }
1384
+
1385
+ const int64_t n_ff = gate_up->ne[0] / 2;
1386
+ cur = ggml_view_3d(ctx0, gate_up, n_ff, gate_up->ne[1], gate_up->ne[2], gate_up->nb[1], gate_up->nb[2], 0);
1180
1387
  cb(cur, "ffn_moe_gate", il);
1388
+ up = ggml_view_3d(ctx0, gate_up, n_ff, gate_up->ne[1], gate_up->ne[2], gate_up->nb[1], gate_up->nb[2], n_ff * gate_up->nb[0]);
1389
+ cb(up, "ffn_moe_up", il);
1181
1390
  } else {
1182
- cur = up;
1183
- }
1391
+ // separate gate and up path
1392
+ up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
1393
+ cb(up, "ffn_moe_up", il);
1394
+
1395
+ if (up_exps_b) {
1396
+ up = ggml_add_id(ctx0, up, up_exps_b, selected_experts);
1397
+ cb(up, "ffn_moe_up_biased", il);
1398
+ }
1399
+
1400
+ // apply per-expert scale2 to up
1401
+ if (up_exps_s) {
1402
+ ggml_tensor * s = ggml_reshape_3d(ctx0, up_exps_s, 1, n_expert, 1);
1403
+ s = ggml_repeat_4d(ctx0, s, 1, n_expert, n_tokens, 1);
1404
+ s = ggml_get_rows(ctx0, s, selected_experts); // [1, n_expert_used, n_tokens]
1405
+ up = ggml_mul(ctx0, up, s);
1406
+ cb(up, "ffn_moe_up_scaled", il);
1407
+ }
1408
+
1409
+ if (gate_exps) {
1410
+ cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
1411
+ cb(cur, "ffn_moe_gate", il);
1412
+ } else {
1413
+ cur = up;
1414
+ }
1415
+
1416
+ if (gate_exps_b) {
1417
+ cur = ggml_add_id(ctx0, cur, gate_exps_b, selected_experts);
1418
+ cb(cur, "ffn_moe_gate_biased", il);
1419
+ }
1184
1420
 
1185
- if (gate_exps_b) {
1186
- cur = ggml_add_id(ctx0, cur, gate_exps_b, selected_experts);
1187
- cb(cur, "ffn_moe_gate_biased", il);
1421
+ // apply per-expert scale2 to gate
1422
+ if (gate_exps_s) {
1423
+ ggml_tensor * s = ggml_reshape_3d(ctx0, gate_exps_s, 1, n_expert, 1);
1424
+ s = ggml_repeat_4d(ctx0, s, 1, n_expert, n_tokens, 1);
1425
+ s = ggml_get_rows(ctx0, s, selected_experts); // [1, n_expert_used, n_tokens]
1426
+ cur = ggml_mul(ctx0, cur, s);
1427
+ cb(cur, "ffn_moe_gate_scaled", il);
1428
+ }
1188
1429
  }
1189
1430
 
1431
+ const bool has_gate = gate_exps || gate_up_exps;
1432
+
1190
1433
  switch (type_op) {
1191
1434
  case LLM_FFN_SILU:
1192
1435
  if (gate_exps) {
1436
+ // Step35: per-layer clamp for routed experts
1437
+ if (arch == LLM_ARCH_STEP35 && il >= 0) {
1438
+ const float limit = hparams.swiglu_clamp_exp[il];
1439
+ constexpr float eps = 1e-6f;
1440
+ if (limit > eps) {
1441
+ ggml_tensor * gate_act = ggml_silu(ctx0, cur);
1442
+ cb(gate_act, "ffn_moe_silu", il);
1443
+ gate_act = ggml_clamp(ctx0, gate_act, -INFINITY, limit);
1444
+ cb(gate_act, "ffn_moe_silu_clamped", il);
1445
+
1446
+ up = ggml_clamp(ctx0, up, -limit, limit);
1447
+ cb(up, "ffn_moe_up_clamped", il);
1448
+
1449
+ cur = ggml_mul(ctx0, gate_act, up);
1450
+ cb(cur, "ffn_moe_swiglu_limited", il);
1451
+ break;
1452
+ }
1453
+ }
1454
+ }
1455
+
1456
+ if (has_gate) {
1193
1457
  cur = ggml_swiglu_split(ctx0, cur, up);
1194
1458
  cb(cur, "ffn_moe_swiglu", il);
1195
1459
  } else {
@@ -1197,7 +1461,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
1197
1461
  cb(cur, "ffn_moe_silu", il);
1198
1462
  } break;
1199
1463
  case LLM_FFN_GELU:
1200
- if (gate_exps) {
1464
+ if (has_gate) {
1201
1465
  cur = ggml_geglu_split(ctx0, cur, up);
1202
1466
  cb(cur, "ffn_moe_geglu", il);
1203
1467
  } else {
@@ -1213,7 +1477,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
1213
1477
  cb(cur, "ffn_moe_swiglu_oai", il);
1214
1478
  } break;
1215
1479
  case LLM_FFN_RELU:
1216
- if (gate_exps) {
1480
+ if (has_gate) {
1217
1481
  cur = ggml_reglu_split(ctx0, cur, up);
1218
1482
  cb(cur, "ffn_moe_reglu", il);
1219
1483
  } else {
@@ -1221,7 +1485,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
1221
1485
  cb(cur, "ffn_moe_relu", il);
1222
1486
  } break;
1223
1487
  case LLM_FFN_RELU_SQR:
1224
- if (gate_exps) {
1488
+ if (has_gate) {
1225
1489
  // TODO: add support for gated squared relu
1226
1490
  GGML_ABORT("fatal error: gated squared relu not implemented");
1227
1491
  } else {
@@ -1241,6 +1505,15 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
1241
1505
  cb(experts, "ffn_moe_down_biased", il);
1242
1506
  }
1243
1507
 
1508
+ // apply per-expert scale2 to down
1509
+ if (down_exps_s) {
1510
+ ggml_tensor * s = ggml_reshape_3d(ctx0, down_exps_s, 1, n_expert, 1);
1511
+ s = ggml_repeat_4d(ctx0, s, 1, n_expert, n_tokens, 1);
1512
+ s = ggml_get_rows(ctx0, s, selected_experts); // [1, n_expert_used, n_tokens]
1513
+ experts = ggml_mul(ctx0, experts, s);
1514
+ cb(experts, "ffn_moe_down_scaled", il);
1515
+ }
1516
+
1244
1517
  if (!weight_before_ffn) {
1245
1518
  experts = ggml_mul(ctx0, experts, weights);
1246
1519
  cb(cur, "ffn_moe_weighted", il);
@@ -1279,17 +1552,29 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
1279
1552
 
1280
1553
  // input embeddings with optional lora
1281
1554
  ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
1282
- const int64_t n_embd = hparams.n_embd_inp();
1555
+ const int64_t n_embd_inp = hparams.n_embd_inp();
1556
+ const int64_t n_embd = hparams.n_embd;
1283
1557
 
1284
- auto inp = std::make_unique<llm_graph_input_embd>();
1558
+ assert(n_embd_inp >= n_embd);
1285
1559
 
1286
- ggml_tensor * cur = nullptr;
1560
+ auto inp = std::make_unique<llm_graph_input_embd>(n_embd_inp);
1287
1561
 
1288
- if (ubatch.token) {
1289
- inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
1290
- //cb(inp->tokens, "inp_tokens", -1);
1291
- ggml_set_input(inp->tokens);
1292
- res->t_tokens = inp->tokens;
1562
+ inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
1563
+ cb(inp->tokens, "inp_tokens", -1);
1564
+ ggml_set_input(inp->tokens);
1565
+ res->t_inp_tokens = inp->tokens;
1566
+
1567
+ inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd_inp, ubatch.n_tokens);
1568
+ cb(inp->embd, "inp_embd", -1);
1569
+ ggml_set_input(inp->embd);
1570
+
1571
+ // select one of the 2 inputs, based on the batch contents
1572
+ // ref: https://github.com/ggml-org/llama.cpp/pull/18550
1573
+ std::array<ggml_tensor *, 2> inps;
1574
+
1575
+ // token embeddings path (ubatch.token != nullptr)
1576
+ {
1577
+ auto & cur = inps[0];
1293
1578
 
1294
1579
  cur = ggml_get_rows(ctx0, tok_embd, inp->tokens);
1295
1580
 
@@ -1310,19 +1595,36 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
1310
1595
 
1311
1596
  cur = ggml_add(ctx0, cur, inpL_delta);
1312
1597
  }
1313
- } else {
1314
- inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, ubatch.n_tokens);
1315
- ggml_set_input(inp->embd);
1598
+
1599
+ if (n_embd_inp != n_embd) {
1600
+ cur = ggml_pad(ctx0, cur, hparams.n_embd_inp() - n_embd, 0, 0, 0);
1601
+ }
1602
+ }
1603
+
1604
+ // vector embeddings path (ubatch.embd != nullptr)
1605
+ {
1606
+ auto & cur = inps[1];
1316
1607
 
1317
1608
  cur = inp->embd;
1318
1609
  }
1319
1610
 
1611
+ assert(ggml_are_same_shape (inps[0], inps[1]));
1612
+ assert(ggml_are_same_stride(inps[0], inps[1]));
1613
+
1614
+ ggml_tensor * cur = ggml_build_forward_select(gf, inps.data(), inps.size(), ubatch.token ? 0 : 1);
1615
+
1616
+ if (n_embd_inp != n_embd) {
1617
+ cur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0);
1618
+ }
1619
+
1620
+ res->t_inp_embd = cur;
1621
+
1320
1622
  // For Granite architecture
1321
1623
  if (hparams.f_embedding_scale != 0.0f) {
1322
1624
  cur = ggml_scale(ctx0, cur, hparams.f_embedding_scale);
1323
1625
  }
1324
1626
 
1325
- cb(cur, "inp_embd", -1);
1627
+ cb(cur, "embd", -1);
1326
1628
 
1327
1629
  res->add_input(std::move(inp));
1328
1630
 
@@ -1354,6 +1656,7 @@ ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
1354
1656
  // this need to be 1x1xN for broadcasting
1355
1657
  cur = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, 1, n_tokens);
1356
1658
  ggml_set_input(cur);
1659
+ ggml_set_name(cur, "attn_scale");
1357
1660
 
1358
1661
  res->add_input(std::move(inp));
1359
1662
 
@@ -1363,7 +1666,7 @@ ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
1363
1666
  ggml_tensor * llm_graph_context::build_inp_out_ids() const {
1364
1667
  // note: when all tokens are output, we could skip this optimization to spare the ggml_get_rows() calls,
1365
1668
  // but this would make the graph topology depend on the number of output tokens, which can interere with
1366
- // features that require constant topology such as pipline parallelism
1669
+ // features that require constant topology such as pipeline parallelism
1367
1670
  // ref: https://github.com/ggml-org/llama.cpp/pull/14275#issuecomment-2987424471
1368
1671
  //if (n_outputs < n_tokens) {
1369
1672
  // return nullptr;
@@ -1421,7 +1724,7 @@ ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
1421
1724
  //}
1422
1725
 
1423
1726
  const auto n_embd = !cross->v_embd.empty() ? cross->n_embd : hparams.n_embd_inp();
1424
- const auto n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
1727
+ const auto n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
1425
1728
 
1426
1729
  cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc);
1427
1730
  ggml_set_input(cur);
@@ -1499,7 +1802,8 @@ ggml_tensor * llm_graph_context::build_attn_mha(
1499
1802
 
1500
1803
  ggml_tensor * cur;
1501
1804
 
1502
- if (cparams.flash_attn && kq_b == nullptr) {
1805
+ const bool use_flash_attn = cparams.flash_attn && kq_b == nullptr;
1806
+ if (use_flash_attn) {
1503
1807
  GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet");
1504
1808
 
1505
1809
  if (v_trans) {
@@ -1525,7 +1829,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
1525
1829
  if (v_mla) {
1526
1830
  #if 0
1527
1831
  // v_mla can be applied as a matrix-vector multiplication with broadcasting across dimension 3 == n_tokens.
1528
- // However, the code is optimized for dimensions 0 and 1 being large, so this is ineffient.
1832
+ // However, the code is optimized for dimensions 0 and 1 being large, so this is inefficient.
1529
1833
  cur = ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens);
1530
1834
  cur = ggml_mul_mat(ctx0, v_mla, cur);
1531
1835
  #else
@@ -1695,14 +1999,11 @@ static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl(
1695
1999
  {
1696
2000
  GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA");
1697
2001
 
1698
- const auto n_kv = mctx_cur->get_n_kv();
1699
- const auto n_tokens = ubatch.n_tokens;
1700
- const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
1701
-
1702
2002
  inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
1703
2003
  inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
1704
2004
 
1705
- inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
2005
+ inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams);
2006
+
1706
2007
  ggml_set_input(inp->self_kq_mask);
1707
2008
 
1708
2009
  inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1728,9 +2029,11 @@ ggml_tensor * llm_graph_context::build_attn(
1728
2029
  ggml_tensor * v_cur,
1729
2030
  ggml_tensor * kq_b,
1730
2031
  ggml_tensor * sinks,
1731
- ggml_tensor * v_mla,
2032
+ ggml_tensor * v_mla, // TODO: remove
1732
2033
  float kq_scale,
1733
2034
  int il) const {
2035
+ GGML_ASSERT(v_mla == nullptr);
2036
+
1734
2037
  // these nodes are added to the graph together so that they are not reordered
1735
2038
  // by doing so, the number of splits in the graph is reduced
1736
2039
  // expand k later to enable rope fusion which directly writes into k-v cache
@@ -1758,6 +2061,89 @@ ggml_tensor * llm_graph_context::build_attn(
1758
2061
  ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
1759
2062
  cb(cur, "kqv_out", il);
1760
2063
 
2064
+ if (wo) {
2065
+ cur = build_lora_mm(wo, cur);
2066
+ if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE || arch == LLM_ARCH_JAIS2) {
2067
+ // GLM4, GLM4_MOE, and JAIS2 seem to have numerical issues with half-precision accumulators
2068
+ ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
2069
+ }
2070
+ }
2071
+
2072
+ if (wo_b) {
2073
+ cur = ggml_add(ctx0, cur, wo_b);
2074
+ }
2075
+
2076
+ return cur;
2077
+ }
2078
+
2079
+ static std::unique_ptr<llm_graph_input_attn_k> build_attn_inp_k_impl(
2080
+ ggml_context * ctx0,
2081
+ const llama_ubatch & ubatch,
2082
+ const llama_hparams & hparams,
2083
+ const llama_cparams & cparams,
2084
+ const llama_kv_cache_context * mctx_cur) {
2085
+
2086
+ auto inp = std::make_unique<llm_graph_input_attn_k>(hparams, cparams, mctx_cur);
2087
+
2088
+ {
2089
+ GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA");
2090
+
2091
+ inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
2092
+
2093
+ inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams);
2094
+ ggml_set_input(inp->self_kq_mask);
2095
+
2096
+ inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
2097
+ }
2098
+
2099
+ return inp;
2100
+ }
2101
+
2102
+ llm_graph_input_attn_k * llm_graph_context::build_attn_inp_k() const {
2103
+ const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);
2104
+
2105
+ auto inp = build_attn_inp_k_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
2106
+
2107
+ return (llm_graph_input_attn_k *) res->add_input(std::move(inp));
2108
+ }
2109
+
2110
+ ggml_tensor * llm_graph_context::build_attn(
2111
+ llm_graph_input_attn_k * inp,
2112
+ ggml_tensor * wo,
2113
+ ggml_tensor * wo_b,
2114
+ ggml_tensor * q_cur,
2115
+ ggml_tensor * k_cur,
2116
+ ggml_tensor * v_cur,
2117
+ ggml_tensor * kq_b,
2118
+ ggml_tensor * sinks,
2119
+ ggml_tensor * v_mla,
2120
+ float kq_scale,
2121
+ int il) const {
2122
+ // these nodes are added to the graph together so that they are not reordered
2123
+ // by doing so, the number of splits in the graph is reduced
2124
+ // expand k later to enable rope fusion which directly writes into k-v cache
2125
+ ggml_build_forward_expand(gf, q_cur);
2126
+ ggml_build_forward_expand(gf, v_cur);
2127
+ ggml_build_forward_expand(gf, k_cur);
2128
+
2129
+ const auto * mctx_cur = inp->mctx;
2130
+
2131
+ // store to KV cache
2132
+ {
2133
+ const auto & k_idxs = inp->get_k_idxs();
2134
+
2135
+ ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
2136
+ }
2137
+
2138
+ const auto & kq_mask = inp->get_kq_mask();
2139
+
2140
+ ggml_tensor * q = q_cur;
2141
+ ggml_tensor * k = mctx_cur->get_k(ctx0, il);
2142
+ ggml_tensor * v = ggml_view_4d(ctx0, k, v_cur->ne[0], k->ne[1], k->ne[2], k->ne[3], k->nb[1], k->nb[2], k->nb[3], 0);
2143
+
2144
+ ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
2145
+ cb(cur, "kqv_out", il);
2146
+
1761
2147
  if (wo) {
1762
2148
  cur = build_lora_mm(wo, cur);
1763
2149
  if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
@@ -1903,15 +2289,11 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
1903
2289
 
1904
2290
  auto inp = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, mctx_cur);
1905
2291
 
1906
- const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
1907
-
1908
2292
  {
1909
- const auto n_kv = mctx_cur->get_base()->get_n_kv();
1910
-
1911
2293
  inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
1912
2294
  inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
1913
2295
 
1914
- inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
2296
+ inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur->get_base(), ubatch, cparams);
1915
2297
  ggml_set_input(inp->self_kq_mask);
1916
2298
  ggml_set_name(inp->self_kq_mask, "self_kq_mask");
1917
2299
 
@@ -1922,12 +2304,10 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
1922
2304
  {
1923
2305
  GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache for non-SWA");
1924
2306
 
1925
- const auto n_kv = mctx_cur->get_swa()->get_n_kv();
1926
-
1927
2307
  inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
1928
2308
  inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
1929
2309
 
1930
- inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
2310
+ inp->self_kq_mask_swa = build_kq_mask(ctx0, mctx_cur->get_swa(), ubatch, cparams);
1931
2311
  ggml_set_input(inp->self_kq_mask_swa);
1932
2312
  ggml_set_name(inp->self_kq_mask_swa, "self_kq_mask_swa");
1933
2313
 
@@ -2068,10 +2448,57 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
2068
2448
  return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
2069
2449
  }
2070
2450
 
2451
+ llm_graph_input_mem_hybrid_k * llm_graph_context::build_inp_mem_hybrid_k() const {
2452
+ const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
2453
+
2454
+ auto inp_rs = build_rs_inp_impl (ctx0, ubatch, mctx_cur->get_recr());
2455
+ auto inp_attn = build_attn_inp_k_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
2456
+
2457
+ auto inp = std::make_unique<llm_graph_input_mem_hybrid_k>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);
2458
+
2459
+ return (llm_graph_input_mem_hybrid_k *) res->add_input(std::move(inp));
2460
+ }
2461
+
2462
+ llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa() const {
2463
+ const auto * mctx_cur = static_cast<const llama_memory_hybrid_iswa_context *>(mctx);
2464
+
2465
+ auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
2466
+
2467
+ // build iswa attention input
2468
+ const auto * attn_ctx = mctx_cur->get_attn();
2469
+
2470
+ auto inp_attn = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, attn_ctx);
2471
+
2472
+ {
2473
+ inp_attn->self_k_idxs = attn_ctx->get_base()->build_input_k_idxs(ctx0, ubatch);
2474
+ inp_attn->self_v_idxs = attn_ctx->get_base()->build_input_v_idxs(ctx0, ubatch);
2475
+
2476
+ inp_attn->self_kq_mask = build_kq_mask(ctx0, attn_ctx->get_base(), ubatch, cparams);
2477
+ ggml_set_input(inp_attn->self_kq_mask);
2478
+
2479
+ inp_attn->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask, GGML_TYPE_F16) : inp_attn->self_kq_mask;
2480
+ }
2481
+
2482
+ {
2483
+ inp_attn->self_k_idxs_swa = attn_ctx->get_swa()->build_input_k_idxs(ctx0, ubatch);
2484
+ inp_attn->self_v_idxs_swa = attn_ctx->get_swa()->build_input_v_idxs(ctx0, ubatch);
2485
+
2486
+ inp_attn->self_kq_mask_swa = build_kq_mask(ctx0, attn_ctx->get_swa(), ubatch, cparams);
2487
+ ggml_set_input(inp_attn->self_kq_mask_swa);
2488
+
2489
+ inp_attn->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask_swa, GGML_TYPE_F16) : inp_attn->self_kq_mask_swa;
2490
+ }
2491
+
2492
+ auto inp = std::make_unique<llm_graph_input_mem_hybrid_iswa>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);
2493
+
2494
+ return (llm_graph_input_mem_hybrid_iswa *) res->add_input(std::move(inp));
2495
+ }
2496
+
2071
2497
  void llm_graph_context::build_dense_out(
2072
2498
  ggml_tensor * dense_2,
2499
+ ggml_tensor * dense_2_b,
2073
2500
  ggml_tensor * dense_3) const {
2074
- if (!cparams.embeddings || !(dense_2 || dense_3)) {
2501
+ if (!cparams.embeddings || !(dense_2 || dense_2_b || dense_3)) {
2075
2502
  return;
2076
2503
  }
2077
2504
  ggml_tensor * cur = res->t_embd_pooled != nullptr ? res->t_embd_pooled : res->t_embd;
@@ -2080,6 +2507,9 @@ void llm_graph_context::build_dense_out(
2080
2507
  if (dense_2) {
2081
2508
  cur = ggml_mul_mat(ctx0, dense_2, cur);
2082
2509
  }
2510
+ if (dense_2_b) {
2511
+ cur = ggml_add(ctx0, cur, dense_2_b);
2512
+ }
2083
2513
  if (dense_3) {
2084
2514
  cur = ggml_mul_mat(ctx0, dense_3, cur);
2085
2515
  }
@@ -2093,7 +2523,8 @@ void llm_graph_context::build_pooling(
2093
2523
  ggml_tensor * cls,
2094
2524
  ggml_tensor * cls_b,
2095
2525
  ggml_tensor * cls_out,
2096
- ggml_tensor * cls_out_b) const {
2526
+ ggml_tensor * cls_out_b,
2527
+ ggml_tensor * cls_norm) const {
2097
2528
  if (!cparams.embeddings) {
2098
2529
  return;
2099
2530
  }
@@ -2132,8 +2563,15 @@ void llm_graph_context::build_pooling(
2132
2563
  } break;
2133
2564
  case LLAMA_POOLING_TYPE_RANK:
2134
2565
  {
2135
- ggml_tensor * inp_cls = build_inp_cls();
2136
- cur = ggml_get_rows(ctx0, inp, inp_cls);
2566
+ if (arch == LLM_ARCH_MODERN_BERT) {
2567
+ // modern bert gte reranker builds mean first then applies prediction head and classifier
2568
+ // https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modular_modernbert.py#L1404-1411
2569
+ ggml_tensor * inp_mean = build_inp_mean();
2570
+ cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, inp)), inp_mean);
2571
+ } else {
2572
+ ggml_tensor * inp_cls = build_inp_cls();
2573
+ cur = ggml_get_rows(ctx0, inp, inp_cls);
2574
+ }
2137
2575
 
2138
2576
  // classification head
2139
2577
  // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
@@ -2142,7 +2580,15 @@ void llm_graph_context::build_pooling(
2142
2580
  if (cls_b) {
2143
2581
  cur = ggml_add(ctx0, cur, cls_b);
2144
2582
  }
2145
- cur = ggml_tanh(ctx0, cur);
2583
+ if (arch == LLM_ARCH_MODERN_BERT) {
2584
+ cur = ggml_gelu(ctx0, cur);
2585
+ } else {
2586
+ cur = ggml_tanh(ctx0, cur);
2587
+ }
2588
+ if (cls_norm) {
2589
+ // head norm
2590
+ cur = build_norm(cur, cls_norm, NULL, LLM_NORM, -1);
2591
+ }
2146
2592
  }
2147
2593
 
2148
2594
  // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
@@ -2157,7 +2603,7 @@ void llm_graph_context::build_pooling(
2157
2603
  }
2158
2604
 
2159
2605
  // softmax for qwen3 reranker
2160
- if (arch == LLM_ARCH_QWEN3) {
2606
+ if (arch == LLM_ARCH_QWEN3 || arch == LLM_ARCH_QWEN3VL) {
2161
2607
  cur = ggml_soft_max(ctx0, cur);
2162
2608
  }
2163
2609
  } break;
@@ -2178,6 +2624,9 @@ void llm_graph_context::build_sampling() const {
2178
2624
  return;
2179
2625
  }
2180
2626
 
2627
+ std::array<ggml_tensor *, 2> outs;
2628
+ outs[0] = res->t_logits;
2629
+
2181
2630
  auto inp_sampling = std::make_unique<llm_graph_input_sampling>(samplers);
2182
2631
  res->add_input(std::move(inp_sampling));
2183
2632
 
@@ -2198,14 +2647,14 @@ void llm_graph_context::build_sampling() const {
2198
2647
  // add a dummy row of logits
2199
2648
  // this trick makes the graph static, regardless of which samplers are activated
2200
2649
  // this is important in order to minimize graph reallocations
2201
- // TODO: use `ggml_build_forward_select()` when available (https://github.com/ggml-org/llama.cpp/pull/18550)
2202
2650
  ggml_tensor * logits_t = ggml_pad(ctx0, res->t_logits, 0, 1, 0, 0);
2203
2651
 
2204
2652
  for (const auto & [seq_id, sampler] : samplers) {
2205
2653
  const auto it = seq_to_logit_row.find(seq_id);
2206
2654
 
2207
2655
  // inactive samplers always work on the first row
2208
- const auto row_idx = seq_to_logit_row.find(seq_id) != seq_to_logit_row.end() ? it->second : 0;
2656
+ const auto row_idx = it != seq_to_logit_row.end() ? it->second : 0;
2657
+ const int i_out = it != seq_to_logit_row.end() ? 1 : 0;
2209
2658
 
2210
2659
  ggml_tensor * logits_seq = ggml_view_1d(ctx0, logits_t, logits_t->ne[0], row_idx * logits_t->nb[1]);
2211
2660
  ggml_format_name(logits_seq, "logits_seq_%d", seq_id);
@@ -2222,22 +2671,26 @@ void llm_graph_context::build_sampling() const {
2222
2671
 
2223
2672
  if (data.sampled != nullptr) {
2224
2673
  res->t_sampled[seq_id] = data.sampled;
2225
- ggml_build_forward_expand(gf, data.sampled);
2674
+ outs[1] = data.sampled;
2675
+ ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
2226
2676
  }
2227
2677
 
2228
2678
  if (data.probs != nullptr) {
2229
2679
  res->t_sampled_probs[seq_id] = data.probs;
2230
- ggml_build_forward_expand(gf, data.probs);
2680
+ outs[1] = data.probs;
2681
+ ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
2231
2682
  }
2232
2683
 
2233
2684
  if (data.logits != nullptr) {
2234
2685
  res->t_sampled_logits[seq_id] = data.logits;
2235
- ggml_build_forward_expand(gf, data.logits);
2686
+ outs[1] = data.logits;
2687
+ ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
2236
2688
  }
2237
2689
 
2238
2690
  if (data.candidates != nullptr) {
2239
2691
  res->t_candidates[seq_id] = data.candidates;
2240
- ggml_build_forward_expand(gf, data.candidates);
2692
+ outs[1] = data.candidates;
2693
+ ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
2241
2694
  }
2242
2695
  }
2243
2696