whispercpp 1.3.4 → 1.3.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (891)
  1. checksums.yaml +4 -4
  2. data/LICENSE +1 -1
  3. data/README.md +158 -44
  4. data/ext/extconf.rb +3 -2
  5. data/ext/ruby_whisper.c +34 -6
  6. data/ext/ruby_whisper.h +67 -0
  7. data/ext/ruby_whisper_context.c +236 -144
  8. data/ext/ruby_whisper_context_params.c +163 -0
  9. data/ext/ruby_whisper_model.c +12 -13
  10. data/ext/ruby_whisper_params.c +47 -24
  11. data/ext/ruby_whisper_segment.c +84 -20
  12. data/ext/ruby_whisper_token.c +371 -0
  13. data/ext/ruby_whisper_transcribe.cpp +5 -2
  14. data/ext/ruby_whisper_vad_context.c +122 -0
  15. data/ext/ruby_whisper_vad_context_detect.cpp +51 -0
  16. data/ext/ruby_whisper_vad_params.c +0 -1
  17. data/ext/ruby_whisper_vad_segment.c +138 -0
  18. data/ext/ruby_whisper_vad_segments.c +105 -0
  19. data/ext/sources/CMakeLists.txt +4 -1
  20. data/ext/sources/bindings/javascript/package.json +1 -1
  21. data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
  22. data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
  23. data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  24. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  25. data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
  26. data/ext/sources/examples/addon.node/vad-example.js +2 -2
  27. data/ext/sources/examples/bench/bench.cpp +23 -18
  28. data/ext/sources/examples/cli/cli.cpp +129 -112
  29. data/ext/sources/examples/common-ggml.cpp +2 -0
  30. data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
  31. data/ext/sources/examples/miniaudio.h +4507 -2131
  32. data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
  33. data/ext/sources/examples/server/server.cpp +28 -15
  34. data/ext/sources/examples/talk-llama/CMakeLists.txt +8 -3
  35. data/ext/sources/examples/talk-llama/llama-adapter.cpp +5 -2
  36. data/ext/sources/examples/talk-llama/llama-adapter.h +7 -0
  37. data/ext/sources/examples/talk-llama/llama-arch.cpp +2378 -1988
  38. data/ext/sources/examples/talk-llama/llama-arch.h +109 -2
  39. data/ext/sources/examples/talk-llama/llama-batch.cpp +78 -34
  40. data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
  41. data/ext/sources/examples/talk-llama/llama-chat.cpp +100 -4
  42. data/ext/sources/examples/talk-llama/llama-chat.h +5 -0
  43. data/ext/sources/examples/talk-llama/llama-context.cpp +1088 -403
  44. data/ext/sources/examples/talk-llama/llama-context.h +70 -23
  45. data/ext/sources/examples/talk-llama/llama-cparams.h +6 -0
  46. data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
  47. data/ext/sources/examples/talk-llama/llama-grammar.cpp +295 -60
  48. data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
  49. data/ext/sources/examples/talk-llama/llama-graph.cpp +925 -155
  50. data/ext/sources/examples/talk-llama/llama-graph.h +234 -23
  51. data/ext/sources/examples/talk-llama/llama-hparams.cpp +79 -38
  52. data/ext/sources/examples/talk-llama/llama-hparams.h +118 -18
  53. data/ext/sources/examples/talk-llama/llama-impl.cpp +11 -7
  54. data/ext/sources/examples/talk-llama/llama-impl.h +14 -2
  55. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +8 -4
  56. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +405 -140
  57. data/ext/sources/examples/talk-llama/llama-kv-cache.h +24 -10
  58. data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
  59. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
  60. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
  61. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
  62. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +42 -31
  63. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
  64. data/ext/sources/examples/talk-llama/llama-mmap.cpp +197 -45
  65. data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
  66. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +606 -116
  67. data/ext/sources/examples/talk-llama/llama-model-loader.h +41 -5
  68. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +61 -44
  69. data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
  70. data/ext/sources/examples/talk-llama/llama-model.cpp +2756 -13643
  71. data/ext/sources/examples/talk-llama/llama-model.h +112 -18
  72. data/ext/sources/examples/talk-llama/llama-quant.cpp +582 -365
  73. data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +1409 -199
  74. data/ext/sources/examples/talk-llama/llama-sampler.h +42 -0
  75. data/ext/sources/examples/talk-llama/llama-vocab.cpp +248 -82
  76. data/ext/sources/examples/talk-llama/llama-vocab.h +50 -40
  77. data/ext/sources/examples/talk-llama/llama.cpp +802 -21
  78. data/ext/sources/examples/talk-llama/llama.h +210 -39
  79. data/ext/sources/examples/talk-llama/models/afmoe.cpp +190 -0
  80. data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
  81. data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
  82. data/ext/sources/examples/talk-llama/models/arctic.cpp +137 -0
  83. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
  84. data/ext/sources/examples/talk-llama/models/baichuan.cpp +123 -0
  85. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +143 -0
  86. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +133 -0
  87. data/ext/sources/examples/talk-llama/models/bert.cpp +184 -0
  88. data/ext/sources/examples/talk-llama/models/bitnet.cpp +145 -0
  89. data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
  90. data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
  91. data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
  92. data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
  93. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
  94. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
  95. data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
  96. data/ext/sources/examples/talk-llama/models/dbrx.cpp +122 -0
  97. data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
  98. data/ext/sources/examples/talk-llama/models/deepseek.cpp +142 -0
  99. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +262 -0
  100. data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
  101. data/ext/sources/examples/talk-llama/models/dots1.cpp +132 -0
  102. data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
  103. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +148 -0
  104. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
  105. data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
  106. data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
  107. data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
  108. data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
  109. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +111 -0
  110. data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
  111. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
  112. data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
  113. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
  114. data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
  115. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
  116. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
  117. data/ext/sources/examples/talk-llama/models/glm4.cpp +157 -0
  118. data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
  119. data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
  120. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +195 -0
  121. data/ext/sources/examples/talk-llama/models/granite.cpp +210 -0
  122. data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
  123. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +139 -0
  124. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
  125. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +153 -0
  126. data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
  127. data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
  128. data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
  129. data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
  130. data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
  131. data/ext/sources/examples/talk-llama/models/lfm2.cpp +196 -0
  132. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
  133. data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
  134. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
  135. data/ext/sources/examples/talk-llama/models/llama.cpp +175 -0
  136. data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
  137. data/ext/sources/examples/talk-llama/models/mamba-base.cpp +289 -0
  138. data/ext/sources/examples/talk-llama/models/mamba.cpp +54 -0
  139. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +129 -0
  140. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +200 -0
  141. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +123 -0
  142. data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
  143. data/ext/sources/examples/talk-llama/models/models.h +704 -0
  144. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +109 -0
  145. data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
  146. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +162 -0
  147. data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
  148. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
  149. data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
  150. data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
  151. data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
  152. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
  153. data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
  154. data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
  155. data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
  156. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
  157. data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
  158. data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
  159. data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
  160. data/ext/sources/examples/talk-llama/models/plamo2.cpp +320 -0
  161. data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
  162. data/ext/sources/examples/talk-llama/models/plm.cpp +169 -0
  163. data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
  164. data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
  165. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
  166. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
  167. data/ext/sources/examples/talk-llama/models/qwen3.cpp +120 -0
  168. data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
  169. data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
  170. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +131 -0
  171. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +525 -0
  172. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +140 -0
  173. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +132 -0
  174. data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
  175. data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
  176. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +164 -0
  177. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
  178. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
  179. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +137 -0
  180. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
  181. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
  182. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
  183. data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
  184. data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
  185. data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
  186. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
  187. data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
  188. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
  189. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
  190. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
  191. data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
  192. data/ext/sources/examples/talk-llama/unicode.cpp +121 -79
  193. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
  194. data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
  195. data/ext/sources/ggml/CMakeLists.txt +90 -56
  196. data/ext/sources/ggml/include/ggml-alloc.h +9 -0
  197. data/ext/sources/ggml/include/ggml-backend.h +5 -2
  198. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  199. data/ext/sources/ggml/include/ggml-cpu.h +6 -0
  200. data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
  201. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  202. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  203. data/ext/sources/ggml/include/ggml-rpc.h +14 -12
  204. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  205. data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
  206. data/ext/sources/ggml/include/ggml.h +246 -21
  207. data/ext/sources/ggml/src/CMakeLists.txt +85 -11
  208. data/ext/sources/ggml/src/ggml-alloc.c +128 -50
  209. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  210. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  211. data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
  212. data/ext/sources/ggml/src/ggml-backend-reg.cpp +54 -88
  213. data/ext/sources/ggml/src/ggml-backend.cpp +76 -23
  214. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +18 -4
  215. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +11 -11
  216. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +58 -46
  217. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +139 -48
  218. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2427 -1785
  219. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -362
  220. data/ext/sources/ggml/src/ggml-cann/common.h +285 -211
  221. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +663 -831
  222. data/ext/sources/ggml/src/ggml-common.h +11 -0
  223. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +170 -95
  224. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -18
  225. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  226. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
  227. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  228. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +513 -27
  229. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +4192 -992
  230. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
  231. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  232. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1761 -49
  233. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
  234. data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  235. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
  236. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
  237. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +124 -24
  238. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +157 -28
  239. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  240. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  241. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +8 -3
  242. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +251 -80
  243. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +19 -0
  244. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +587 -119
  245. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
  246. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1093 -194
  247. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1284 -203
  248. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
  249. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1519 -527
  250. data/ext/sources/ggml/src/ggml-cpu/ops.h +6 -4
  251. data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
  252. data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
  253. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +3632 -781
  254. data/ext/sources/ggml/src/ggml-cpu/repack.h +129 -4
  255. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
  256. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +152 -46
  257. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
  258. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +152 -1
  259. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
  260. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +140 -0
  261. data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
  262. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
  263. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
  264. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +132 -6
  265. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
  266. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +33 -31
  267. data/ext/sources/ggml/src/ggml-cuda/common.cuh +474 -85
  268. data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
  269. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  270. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
  271. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +342 -246
  272. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
  273. data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
  274. data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
  275. data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
  276. data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
  277. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +98 -74
  278. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +973 -665
  279. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
  280. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1255 -0
  281. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +33 -40
  282. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +40 -18
  283. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  284. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +206 -45
  285. data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
  286. data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
  287. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
  288. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  289. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1688 -302
  290. data/ext/sources/ggml/src/ggml-cuda/mean.cu +12 -10
  291. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +908 -48
  292. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +88 -20
  293. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +502 -90
  294. data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
  295. data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
  296. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
  297. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +532 -193
  298. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +460 -104
  299. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +5 -2
  300. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +360 -122
  301. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +2 -1
  302. data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
  303. data/ext/sources/ggml/src/ggml-cuda/pad.cu +73 -39
  304. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +152 -1
  305. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
  306. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
  307. data/ext/sources/ggml/src/ggml-cuda/rope.cu +364 -149
  308. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
  309. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
  310. data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
  311. data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
  312. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +163 -41
  313. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
  314. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
  315. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +68 -50
  316. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  317. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
  318. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
  319. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
  320. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
  321. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
  322. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
  323. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
  324. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  325. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  326. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  327. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  328. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  329. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  330. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  331. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  332. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  333. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +22 -4
  334. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +95 -0
  335. data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
  336. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +275 -119
  337. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -7
  338. data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
  339. data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
  340. data/ext/sources/ggml/src/ggml-cuda/unary.cu +160 -11
  341. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +38 -0
  342. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
  343. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
  344. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  345. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +22 -1
  346. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
  347. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +117 -0
  348. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3325 -0
  349. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +46 -0
  350. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +813 -0
  351. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
  352. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +891 -0
  353. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  354. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
  355. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +713 -0
  356. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
  357. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.c +63 -0
  358. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +182 -0
  359. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
  360. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  361. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
  362. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  363. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +155 -0
  364. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +63 -0
  365. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  366. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  367. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
  368. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
  369. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
  370. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  371. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
  372. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  373. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  374. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  375. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  376. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
  377. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  378. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  379. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +26 -0
  380. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1199 -0
  381. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2670 -0
  382. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +497 -0
  383. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
  384. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +419 -0
  385. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
  386. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  387. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +382 -0
  388. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +293 -0
  389. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  390. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  391. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  392. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  393. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
  394. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
  395. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +14 -13
  396. data/ext/sources/ggml/src/ggml-impl.h +129 -6
  397. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  398. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +15 -4
  399. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  400. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +173 -34
  401. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +912 -344
  402. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +124 -59
  403. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +588 -144
  404. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +396 -23
  405. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1724 -421
  406. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +16 -3
  407. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +333 -114
  408. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +3050 -1539
  409. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
  410. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +30 -1
  411. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4279 -497
  412. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
  413. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
  414. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  415. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +267 -0
  416. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  417. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  418. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +113 -0
  419. data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  420. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
  421. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  422. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  423. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  424. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
  425. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  426. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  427. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  428. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +140 -0
  429. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  430. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
  431. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
  432. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  433. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  434. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  435. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
  436. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  437. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  438. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  439. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  440. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  441. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
  442. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  443. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  444. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  445. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
  446. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  447. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  448. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
  449. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +116 -0
  450. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  451. data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  452. data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  453. data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  454. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  455. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  456. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +39 -0
  457. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  458. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  459. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  460. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
  461. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  462. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
  463. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  464. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
  465. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
  466. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  467. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  468. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  469. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  470. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  471. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  472. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  473. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  474. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  475. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  476. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  477. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  478. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  479. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  480. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  481. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  482. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  483. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
  484. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  485. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  486. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  487. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  488. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  489. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  490. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
  491. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
  492. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
  493. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
  494. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  495. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  496. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  497. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  498. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  499. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
  500. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  501. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
  502. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
  503. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
  504. data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
  505. data/ext/sources/ggml/src/ggml-quants.c +96 -5
  506. data/ext/sources/ggml/src/ggml-quants.h +3 -0
  507. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
  508. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +59 -87
  509. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +81 -0
  510. data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
  511. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +7 -0
  512. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -29
  513. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
  514. data/ext/sources/ggml/src/ggml-sycl/common.hpp +427 -20
  515. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
  516. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +103 -1
  517. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
  518. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  519. data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  520. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
  521. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
  522. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +867 -50
  523. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +401 -358
  524. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +12 -2
  525. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
  526. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
  527. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
  528. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
  529. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
  530. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  531. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
  532. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
  533. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +645 -155
  534. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
  535. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +221 -66
  536. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
  537. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  538. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
  539. data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
  540. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
  541. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
  542. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +5 -0
  543. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
  544. data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  545. data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  546. data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
  547. data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
  548. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +457 -281
  549. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  550. data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
  551. data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
  552. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
  553. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
  554. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  555. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  556. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  557. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  558. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  559. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  560. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  561. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  562. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  563. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  564. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  565. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
  566. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
  567. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
  568. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
  569. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
  570. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
  571. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
  572. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
  573. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
  574. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
  575. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
  576. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
  577. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
  578. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
  579. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
  580. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
  581. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
  582. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
  583. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
  584. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
  585. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
  586. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
  587. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
  588. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
  589. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
  590. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
  591. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
  592. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
  593. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
  594. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
  595. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
  596. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
  597. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
  598. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
  599. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
  600. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
  601. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +71 -0
  602. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  603. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  604. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  605. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  606. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  607. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  608. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  609. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  610. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  611. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  612. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  613. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  614. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  615. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  616. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  617. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  618. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  619. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  620. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  621. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  622. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  623. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  624. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
  625. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
  626. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  627. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
  628. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  629. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  630. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  631. data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
  632. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  633. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  634. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  635. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  636. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  637. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  638. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  639. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
  640. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  641. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  642. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  643. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
  644. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
  645. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +39 -19
  646. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5994 -3055
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +18 -10
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  663. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  664. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
  665. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
  666. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  667. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  668. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  669. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
  670. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
  671. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
  672. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
  673. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  674. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
  675. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
  676. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  677. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  678. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  679. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  680. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  681. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
  682. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
  683. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
  684. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  685. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  686. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
  687. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
  688. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  689. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  690. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  691. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
  692. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  693. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  694. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
  695. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  696. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  697. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
  698. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  699. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  700. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  701. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
  702. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  703. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +386 -160
  704. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +82 -20
  705. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +400 -174
  706. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +123 -37
  707. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  708. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +10 -9
  709. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  710. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
  711. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  712. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
  713. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
  714. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  715. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
  716. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  717. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
  718. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
  719. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
  720. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
  721. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
  722. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
  723. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  724. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
  725. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
  726. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
  727. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
  728. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +13 -10
  729. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  730. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  731. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  732. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
  733. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +77 -29
  734. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
  735. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
  736. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
  737. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
  738. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
  739. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
  740. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
  741. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
  742. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
  743. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  744. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
  745. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
  746. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
  747. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
  748. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
  749. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
  750. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
  751. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +88 -105
  752. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +41 -26
  753. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
  754. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +74 -0
  755. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +92 -230
  756. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
  757. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  758. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
  759. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  760. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  761. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  762. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
  763. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
  764. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  765. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
  766. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  767. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  768. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  769. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  770. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +49 -4
  771. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  772. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
  773. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
  774. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +207 -0
  775. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  776. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +8 -49
  777. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +8 -32
  778. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +8 -32
  779. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +33 -0
  780. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +8 -38
  781. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  782. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
  783. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  784. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  785. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  786. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  787. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  788. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
  789. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
  790. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
  791. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
  792. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
  793. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
  794. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  795. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
  796. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
  797. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  798. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +50 -0
  799. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
  800. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  801. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  802. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
  803. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
  804. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  805. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
  806. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  807. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
  808. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
  809. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
  810. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
  811. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
  812. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  813. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
  814. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
  815. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +384 -180
  816. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
  817. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
  818. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1374 -0
  819. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +2544 -726
  820. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  821. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  822. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  823. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  824. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
  825. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
  826. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
  827. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +107 -0
  828. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  829. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +73 -15
  830. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +636 -0
  831. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
  832. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  833. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +72 -261
  834. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +766 -0
  835. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +147 -0
  836. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +196 -0
  837. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
  838. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  839. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  840. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
  841. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  842. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +63 -0
  843. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
  844. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  845. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  846. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
  847. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
  848. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +91 -0
  849. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +469 -0
  850. data/ext/sources/ggml/src/ggml.c +590 -64
  851. data/ext/sources/ggml/src/gguf.cpp +229 -44
  852. data/ext/sources/include/whisper.h +1 -0
  853. data/ext/sources/src/CMakeLists.txt +3 -1
  854. data/ext/sources/src/whisper.cpp +106 -62
  855. data/ext/sources/tests/CMakeLists.txt +2 -2
  856. data/ext/sources/tests/test-vad-full.cpp +4 -2
  857. data/ext/sources/tests/test-vad.cpp +1 -1
  858. data/extsources.rb +1 -0
  859. data/lib/whisper/model/uri.rb +17 -18
  860. data/sig/whisper.rbs +162 -4
  861. data/test/test_context_params.rb +82 -0
  862. data/test/test_params.rb +16 -8
  863. data/test/test_segment.rb +0 -1
  864. data/test/test_token.rb +81 -0
  865. data/test/test_vad.rb +1 -1
  866. data/test/test_vad_context.rb +100 -0
  867. data/test/test_vad_segment.rb +19 -0
  868. data/test/test_vad_segments.rb +16 -0
  869. data/test/test_whisper.rb +27 -0
  870. data/whispercpp.gemspec +1 -1
  871. metadata +502 -37
  872. data/ext/sources/build-xcframework.sh +0 -571
  873. data/ext/sources/examples/talk-llama/llama-sampling.h +0 -32
  874. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  875. data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
  876. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
  877. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
  878. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
  879. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
  880. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  881. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
  882. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
  883. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
  884. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
  885. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  886. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  887. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  888. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  889. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
  890. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
  891. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
@@ -1,9 +1,12 @@
1
- #include "llama-sampling.h"
1
+ #include "llama-sampler.h"
2
2
 
3
3
  #include "llama-impl.h"
4
4
  #include "llama-vocab.h"
5
5
  #include "llama-grammar.h"
6
6
 
7
+ #include "ggml-cpp.h"
8
+
9
+ #include <array>
7
10
  #include <algorithm>
8
11
  #include <cassert>
9
12
  #include <cfloat>
@@ -345,7 +348,9 @@ static uint32_t get_rng_seed(uint32_t seed) {
345
348
 
346
349
  // llama_sampler API
347
350
 
348
- struct llama_sampler * llama_sampler_init(const struct llama_sampler_i * iface, llama_sampler_context_t ctx) {
351
+ struct llama_sampler * llama_sampler_init(
352
+ struct llama_sampler_i * iface,
353
+ llama_sampler_context_t ctx) {
349
354
  return new llama_sampler {
350
355
  /* .iface = */ iface,
351
356
  /* .ctx = */ ctx,
@@ -361,23 +366,39 @@ const char * llama_sampler_name(const struct llama_sampler * smpl) {
361
366
  }
362
367
 
363
368
  void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) {
369
+ if (!smpl) {
370
+ return;
371
+ }
372
+
364
373
  if (smpl->iface->accept) {
365
374
  smpl->iface->accept(smpl, token);
366
375
  }
367
376
  }
368
377
 
369
378
  void llama_sampler_apply(struct llama_sampler * smpl, struct llama_token_data_array * cur_p) {
379
+ if (!smpl) {
380
+ return;
381
+ }
382
+
370
383
  GGML_ASSERT(smpl->iface->apply);
371
384
  smpl->iface->apply(smpl, cur_p);
372
385
  }
373
386
 
374
387
  void llama_sampler_reset(struct llama_sampler * smpl) {
388
+ if (!smpl) {
389
+ return;
390
+ }
391
+
375
392
  if (smpl->iface->reset) {
376
393
  smpl->iface->reset(smpl);
377
394
  }
378
395
  }
379
396
 
380
397
  struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) {
398
+ if (!smpl) {
399
+ return nullptr;
400
+ }
401
+
381
402
  if (smpl->iface->clone) {
382
403
  return smpl->iface->clone(smpl);
383
404
  }
@@ -404,37 +425,200 @@ void llama_sampler_free(struct llama_sampler * smpl) {
404
425
  delete smpl;
405
426
  }
406
427
 
407
- llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
408
- const auto * logits = llama_get_logits_ith(ctx, idx);
428
+ // empty sampler
409
429
 
410
- const llama_model * model = llama_get_model(ctx);
411
- const llama_vocab * vocab = llama_model_get_vocab(model);
430
+ struct llama_sampler_empty {
431
+ const char * name;
432
+ };
412
433
 
413
- const int n_vocab = llama_vocab_n_tokens(vocab);
434
+ static struct llama_sampler * llama_sampler_init_empty(const char * name);
435
+
436
+ static const char * llama_sampler_empty_name(const struct llama_sampler * smpl) {
437
+ auto * ctx = (llama_sampler_empty *) smpl->ctx;
438
+ return ctx->name;
439
+ }
440
+
441
+ static void llama_sampler_empty_accept(struct llama_sampler * smpl, llama_token token) {
442
+ GGML_UNUSED(smpl);
443
+ GGML_UNUSED(token);
444
+ }
445
+
446
+ static void llama_sampler_empty_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
447
+ GGML_UNUSED(smpl);
448
+ GGML_UNUSED(cur_p);
449
+ }
450
+
451
+ static void llama_sampler_empty_reset(struct llama_sampler * smpl) {
452
+ GGML_UNUSED(smpl);
453
+ }
454
+
455
+ static struct llama_sampler * llama_sampler_empty_clone(const struct llama_sampler * smpl) {
456
+ auto * ctx = (llama_sampler_empty *) smpl->ctx;
457
+ return llama_sampler_init_empty(ctx->name);
458
+ }
459
+
460
+ static void llama_sampler_empty_free(struct llama_sampler * smpl) {
461
+ delete (llama_sampler_empty *) smpl->ctx;
462
+ }
414
463
 
415
- // TODO: do not allocate each time
416
- std::vector<llama_token_data> cur;
417
- cur.reserve(n_vocab);
418
- for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
419
- cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
464
+ static bool llama_sampler_empty_backend_init(
465
+ struct llama_sampler * smpl,
466
+ ggml_backend_buffer_type_t buft) {
467
+ GGML_UNUSED(smpl);
468
+ GGML_UNUSED(buft);
469
+
470
+ return true;
471
+ }
472
+
473
+ static void llama_sampler_empty_backend_accept(
474
+ struct llama_sampler * smpl,
475
+ ggml_context * ctx,
476
+ ggml_cgraph * gf,
477
+ struct ggml_tensor * selected_token) {
478
+ GGML_UNUSED(smpl);
479
+ GGML_UNUSED(ctx);
480
+ GGML_UNUSED(gf);
481
+ GGML_UNUSED(selected_token);
482
+ }
483
+
484
+ static void llama_sampler_empty_backend_apply(
485
+ struct llama_sampler * smpl,
486
+ struct ggml_context * ctx,
487
+ struct ggml_cgraph * gf,
488
+ struct llama_sampler_data * data) {
489
+ GGML_UNUSED(smpl);
490
+ GGML_UNUSED(ctx);
491
+ GGML_UNUSED(gf);
492
+ GGML_UNUSED(data);
493
+ }
494
+
495
+ static void llama_sampler_empty_backend_set_input(struct llama_sampler * smpl) {
496
+ GGML_UNUSED(smpl);
497
+ }
498
+
499
+ static struct llama_sampler_i llama_sampler_empty_i = {
500
+ /* .name = */ llama_sampler_empty_name,
501
+ /* .accept = */ llama_sampler_empty_accept,
502
+ /* .apply = */ llama_sampler_empty_apply,
503
+ /* .reset = */ llama_sampler_empty_reset,
504
+ /* .clone = */ llama_sampler_empty_clone,
505
+ /* .free = */ llama_sampler_empty_free,
506
+ /* .backend_init = */ llama_sampler_empty_backend_init,
507
+ /* .backend_accept = */ llama_sampler_empty_backend_accept,
508
+ /* .backend_apply = */ llama_sampler_empty_backend_apply,
509
+ /* .backend_set_input = */ llama_sampler_empty_backend_set_input,
510
+ };
511
+
512
+ struct llama_sampler * llama_sampler_init_empty(const char * name) {
513
+ return llama_sampler_init(
514
+ /* .iface = */ &llama_sampler_empty_i,
515
+ /* .ctx = */ new llama_sampler_empty {
516
+ /* .name = */ name,
517
+ }
518
+ );
519
+ }
520
+
521
+ // common backend sampler functionality
522
+ //
523
+ // +name : means that the sampler is support and will run on the backend
524
+ // -name : means that a ggml operator is not supported by the backend
525
+ //
526
+ struct llama_sampler_backend {
527
+ llama_sampler_backend(const char * name) : name(name), name_ext(name), is_init(false), support(false) {}
528
+
529
+ const char * get_name() {
530
+ if (!is_init) {
531
+ return name.c_str();
532
+ }
533
+
534
+ if (support) {
535
+ name_ext = "+" + name;
536
+ } else {
537
+ name_ext = "-" + name;
538
+ }
539
+
540
+ return name_ext.c_str();
420
541
  }
421
542
 
422
- llama_token_data_array cur_p = {
423
- /* .data = */ cur.data(),
424
- /* .size = */ cur.size(),
425
- /* .selected = */ -1,
426
- /* .sorted = */ false,
543
+ void init(bool support) {
544
+ GGML_ASSERT(this->is_init == false);
545
+
546
+ this->is_init = true;
547
+ this->support = support;
548
+ }
549
+
550
+ private:
551
+ std::string name;
552
+ std::string name_ext;
553
+
554
+ bool is_init;
555
+ bool support;
556
+ };
557
+
558
+ // check if all ggml ops used by the sampler are supported by the backend
559
+ static bool llama_sampler_backend_support(
560
+ llama_sampler * smpl,
561
+ ggml_backend_buffer_type_t buft) {
562
+ auto * device = ggml_backend_buft_get_device(buft);
563
+ if (!device) {
564
+ // CPU backend always supported
565
+ return true;
566
+ }
567
+
568
+ ggml_init_params params = {
569
+ /*.mem_size =*/ 128*ggml_tensor_overhead() + ggml_graph_overhead(),
570
+ /*.mem_buffer =*/ NULL,
571
+ /*.no_alloc =*/ true,
427
572
  };
428
573
 
429
- llama_sampler_apply(smpl, &cur_p);
574
+ ggml_context_ptr ctx_ptr { ggml_init(params) };
575
+ if (!ctx_ptr) {
576
+ throw std::runtime_error(format("failed to create ggml context"));
577
+ }
430
578
 
431
- GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (int32_t) cur_p.size);
579
+ ggml_context * ctx = ctx_ptr.get();
432
580
 
433
- auto token = cur_p.data[cur_p.selected].id;
581
+ const int64_t n = 1024*1024;
434
582
 
435
- llama_sampler_accept(smpl, token);
583
+ llama_sampler_data data = {
584
+ /*.logits = */ ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n),
585
+ /*.probs = */ nullptr,
586
+ /*.sampled = */ nullptr,
587
+ /*.candidates = */ ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n),
588
+ };
436
589
 
437
- return token;
590
+ ggml_cgraph * gf = ggml_new_graph(ctx);
591
+
592
+ smpl->iface->backend_apply(smpl, ctx, gf, &data);
593
+
594
+ if (data.logits) {
595
+ ggml_build_forward_expand(gf, data.logits);
596
+ }
597
+
598
+ if (data.probs) {
599
+ ggml_build_forward_expand(gf, data.probs);
600
+ }
601
+
602
+ if (data.sampled) {
603
+ ggml_build_forward_expand(gf, data.sampled);
604
+ }
605
+
606
+ if (data.candidates) {
607
+ ggml_build_forward_expand(gf, data.candidates);
608
+ }
609
+
610
+ for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
611
+ struct ggml_tensor * op = ggml_graph_node(gf, i);
612
+
613
+ if (!ggml_backend_dev_supports_op(device, op)) {
614
+ LLAMA_LOG_WARN("%s: device '%s' does not have support for op %s needed for sampler '%s'\n",
615
+ __func__, ggml_backend_dev_name(device), ggml_op_name(op->op), smpl->iface->name(smpl));
616
+
617
+ return false;
618
+ }
619
+ }
620
+
621
+ return true;
438
622
  }
439
623
 
440
624
  // sampler chain
@@ -448,8 +632,8 @@ static void llama_sampler_chain_accept(struct llama_sampler * smpl, llama_token
448
632
 
449
633
  time_meas tm(chain->t_sample_us, chain->params.no_perf);
450
634
 
451
- for (auto * smpl : chain->samplers) {
452
- llama_sampler_accept(smpl, token);
635
+ for (auto & smpl : chain->samplers) {
636
+ llama_sampler_accept(smpl.ptr, token);
453
637
  }
454
638
 
455
639
  chain->n_sample++;
@@ -460,20 +644,29 @@ static void llama_sampler_chain_apply(struct llama_sampler * smpl, llama_token_d
460
644
 
461
645
  time_meas tm(chain->t_sample_us, chain->params.no_perf);
462
646
 
463
- for (auto * smpl : chain->samplers) {
464
- llama_sampler_apply(smpl, cur_p);
647
+ bool is_backend = chain->is_init;
648
+
649
+ for (auto & smpl : chain->samplers) {
650
+ if (is_backend && smpl.is_backend) {
651
+ continue;
652
+ }
653
+
654
+ is_backend = false;
655
+
656
+ if (smpl.ptr->iface->apply == nullptr) {
657
+ continue;
658
+ }
659
+
660
+ llama_sampler_apply(smpl.ptr, cur_p);
465
661
  }
466
662
  }
467
663
 
468
664
  static void llama_sampler_chain_reset(struct llama_sampler * smpl) {
469
665
  auto * chain = (llama_sampler_chain *) smpl->ctx;
470
666
 
471
- for (auto * smpl : chain->samplers) {
472
- llama_sampler_reset(smpl);
667
+ for (auto & smpl : chain->samplers) {
668
+ llama_sampler_reset(smpl.ptr);
473
669
  }
474
-
475
- chain->t_sample_us = 0;
476
- chain->n_sample = 0;
477
670
  }
478
671
 
479
672
  static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampler * smpl) {
@@ -481,8 +674,8 @@ static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampl
481
674
 
482
675
  auto * result = llama_sampler_chain_init(chain_src->params);
483
676
 
484
- for (auto * smpl : chain_src->samplers) {
485
- llama_sampler_chain_add(result, llama_sampler_clone(smpl));
677
+ for (const auto & smpl : chain_src->samplers) {
678
+ llama_sampler_chain_add(result, llama_sampler_clone(smpl.ptr));
486
679
  }
487
680
 
488
681
  return result;
@@ -491,20 +684,109 @@ static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampl
491
684
  static void llama_sampler_chain_free(struct llama_sampler * smpl) {
492
685
  auto * chain = (llama_sampler_chain *) smpl->ctx;
493
686
 
494
- for (auto * smpl : chain->samplers) {
495
- llama_sampler_free(smpl);
687
+ for (auto & smpl : chain->samplers) {
688
+ llama_sampler_free(smpl.ptr);
496
689
  }
497
690
 
498
691
  delete chain;
499
692
  }
500
693
 
694
+ static bool llama_sampler_chain_backend_init(
695
+ struct llama_sampler * smpl,
696
+ ggml_backend_buffer_type_t buft) {
697
+ auto * chain = (llama_sampler_chain *) smpl->ctx;
698
+
699
+ GGML_ASSERT(chain->is_init == false && "llama_sampler_chain_backend_init() called twice");
700
+
701
+ chain->is_init = true;
702
+
703
+ bool res = true;
704
+
705
+ for (auto & smpl : chain->samplers) {
706
+ bool res_cur = true;
707
+
708
+ // to be able to run a sampler on the backend, it has to:
709
+ // - have the .backend_init() API implemented
710
+ // - return true during .backend_init()
711
+ if (smpl.ptr->iface->backend_init) {
712
+ if (!smpl.ptr->iface->backend_init(smpl.ptr, buft)) {
713
+ res_cur = false;
714
+ }
715
+ } else {
716
+ res_cur = false;
717
+ }
718
+
719
+ smpl.is_backend = res_cur;
720
+
721
+ res = res && res_cur;
722
+ }
723
+
724
+ return res;
725
+ }
726
+
727
+ static void llama_sampler_chain_backend_accept(
728
+ struct llama_sampler * smpl,
729
+ ggml_context * ctx,
730
+ ggml_cgraph * gf,
731
+ struct ggml_tensor * selected_token) {
732
+ auto * chain = (llama_sampler_chain *) smpl->ctx;
733
+
734
+ for (auto & smpl : chain->samplers) {
735
+ if (!smpl.is_backend) {
736
+ break;
737
+ }
738
+
739
+ if (smpl.ptr->iface->backend_accept) {
740
+ smpl.ptr->iface->backend_accept(smpl.ptr, ctx, gf, selected_token);
741
+ }
742
+ }
743
+ }
744
+
745
+ static void llama_sampler_chain_backend_apply(
746
+ struct llama_sampler * smpl,
747
+ struct ggml_context * ctx,
748
+ struct ggml_cgraph * gf,
749
+ struct llama_sampler_data * data) {
750
+ auto * chain = (llama_sampler_chain *) smpl->ctx;
751
+
752
+ GGML_ASSERT(chain->is_init && "llama_sampler_chain_backend_init() not called");
753
+
754
+ for (auto & smpl : chain->samplers) {
755
+ if (!smpl.is_backend) {
756
+ break;
757
+ }
758
+
759
+ if (smpl.ptr->iface->backend_apply) {
760
+ smpl.ptr->iface->backend_apply(smpl.ptr, ctx, gf, data);
761
+ }
762
+ }
763
+ }
764
+
765
+ static void llama_sampler_chain_backend_set_input(struct llama_sampler * smpl) {
766
+ auto * chain = (llama_sampler_chain *) smpl->ctx;
767
+
768
+ for (auto & smpl : chain->samplers) {
769
+ if (!smpl.is_backend) {
770
+ break;
771
+ }
772
+
773
+ if (smpl.ptr->iface->backend_set_input) {
774
+ smpl.ptr->iface->backend_set_input(smpl.ptr);
775
+ }
776
+ }
777
+ }
778
+
501
779
  static struct llama_sampler_i llama_sampler_chain_i = {
502
- /* .name = */ llama_sampler_chain_name,
503
- /* .accept = */ llama_sampler_chain_accept,
504
- /* .apply = */ llama_sampler_chain_apply,
505
- /* .reset = */ llama_sampler_chain_reset,
506
- /* .clone = */ llama_sampler_chain_clone,
507
- /* .free = */ llama_sampler_chain_free,
780
+ /* .name = */ llama_sampler_chain_name,
781
+ /* .accept = */ llama_sampler_chain_accept,
782
+ /* .apply = */ llama_sampler_chain_apply,
783
+ /* .reset = */ llama_sampler_chain_reset,
784
+ /* .clone = */ llama_sampler_chain_clone,
785
+ /* .free = */ llama_sampler_chain_free,
786
+ /* .backend_init = */ llama_sampler_chain_backend_init,
787
+ /* .backend_accept = */ llama_sampler_chain_backend_accept,
788
+ /* .backend_apply = */ llama_sampler_chain_backend_apply,
789
+ /* .backend_set_input = */ llama_sampler_chain_backend_set_input,
508
790
  };
509
791
 
510
792
  struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) {
@@ -512,26 +794,113 @@ struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_param
512
794
  /* .iface = */ &llama_sampler_chain_i,
513
795
  /* .ctx = */ new llama_sampler_chain {
514
796
  /* .params = */ params,
797
+ /* .is_init = */ false,
515
798
  /* .samplers = */ {},
799
+ /* .cur = */ {},
516
800
  /* .t_sample_us = */ 0,
517
801
  /* .n_sample = */ 0,
518
802
  }
519
803
  );
520
804
  }
521
805
 
806
+ llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
807
+ const llama_token sampled_token = llama_get_sampled_token_ith (ctx, idx);
808
+ const float * sampled_probs = llama_get_sampled_probs_ith (ctx, idx);
809
+ const float * sampled_logits = llama_get_sampled_logits_ith (ctx, idx);
810
+ const llama_token * sampled_ids = llama_get_sampled_candidates_ith(ctx, idx);
811
+
812
+ // If a backend sampler has already sampled a token, return it.
813
+ if (sampled_token != LLAMA_TOKEN_NULL) {
814
+ LLAMA_LOG_DEBUG("%s: Backend sampler selected token for idx %d. Skipping CPU samplers\n", __func__, idx);
815
+ return sampled_token;
816
+ }
817
+
818
+ const llama_model * model = llama_get_model(ctx);
819
+ const llama_vocab * vocab = llama_model_get_vocab(model);
820
+
821
+ const int n_vocab = llama_vocab_n_tokens(vocab);
822
+
823
+ // use pre-allocated buffer from chain if available, otherwise allocate locally
824
+ std::vector<llama_token_data> * cur_ptr;
825
+ std::vector<llama_token_data> cur_local;
826
+
827
+ if (smpl->iface == &llama_sampler_chain_i) {
828
+ auto * chain = (llama_sampler_chain *) smpl->ctx;
829
+ cur_ptr = &chain->cur;
830
+ } else {
831
+ cur_ptr = &cur_local;
832
+ }
833
+
834
+ auto & cur = *cur_ptr;
835
+
836
+ if (sampled_probs) {
837
+ const uint32_t sampled_probs_count = llama_get_sampled_probs_count_ith(ctx, idx);
838
+ cur.resize(sampled_probs_count);
839
+ for (uint32_t i = 0; i < sampled_probs_count; ++i) {
840
+ cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], sampled_probs[i]};
841
+ }
842
+ } else if (sampled_logits) {
843
+ const uint32_t sampled_logits_count = llama_get_sampled_logits_count_ith(ctx, idx);
844
+ cur.resize(sampled_logits_count);
845
+ for (llama_token i = 0; i < (int)sampled_logits_count; i++) {
846
+ cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f};
847
+ }
848
+ } else {
849
+ const auto * logits = llama_get_logits_ith(ctx, idx);
850
+ GGML_ASSERT(logits != nullptr);
851
+ cur.resize(n_vocab);
852
+ for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
853
+ cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
854
+ }
855
+ }
856
+
857
+ llama_token_data_array cur_p = {
858
+ /* .data = */ cur.data(),
859
+ /* .size = */ cur.size(),
860
+ /* .selected = */ -1,
861
+ /* .sorted = */ false,
862
+ };
863
+
864
+ llama_sampler_apply(smpl, &cur_p);
865
+
866
+ GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (int32_t) cur_p.size);
867
+
868
+ auto token = cur_p.data[cur_p.selected].id;
869
+
870
+ llama_sampler_accept(smpl, token);
871
+
872
+ return token;
873
+ }
874
+
875
+
522
876
  void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
523
877
  auto * p = (llama_sampler_chain *) chain->ctx;
524
- p->samplers.push_back(smpl);
878
+ p->samplers.push_back({
879
+ /* .is_backend = */ false,
880
+ /* .ptr = */ smpl,
881
+ });
525
882
  }
526
883
 
527
- struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i) {
884
+ struct llama_sampler * llama_sampler_chain_get(struct llama_sampler * chain, int32_t i) {
885
+ if (chain == nullptr) {
886
+ return nullptr;
887
+ }
888
+
889
+ if (chain->iface != &llama_sampler_chain_i) {
890
+ return nullptr;
891
+ }
892
+
893
+ if (i == -1) {
894
+ return chain;
895
+ }
896
+
528
897
  const auto * p = (const llama_sampler_chain *) chain->ctx;
529
898
 
530
899
  if (i < 0 || (size_t) i >= p->samplers.size()) {
531
900
  return nullptr;
532
901
  }
533
902
 
534
- return p->samplers[i];
903
+ return p->samplers[i].ptr;
535
904
  }
536
905
 
537
906
  struct llama_sampler * llama_sampler_chain_remove(struct llama_sampler * chain, int32_t i) {
@@ -541,7 +910,7 @@ struct llama_sampler * llama_sampler_chain_remove(struct llama_sampler * chain,
541
910
  return nullptr;
542
911
  }
543
912
 
544
- auto * result = p->samplers[i];
913
+ auto * result = p->samplers[i].ptr;
545
914
  p->samplers.erase(p->samplers.begin() + i);
546
915
 
547
916
  return result;
@@ -559,8 +928,36 @@ int llama_sampler_chain_n(const struct llama_sampler * chain) {
559
928
 
560
929
  // greedy
561
930
 
562
- static const char * llama_sampler_greedy_name(const struct llama_sampler * /*smpl*/) {
563
- return "greedy";
931
+ struct llama_sampler_greedy : public llama_sampler_backend {
932
+ };
933
+
934
+ static const char * llama_sampler_greedy_name(const struct llama_sampler * smpl) {
935
+ auto * sctx = (llama_sampler_greedy *) smpl->ctx;
936
+ return sctx->get_name();
937
+ }
938
+
939
+ static void llama_sampler_greedy_reset(struct llama_sampler * smpl) {
940
+ auto * ctx = (llama_sampler_greedy *) smpl->ctx;
941
+ GGML_UNUSED(ctx);
942
+ }
943
+
944
+ static struct llama_sampler * llama_sampler_greedy_clone(const struct llama_sampler * smpl) {
945
+ const auto * ctx = (const llama_sampler_greedy *) smpl->ctx;
946
+ auto * result = llama_sampler_init_greedy();
947
+
948
+ // copy the state
949
+ {
950
+ auto * result_ctx = (llama_sampler_greedy *) result->ctx;
951
+
952
+ GGML_UNUSED(ctx);
953
+ GGML_UNUSED(result_ctx);
954
+ }
955
+
956
+ return result;
957
+ }
958
+
959
+ static void llama_sampler_greedy_free(struct llama_sampler * smpl) {
960
+ delete (llama_sampler_greedy *) smpl->ctx;
564
961
  }
565
962
 
566
963
  static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
@@ -572,33 +969,68 @@ static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_to
572
969
  }
573
970
  }
574
971
 
972
+ static bool llama_sampler_greedy_backend_init(
973
+ struct llama_sampler * smpl,
974
+ ggml_backend_buffer_type_t buft) {
975
+ auto * sctx = (llama_sampler_greedy *) smpl->ctx;
976
+
977
+ const bool res = llama_sampler_backend_support(smpl, buft);
978
+
979
+ sctx->init(res);
980
+
981
+ return res;
982
+ }
983
+
984
+ static void llama_sampler_greedy_backend_apply(
985
+ struct llama_sampler * smpl,
986
+ struct ggml_context * ctx,
987
+ struct ggml_cgraph * gf,
988
+ struct llama_sampler_data * data) {
989
+ GGML_UNUSED(gf);
990
+ GGML_UNUSED(smpl);
991
+
992
+ struct ggml_tensor * curl = ggml_argmax(ctx, data->logits);
993
+ ggml_set_name(curl, "greedy_argmax");
994
+
995
+ data->sampled = curl;
996
+ }
997
+
575
998
  static struct llama_sampler_i llama_sampler_greedy_i = {
576
- /* .name = */ llama_sampler_greedy_name,
577
- /* .accept = */ nullptr,
578
- /* .apply = */ llama_sampler_greedy_apply,
579
- /* .reset = */ nullptr,
580
- /* .clone = */ nullptr,
581
- /* .free = */ nullptr,
999
+ /* .name = */ llama_sampler_greedy_name,
1000
+ /* .accept = */ nullptr,
1001
+ /* .apply = */ llama_sampler_greedy_apply,
1002
+ /* .reset = */ llama_sampler_greedy_reset,
1003
+ /* .clone = */ llama_sampler_greedy_clone,
1004
+ /* .free = */ llama_sampler_greedy_free,
1005
+ /* .backend_init = */ llama_sampler_greedy_backend_init,
1006
+ /* .backend_accept = */ nullptr,
1007
+ /* .backend_apply = */ llama_sampler_greedy_backend_apply,
1008
+ /* .backend_set_input = */ nullptr,
582
1009
  };
583
1010
 
584
1011
  struct llama_sampler * llama_sampler_init_greedy() {
585
1012
  return llama_sampler_init(
586
1013
  /* .iface = */ &llama_sampler_greedy_i,
587
- /* .ctx = */ nullptr
1014
+ /* .ctx = */ new llama_sampler_greedy {
1015
+ ("greedy"),
1016
+ }
588
1017
  );
589
1018
  }
590
1019
 
591
1020
  // dist
592
1021
 
593
- struct llama_sampler_dist {
1022
+ struct llama_sampler_dist : public llama_sampler_backend {
594
1023
  const uint32_t seed;
595
1024
  uint32_t seed_cur;
596
1025
 
597
1026
  std::mt19937 rng;
1027
+
1028
+ ggml_tensor * inp_uniform;
598
1029
  };
599
1030
 
600
- static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*/) {
601
- return "dist";
1031
+ static const char * llama_sampler_dist_name(const struct llama_sampler * smpl) {
1032
+ auto * sctx = (llama_sampler_dist *) smpl->ctx;
1033
+ return sctx->get_name();
602
1034
  }
603
1035
 
604
1036
  static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
@@ -673,6 +1105,12 @@ static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_da
673
1105
  #endif
674
1106
  }
675
1107
 
1108
+ static void llama_sampler_dist_reset(struct llama_sampler * smpl) {
1109
+ auto * ctx = (llama_sampler_dist *) smpl->ctx;
1110
+ ctx->seed_cur = get_rng_seed(ctx->seed);
1111
+ ctx->rng.seed(ctx->seed_cur);
1112
+ }
1113
+
676
1114
  static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) {
677
1115
  const auto * ctx = (const llama_sampler_dist *) smpl->ctx;
678
1116
  auto * result = llama_sampler_init_dist(ctx->seed);
@@ -687,23 +1125,106 @@ static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sample
687
1125
  return result;
688
1126
  }
689
1127
 
690
- static void llama_sampler_dist_reset(struct llama_sampler * smpl) {
691
- auto * ctx = (llama_sampler_dist *) smpl->ctx;
692
- ctx->seed_cur = get_rng_seed(ctx->seed);
693
- ctx->rng.seed(ctx->seed_cur);
694
- }
695
-
696
1128
  static void llama_sampler_dist_free(struct llama_sampler * smpl) {
697
1129
  delete (llama_sampler_dist *) smpl->ctx;
698
1130
  }
699
1131
 
1132
+ static bool llama_sampler_dist_backend_init(
1133
+ struct llama_sampler * smpl,
1134
+ ggml_backend_buffer_type_t buft) {
1135
+ auto * sctx = (llama_sampler_dist *) smpl->ctx;
1136
+
1137
+ const bool res = llama_sampler_backend_support(smpl, buft);
1138
+
1139
+ sctx->init(res);
1140
+
1141
+ return res;
1142
+ }
1143
+
1144
+ static void llama_sampler_dist_backend_apply(
1145
+ struct llama_sampler * smpl,
1146
+ struct ggml_context * ctx,
1147
+ struct ggml_cgraph * gf,
1148
+ struct llama_sampler_data * data) {
1149
+ GGML_UNUSED(gf);
1150
+
1151
+ auto * sctx = (llama_sampler_dist *) smpl->ctx;
1152
+
1153
+ sctx->inp_uniform = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
1154
+ ggml_set_name (sctx->inp_uniform, "uniform");
1155
+ ggml_set_input(sctx->inp_uniform);
1156
+
1157
+ struct ggml_tensor * probs = ggml_soft_max(ctx, data->logits);
1158
+ ggml_set_name(probs, "dist_probs");
1159
+
1160
+ struct ggml_tensor * cumsum = ggml_cumsum(ctx, probs);
1161
+ ggml_set_name(cumsum, "dist_cumsum");
1162
+
1163
+ // The uniform tensor has a random value and we subtract this tensor with
1164
+ // the cumsum tensor (the uniform tensor will be broadcasted by ggml_sub).
1165
+ // Recall that each entry in cumsum is the cumulative probability up to that
1166
+ // index so values stay negative while the cumulative total is below the
1167
+ // random value, and become zero/positive once the threshold is crossed.
1168
+ struct ggml_tensor * diff = ggml_sub(ctx, cumsum, sctx->inp_uniform);
1169
+ ggml_set_name(diff, "dist_cumsum");
1170
+
1171
+ // The ggml_step function produces a tensor where entries are 1 if the
1172
+ // corresponding entry in diff is > 0, and 0 otherwise. So all values up to
1173
+ // the index where the cumulative probability exceeds the random value are 0,
1174
+ // and all entries after that are 1.
1175
+ struct ggml_tensor * mask = ggml_step(ctx, diff);
1176
+ ggml_set_name(mask, "dist_mask");
1177
+
1178
+ // Taking the sum of the mask gives us the sum of elements after the threshold
1179
+ // we are interested in.
1180
+ struct ggml_tensor * idxf = ggml_sum(ctx, mask);
1181
+ ggml_set_name(idxf, "dist_index_f32");
1182
+
1183
+ // Use ggml_scale_bias to scale the index value by -1 and then add the size
1184
+ // of the mask to that value so we get the correct index ((-1 * idxf) + n).
1185
+ struct ggml_tensor * idx = ggml_cast(ctx, ggml_scale_bias(ctx, idxf, -1.0f, mask->ne[0]), GGML_TYPE_I32);
1186
+ ggml_set_name(idx, "dist_index_i32");
1187
+
1188
+ // Map back to original vocab ids if a candidates tensor is available.
1189
+ struct ggml_tensor * sampled_token = idx;
1190
+ if (data->candidates != nullptr) {
1191
+ struct ggml_tensor * candidates = ggml_reshape_2d(ctx, data->candidates, 1, ggml_nelements(data->candidates));
1192
+
1193
+ sampled_token = ggml_get_rows(ctx, candidates, idx);
1194
+ ggml_set_name(sampled_token, "dist_sampled_token");
1195
+ }
1196
+
1197
+ data->sampled = sampled_token;
1198
+ data->probs = probs;
1199
+ }
1200
+
1201
+ static void llama_sampler_dist_backend_set_input(struct llama_sampler * smpl) {
1202
+ auto * sctx = (llama_sampler_dist *) smpl->ctx;
1203
+
1204
+ GGML_ASSERT(sctx->inp_uniform != nullptr);
1205
+
1206
+ // We sample in double precision and cast to float to match rnd numbers of
1207
+ // llama_dampler_dist which uses double precision (sampling from
1208
+ // std::uniform_real_distribution<double> and
1209
+ // std::uniform_real_distribution<float> with same rng will produce
1210
+ // different sequences).
1211
+ std::uniform_real_distribution<double> dist(0.0f, 1.0f);
1212
+ const float rnd = dist(sctx->rng);
1213
+
1214
+ ggml_backend_tensor_set(sctx->inp_uniform, &rnd, 0, sizeof(float));
1215
+ }
1216
+
700
1217
  static struct llama_sampler_i llama_sampler_dist_i = {
701
- /* .name = */ llama_sampler_dist_name,
702
- /* .accept = */ nullptr,
703
- /* .apply = */ llama_sampler_dist_apply,
704
- /* .reset = */ llama_sampler_dist_reset,
705
- /* .clone = */ llama_sampler_dist_clone,
706
- /* .free = */ llama_sampler_dist_free,
1218
+ /* .name = */ llama_sampler_dist_name,
1219
+ /* .accept = */ nullptr,
1220
+ /* .apply = */ llama_sampler_dist_apply,
1221
+ /* .reset = */ llama_sampler_dist_reset,
1222
+ /* .clone = */ llama_sampler_dist_clone,
1223
+ /* .free = */ llama_sampler_dist_free,
1224
+ /* .backend_init = */ llama_sampler_dist_backend_init,
1225
+ /* .backend_accept = */ nullptr,
1226
+ /* .backend_apply = */ llama_sampler_dist_backend_apply,
1227
+ /* .backend_set_input = */ llama_sampler_dist_backend_set_input,
707
1228
  };
708
1229
 
709
1230
  struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
@@ -711,21 +1232,24 @@ struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
711
1232
  return llama_sampler_init(
712
1233
  /* .iface = */ &llama_sampler_dist_i,
713
1234
  /* .ctx = */ new llama_sampler_dist {
714
- /* .seed = */ seed,
715
- /* .seed_cur = */ seed_cur,
716
- /* .rng = */ std::mt19937(seed_cur),
1235
+ ("dist"),
1236
+ /* .seed = */ seed,
1237
+ /* .seed_cur = */ seed_cur,
1238
+ /* .rng = */ std::mt19937(seed_cur),
1239
+ /* .inp_uniform = */ nullptr,
717
1240
  }
718
1241
  );
719
1242
  }
720
1243
 
721
1244
  // top-k
722
1245
 
723
- struct llama_sampler_top_k {
1246
+ struct llama_sampler_top_k : public llama_sampler_backend {
724
1247
  const int32_t k;
725
1248
  };
726
1249
 
727
- static const char * llama_sampler_top_k_name(const struct llama_sampler * /*smpl*/) {
728
- return "top-k";
1250
+ static const char * llama_sampler_top_k_name(const struct llama_sampler * smpl) {
1251
+ auto * sctx = (llama_sampler_top_k *) smpl->ctx;
1252
+ return sctx->get_name();
729
1253
  }
730
1254
 
731
1255
  static void llama_sampler_top_k_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
@@ -742,19 +1266,69 @@ static void llama_sampler_top_k_free(struct llama_sampler * smpl) {
742
1266
  delete (llama_sampler_top_k *) smpl->ctx;
743
1267
  }
744
1268
 
1269
+ static bool llama_sampler_top_k_backend_init(
1270
+ struct llama_sampler * smpl,
1271
+ ggml_backend_buffer_type_t buft) {
1272
+ auto * sctx = (llama_sampler_top_k *) smpl->ctx;
1273
+
1274
+ const bool res = llama_sampler_backend_support(smpl, buft);
1275
+
1276
+ sctx->init(res);
1277
+
1278
+ return res;
1279
+ }
1280
+
1281
+ static void llama_sampler_top_k_backend_apply(
1282
+ struct llama_sampler * smpl,
1283
+ struct ggml_context * ctx,
1284
+ struct ggml_cgraph * gf,
1285
+ struct llama_sampler_data * data) {
1286
+ auto * sctx = (llama_sampler_top_k *) smpl->ctx;
1287
+
1288
+ struct ggml_tensor * top_k = ggml_top_k(ctx, data->logits, sctx->k);
1289
+ ggml_set_name(top_k, "top_k");
1290
+
1291
+ if (data->candidates) {
1292
+ struct ggml_tensor * candidates_rows = ggml_reshape_2d(ctx, data->candidates, 1, data->candidates->ne[0]);
1293
+ data->candidates = ggml_get_rows(ctx, candidates_rows, top_k);
1294
+ data->candidates = ggml_reshape_1d(ctx, data->candidates, sctx->k);
1295
+ ggml_set_name(data->candidates, "top_k_candidates");
1296
+ } else {
1297
+ data->candidates = top_k;
1298
+ }
1299
+
1300
+ struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]);
1301
+ struct ggml_tensor * top_k_rows = ggml_get_rows(ctx, logits_rows, top_k);
1302
+ data->logits = ggml_reshape_1d(ctx, top_k_rows, sctx->k);
1303
+ ggml_set_name(top_k_rows, "top_k_rows");
1304
+
1305
+ GGML_UNUSED(gf);
1306
+ }
1307
+
745
1308
  static struct llama_sampler_i llama_sampler_top_k_i = {
746
- /* .name = */ llama_sampler_top_k_name,
747
- /* .accept = */ nullptr,
748
- /* .apply = */ llama_sampler_top_k_apply,
749
- /* .reset = */ nullptr,
750
- /* .clone = */ llama_sampler_top_k_clone,
751
- /* .free = */ llama_sampler_top_k_free,
1309
+ /* .name = */ llama_sampler_top_k_name,
1310
+ /* .accept = */ nullptr,
1311
+ /* .apply = */ llama_sampler_top_k_apply,
1312
+ /* .reset = */ nullptr,
1313
+ /* .clone = */ llama_sampler_top_k_clone,
1314
+ /* .free = */ llama_sampler_top_k_free,
1315
+ /* .backend_init = */ llama_sampler_top_k_backend_init,
1316
+ /* .backend_accept = */ nullptr,
1317
+ /* .backend_apply = */ llama_sampler_top_k_backend_apply,
1318
+ /* .backend_set_input = */ nullptr,
752
1319
  };
753
1320
 
754
1321
  struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
1322
+ const bool is_empty = (k <= 0);
1323
+
1324
+ if (is_empty) {
1325
+ return llama_sampler_init_empty("?top-k");
1326
+ }
1327
+
755
1328
  return llama_sampler_init(
756
1329
  /* .iface = */ &llama_sampler_top_k_i,
757
1330
  /* .ctx = */ new llama_sampler_top_k {
1331
+ ("top-k"),
758
1332
  /* .k = */ k,
759
1333
  }
760
1334
  );
@@ -762,15 +1336,16 @@ struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
762
1336
 
763
1337
  // top-p
764
1338
 
765
- struct llama_sampler_top_p {
1339
+ struct llama_sampler_top_p : public llama_sampler_backend {
766
1340
  const float p;
767
1341
  const size_t min_keep;
768
1342
 
769
1343
  std::vector<llama_token_data> buf_sort;
770
1344
  };
771
1345
 
772
- static const char * llama_sampler_top_p_name(const struct llama_sampler * /*smpl*/) {
773
- return "top-p";
1346
+ static const char * llama_sampler_top_p_name(const struct llama_sampler * smpl) {
1347
+ auto * sctx = (llama_sampler_top_p *) smpl->ctx;
1348
+ return sctx->get_name();
774
1349
  }
775
1350
 
776
1351
  static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
@@ -837,19 +1412,115 @@ static void llama_sampler_top_p_free(struct llama_sampler * smpl) {
837
1412
  delete (llama_sampler_top_p *) smpl->ctx;
838
1413
  }
839
1414
 
1415
+ static bool llama_sampler_top_p_backend_init(
1416
+ struct llama_sampler * smpl,
1417
+ ggml_backend_buffer_type_t buft) {
1418
+ auto * sctx = (llama_sampler_top_p *) smpl->ctx;
1419
+
1420
+ const bool res = llama_sampler_backend_support(smpl, buft);
1421
+
1422
+ sctx->init(res);
1423
+
1424
+ return res;
1425
+ }
1426
+
1427
+ static void llama_sampler_top_p_backend_apply(
1428
+ struct llama_sampler * smpl,
1429
+ struct ggml_context * ctx,
1430
+ struct ggml_cgraph * gf,
1431
+ struct llama_sampler_data * data) {
1432
+ auto * sctx = (llama_sampler_top_p *) smpl->ctx;
1433
+
1434
+ auto ggml_sort = [ctx](struct ggml_tensor * a, struct ggml_tensor * b) {
1435
+ GGML_ASSERT(ggml_nrows(a) == 1);
1436
+ struct ggml_tensor * a_reshaped = ggml_reshape_2d(ctx, a, 1, a->ne[0]);
1437
+ struct ggml_tensor * a_sorted = ggml_get_rows(ctx, a_reshaped, b);
1438
+ return ggml_reshape_1d(ctx, a_sorted, a->ne[0]);
1439
+ };
1440
+
1441
+ // Get the sorted logits in descending order.
1442
+ struct ggml_tensor * sorted_idx = ggml_argsort(ctx, data->logits, GGML_SORT_ORDER_DESC);
1443
+ ggml_set_name(sorted_idx, "top_p_sorted_idx");
1444
+
1445
+ // Do the sorting via reshape + get_rows
1446
+ struct ggml_tensor * sorted_logits = ggml_sort(data->logits, sorted_idx);
1447
+ ggml_set_name(sorted_logits, "top_p_sorted_logits");
1448
+
1449
+ struct ggml_tensor * softmax = ggml_soft_max(ctx, sorted_logits);
1450
+ ggml_set_name(softmax, "top_p_softmax");
1451
+
1452
+ // If candidates are provided, sort them as well. Otherwise, set sorted indices as candidates.
1453
+ if (data->candidates) {
1454
+ data->candidates = ggml_sort(data->candidates, sorted_idx);
1455
+ } else {
1456
+ data->candidates = sorted_idx;
1457
+ }
1458
+ ggml_set_name(data->candidates, "top_p_candidates");
1459
+
1460
+ // Compute Cumulative Distribution Function (CDF) by means of GGML_OP_CUMSUM.
1461
+ struct ggml_tensor * cdf = ggml_cumsum(ctx, softmax);
1462
+ ggml_set_name(cdf, "top_p_cdf");
1463
+
1464
+ // Invert CDF and add top-p value so that ggml_step yields 1 for values we want to keep
1465
+ struct ggml_tensor * cdf_scaled = ggml_scale_bias(ctx, cdf, -1.0f, sctx->p);
1466
+ ggml_set_name(cdf_scaled, "top_p_cdf_scaled");
1467
+
1468
+ struct ggml_tensor * mask = ggml_step(ctx, cdf_scaled);
1469
+ ggml_set_name(mask, "top_p_mask");
1470
+
1471
+ // Taking the sum of the mask gives us the sum of elements after the threshold
1472
+ // we are interested in.
1473
+ struct ggml_tensor * idxf = ggml_sum(ctx, mask);
1474
+ ggml_set_name(idxf, "top_p_index_f32");
1475
+
1476
+ // prevent out-of-bounds access
1477
+ idxf = ggml_clamp(ctx, idxf, 0.0f, mask->ne[0] - 1);
1478
+
1479
+ // construct ones tensor to set the value in the mask
1480
+ struct ggml_tensor * ones = ggml_scale_bias(ctx, idxf, 0.0f, 1.0f);
1481
+ ggml_set_name(ones, "top_p_ones");
1482
+
1483
+ // Make top-p inclusive (i.e. return all values such that cum_sum/cdf >= p)
1484
+ struct ggml_tensor * mask_reshaped = ggml_reshape_2d(ctx, mask, 1, mask->ne[0]);
1485
+
1486
+ mask_reshaped = ggml_set_rows(ctx, mask_reshaped, ones, ggml_cast(ctx, idxf, GGML_TYPE_I32));
1487
+ mask = ggml_reshape_1d(ctx, mask_reshaped, mask->ne[0]);
1488
+
1489
+ // Apply -INFINITY bias for masked-out tokens
1490
+ // log(1) = 0 (keep), log(0) = -INF (discard)
1491
+ struct ggml_tensor * top_p_bias = ggml_log(ctx, mask);
1492
+ ggml_set_name(top_p_bias, "top_p_bias");
1493
+
1494
+ data->logits = ggml_add(ctx, sorted_logits, top_p_bias);
1495
+ ggml_set_name(data->logits, "top_p_logits");
1496
+
1497
+ GGML_UNUSED(gf);
1498
+ }
1499
+
840
1500
  static struct llama_sampler_i llama_sampler_top_p_i = {
841
- /* .name = */ llama_sampler_top_p_name,
842
- /* .accept = */ nullptr,
843
- /* .apply = */ llama_sampler_top_p_apply,
844
- /* .reset = */ nullptr,
845
- /* .clone = */ llama_sampler_top_p_clone,
846
- /* .free = */ llama_sampler_top_p_free,
1501
+ /* .name = */ llama_sampler_top_p_name,
1502
+ /* .accept = */ nullptr,
1503
+ /* .apply = */ llama_sampler_top_p_apply,
1504
+ /* .reset = */ nullptr,
1505
+ /* .clone = */ llama_sampler_top_p_clone,
1506
+ /* .free = */ llama_sampler_top_p_free,
1507
+ /* .backend_init = */ llama_sampler_top_p_backend_init,
1508
+ /* .backend_accept = */ nullptr,
1509
+ /* .backend_apply = */ llama_sampler_top_p_backend_apply,
1510
+ /* .backend_set_input = */ nullptr,
847
1511
  };
848
1512
 
849
1513
  struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
1514
+ const bool is_empty = p >= 1.0f;
1515
+
1516
+ if (is_empty) {
1517
+ return llama_sampler_init_empty("?top-p");
1518
+ }
1519
+
850
1520
  return llama_sampler_init(
851
1521
  /* .iface = */ &llama_sampler_top_p_i,
852
1522
  /* .ctx = */ new llama_sampler_top_p {
1523
+ ("top-p"),
853
1524
  /* .p = */ p,
854
1525
  /* .min_keep = */ min_keep,
855
1526
  /* .buf_sort = */ {},
@@ -859,13 +1530,14 @@ struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
859
1530
 
860
1531
  // min-p
861
1532
 
862
- struct llama_sampler_min_p {
1533
+ struct llama_sampler_min_p : public llama_sampler_backend {
863
1534
  const float p;
864
1535
  const size_t min_keep;
865
1536
  };
866
1537
 
867
- static const char * llama_sampler_min_p_name(const struct llama_sampler * /*smpl*/) {
868
- return "min-p";
1538
+ static const char * llama_sampler_min_p_name(const struct llama_sampler * smpl) {
1539
+ auto * sctx = (llama_sampler_min_p *) smpl->ctx;
1540
+ return sctx->get_name();
869
1541
  }
870
1542
 
871
1543
  static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
@@ -931,19 +1603,81 @@ static void llama_sampler_min_p_free(struct llama_sampler * smpl) {
931
1603
  delete (llama_sampler_min_p *) smpl->ctx;
932
1604
  }
933
1605
 
1606
+ static bool llama_sampler_min_p_backend_init(
1607
+ struct llama_sampler * smpl,
1608
+ ggml_backend_buffer_type_t buft) {
1609
+ auto * sctx = (llama_sampler_min_p *) smpl->ctx;
1610
+
1611
+ const bool res = llama_sampler_backend_support(smpl, buft);
1612
+
1613
+ sctx->init(res);
1614
+
1615
+ return res;
1616
+ }
1617
+
1618
+ static void llama_sampler_min_p_backend_apply(
1619
+ struct llama_sampler * smpl,
1620
+ struct ggml_context * ctx,
1621
+ struct ggml_cgraph * gf,
1622
+ struct llama_sampler_data * data) {
1623
+ auto * sctx = (llama_sampler_min_p *) smpl->ctx;
1624
+
1625
+ struct ggml_tensor * max_idx = ggml_argmax(ctx, data->logits);
1626
+ ggml_set_name(max_idx, "max_idx");
1627
+
1628
+ struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]);
1629
+ ggml_set_name(logits_rows, "logits_rows");
1630
+
1631
+ struct ggml_tensor * max_logit = ggml_get_rows(ctx, logits_rows, max_idx);
1632
+ ggml_set_name(max_logit, "max_logit");
1633
+
1634
+ // Calculate the threshold value.
1635
+ struct ggml_tensor * threshold = ggml_scale_bias(ctx, max_logit, 1.0f, logf(sctx->p));
1636
+ ggml_set_name(threshold, "min_p_threshold");
1637
+
1638
+ // Subtract the threshold from logits.
1639
+ struct ggml_tensor * sub = ggml_sub(ctx, data->logits, threshold);
1640
+
1641
+ // Create a mask where logits below the threshold are 0 (discard),
1642
+ // and others are 1 (keep).
1643
+ struct ggml_tensor * mask = ggml_step(ctx, sub);
1644
+ ggml_set_name(mask, "min_p_mask");
1645
+
1646
+ // Apply -INFINITY bias for masked-out tokens
1647
+ // log(1) = 0 (keep), log(0) = -INF (discard)
1648
+ struct ggml_tensor * min_p_bias = ggml_log(ctx, mask);
1649
+ ggml_set_name(min_p_bias, "min_p_bias");
1650
+
1651
+ data->logits = ggml_add(ctx, data->logits, min_p_bias);
1652
+ ggml_set_name(data->logits, "min_p_logits");
1653
+
1654
+ GGML_UNUSED(gf);
1655
+ }
1656
+
934
1657
  static struct llama_sampler_i llama_sampler_min_p_i = {
935
- /* .name = */ llama_sampler_min_p_name,
936
- /* .accept = */ nullptr,
937
- /* .apply = */ llama_sampler_min_p_apply,
938
- /* .reset = */ nullptr,
939
- /* .clone = */ llama_sampler_min_p_clone,
940
- /* .free = */ llama_sampler_min_p_free,
1658
+ /* .name = */ llama_sampler_min_p_name,
1659
+ /* .accept = */ nullptr,
1660
+ /* .apply = */ llama_sampler_min_p_apply,
1661
+ /* .reset = */ nullptr,
1662
+ /* .clone = */ llama_sampler_min_p_clone,
1663
+ /* .free = */ llama_sampler_min_p_free,
1664
+ /* .backend_init = */ llama_sampler_min_p_backend_init,
1665
+ /* .backend_accept = */ nullptr,
1666
+ /* .backend_apply = */ llama_sampler_min_p_backend_apply,
1667
+ /* .backend_set_input = */ nullptr,
941
1668
  };
942
1669
 
943
1670
  struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
1671
+ const bool is_empty = (p <= 0.0f);
1672
+
1673
+ if (is_empty) {
1674
+ return llama_sampler_init_empty("?min-p");
1675
+ }
1676
+
944
1677
  return llama_sampler_init(
945
1678
  /* .iface = */ &llama_sampler_min_p_i,
946
1679
  /* .ctx = */ new llama_sampler_min_p {
1680
+ ("min-p"),
947
1681
  /* .p = */ p,
948
1682
  /* .min_keep = */ min_keep,
949
1683
  }
@@ -1031,15 +1765,25 @@ static void llama_sampler_typical_free(struct llama_sampler * smpl) {
1031
1765
  }
1032
1766
 
1033
1767
  static struct llama_sampler_i llama_sampler_typical_i = {
1034
- /* .name = */ llama_sampler_typical_name,
1035
- /* .accept = */ nullptr,
1036
- /* .apply = */ llama_sampler_typical_apply,
1037
- /* .reset = */ nullptr,
1038
- /* .clone = */ llama_sampler_typical_clone,
1039
- /* .free = */ llama_sampler_typical_free,
1768
+ /* .name = */ llama_sampler_typical_name,
1769
+ /* .accept = */ nullptr,
1770
+ /* .apply = */ llama_sampler_typical_apply,
1771
+ /* .reset = */ nullptr,
1772
+ /* .clone = */ llama_sampler_typical_clone,
1773
+ /* .free = */ llama_sampler_typical_free,
1774
+ /* .backend_init = */ nullptr,
1775
+ /* .backend_accept = */ nullptr,
1776
+ /* .backend_apply = */ nullptr,
1777
+ /* .backend_set_input = */ nullptr,
1040
1778
  };
1041
1779
 
1042
1780
  struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
1781
+ const bool is_empty = (p >= 1.0f);
1782
+
1783
+ if (is_empty) {
1784
+ return llama_sampler_init_empty("?typical");
1785
+ }
1786
+
1043
1787
  return llama_sampler_init(
1044
1788
  /* .iface = */ &llama_sampler_typical_i,
1045
1789
  /* .ctx = */ new llama_sampler_typical {
@@ -1051,12 +1795,13 @@ struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
1051
1795
 
1052
1796
  // temp
1053
1797
 
1054
- struct llama_sampler_temp {
1798
+ struct llama_sampler_temp : public llama_sampler_backend {
1055
1799
  const float temp;
1056
1800
  };
1057
1801
 
1058
- static const char * llama_sampler_temp_name(const struct llama_sampler * /*smpl*/) {
1059
- return "temp";
1802
+ static const char * llama_sampler_temp_name(const struct llama_sampler * smpl) {
1803
+ auto * sctx = (llama_sampler_temp *) smpl->ctx;
1804
+ return sctx->get_name();
1060
1805
  }
1061
1806
 
1062
1807
  static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
@@ -1074,19 +1819,79 @@ static void llama_sampler_temp_free(struct llama_sampler * smpl) {
1074
1819
  delete (llama_sampler_temp *) smpl->ctx;
1075
1820
  }
1076
1821
 
1822
+ static void llama_sampler_backend_temp_sampling(
1823
+ struct ggml_context * ctx,
1824
+ struct ggml_cgraph * gf,
1825
+ struct llama_sampler_data * data,
1826
+ float temp) {
1827
+ if (temp <= 0.0f) {
1828
+ // Find the most probable token index.
1829
+ struct ggml_tensor * max_idx = ggml_argmax(ctx, data->logits);
1830
+ ggml_set_name(max_idx, "temp_max_idx");
1831
+
1832
+ if (data->candidates) {
1833
+ struct ggml_tensor * candidates_rows = ggml_reshape_2d(ctx, data->candidates, 1, data->candidates->ne[0]);
1834
+ data->candidates = ggml_get_rows(ctx, candidates_rows, max_idx);
1835
+ } else {
1836
+ data->candidates = max_idx;
1837
+ }
1838
+
1839
+ struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]);
1840
+ data->logits = ggml_get_rows(ctx, logits_rows, max_idx);
1841
+
1842
+ return;
1843
+ }
1844
+
1845
+ data->logits = ggml_scale(ctx, data->logits, 1.0f / temp);
1846
+
1847
+ GGML_UNUSED(gf);
1848
+ }
1849
+
1850
+ static bool llama_sampler_temp_backend_init(
1851
+ struct llama_sampler * smpl,
1852
+ ggml_backend_buffer_type_t buft) {
1853
+ auto * sctx = (llama_sampler_temp *) smpl->ctx;
1854
+
1855
+ const bool res = llama_sampler_backend_support(smpl, buft);
1856
+
1857
+ sctx->init(res);
1858
+
1859
+ return res;
1860
+ }
1861
+
1862
+ static void llama_sampler_temp_backend_apply(
1863
+ struct llama_sampler * smpl,
1864
+ struct ggml_context * ctx,
1865
+ struct ggml_cgraph * gf,
1866
+ struct llama_sampler_data * data) {
1867
+ auto * sctx = (llama_sampler_temp *) smpl->ctx;
1868
+ llama_sampler_backend_temp_sampling(ctx, gf, data, sctx->temp);
1869
+ }
1870
+
1077
1871
  static struct llama_sampler_i llama_sampler_temp_i = {
1078
- /* .name = */ llama_sampler_temp_name,
1079
- /* .accept = */ nullptr,
1080
- /* .apply = */ llama_sampler_temp_apply,
1081
- /* .reset = */ nullptr,
1082
- /* .clone = */ llama_sampler_temp_clone,
1083
- /* .free = */ llama_sampler_temp_free,
1872
+ /* .name = */ llama_sampler_temp_name,
1873
+ /* .accept = */ nullptr,
1874
+ /* .apply = */ llama_sampler_temp_apply,
1875
+ /* .reset = */ nullptr,
1876
+ /* .clone = */ llama_sampler_temp_clone,
1877
+ /* .free = */ llama_sampler_temp_free,
1878
+ /* .backend_init = */ llama_sampler_temp_backend_init,
1879
+ /* .backend_accept = */ nullptr,
1880
+ /* .backend_apply = */ llama_sampler_temp_backend_apply,
1881
+ /* .backend_set_input = */ nullptr,
1084
1882
  };
1085
1883
 
1086
1884
  struct llama_sampler * llama_sampler_init_temp(float temp) {
1885
+ const bool is_empty = temp == 1.0f;
1886
+
1887
+ if (is_empty) {
1888
+ return llama_sampler_init_empty("?temp");
1889
+ }
1890
+
1087
1891
  return llama_sampler_init(
1088
1892
  /* .iface = */ &llama_sampler_temp_i,
1089
1893
  /* .ctx = */ new llama_sampler_temp {
1894
+ ("temp"),
1090
1895
  /*.temp = */ temp,
1091
1896
  }
1092
1897
  );
@@ -1094,14 +1899,15 @@ struct llama_sampler * llama_sampler_init_temp(float temp) {
1094
1899
 
1095
1900
  // temp-ext
1096
1901
 
1097
- struct llama_sampler_temp_ext {
1902
+ struct llama_sampler_temp_ext : public llama_sampler_backend {
1098
1903
  const float temp;
1099
1904
  const float delta;
1100
1905
  const float exponent;
1101
1906
  };
1102
1907
 
1103
- static const char * llama_sampler_temp_ext_name(const struct llama_sampler * /*smpl*/) {
1104
- return "temp-ext";
1908
+ static const char * llama_sampler_temp_ext_name(const struct llama_sampler * smpl) {
1909
+ auto * sctx = (llama_sampler_temp_ext *) smpl->ctx;
1910
+ return sctx->get_name();
1105
1911
  }
1106
1912
 
1107
1913
  static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
@@ -1184,24 +1990,112 @@ static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) {
1184
1990
  delete (llama_sampler_temp_ext *) smpl->ctx;
1185
1991
  }
1186
1992
 
1993
+ static bool llama_sampler_temp_ext_backend_init(
1994
+ struct llama_sampler * smpl,
1995
+ ggml_backend_buffer_type_t buft) {
1996
+ auto * sctx = (llama_sampler_temp_ext *) smpl->ctx;
1997
+
1998
+ const bool res = llama_sampler_backend_support(smpl, buft);
1999
+
2000
+ sctx->init(res);
2001
+
2002
+ return res;
2003
+ }
2004
+
2005
+ static void llama_sampler_temp_ext_backend_apply(
2006
+ struct llama_sampler * smpl,
2007
+ struct ggml_context * ctx,
2008
+ struct ggml_cgraph * gf,
2009
+ struct llama_sampler_data * data) {
2010
+ auto * sctx = (llama_sampler_temp_ext *) smpl->ctx;
2011
+
2012
+ // Revert to standard temperature scaling if delta or temp are non-positive.
2013
+ if (sctx->delta <= 0.0f || sctx->temp <= 0.0f) {
2014
+ llama_sampler_backend_temp_sampling(ctx, gf, data, sctx->temp);
2015
+ return;
2016
+ }
2017
+
2018
+ // Calculate min_temp, max_temp, and max_entropy.
2019
+ const float min_temp = std::max(0.0f, sctx->temp - sctx->delta);
2020
+ const float max_temp = sctx->temp + sctx->delta;
2021
+ const float max_entropy = logf(data->logits->ne[0]);
2022
+
2023
+ // Calculate the probabilities.
2024
+ struct ggml_tensor * probs = ggml_soft_max(ctx, data->logits);
2025
+ ggml_set_name(probs, "temp_ext_softmax_probs");
2026
+
2027
+ // Clamp probabilities to avoid log(0) which would give -inf
2028
+ struct ggml_tensor * probs_clamped = ggml_clamp(ctx, probs, 1e-10f, 1.0f);
2029
+ ggml_set_name(probs_clamped, "temp_ext_probs_clamped");
2030
+
2031
+ // Calculate the entropy, entropy = -Σ(p * log(p)).
2032
+ struct ggml_tensor * log_probs = ggml_log(ctx, probs_clamped);
2033
+ struct ggml_tensor * p_log_p = ggml_mul(ctx, probs_clamped, log_probs);
2034
+ struct ggml_tensor * sum_p_log_p = ggml_sum(ctx, p_log_p);
2035
+ struct ggml_tensor * entropy = ggml_scale(ctx, sum_p_log_p, -1.0f);
2036
+ ggml_set_name(log_probs, "temp_ext_log_probs");
2037
+ ggml_set_name(p_log_p, "temp_ext_p_log_p");
2038
+ ggml_set_name(sum_p_log_p, "temp_ext_sum_p_log_p");
2039
+ ggml_set_name(entropy, "temp_ext_entropy");
2040
+
2041
+ // Normalize the entropy, norm_entropy = entropy / max_entropy
2042
+ struct ggml_tensor * norm_entropy = ggml_scale(ctx, entropy, 1.0f / max_entropy);
2043
+ ggml_set_name(norm_entropy, "temp_ext_norm_entropy");
2044
+
2045
+ // Calculate the dynamic temperature:
2046
+ // dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent);
2047
+ //
2048
+ // Calculate powf(normalized_entropy, exponent) as
2049
+ // norm_entropy^exponent = exp(exponent * log(norm_entropy))
2050
+ struct ggml_tensor * log_norm_entropy = ggml_log(ctx, norm_entropy);
2051
+ struct ggml_tensor * scaled_log = ggml_scale(ctx, log_norm_entropy, sctx->exponent);
2052
+ struct ggml_tensor * pow_entropy = ggml_exp(ctx, scaled_log);
2053
+ // With pow_entropy computed we can now compute dyn_temp, scaling by
2054
+ // (max_temp - min_temp) and then adding min_temp.
2055
+ struct ggml_tensor * dyn_temp = ggml_scale_bias(ctx, pow_entropy, max_temp - min_temp, min_temp);
2056
+ ggml_set_name(log_norm_entropy, "temp_ext_log_norm_entropy");
2057
+ ggml_set_name(scaled_log, "temp_ext_scaled_log");
2058
+ ggml_set_name(pow_entropy, "temp_ext_pow_entropy");
2059
+ ggml_set_name(dyn_temp, "temp_ext_dyn_temp");
2060
+
2061
+ // Scale the logits by the dynamic temperature
2062
+ struct ggml_tensor * scaled_logits = ggml_div(ctx, data->logits, dyn_temp);
2063
+ ggml_set_name(scaled_logits, "temp_ext_scaled_logits");
2064
+
2065
+ data->logits = scaled_logits;
2066
+ }
2067
+
1187
2068
  static struct llama_sampler_i llama_sampler_temp_ext_i = {
1188
- /* .name = */ llama_sampler_temp_ext_name,
1189
- /* .accept = */ nullptr,
1190
- /* .apply = */ llama_sampler_temp_ext_apply,
1191
- /* .reset = */ nullptr,
1192
- /* .clone = */ llama_sampler_temp_ext_clone,
1193
- /* .free = */ llama_sampler_temp_ext_free,
2069
+ /* .name = */ llama_sampler_temp_ext_name,
2070
+ /* .accept = */ nullptr,
2071
+ /* .apply = */ llama_sampler_temp_ext_apply,
2072
+ /* .reset = */ nullptr,
2073
+ /* .clone = */ llama_sampler_temp_ext_clone,
2074
+ /* .free = */ llama_sampler_temp_ext_free,
2075
+ /* .backend_init = */ llama_sampler_temp_ext_backend_init,
2076
+ /* .backend_accept = */ nullptr,
2077
+ /* .backend_apply = */ llama_sampler_temp_ext_backend_apply,
2078
+ /* .backend_set_input = */ nullptr,
1194
2079
  };
1195
2080
 
1196
2081
  struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
1197
- return llama_sampler_init(
2082
+ const bool is_empty = temp == 1.0f && delta <= 0.0f;
2083
+
2084
+ if (is_empty) {
2085
+ return llama_sampler_init_empty("?temp-ext");
2086
+ }
2087
+
2088
+ auto * res = llama_sampler_init(
1198
2089
  /* .iface = */ &llama_sampler_temp_ext_i,
1199
2090
  /* .ctx = */ new llama_sampler_temp_ext {
2091
+ ("temp-ext"),
1200
2092
  /* .temp = */ temp,
1201
2093
  /* .delta = */ delta,
1202
2094
  /* .exponent = */ exponent,
1203
2095
  }
1204
2096
  );
2097
+
2098
+ return res;
1205
2099
  }
1206
2100
 
1207
2101
  // xtc
@@ -1214,7 +2108,7 @@ struct llama_sampler_xtc {
1214
2108
  const uint32_t seed;
1215
2109
  uint32_t seed_cur;
1216
2110
 
1217
- std::mt19937 rng;
2111
+ std::mt19937 rng;
1218
2112
  };
1219
2113
 
1220
2114
  static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/) {
@@ -1279,16 +2173,27 @@ static void llama_sampler_xtc_reset(struct llama_sampler * smpl) {
1279
2173
  }
1280
2174
 
1281
2175
  static struct llama_sampler_i llama_sampler_xtc_i = {
1282
- /* .name = */ llama_sampler_xtc_name,
1283
- /* .accept = */ nullptr,
1284
- /* .apply = */ llama_sample_xtc_apply,
1285
- /* .reset = */ llama_sampler_xtc_reset,
1286
- /* .clone = */ llama_sampler_xtc_clone,
1287
- /* .free = */ llama_sampler_xtc_free,
2176
+ /* .name = */ llama_sampler_xtc_name,
2177
+ /* .accept = */ nullptr,
2178
+ /* .apply = */ llama_sample_xtc_apply,
2179
+ /* .reset = */ llama_sampler_xtc_reset,
2180
+ /* .clone = */ llama_sampler_xtc_clone,
2181
+ /* .free = */ llama_sampler_xtc_free,
2182
+ /* .backend_init = */ nullptr,
2183
+ /* .backend_accept = */ nullptr,
2184
+ /* .backend_apply = */ nullptr,
2185
+ /* .backend_set_input = */ nullptr,
1288
2186
  };
1289
2187
 
1290
2188
  struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) {
1291
- auto seed_cur = get_rng_seed(seed);
2189
+ const bool is_empty = (p <= 0.0f || t > 0.5f);
2190
+
2191
+ if (is_empty) {
2192
+ return llama_sampler_init_empty("?xtc");
2193
+ }
2194
+
2195
+ const auto seed_cur = get_rng_seed(seed);
2196
+
1292
2197
  return llama_sampler_init(
1293
2198
  /* .iface = */ &llama_sampler_xtc_i,
1294
2199
  /* .ctx = */ new llama_sampler_xtc {
@@ -1387,16 +2292,21 @@ static void llama_sampler_mirostat_free(struct llama_sampler * smpl) {
1387
2292
  }
1388
2293
 
1389
2294
  static struct llama_sampler_i llama_sampler_mirostat_i = {
1390
- /* .name = */ llama_sampler_mirostat_name,
1391
- /* .accept = */ nullptr,
1392
- /* .apply = */ llama_sampler_mirostat_apply,
1393
- /* .reset = */ llama_sampler_mirostat_reset,
1394
- /* .clone = */ llama_sampler_mirostat_clone,
1395
- /* .free = */ llama_sampler_mirostat_free,
2295
+ /* .name = */ llama_sampler_mirostat_name,
2296
+ /* .accept = */ nullptr,
2297
+ /* .apply = */ llama_sampler_mirostat_apply,
2298
+ /* .reset = */ llama_sampler_mirostat_reset,
2299
+ /* .clone = */ llama_sampler_mirostat_clone,
2300
+ /* .free = */ llama_sampler_mirostat_free,
2301
+ /* .backend_init = */ nullptr,
2302
+ /* .backend_accept = */ nullptr,
2303
+ /* .backend_apply = */ nullptr,
2304
+ /* .backend_set_input = */ nullptr,
1396
2305
  };
1397
2306
 
1398
2307
  struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
1399
- auto seed_cur = get_rng_seed(seed);
2308
+ const auto seed_cur = get_rng_seed(seed);
2309
+
1400
2310
  return llama_sampler_init(
1401
2311
  /* .iface = */ &llama_sampler_mirostat_i,
1402
2312
  /* .ctx = */ new llama_sampler_mirostat {
@@ -1486,12 +2396,16 @@ static void llama_sampler_mirostat_v2_free(struct llama_sampler * smpl) {
1486
2396
  }
1487
2397
 
1488
2398
  static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
1489
- /* .name = */ llama_sampler_mirostat_v2_name,
1490
- /* .accept = */ nullptr,
1491
- /* .apply = */ llama_sampler_mirostat_v2_apply,
1492
- /* .reset = */ llama_sampler_mirostat_v2_reset,
1493
- /* .clone = */ llama_sampler_mirostat_v2_clone,
1494
- /* .free = */ llama_sampler_mirostat_v2_free,
2399
+ /* .name = */ llama_sampler_mirostat_v2_name,
2400
+ /* .accept = */ nullptr,
2401
+ /* .apply = */ llama_sampler_mirostat_v2_apply,
2402
+ /* .reset = */ llama_sampler_mirostat_v2_reset,
2403
+ /* .clone = */ llama_sampler_mirostat_v2_clone,
2404
+ /* .free = */ llama_sampler_mirostat_v2_free,
2405
+ /* .backend_init = */ nullptr,
2406
+ /* .backend_accept = */ nullptr,
2407
+ /* .backend_apply = */ nullptr,
2408
+ /* .backend_set_input = */ nullptr,
1495
2409
  };
1496
2410
 
1497
2411
  struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
@@ -1603,12 +2517,16 @@ static void llama_sampler_grammar_free(struct llama_sampler * smpl) {
1603
2517
  }
1604
2518
 
1605
2519
  static struct llama_sampler_i llama_sampler_grammar_i = {
1606
- /* .name = */ llama_sampler_grammar_name,
1607
- /* .accept = */ llama_sampler_grammar_accept_impl,
1608
- /* .apply = */ llama_sampler_grammar_apply,
1609
- /* .reset = */ llama_sampler_grammar_reset,
1610
- /* .clone = */ llama_sampler_grammar_clone,
1611
- /* .free = */ llama_sampler_grammar_free,
2520
+ /* .name = */ llama_sampler_grammar_name,
2521
+ /* .accept = */ llama_sampler_grammar_accept_impl,
2522
+ /* .apply = */ llama_sampler_grammar_apply,
2523
+ /* .reset = */ llama_sampler_grammar_reset,
2524
+ /* .clone = */ llama_sampler_grammar_clone,
2525
+ /* .free = */ llama_sampler_grammar_free,
2526
+ /* .backend_init = */ nullptr,
2527
+ /* .backend_accept = */ nullptr,
2528
+ /* .backend_apply = */ nullptr,
2529
+ /* .backend_set_input = */ nullptr,
1612
2530
  };
1613
2531
 
1614
2532
  static struct llama_sampler * llama_sampler_init_grammar_impl(
@@ -1625,10 +2543,12 @@ static struct llama_sampler * llama_sampler_init_grammar_impl(
1625
2543
  auto * ctx = new llama_sampler_grammar;
1626
2544
 
1627
2545
  if (grammar_str != nullptr && grammar_str[0] != '\0') {
2546
+ std::string trigger_pattern;
2547
+ llama_grammar * grammar = nullptr;
1628
2548
  // TODO: remove trigger_words support.
1629
2549
  if (trigger_words != nullptr && num_trigger_words > 0) {
1630
2550
  GGML_ASSERT(trigger_patterns == nullptr && num_trigger_patterns == 0);
1631
- std::string trigger_pattern("[\\s\\S]*?(");
2551
+ trigger_pattern = "[\\s\\S]*?(";
1632
2552
  for (size_t i = 0; i < num_trigger_words; ++i) {
1633
2553
  static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");
1634
2554
  if (i > 0) {
@@ -1637,15 +2557,17 @@ static struct llama_sampler * llama_sampler_init_grammar_impl(
1637
2557
  trigger_pattern += std::regex_replace(trigger_words[i], special_chars, "\\$0");
1638
2558
  }
1639
2559
  trigger_pattern += ")[\\s\\S]*";
1640
- const auto * trigger_pattern_c = trigger_pattern.c_str();
1641
- trigger_patterns = &trigger_pattern_c;
1642
- num_trigger_patterns = 1;
2560
+
2561
+ std::array<const char *, 1> tmp_trigger_patterns = { trigger_pattern.c_str() };
2562
+ grammar = llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, tmp_trigger_patterns.data(), tmp_trigger_patterns.size(), trigger_tokens, num_trigger_tokens);
2563
+ } else {
2564
+ grammar = llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens);
1643
2565
  }
1644
2566
  *ctx = {
1645
2567
  /* .vocab = */ vocab,
1646
2568
  /* .grammar_str = */ grammar_str,
1647
2569
  /* .grammar_root = */ grammar_root,
1648
- /* .grammar = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens),
2570
+ /* .grammar = */ grammar,
1649
2571
  };
1650
2572
  if (!ctx->grammar) {
1651
2573
  delete ctx;
@@ -1806,12 +2728,16 @@ static void llama_sampler_penalties_free(struct llama_sampler * smpl) {
1806
2728
  }
1807
2729
 
1808
2730
  static struct llama_sampler_i llama_sampler_penalties_i = {
1809
- /* .name = */ llama_sampler_penalties_name,
1810
- /* .accept = */ llama_sampler_penalties_accept,
1811
- /* .apply = */ llama_sampler_penalties_apply,
1812
- /* .reset = */ llama_sampler_penalties_reset,
1813
- /* .clone = */ llama_sampler_penalties_clone,
1814
- /* .free = */ llama_sampler_penalties_free,
2731
+ /* .name = */ llama_sampler_penalties_name,
2732
+ /* .accept = */ llama_sampler_penalties_accept,
2733
+ /* .apply = */ llama_sampler_penalties_apply,
2734
+ /* .reset = */ llama_sampler_penalties_reset,
2735
+ /* .clone = */ llama_sampler_penalties_clone,
2736
+ /* .free = */ llama_sampler_penalties_free,
2737
+ /* .backend_init = */ nullptr,
2738
+ /* .backend_accept = */ nullptr,
2739
+ /* .backend_apply = */ nullptr,
2740
+ /* .backend_set_input = */ nullptr,
1815
2741
  };
1816
2742
 
1817
2743
  struct llama_sampler * llama_sampler_init_penalties(
@@ -1821,6 +2747,12 @@ struct llama_sampler * llama_sampler_init_penalties(
1821
2747
  float penalty_present) {
1822
2748
  penalty_last_n = std::max(penalty_last_n, 0);
1823
2749
 
2750
+ const bool is_empty = (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f));
2751
+
2752
+ if (is_empty) {
2753
+ return llama_sampler_init_empty("?penalties");
2754
+ }
2755
+
1824
2756
  return llama_sampler_init(
1825
2757
  /* .iface = */ &llama_sampler_penalties_i,
1826
2758
  /* .ctx = */ new llama_sampler_penalties {
@@ -1858,9 +2790,7 @@ static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_t
1858
2790
  for (size_t i = 0; i < cur_p->size; ++i) {
1859
2791
  // Only count non-negative infinity values
1860
2792
  if (cur_p->data[i].logit != -INFINITY) {
1861
- if (cur_p->data[i].logit > max) {
1862
- max = cur_p->data[i].logit;
1863
- }
2793
+ max = std::max(max, cur_p->data[i].logit);
1864
2794
  logits_sum += cur_p->data[i].logit;
1865
2795
  valid_count++;
1866
2796
  }
@@ -1897,15 +2827,25 @@ static void llama_sampler_top_n_sigma_free(struct llama_sampler * smpl) {
1897
2827
  }
1898
2828
 
1899
2829
  static struct llama_sampler_i llama_sampler_top_n_sigma_i = {
1900
- /* .name = */ llama_sampler_top_n_sigma_name,
1901
- /* .accept = */ nullptr,
1902
- /* .apply = */ llama_sampler_top_n_sigma_apply,
1903
- /* .reset = */ nullptr,
1904
- /* .clone = */ llama_sampler_top_n_sigma_clone,
1905
- /* .free = */ llama_sampler_top_n_sigma_free,
2830
+ /* .name = */ llama_sampler_top_n_sigma_name,
2831
+ /* .accept = */ nullptr,
2832
+ /* .apply = */ llama_sampler_top_n_sigma_apply,
2833
+ /* .reset = */ nullptr,
2834
+ /* .clone = */ llama_sampler_top_n_sigma_clone,
2835
+ /* .free = */ llama_sampler_top_n_sigma_free,
2836
+ /* .backend_init = */ nullptr,
2837
+ /* .backend_accept = */ nullptr,
2838
+ /* .backend_apply = */ nullptr,
2839
+ /* .backend_set_input = */ nullptr,
1906
2840
  };
1907
2841
 
1908
2842
  struct llama_sampler * llama_sampler_init_top_n_sigma(float n) {
2843
+ const bool is_empty = (n <= 0.0f);
2844
+
2845
+ if (is_empty) {
2846
+ return llama_sampler_init_empty("?top-n-sigma");
2847
+ }
2848
+
1909
2849
  return llama_sampler_init(
1910
2850
  /* .iface = */ &llama_sampler_top_n_sigma_i,
1911
2851
  /* .ctx = */ new llama_sampler_top_n_sigma {
@@ -2227,12 +3167,16 @@ static void llama_sampler_dry_free(struct llama_sampler * smpl) {
2227
3167
  }
2228
3168
 
2229
3169
  static struct llama_sampler_i llama_sampler_dry_i = {
2230
- /* .name = */ llama_sampler_dry_name,
2231
- /* .accept = */ llama_sampler_dry_accept,
2232
- /* .apply = */ llama_sampler_dry_apply,
2233
- /* .reset = */ llama_sampler_dry_reset,
2234
- /* .clone = */ llama_sampler_dry_clone,
2235
- /* .free = */ llama_sampler_dry_free,
3170
+ /* .name = */ llama_sampler_dry_name,
3171
+ /* .accept = */ llama_sampler_dry_accept,
3172
+ /* .apply = */ llama_sampler_dry_apply,
3173
+ /* .reset = */ llama_sampler_dry_reset,
3174
+ /* .clone = */ llama_sampler_dry_clone,
3175
+ /* .free = */ llama_sampler_dry_free,
3176
+ /* .backend_init = */ nullptr,
3177
+ /* .backend_accept = */ nullptr,
3178
+ /* .backend_apply = */ nullptr,
3179
+ /* .backend_set_input = */ nullptr,
2236
3180
  };
2237
3181
 
2238
3182
  struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, int32_t n_ctx_train, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
@@ -2243,6 +3187,10 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab,
2243
3187
 
2244
3188
  const bool dry_enabled = (dry_multiplier != 0.0f && dry_base >= 1.0f && dry_penalty_last_n != 0);
2245
3189
 
3190
+ if (!dry_enabled) {
3191
+ return llama_sampler_init_empty("?dry");
3192
+ }
3193
+
2246
3194
  if (dry_enabled && seq_breakers != nullptr && num_breakers > 0) {
2247
3195
  // Process sequence breakers
2248
3196
  for (size_t i = 0; i < num_breakers; ++i) {
@@ -2311,18 +3259,186 @@ struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, floa
2311
3259
  return result;
2312
3260
  }
2313
3261
 
3262
+ // adaptive-p sampler state
3263
+ //
3264
+ // maintains an exponential moving average of the *ORIGINAL* probabilities
3265
+ // of selected tokens, used to compute an adapted target at each sampling step.
3266
+ //
3267
+ // see llama.h for a full description of the sampler
3268
+ //
3269
+ // ref: https://github.com/ggml-org/llama.cpp/pull/17927
3270
+ //
3271
+ struct llama_sampler_adaptive_p {
3272
+ const float target; // target probability (0.0 - 1.0; negative = disabled)
3273
+ const float decay; // EMA decay; history ~= 1/(1-decay) tokens (0.0 - 0.99)
3274
+ const uint32_t seed; // original RNG seed
3275
+ uint32_t seed_cur; // actual RNG seed
3276
+ std::mt19937 rng; // RNG state
3277
+ float weighted_sum; // sum(p_i * decay^i)
3278
+ float total_weight; // sum(decay^i), converges to 1/(1-decay)
3279
+ std::vector<float> original_probs; // pre-transform probs, cached for EMA update
3280
+ llama_token pending_token_id; // token ID of selected token
3281
+ int32_t pending_token_idx; // index of orig. prob. of selected token in original_probs
3282
+ };
3283
+
3284
+ // adaptive probability transformation constants
3285
+ static constexpr float DISTRIBUTION_WIDTH = 0.3f;
3286
+ static constexpr float PEAK_LOGIT_VALUE = 5.0f;
3287
+ static constexpr float SHARPNESS = 10.0f;
3288
+ static constexpr float INV_WIDTH = 1.0f / DISTRIBUTION_WIDTH;
3289
+
3290
+ static const char * llama_sampler_adaptive_p_name(const struct llama_sampler * /*smpl*/) {
3291
+ return "adaptive-p";
3292
+ }
3293
+
3294
+ static void llama_sampler_adaptive_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
3295
+ auto * ctx = (llama_sampler_adaptive_p *) smpl->ctx;
3296
+
3297
+ llama_sampler_softmax_impl(cur_p, false);
3298
+
3299
+ if (ctx->target < 0.0f) {
3300
+ // at negative target values, adaptive-p is no-op
3301
+ // we simply sample from the existing distribution
3302
+ cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
3303
+ return;
3304
+ }
3305
+
3306
+ // store the original probabilities
3307
+ ctx->original_probs.resize(cur_p->size);
3308
+ for (size_t i = 0; i < cur_p->size; ++i) {
3309
+ ctx->original_probs[i] = cur_p->data[i].p;
3310
+ }
3311
+
3312
+ // using the EMA, compute the adapted target probability for the current sampling step
3313
+ auto target = std::clamp(ctx->target, 0.0f, 1.0f);
3314
+ float adapted_target = std::clamp(
3315
+ ctx->total_weight == 0.0f ? target : 2.0f * target - (ctx->weighted_sum / ctx->total_weight),
3316
+ 0.0f, 1.0f
3317
+ );
3318
+
3319
+ // adaptive probability transform
3320
+ //
3321
+ // quadratic near target for fine differentiation, transitioning to linear decay in the
3322
+ // tails. unbounded negative logits ensure proper suppression of far-from-target tokens
3323
+ // after the softmax.
3324
+ //
3325
+ for (size_t i = 0; i < cur_p->size; ++i) {
3326
+ if (cur_p->data[i].logit == -INFINITY) {
3327
+ // don't transform logits that are -INFINITY
3328
+ // (as masked out by e.g. min-p and top-p when using backend sampling)
3329
+ continue;
3330
+ }
3331
+ float dist = std::abs((cur_p->data[i].p - adapted_target) * INV_WIDTH);
3332
+ cur_p->data[i].logit = PEAK_LOGIT_VALUE - SHARPNESS * dist * dist / (1.0f + dist);
3333
+ }
3334
+
3335
+ // softmax and sample from the transformed distribution
3336
+ llama_sampler_softmax_impl(cur_p, false);
3337
+ const int idx = llama_sample_dist(cur_p, ctx->rng);
3338
+ cur_p->selected = idx;
3339
+
3340
+ // store the selected token ID for acceptance later
3341
+ ctx->pending_token_id = cur_p->data[idx].id;
3342
+ ctx->pending_token_idx = idx;
3343
+ }
3344
+
3345
+ static void llama_sampler_adaptive_p_accept(struct llama_sampler * smpl, llama_token token) {
3346
+ auto * ctx = (llama_sampler_adaptive_p *) smpl->ctx;
3347
+ if (ctx->pending_token_id == token) {
3348
+ GGML_ASSERT(ctx->pending_token_id != LLAMA_TOKEN_NULL);
3349
+ GGML_ASSERT(ctx->pending_token_idx != -1);
3350
+ // update EMA with the original probability of the selected token
3351
+ ctx->weighted_sum = ctx->original_probs[ctx->pending_token_idx] + ctx->decay * ctx->weighted_sum;
3352
+ ctx->total_weight = 1.0f + ctx->decay * ctx->total_weight;
3353
+ }
3354
+ ctx->pending_token_id = LLAMA_TOKEN_NULL;
3355
+ ctx->pending_token_idx = -1;
3356
+ }
3357
+
3358
+ static void llama_sampler_adaptive_p_reset(struct llama_sampler * smpl) {
3359
+ auto * ctx = (llama_sampler_adaptive_p *) smpl->ctx;
3360
+ // ctx->target and ctx->decay never change after init, so it's safe to keep them as is.
3361
+ // original_probs is completely overwritten on every call to _apply.
3362
+ // so we only need to reset the EMA state and pending token.
3363
+ ctx->weighted_sum = ctx->target / (1.0f - ctx->decay);
3364
+ ctx->total_weight = 1.0f / (1.0f - ctx->decay);
3365
+ ctx->pending_token_id = LLAMA_TOKEN_NULL;
3366
+ ctx->pending_token_idx = -1;
3367
+ ctx->seed_cur = get_rng_seed(ctx->seed);
3368
+ ctx->rng.seed(ctx->seed_cur);
3369
+ }
3370
+
3371
+ static struct llama_sampler * llama_sampler_adaptive_p_clone(const struct llama_sampler * smpl) {
3372
+ const auto * ctx = (const llama_sampler_adaptive_p *) smpl->ctx;
3373
+ auto * result = llama_sampler_init_adaptive_p(ctx->target, ctx->decay, ctx->seed);
3374
+ auto * result_ctx = (llama_sampler_adaptive_p *) result->ctx;
3375
+
3376
+ // copy everything (target, decay, seed, and RNG are already set)
3377
+ result_ctx->weighted_sum = ctx->weighted_sum;
3378
+ result_ctx->total_weight = ctx->total_weight;
3379
+ result_ctx->pending_token_id = ctx->pending_token_id;
3380
+ result_ctx->pending_token_idx = ctx->pending_token_idx;
3381
+
3382
+ return result;
3383
+ }
3384
+
3385
+ static void llama_sampler_adaptive_p_free(struct llama_sampler * smpl) {
3386
+ delete (llama_sampler_adaptive_p *) smpl->ctx;
3387
+ }
3388
+
3389
+ static struct llama_sampler_i llama_sampler_adaptive_p_i = {
3390
+ /* .name = */ llama_sampler_adaptive_p_name,
3391
+ /* .accept = */ llama_sampler_adaptive_p_accept,
3392
+ /* .apply = */ llama_sampler_adaptive_p_apply,
3393
+ /* .reset = */ llama_sampler_adaptive_p_reset,
3394
+ /* .clone = */ llama_sampler_adaptive_p_clone,
3395
+ /* .free = */ llama_sampler_adaptive_p_free,
3396
+ /* .backend_init = */ nullptr,
3397
+ /* .backend_accept = */ nullptr,
3398
+ /* .backend_apply = */ nullptr,
3399
+ /* .backend_set_input = */ nullptr,
3400
+ };
3401
+
3402
+ struct llama_sampler * llama_sampler_init_adaptive_p(
3403
+ float target,
3404
+ float decay,
3405
+ uint32_t seed
3406
+ ) {
3407
+ auto seed_cur = get_rng_seed(seed);
3408
+ float clamped_decay = std::clamp(decay, 0.0f, 0.99f);
3409
+ return llama_sampler_init(
3410
+ /* .iface = */ &llama_sampler_adaptive_p_i,
3411
+ /* .ctx = */ new llama_sampler_adaptive_p {
3412
+ /* .target = */ target,
3413
+ /* .decay = */ clamped_decay,
3414
+ /* .seed = */ seed,
3415
+ /* .seed_cur = */ seed_cur,
3416
+ /* .rng = */ std::mt19937(seed_cur),
3417
+ /* .weighted_sum = */ target / (1.0f - clamped_decay),
3418
+ /* .total_weight = */ 1.0f / (1.0f - clamped_decay),
3419
+ /* .original_probs = */ {},
3420
+ /* .pending_token_id = */ LLAMA_TOKEN_NULL,
3421
+ /* .pending_token_idx = */ -1
3422
+ }
3423
+ );
3424
+ }
3425
+
2314
3426
  // logit-bias
2315
3427
 
2316
- struct llama_sampler_logit_bias {
3428
+ struct llama_sampler_logit_bias : public llama_sampler_backend {
2317
3429
  const int32_t n_vocab;
2318
3430
 
2319
3431
  const std::vector<llama_logit_bias> logit_bias;
2320
3432
 
2321
3433
  std::vector<llama_logit_bias> to_search;
3434
+
3435
+ struct ggml_tensor * inp_logit_bias;
3436
+ struct ggml_tensor * inp_logit_idxs;
2322
3437
  };
2323
3438
 
2324
- static const char * llama_sampler_logit_bias_name(const struct llama_sampler * /*smpl*/) {
2325
- return "logit-bias";
3439
+ static const char * llama_sampler_logit_bias_name(const struct llama_sampler * smpl) {
3440
+ auto * ctx = (llama_sampler_logit_bias *) smpl->ctx;
3441
+ return ctx->get_name();
2326
3442
  }
2327
3443
 
2328
3444
  static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
@@ -2367,25 +3483,110 @@ static void llama_sampler_logit_bias_free(struct llama_sampler * smpl) {
2367
3483
  delete (llama_sampler_logit_bias *) smpl->ctx;
2368
3484
  }
2369
3485
 
3486
+ static void llama_sampler_logit_bias_backend_apply(
3487
+ struct llama_sampler * smpl,
3488
+ struct ggml_context * ctx,
3489
+ struct ggml_cgraph * gf,
3490
+ struct llama_sampler_data * data) {
3491
+ GGML_UNUSED(gf);
3492
+ GGML_UNUSED(ctx);
3493
+
3494
+ auto * sctx = (llama_sampler_logit_bias *) smpl->ctx;
3495
+ if (sctx->logit_bias.empty()) {
3496
+ return;
3497
+ }
3498
+
3499
+ const size_t n = sctx->logit_bias.size();
3500
+
3501
+ sctx->inp_logit_bias = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n);
3502
+ ggml_set_name(sctx->inp_logit_bias, "logit_bias");
3503
+ ggml_set_input(sctx->inp_logit_bias);
3504
+
3505
+ sctx->inp_logit_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n);
3506
+ ggml_set_name(sctx->inp_logit_idxs, "logit_idxs");
3507
+ ggml_set_input(sctx->inp_logit_idxs);
3508
+
3509
+ ggml_tensor * cur = ggml_fill(ctx, data->logits, 0.0f);
3510
+
3511
+ cur = ggml_reshape_2d(ctx, cur, 1, ggml_nelements(cur));
3512
+ cur = ggml_set_rows(ctx, cur, sctx->inp_logit_bias, sctx->inp_logit_idxs);
3513
+ cur = ggml_reshape_1d(ctx, cur, ggml_nelements(cur));
3514
+
3515
+ data->logits = ggml_add(ctx, data->logits, cur);
3516
+ }
3517
+
3518
+ static void llama_sampler_logit_bias_backend_set_input(struct llama_sampler * smpl) {
3519
+ auto * sctx = (llama_sampler_logit_bias *) smpl->ctx;
3520
+ if (sctx->logit_bias.empty()) {
3521
+ return;
3522
+ }
3523
+
3524
+ GGML_ASSERT(sctx->inp_logit_bias != nullptr);
3525
+ GGML_ASSERT(sctx->inp_logit_idxs != nullptr);
3526
+
3527
+ const size_t n = sctx->logit_bias.size();
3528
+
3529
+ std::vector<float> data_logit_bias(n, 0.0f);
3530
+ std::vector<int32_t> data_logit_idxs(n, 0);
3531
+ for (size_t i = 0; i < n; ++i) {
3532
+ const auto & lb = sctx->logit_bias[i];
3533
+ GGML_ASSERT(lb.token >= 0 && lb.token < (int32_t) sctx->n_vocab);
3534
+ data_logit_bias[i] = lb.bias;
3535
+ data_logit_idxs[i] = lb.token;
3536
+ }
3537
+
3538
+ ggml_backend_tensor_set(sctx->inp_logit_bias, data_logit_bias.data(), 0, ggml_nbytes(sctx->inp_logit_bias));
3539
+ ggml_backend_tensor_set(sctx->inp_logit_idxs, data_logit_idxs.data(), 0, ggml_nbytes(sctx->inp_logit_idxs));
3540
+ }
3541
+
3542
+ static bool llama_sampler_logit_bias_backend_init(
3543
+ struct llama_sampler * smpl,
3544
+ ggml_backend_buffer_type_t buft) {
3545
+ GGML_UNUSED(buft);
3546
+
3547
+ auto * sctx = (llama_sampler_logit_bias *) smpl->ctx;
3548
+
3549
+ sctx->init(true);
3550
+
3551
+ if (sctx->logit_bias.empty()) {
3552
+ return true;
3553
+ }
3554
+
3555
+ return true;
3556
+ }
3557
+
2370
3558
  static struct llama_sampler_i llama_sampler_logit_bias_i = {
2371
- /* .name = */ llama_sampler_logit_bias_name,
2372
- /* .accept = */ nullptr,
2373
- /* .apply = */ llama_sampler_logit_bias_apply,
2374
- /* .reset = */ nullptr,
2375
- /* .clone = */ llama_sampler_logit_bias_clone,
2376
- /* .free = */ llama_sampler_logit_bias_free,
3559
+ /* .name = */ llama_sampler_logit_bias_name,
3560
+ /* .accept = */ nullptr,
3561
+ /* .apply = */ llama_sampler_logit_bias_apply,
3562
+ /* .reset = */ nullptr,
3563
+ /* .clone = */ llama_sampler_logit_bias_clone,
3564
+ /* .free = */ llama_sampler_logit_bias_free,
3565
+ /* .backend_init = */ llama_sampler_logit_bias_backend_init,
3566
+ /* .backend_accept = */ nullptr,
3567
+ /* .backend_apply = */ llama_sampler_logit_bias_backend_apply,
3568
+ /* .backend_set_input = */ llama_sampler_logit_bias_backend_set_input,
2377
3569
  };
2378
3570
 
2379
3571
  struct llama_sampler * llama_sampler_init_logit_bias(
2380
3572
  int32_t n_vocab,
2381
3573
  int32_t n_logit_bias,
2382
3574
  const llama_logit_bias * logit_bias) {
3575
+ const bool is_empty = n_logit_bias <= 0;
3576
+
3577
+ if (is_empty) {
3578
+ return llama_sampler_init_empty("?logit-bias");
3579
+ }
3580
+
2383
3581
  return llama_sampler_init(
2384
3582
  /* .iface = */ &llama_sampler_logit_bias_i,
2385
3583
  /* .ctx = */ new llama_sampler_logit_bias {
2386
- /* .n_vocab = */ n_vocab,
2387
- /* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
2388
- /* .to_search = */ {},
3584
+ ("logit-bias"),
3585
+ /* .n_vocab = */ n_vocab,
3586
+ /* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
3587
+ /* .to_search = */ {},
3588
+ /* .inp_logit_bias = */ nullptr,
3589
+ /* .inp_logit_idxs = */ nullptr,
2389
3590
  }
2390
3591
  );
2391
3592
  }
@@ -2541,8 +3742,13 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
2541
3742
  if (n_non_eog == 0) {
2542
3743
  cur_p->size = 1;
2543
3744
  cur_p->data[0].id = ctx->vocab->token_eot();
3745
+ if (cur_p->data[0].id == LLAMA_TOKEN_NULL) {
3746
+ cur_p->data[0].id = ctx->vocab->token_eos();
3747
+ }
2544
3748
  cur_p->data[0].logit = 1.0f;
2545
3749
 
3750
+ GGML_ASSERT(cur_p->data[0].id != LLAMA_TOKEN_NULL);
3751
+
2546
3752
  return;
2547
3753
  }
2548
3754
 
@@ -2593,12 +3799,16 @@ static void llama_sampler_infill_free(struct llama_sampler * smpl) {
2593
3799
  }
2594
3800
 
2595
3801
  static struct llama_sampler_i llama_sampler_infill_i = {
2596
- /* .name = */ llama_sampler_infill_name,
2597
- /* .accept = */ nullptr,
2598
- /* .apply = */ llama_sampler_infill_apply,
2599
- /* .reset = */ nullptr,
2600
- /* .clone = */ llama_sampler_infill_clone,
2601
- /* .free = */ llama_sampler_infill_free,
3802
+ /* .name = */ llama_sampler_infill_name,
3803
+ /* .accept = */ nullptr,
3804
+ /* .apply = */ llama_sampler_infill_apply,
3805
+ /* .reset = */ nullptr,
3806
+ /* .clone = */ llama_sampler_infill_clone,
3807
+ /* .free = */ llama_sampler_infill_free,
3808
+ /* .backend_apply = */ nullptr,
3809
+ /* .backend_accept = */ nullptr,
3810
+ /* .backend_set_input = */ nullptr,
3811
+ /* .backend_init = */ nullptr,
2602
3812
  };
2603
3813
 
2604
3814
  struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab) {
@@ -2630,7 +3840,7 @@ uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {
2630
3840
  if (smpl->iface == &llama_sampler_chain_i) {
2631
3841
  const auto * ctx = (const llama_sampler_chain *) smpl->ctx;
2632
3842
  for (auto it = ctx->samplers.rbegin(); it != ctx->samplers.rend(); ++it) {
2633
- const uint32_t seed = llama_sampler_get_seed(*it);
3843
+ const uint32_t seed = llama_sampler_get_seed(it->ptr);
2634
3844
  if (seed != LLAMA_DEFAULT_SEED) {
2635
3845
  return seed;
2636
3846
  }
@@ -2660,8 +3870,7 @@ struct llama_perf_sampler_data llama_perf_sampler(const struct llama_sampler * c
2660
3870
  void llama_perf_sampler_print(const struct llama_sampler * chain) {
2661
3871
  const auto data = llama_perf_sampler(chain);
2662
3872
 
2663
- LLAMA_LOG_INFO("%s: sampling time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
2664
- __func__, data.t_sample_ms, data.n_sample, data.t_sample_ms / data.n_sample, 1e3 / data.t_sample_ms * data.n_sample);
3873
+ LLAMA_LOG_INFO("%s: samplers time = %10.2f ms / %5d runs\n", __func__, data.t_sample_ms, data.n_sample);
2665
3874
  }
2666
3875
 
2667
3876
  void llama_perf_sampler_reset(struct llama_sampler * chain) {
@@ -2671,5 +3880,6 @@ void llama_perf_sampler_reset(struct llama_sampler * chain) {
2671
3880
 
2672
3881
  auto * ctx = (struct llama_sampler_chain *) chain->ctx;
2673
3882
 
2674
- ctx->t_sample_us = ctx->n_sample = 0;
3883
+ ctx->t_sample_us = 0;
3884
+ ctx->n_sample = 0;
2675
3885
  }