whispercpp 1.3.4 → 1.3.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (630)
  1. checksums.yaml +4 -4
  2. data/README.md +60 -43
  3. data/ext/extconf.rb +2 -2
  4. data/ext/ruby_whisper.c +14 -2
  5. data/ext/ruby_whisper.h +39 -0
  6. data/ext/ruby_whisper_context.c +22 -22
  7. data/ext/ruby_whisper_model.c +12 -12
  8. data/ext/ruby_whisper_params.c +47 -23
  9. data/ext/ruby_whisper_segment.c +84 -19
  10. data/ext/ruby_whisper_token.c +351 -0
  11. data/ext/ruby_whisper_transcribe.cpp +1 -1
  12. data/ext/ruby_whisper_vad_context.c +75 -0
  13. data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
  14. data/ext/ruby_whisper_vad_segment.c +139 -0
  15. data/ext/ruby_whisper_vad_segments.c +106 -0
  16. data/ext/sources/CMakeLists.txt +4 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
  19. data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
  20. data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  21. data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
  22. data/ext/sources/examples/addon.node/vad-example.js +2 -2
  23. data/ext/sources/examples/cli/cli.cpp +121 -112
  24. data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
  25. data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
  26. data/ext/sources/examples/server/server.cpp +10 -11
  27. data/ext/sources/examples/talk-llama/CMakeLists.txt +5 -1
  28. data/ext/sources/examples/talk-llama/llama-adapter.cpp +12 -3
  29. data/ext/sources/examples/talk-llama/llama-adapter.h +7 -1
  30. data/ext/sources/examples/talk-llama/llama-arch.cpp +2046 -1974
  31. data/ext/sources/examples/talk-llama/llama-arch.h +67 -2
  32. data/ext/sources/examples/talk-llama/llama-batch.cpp +75 -33
  33. data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
  34. data/ext/sources/examples/talk-llama/llama-chat.cpp +79 -3
  35. data/ext/sources/examples/talk-llama/llama-chat.h +4 -0
  36. data/ext/sources/examples/talk-llama/llama-context.cpp +775 -78
  37. data/ext/sources/examples/talk-llama/llama-context.h +57 -9
  38. data/ext/sources/examples/talk-llama/llama-cparams.h +1 -0
  39. data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
  40. data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
  41. data/ext/sources/examples/talk-llama/llama-graph.cpp +381 -64
  42. data/ext/sources/examples/talk-llama/llama-graph.h +103 -13
  43. data/ext/sources/examples/talk-llama/llama-hparams.cpp +26 -2
  44. data/ext/sources/examples/talk-llama/llama-hparams.h +41 -10
  45. data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
  46. data/ext/sources/examples/talk-llama/llama-impl.h +1 -1
  47. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +5 -3
  48. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +145 -65
  49. data/ext/sources/examples/talk-llama/llama-kv-cache.h +22 -7
  50. data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
  51. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
  52. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +32 -19
  53. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
  54. data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
  55. data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
  56. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +91 -9
  57. data/ext/sources/examples/talk-llama/llama-model-loader.h +6 -0
  58. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
  59. data/ext/sources/examples/talk-llama/llama-model.cpp +1529 -13134
  60. data/ext/sources/examples/talk-llama/llama-model.h +44 -3
  61. data/ext/sources/examples/talk-llama/llama-quant.cpp +8 -23
  62. data/ext/sources/examples/talk-llama/llama-sampling.cpp +1294 -198
  63. data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
  64. data/ext/sources/examples/talk-llama/llama-vocab.cpp +133 -37
  65. data/ext/sources/examples/talk-llama/llama-vocab.h +45 -40
  66. data/ext/sources/examples/talk-llama/llama.cpp +729 -2
  67. data/ext/sources/examples/talk-llama/llama.h +152 -14
  68. data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
  69. data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
  70. data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
  71. data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
  72. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
  73. data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
  74. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
  75. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
  76. data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
  77. data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
  78. data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
  79. data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
  80. data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
  81. data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
  82. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
  83. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
  84. data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
  85. data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
  86. data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
  87. data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
  88. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
  89. data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
  90. data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
  91. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
  92. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
  93. data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
  94. data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
  95. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
  96. data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
  97. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
  98. data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
  99. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
  100. data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
  101. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
  102. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
  103. data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
  104. data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
  105. data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
  106. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
  107. data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
  108. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
  109. data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
  110. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
  111. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
  112. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
  113. data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
  114. data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
  115. data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
  116. data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
  117. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
  118. data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
  119. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
  120. data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
  121. data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
  122. data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
  123. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
  124. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
  125. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
  126. data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
  127. data/ext/sources/examples/talk-llama/models/models.h +569 -0
  128. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
  129. data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
  130. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
  131. data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
  132. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
  133. data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
  134. data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
  135. data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
  136. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
  137. data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
  138. data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
  139. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
  140. data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
  141. data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
  142. data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
  143. data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
  144. data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
  145. data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
  146. data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
  147. data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
  148. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
  149. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
  150. data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
  151. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
  152. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
  153. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
  154. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
  155. data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
  156. data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
  157. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
  158. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
  159. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
  160. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
  161. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
  162. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
  163. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
  164. data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
  165. data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
  166. data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
  167. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
  168. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
  169. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
  170. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
  171. data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
  172. data/ext/sources/examples/talk-llama/unicode.cpp +102 -16
  173. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
  174. data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
  175. data/ext/sources/ggml/CMakeLists.txt +82 -54
  176. data/ext/sources/ggml/include/ggml-alloc.h +9 -0
  177. data/ext/sources/ggml/include/ggml-backend.h +4 -1
  178. data/ext/sources/ggml/include/ggml-cpu.h +1 -0
  179. data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
  180. data/ext/sources/ggml/include/ggml-rpc.h +8 -11
  181. data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
  182. data/ext/sources/ggml/include/ggml.h +190 -12
  183. data/ext/sources/ggml/src/CMakeLists.txt +82 -11
  184. data/ext/sources/ggml/src/ggml-alloc.c +124 -41
  185. data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
  186. data/ext/sources/ggml/src/ggml-backend-reg.cpp +27 -3
  187. data/ext/sources/ggml/src/ggml-backend.cpp +71 -21
  188. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
  189. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -9
  190. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +57 -45
  191. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
  192. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2179 -1696
  193. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -317
  194. data/ext/sources/ggml/src/ggml-cann/common.h +283 -208
  195. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +626 -776
  196. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +156 -86
  197. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -0
  198. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  199. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +428 -26
  200. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1004 -0
  201. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
  202. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  203. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +108 -49
  204. data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  205. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6 -6
  206. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +50 -2
  207. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +5 -3
  208. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +195 -71
  209. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
  210. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +573 -106
  211. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
  212. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +298 -112
  213. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
  214. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +819 -125
  215. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
  216. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +708 -431
  217. data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -4
  218. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +671 -31
  219. data/ext/sources/ggml/src/ggml-cpu/repack.h +14 -0
  220. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +41 -43
  221. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
  222. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
  223. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
  224. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +124 -1
  225. data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
  226. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
  227. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
  228. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
  229. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
  230. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +1 -1
  231. data/ext/sources/ggml/src/ggml-cuda/common.cuh +353 -80
  232. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  233. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
  234. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +339 -246
  235. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
  236. data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
  237. data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
  238. data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
  239. data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
  240. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +31 -21
  241. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +663 -596
  242. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
  243. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1241 -0
  244. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +30 -37
  245. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +14 -13
  246. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  247. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +83 -37
  248. data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
  249. data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
  250. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1155 -164
  251. data/ext/sources/ggml/src/ggml-cuda/mean.cu +5 -4
  252. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +741 -48
  253. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +60 -12
  254. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +381 -42
  255. data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
  256. data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
  257. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
  258. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +498 -171
  259. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +375 -79
  260. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +3 -2
  261. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +241 -95
  262. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
  263. data/ext/sources/ggml/src/ggml-cuda/pad.cu +64 -33
  264. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +151 -0
  265. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
  266. data/ext/sources/ggml/src/ggml-cuda/rope.cu +192 -77
  267. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
  268. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
  269. data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
  270. data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
  271. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +203 -6
  272. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
  273. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
  274. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -20
  275. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
  276. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  277. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  278. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  279. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  280. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  281. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  282. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  283. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  284. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  285. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +19 -1
  286. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
  287. data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
  288. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +168 -76
  289. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +11 -4
  290. data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
  291. data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
  292. data/ext/sources/ggml/src/ggml-cuda/unary.cu +105 -11
  293. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +36 -0
  294. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
  295. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  296. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +12 -1
  297. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
  302. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
  303. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  304. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
  305. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
  306. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  307. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
  308. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
  309. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
  310. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
  311. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  312. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
  313. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
  314. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
  315. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
  316. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
  317. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
  318. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
  319. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
  320. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
  321. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
  322. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
  323. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
  324. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
  325. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  326. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
  327. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
  328. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
  329. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +8 -13
  330. data/ext/sources/ggml/src/ggml-impl.h +67 -6
  331. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +2 -2
  332. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +29 -20
  333. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +652 -285
  334. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +103 -56
  335. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +496 -118
  336. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +231 -9
  337. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1227 -224
  338. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +12 -0
  339. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +14 -8
  340. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1972 -704
  341. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
  342. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +11 -0
  343. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1430 -120
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +63 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +13 -0
  365. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
  366. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
  367. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
  368. data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
  369. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +6 -0
  370. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +0 -9
  371. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
  372. data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
  373. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
  374. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +34 -0
  375. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  376. data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  377. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
  378. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
  379. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +76 -3
  380. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +333 -300
  381. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +10 -2
  382. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +335 -110
  383. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
  384. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +156 -0
  385. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
  386. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
  387. data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
  388. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
  389. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
  390. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
  391. data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  392. data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  393. data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
  394. data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
  395. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +30 -17
  396. data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
  397. data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
  398. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
  399. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
  400. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  401. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  402. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +58 -0
  403. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
  404. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5013 -2859
  405. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  406. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
  407. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
  408. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  409. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
  410. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  411. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
  412. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
  413. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  414. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  415. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  416. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  417. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  418. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  419. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
  420. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  421. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  422. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
  423. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
  424. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  425. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  426. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  427. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
  428. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
  429. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
  430. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
  431. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  432. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
  433. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
  434. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  435. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  436. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  437. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  438. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  439. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
  440. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
  441. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
  442. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  443. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  444. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
  445. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
  446. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  447. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  448. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  449. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
  450. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  451. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  452. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
  453. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  454. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  455. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
  456. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  457. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  458. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
  459. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  460. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +39 -17
  461. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +19 -1
  462. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +45 -7
  463. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +50 -12
  464. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +1 -1
  465. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  466. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  467. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
  468. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
  469. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  470. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
  471. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  472. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
  473. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
  474. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
  475. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
  476. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
  477. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
  478. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  479. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
  480. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
  481. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
  482. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
  483. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
  484. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  485. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  486. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  487. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
  488. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +70 -25
  489. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
  490. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
  491. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
  492. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
  493. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
  494. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
  495. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
  496. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
  497. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
  498. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  499. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
  500. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
  501. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
  502. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
  503. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
  504. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
  505. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
  506. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +78 -103
  507. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +34 -23
  508. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
  509. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
  510. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +88 -228
  511. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
  512. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  513. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
  514. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  515. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  516. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  517. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
  518. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
  519. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  520. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
  521. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  522. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  523. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  524. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  525. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +50 -4
  526. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  527. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
  528. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
  529. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
  530. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  531. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -50
  532. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -33
  533. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -33
  534. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
  535. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
  536. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  537. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
  538. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  539. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  540. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  541. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  542. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
  543. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
  544. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
  545. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
  546. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
  547. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
  548. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  549. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
  550. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
  551. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  552. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
  553. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
  554. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  555. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  556. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
  557. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
  558. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  559. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
  560. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  561. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
  562. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
  563. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
  564. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
  565. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
  566. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  567. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
  568. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
  569. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -151
  570. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
  571. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
  572. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
  573. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1964 -435
  574. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  575. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
  576. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
  577. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +33 -10
  578. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
  579. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +1 -1
  580. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  581. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +6 -6
  582. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
  583. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
  584. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
  585. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
  586. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
  587. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  588. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
  589. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
  590. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  591. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
  592. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
  593. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
  594. data/ext/sources/ggml/src/ggml.c +425 -33
  595. data/ext/sources/include/whisper.h +1 -0
  596. data/ext/sources/src/CMakeLists.txt +3 -1
  597. data/ext/sources/src/whisper.cpp +101 -35
  598. data/ext/sources/tests/CMakeLists.txt +2 -2
  599. data/ext/sources/tests/test-vad-full.cpp +4 -2
  600. data/ext/sources/tests/test-vad.cpp +1 -1
  601. data/extsources.rb +1 -0
  602. data/lib/whisper/model/uri.rb +17 -18
  603. data/sig/whisper.rbs +119 -2
  604. data/test/test_params.rb +16 -8
  605. data/test/test_segment.rb +0 -1
  606. data/test/test_token.rb +70 -0
  607. data/test/test_vad.rb +1 -1
  608. data/test/test_vad_context.rb +50 -0
  609. data/test/test_vad_segment.rb +19 -0
  610. data/test/test_vad_segments.rb +16 -0
  611. data/test/test_whisper.rb +7 -0
  612. data/whispercpp.gemspec +1 -1
  613. metadata +287 -34
  614. data/ext/sources/build-xcframework.sh +0 -571
  615. data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
  618. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
  619. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
  620. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
  621. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
  622. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
  623. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
  624. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  625. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  626. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  627. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  628. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
  629. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
  630. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
@@ -4,6 +4,9 @@
4
4
  #include "llama-vocab.h"
5
5
  #include "llama-grammar.h"
6
6
 
7
+ #include "ggml-cpp.h"
8
+
9
+ #include <array>
7
10
  #include <algorithm>
8
11
  #include <cassert>
9
12
  #include <cfloat>
@@ -345,7 +348,9 @@ static uint32_t get_rng_seed(uint32_t seed) {
345
348
 
346
349
  // llama_sampler API
347
350
 
348
- struct llama_sampler * llama_sampler_init(const struct llama_sampler_i * iface, llama_sampler_context_t ctx) {
351
+ struct llama_sampler * llama_sampler_init(
352
+ struct llama_sampler_i * iface,
353
+ llama_sampler_context_t ctx) {
349
354
  return new llama_sampler {
350
355
  /* .iface = */ iface,
351
356
  /* .ctx = */ ctx,
@@ -361,23 +366,39 @@ const char * llama_sampler_name(const struct llama_sampler * smpl) {
361
366
  }
362
367
 
363
368
  void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) {
369
+ if (!smpl) {
370
+ return;
371
+ }
372
+
364
373
  if (smpl->iface->accept) {
365
374
  smpl->iface->accept(smpl, token);
366
375
  }
367
376
  }
368
377
 
369
378
  void llama_sampler_apply(struct llama_sampler * smpl, struct llama_token_data_array * cur_p) {
379
+ if (!smpl) {
380
+ return;
381
+ }
382
+
370
383
  GGML_ASSERT(smpl->iface->apply);
371
384
  smpl->iface->apply(smpl, cur_p);
372
385
  }
373
386
 
374
387
  void llama_sampler_reset(struct llama_sampler * smpl) {
388
+ if (!smpl) {
389
+ return;
390
+ }
391
+
375
392
  if (smpl->iface->reset) {
376
393
  smpl->iface->reset(smpl);
377
394
  }
378
395
  }
379
396
 
380
397
  struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) {
398
+ if (!smpl) {
399
+ return nullptr;
400
+ }
401
+
381
402
  if (smpl->iface->clone) {
382
403
  return smpl->iface->clone(smpl);
383
404
  }
@@ -404,37 +425,200 @@ void llama_sampler_free(struct llama_sampler * smpl) {
404
425
  delete smpl;
405
426
  }
406
427
 
407
- llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
408
- const auto * logits = llama_get_logits_ith(ctx, idx);
428
+ // empty sampler
409
429
 
410
- const llama_model * model = llama_get_model(ctx);
411
- const llama_vocab * vocab = llama_model_get_vocab(model);
430
+ struct llama_sampler_empty {
431
+ const char * name;
432
+ };
412
433
 
413
- const int n_vocab = llama_vocab_n_tokens(vocab);
434
+ static struct llama_sampler * llama_sampler_init_empty(const char * name);
435
+
436
+ static const char * llama_sampler_empty_name(const struct llama_sampler * smpl) {
437
+ auto * ctx = (llama_sampler_empty *) smpl->ctx;
438
+ return ctx->name;
439
+ }
440
+
441
+ static void llama_sampler_empty_accept(struct llama_sampler * smpl, llama_token token) {
442
+ GGML_UNUSED(smpl);
443
+ GGML_UNUSED(token);
444
+ }
445
+
446
+ static void llama_sampler_empty_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
447
+ GGML_UNUSED(smpl);
448
+ GGML_UNUSED(cur_p);
449
+ }
450
+
451
+ static void llama_sampler_empty_reset(struct llama_sampler * smpl) {
452
+ GGML_UNUSED(smpl);
453
+ }
454
+
455
+ static struct llama_sampler * llama_sampler_empty_clone(const struct llama_sampler * smpl) {
456
+ auto * ctx = (llama_sampler_empty *) smpl->ctx;
457
+ return llama_sampler_init_empty(ctx->name);
458
+ }
414
459
 
415
- // TODO: do not allocate each time
416
- std::vector<llama_token_data> cur;
417
- cur.reserve(n_vocab);
418
- for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
419
- cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
460
+ static void llama_sampler_empty_free(struct llama_sampler * smpl) {
461
+ delete (llama_sampler_empty *) smpl->ctx;
462
+ }
463
+
464
+ static bool llama_sampler_empty_backend_init(
465
+ struct llama_sampler * smpl,
466
+ ggml_backend_buffer_type_t buft) {
467
+ GGML_UNUSED(smpl);
468
+ GGML_UNUSED(buft);
469
+
470
+ return true;
471
+ }
472
+
473
+ static void llama_sampler_empty_backend_accept(
474
+ struct llama_sampler * smpl,
475
+ ggml_context * ctx,
476
+ ggml_cgraph * gf,
477
+ struct ggml_tensor * selected_token) {
478
+ GGML_UNUSED(smpl);
479
+ GGML_UNUSED(ctx);
480
+ GGML_UNUSED(gf);
481
+ GGML_UNUSED(selected_token);
482
+ }
483
+
484
+ static void llama_sampler_empty_backend_apply(
485
+ struct llama_sampler * smpl,
486
+ struct ggml_context * ctx,
487
+ struct ggml_cgraph * gf,
488
+ struct llama_sampler_data * data) {
489
+ GGML_UNUSED(smpl);
490
+ GGML_UNUSED(ctx);
491
+ GGML_UNUSED(gf);
492
+ GGML_UNUSED(data);
493
+ }
494
+
495
+ static void llama_sampler_empty_backend_set_input(struct llama_sampler * smpl) {
496
+ GGML_UNUSED(smpl);
497
+ }
498
+
499
+ static struct llama_sampler_i llama_sampler_empty_i = {
500
+ /* .name = */ llama_sampler_empty_name,
501
+ /* .accept = */ llama_sampler_empty_accept,
502
+ /* .apply = */ llama_sampler_empty_apply,
503
+ /* .reset = */ llama_sampler_empty_reset,
504
+ /* .clone = */ llama_sampler_empty_clone,
505
+ /* .free = */ llama_sampler_empty_free,
506
+ /* .backend_init = */ llama_sampler_empty_backend_init,
507
+ /* .backend_accept = */ llama_sampler_empty_backend_accept,
508
+ /* .backend_apply = */ llama_sampler_empty_backend_apply,
509
+ /* .backend_set_input = */ llama_sampler_empty_backend_set_input,
510
+ };
511
+
512
+ struct llama_sampler * llama_sampler_init_empty(const char * name) {
513
+ return llama_sampler_init(
514
+ /* .iface = */ &llama_sampler_empty_i,
515
+ /* .ctx = */ new llama_sampler_empty {
516
+ /* .name = */ name,
517
+ }
518
+ );
519
+ }
520
+
521
+ // common backend sampler functionality
522
+ //
523
+ // +name : means that the sampler is support and will run on the backend
524
+ // -name : means that a ggml operator is not supported by the backend
525
+ //
526
+ struct llama_sampler_backend {
527
+ llama_sampler_backend(const char * name) : name(name), name_ext(name), is_init(false), support(false) {}
528
+
529
+ const char * get_name() {
530
+ if (!is_init) {
531
+ return name.c_str();
532
+ }
533
+
534
+ if (support) {
535
+ name_ext = "+" + name;
536
+ } else {
537
+ name_ext = "-" + name;
538
+ }
539
+
540
+ return name_ext.c_str();
420
541
  }
421
542
 
422
- llama_token_data_array cur_p = {
423
- /* .data = */ cur.data(),
424
- /* .size = */ cur.size(),
425
- /* .selected = */ -1,
426
- /* .sorted = */ false,
543
+ void init(bool support) {
544
+ GGML_ASSERT(this->is_init == false);
545
+
546
+ this->is_init = true;
547
+ this->support = support;
548
+ }
549
+
550
+ private:
551
+ std::string name;
552
+ std::string name_ext;
553
+
554
+ bool is_init;
555
+ bool support;
556
+ };
557
+
558
+ // check if all ggml ops used by the sampler are supported by the backend
559
+ static bool llama_sampler_backend_support(
560
+ llama_sampler * smpl,
561
+ ggml_backend_buffer_type_t buft) {
562
+ auto * device = ggml_backend_buft_get_device(buft);
563
+ if (!device) {
564
+ // CPU backend always supported
565
+ return true;
566
+ }
567
+
568
+ ggml_init_params params = {
569
+ /*.mem_size =*/ 128*ggml_tensor_overhead() + ggml_graph_overhead(),
570
+ /*.mem_buffer =*/ NULL,
571
+ /*.no_alloc =*/ true,
427
572
  };
428
573
 
429
- llama_sampler_apply(smpl, &cur_p);
574
+ ggml_context_ptr ctx_ptr { ggml_init(params) };
575
+ if (!ctx_ptr) {
576
+ throw std::runtime_error(format("failed to create ggml context"));
577
+ }
430
578
 
431
- GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (int32_t) cur_p.size);
579
+ ggml_context * ctx = ctx_ptr.get();
432
580
 
433
- auto token = cur_p.data[cur_p.selected].id;
581
+ const int64_t n = 1024*1024;
434
582
 
435
- llama_sampler_accept(smpl, token);
583
+ llama_sampler_data data = {
584
+ /*.logits = */ ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n),
585
+ /*.probs = */ nullptr,
586
+ /*.sampled = */ nullptr,
587
+ /*.candidates = */ ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n),
588
+ };
436
589
 
437
- return token;
590
+ ggml_cgraph * gf = ggml_new_graph(ctx);
591
+
592
+ smpl->iface->backend_apply(smpl, ctx, gf, &data);
593
+
594
+ if (data.logits) {
595
+ ggml_build_forward_expand(gf, data.logits);
596
+ }
597
+
598
+ if (data.probs) {
599
+ ggml_build_forward_expand(gf, data.probs);
600
+ }
601
+
602
+ if (data.sampled) {
603
+ ggml_build_forward_expand(gf, data.sampled);
604
+ }
605
+
606
+ if (data.candidates) {
607
+ ggml_build_forward_expand(gf, data.candidates);
608
+ }
609
+
610
+ for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
611
+ struct ggml_tensor * op = ggml_graph_node(gf, i);
612
+
613
+ if (!ggml_backend_dev_supports_op(device, op)) {
614
+ LLAMA_LOG_WARN("%s: device '%s' does not have support for op %s needed for sampler '%s'\n",
615
+ __func__, ggml_backend_dev_name(device), ggml_op_name(op->op), smpl->iface->name(smpl));
616
+
617
+ return false;
618
+ }
619
+ }
620
+
621
+ return true;
438
622
  }
439
623
 
440
624
  // sampler chain
@@ -448,8 +632,8 @@ static void llama_sampler_chain_accept(struct llama_sampler * smpl, llama_token
448
632
 
449
633
  time_meas tm(chain->t_sample_us, chain->params.no_perf);
450
634
 
451
- for (auto * smpl : chain->samplers) {
452
- llama_sampler_accept(smpl, token);
635
+ for (auto & smpl : chain->samplers) {
636
+ llama_sampler_accept(smpl.ptr, token);
453
637
  }
454
638
 
455
639
  chain->n_sample++;
@@ -460,20 +644,29 @@ static void llama_sampler_chain_apply(struct llama_sampler * smpl, llama_token_d
460
644
 
461
645
  time_meas tm(chain->t_sample_us, chain->params.no_perf);
462
646
 
463
- for (auto * smpl : chain->samplers) {
464
- llama_sampler_apply(smpl, cur_p);
647
+ bool is_backend = chain->is_init;
648
+
649
+ for (auto & smpl : chain->samplers) {
650
+ if (is_backend && smpl.is_backend) {
651
+ continue;
652
+ }
653
+
654
+ is_backend = false;
655
+
656
+ if (smpl.ptr->iface->apply == nullptr) {
657
+ continue;
658
+ }
659
+
660
+ llama_sampler_apply(smpl.ptr, cur_p);
465
661
  }
466
662
  }
467
663
 
468
664
  static void llama_sampler_chain_reset(struct llama_sampler * smpl) {
469
665
  auto * chain = (llama_sampler_chain *) smpl->ctx;
470
666
 
471
- for (auto * smpl : chain->samplers) {
472
- llama_sampler_reset(smpl);
667
+ for (auto & smpl : chain->samplers) {
668
+ llama_sampler_reset(smpl.ptr);
473
669
  }
474
-
475
- chain->t_sample_us = 0;
476
- chain->n_sample = 0;
477
670
  }
478
671
 
479
672
  static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampler * smpl) {
@@ -481,8 +674,8 @@ static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampl
481
674
 
482
675
  auto * result = llama_sampler_chain_init(chain_src->params);
483
676
 
484
- for (auto * smpl : chain_src->samplers) {
485
- llama_sampler_chain_add(result, llama_sampler_clone(smpl));
677
+ for (const auto & smpl : chain_src->samplers) {
678
+ llama_sampler_chain_add(result, llama_sampler_clone(smpl.ptr));
486
679
  }
487
680
 
488
681
  return result;
@@ -491,20 +684,109 @@ static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampl
491
684
  static void llama_sampler_chain_free(struct llama_sampler * smpl) {
492
685
  auto * chain = (llama_sampler_chain *) smpl->ctx;
493
686
 
494
- for (auto * smpl : chain->samplers) {
495
- llama_sampler_free(smpl);
687
+ for (auto & smpl : chain->samplers) {
688
+ llama_sampler_free(smpl.ptr);
496
689
  }
497
690
 
498
691
  delete chain;
499
692
  }
500
693
 
694
+ static bool llama_sampler_chain_backend_init(
695
+ struct llama_sampler * smpl,
696
+ ggml_backend_buffer_type_t buft) {
697
+ auto * chain = (llama_sampler_chain *) smpl->ctx;
698
+
699
+ GGML_ASSERT(chain->is_init == false && "llama_sampler_chain_backend_init() called twice");
700
+
701
+ chain->is_init = true;
702
+
703
+ bool res = true;
704
+
705
+ for (auto & smpl : chain->samplers) {
706
+ bool res_cur = true;
707
+
708
+ // to be able to run a sampler on the backend, it has to:
709
+ // - have the .backend_init() API implemented
710
+ // - return true during .backend_init()
711
+ if (smpl.ptr->iface->backend_init) {
712
+ if (!smpl.ptr->iface->backend_init(smpl.ptr, buft)) {
713
+ res_cur = false;
714
+ }
715
+ } else {
716
+ res_cur = false;
717
+ }
718
+
719
+ smpl.is_backend = res_cur;
720
+
721
+ res = res && res_cur;
722
+ }
723
+
724
+ return res;
725
+ }
726
+
727
+ static void llama_sampler_chain_backend_accept(
728
+ struct llama_sampler * smpl,
729
+ ggml_context * ctx,
730
+ ggml_cgraph * gf,
731
+ struct ggml_tensor * selected_token) {
732
+ auto * chain = (llama_sampler_chain *) smpl->ctx;
733
+
734
+ for (auto & smpl : chain->samplers) {
735
+ if (!smpl.is_backend) {
736
+ break;
737
+ }
738
+
739
+ if (smpl.ptr->iface->backend_accept) {
740
+ smpl.ptr->iface->backend_accept(smpl.ptr, ctx, gf, selected_token);
741
+ }
742
+ }
743
+ }
744
+
745
+ static void llama_sampler_chain_backend_apply(
746
+ struct llama_sampler * smpl,
747
+ struct ggml_context * ctx,
748
+ struct ggml_cgraph * gf,
749
+ struct llama_sampler_data * data) {
750
+ auto * chain = (llama_sampler_chain *) smpl->ctx;
751
+
752
+ GGML_ASSERT(chain->is_init && "llama_sampler_chain_backend_init() not called");
753
+
754
+ for (auto & smpl : chain->samplers) {
755
+ if (!smpl.is_backend) {
756
+ break;
757
+ }
758
+
759
+ if (smpl.ptr->iface->backend_apply) {
760
+ smpl.ptr->iface->backend_apply(smpl.ptr, ctx, gf, data);
761
+ }
762
+ }
763
+ }
764
+
765
+ static void llama_sampler_chain_backend_set_input(struct llama_sampler * smpl) {
766
+ auto * chain = (llama_sampler_chain *) smpl->ctx;
767
+
768
+ for (auto & smpl : chain->samplers) {
769
+ if (!smpl.is_backend) {
770
+ break;
771
+ }
772
+
773
+ if (smpl.ptr->iface->backend_set_input) {
774
+ smpl.ptr->iface->backend_set_input(smpl.ptr);
775
+ }
776
+ }
777
+ }
778
+
501
779
  static struct llama_sampler_i llama_sampler_chain_i = {
502
- /* .name = */ llama_sampler_chain_name,
503
- /* .accept = */ llama_sampler_chain_accept,
504
- /* .apply = */ llama_sampler_chain_apply,
505
- /* .reset = */ llama_sampler_chain_reset,
506
- /* .clone = */ llama_sampler_chain_clone,
507
- /* .free = */ llama_sampler_chain_free,
780
+ /* .name = */ llama_sampler_chain_name,
781
+ /* .accept = */ llama_sampler_chain_accept,
782
+ /* .apply = */ llama_sampler_chain_apply,
783
+ /* .reset = */ llama_sampler_chain_reset,
784
+ /* .clone = */ llama_sampler_chain_clone,
785
+ /* .free = */ llama_sampler_chain_free,
786
+ /* .backend_init = */ llama_sampler_chain_backend_init,
787
+ /* .backend_accept = */ llama_sampler_chain_backend_accept,
788
+ /* .backend_apply = */ llama_sampler_chain_backend_apply,
789
+ /* .backend_set_input = */ llama_sampler_chain_backend_set_input,
508
790
  };
509
791
 
510
792
  struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) {
@@ -512,26 +794,113 @@ struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_param
512
794
  /* .iface = */ &llama_sampler_chain_i,
513
795
  /* .ctx = */ new llama_sampler_chain {
514
796
  /* .params = */ params,
797
+ /* .is_init = */ false,
515
798
  /* .samplers = */ {},
799
+ /* .cur = */ {},
516
800
  /* .t_sample_us = */ 0,
517
801
  /* .n_sample = */ 0,
518
802
  }
519
803
  );
520
804
  }
521
805
 
806
+ llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
807
+ const llama_token sampled_token = llama_get_sampled_token_ith (ctx, idx);
808
+ const float * sampled_probs = llama_get_sampled_probs_ith (ctx, idx);
809
+ const float * sampled_logits = llama_get_sampled_logits_ith (ctx, idx);
810
+ const llama_token * sampled_ids = llama_get_sampled_candidates_ith(ctx, idx);
811
+
812
+ // If a backend sampler has already sampled a token, return it.
813
+ if (sampled_token != LLAMA_TOKEN_NULL) {
814
+ LLAMA_LOG_DEBUG("%s: Backend sampler selected token for idx %d. Skipping CPU samplers\n", __func__, idx);
815
+ return sampled_token;
816
+ }
817
+
818
+ const llama_model * model = llama_get_model(ctx);
819
+ const llama_vocab * vocab = llama_model_get_vocab(model);
820
+
821
+ const int n_vocab = llama_vocab_n_tokens(vocab);
822
+
823
+ // use pre-allocated buffer from chain if available, otherwise allocate locally
824
+ std::vector<llama_token_data> * cur_ptr;
825
+ std::vector<llama_token_data> cur_local;
826
+
827
+ if (smpl->iface == &llama_sampler_chain_i) {
828
+ auto * chain = (llama_sampler_chain *) smpl->ctx;
829
+ cur_ptr = &chain->cur;
830
+ } else {
831
+ cur_ptr = &cur_local;
832
+ }
833
+
834
+ auto & cur = *cur_ptr;
835
+
836
+ if (sampled_probs) {
837
+ const uint32_t sampled_probs_count = llama_get_sampled_probs_count_ith(ctx, idx);
838
+ cur.resize(sampled_probs_count);
839
+ for (uint32_t i = 0; i < sampled_probs_count; ++i) {
840
+ cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], sampled_probs[i]};
841
+ }
842
+ } else if (sampled_logits) {
843
+ const uint32_t sampled_logits_count = llama_get_sampled_logits_count_ith(ctx, idx);
844
+ cur.resize(sampled_logits_count);
845
+ for (llama_token i = 0; i < (int)sampled_logits_count; i++) {
846
+ cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f};
847
+ }
848
+ } else {
849
+ const auto * logits = llama_get_logits_ith(ctx, idx);
850
+ GGML_ASSERT(logits != nullptr);
851
+ cur.resize(n_vocab);
852
+ for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
853
+ cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
854
+ }
855
+ }
856
+
857
+ llama_token_data_array cur_p = {
858
+ /* .data = */ cur.data(),
859
+ /* .size = */ cur.size(),
860
+ /* .selected = */ -1,
861
+ /* .sorted = */ false,
862
+ };
863
+
864
+ llama_sampler_apply(smpl, &cur_p);
865
+
866
+ GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (int32_t) cur_p.size);
867
+
868
+ auto token = cur_p.data[cur_p.selected].id;
869
+
870
+ llama_sampler_accept(smpl, token);
871
+
872
+ return token;
873
+ }
874
+
875
+
522
876
  void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
523
877
  auto * p = (llama_sampler_chain *) chain->ctx;
524
- p->samplers.push_back(smpl);
878
+ p->samplers.push_back({
879
+ /* .is_backend = */ false,
880
+ /* .ptr = */ smpl,
881
+ });
525
882
  }
526
883
 
527
- struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i) {
884
+ struct llama_sampler * llama_sampler_chain_get(struct llama_sampler * chain, int32_t i) {
885
+ if (chain == nullptr) {
886
+ return nullptr;
887
+ }
888
+
889
+ if (chain->iface != &llama_sampler_chain_i) {
890
+ return nullptr;
891
+ }
892
+
893
+ if (i == -1) {
894
+ return chain;
895
+ }
896
+
528
897
  const auto * p = (const llama_sampler_chain *) chain->ctx;
529
898
 
530
899
  if (i < 0 || (size_t) i >= p->samplers.size()) {
531
900
  return nullptr;
532
901
  }
533
902
 
534
- return p->samplers[i];
903
+ return p->samplers[i].ptr;
535
904
  }
536
905
 
537
906
  struct llama_sampler * llama_sampler_chain_remove(struct llama_sampler * chain, int32_t i) {
@@ -541,7 +910,7 @@ struct llama_sampler * llama_sampler_chain_remove(struct llama_sampler * chain,
541
910
  return nullptr;
542
911
  }
543
912
 
544
- auto * result = p->samplers[i];
913
+ auto * result = p->samplers[i].ptr;
545
914
  p->samplers.erase(p->samplers.begin() + i);
546
915
 
547
916
  return result;
@@ -559,8 +928,36 @@ int llama_sampler_chain_n(const struct llama_sampler * chain) {
559
928
 
560
929
  // greedy
561
930
 
562
- static const char * llama_sampler_greedy_name(const struct llama_sampler * /*smpl*/) {
563
- return "greedy";
931
+ struct llama_sampler_greedy : public llama_sampler_backend {
932
+ };
933
+
934
+ static const char * llama_sampler_greedy_name(const struct llama_sampler * smpl) {
935
+ auto * sctx = (llama_sampler_greedy *) smpl->ctx;
936
+ return sctx->get_name();
937
+ }
938
+
939
+ static void llama_sampler_greedy_reset(struct llama_sampler * smpl) {
940
+ auto * ctx = (llama_sampler_greedy *) smpl->ctx;
941
+ GGML_UNUSED(ctx);
942
+ }
943
+
944
+ static struct llama_sampler * llama_sampler_greedy_clone(const struct llama_sampler * smpl) {
945
+ const auto * ctx = (const llama_sampler_greedy *) smpl->ctx;
946
+ auto * result = llama_sampler_init_greedy();
947
+
948
+ // copy the state
949
+ {
950
+ auto * result_ctx = (llama_sampler_greedy *) result->ctx;
951
+
952
+ GGML_UNUSED(ctx);
953
+ GGML_UNUSED(result_ctx);
954
+ }
955
+
956
+ return result;
957
+ }
958
+
959
+ static void llama_sampler_greedy_free(struct llama_sampler * smpl) {
960
+ delete (llama_sampler_greedy *) smpl->ctx;
564
961
  }
565
962
 
566
963
  static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
@@ -572,33 +969,72 @@ static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_to
572
969
  }
573
970
  }
574
971
 
972
+ static bool llama_sampler_greedy_backend_init(
973
+ struct llama_sampler * smpl,
974
+ ggml_backend_buffer_type_t buft) {
975
+ auto * sctx = (llama_sampler_greedy *) smpl->ctx;
976
+
977
+ const bool res = llama_sampler_backend_support(smpl, buft);
978
+
979
+ sctx->init(res);
980
+
981
+ return res;
982
+ }
983
+
984
+ static void llama_sampler_greedy_backend_apply(
985
+ struct llama_sampler * smpl,
986
+ struct ggml_context * ctx,
987
+ struct ggml_cgraph * gf,
988
+ struct llama_sampler_data * data) {
989
+ GGML_UNUSED(gf);
990
+ GGML_UNUSED(smpl);
991
+
992
+ struct ggml_tensor * curl = ggml_argmax(ctx, data->logits);
993
+ ggml_set_name(curl, "greedy_argmax");
994
+
995
+ data->sampled = curl;
996
+ }
997
+
575
998
  static struct llama_sampler_i llama_sampler_greedy_i = {
576
- /* .name = */ llama_sampler_greedy_name,
577
- /* .accept = */ nullptr,
578
- /* .apply = */ llama_sampler_greedy_apply,
579
- /* .reset = */ nullptr,
580
- /* .clone = */ nullptr,
581
- /* .free = */ nullptr,
999
+ /* .name = */ llama_sampler_greedy_name,
1000
+ /* .accept = */ nullptr,
1001
+ /* .apply = */ llama_sampler_greedy_apply,
1002
+ /* .reset = */ llama_sampler_greedy_reset,
1003
+ /* .clone = */ llama_sampler_greedy_clone,
1004
+ /* .free = */ llama_sampler_greedy_free,
1005
+ /* .backend_init = */ llama_sampler_greedy_backend_init,
1006
+ /* .backend_accept = */ nullptr,
1007
+ /* .backend_apply = */ llama_sampler_greedy_backend_apply,
1008
+ /* .backend_set_input = */ nullptr,
582
1009
  };
583
1010
 
584
1011
  struct llama_sampler * llama_sampler_init_greedy() {
585
1012
  return llama_sampler_init(
586
1013
  /* .iface = */ &llama_sampler_greedy_i,
587
- /* .ctx = */ nullptr
1014
+ /* .ctx = */ new llama_sampler_greedy {
1015
+ ("greedy"),
1016
+ }
588
1017
  );
589
1018
  }
590
1019
 
591
1020
  // dist
592
1021
 
593
- struct llama_sampler_dist {
1022
+ struct llama_sampler_dist : public llama_sampler_backend {
594
1023
  const uint32_t seed;
595
1024
  uint32_t seed_cur;
596
1025
 
597
1026
  std::mt19937 rng;
1027
+
1028
+ // backend input
1029
+ struct ggml_tensor * inp_uniform;
1030
+
1031
+ ggml_context_ptr inp_ctx;
1032
+ ggml_backend_buffer_ptr inp_buf;
598
1033
  };
599
1034
 
600
- static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*/) {
601
- return "dist";
1035
+ static const char * llama_sampler_dist_name(const struct llama_sampler * smpl) {
1036
+ auto * sctx = (llama_sampler_dist *) smpl->ctx;
1037
+ return sctx->get_name();
602
1038
  }
603
1039
 
604
1040
  static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
@@ -673,6 +1109,12 @@ static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_da
673
1109
  #endif
674
1110
  }
675
1111
 
1112
+ static void llama_sampler_dist_reset(struct llama_sampler * smpl) {
1113
+ auto * ctx = (llama_sampler_dist *) smpl->ctx;
1114
+ ctx->seed_cur = get_rng_seed(ctx->seed);
1115
+ ctx->rng.seed(ctx->seed_cur);
1116
+ }
1117
+
676
1118
  static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) {
677
1119
  const auto * ctx = (const llama_sampler_dist *) smpl->ctx;
678
1120
  auto * result = llama_sampler_init_dist(ctx->seed);
@@ -687,23 +1129,127 @@ static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sample
687
1129
  return result;
688
1130
  }
689
1131
 
690
- static void llama_sampler_dist_reset(struct llama_sampler * smpl) {
691
- auto * ctx = (llama_sampler_dist *) smpl->ctx;
692
- ctx->seed_cur = get_rng_seed(ctx->seed);
693
- ctx->rng.seed(ctx->seed_cur);
694
- }
695
-
696
1132
  static void llama_sampler_dist_free(struct llama_sampler * smpl) {
697
1133
  delete (llama_sampler_dist *) smpl->ctx;
698
1134
  }
699
1135
 
1136
+ static bool llama_sampler_dist_backend_init(
1137
+ struct llama_sampler * smpl,
1138
+ ggml_backend_buffer_type_t buft) {
1139
+ auto * sctx = (llama_sampler_dist *) smpl->ctx;
1140
+
1141
+ // allocate inputs
1142
+ {
1143
+ ggml_init_params params = {
1144
+ /*.mem_size =*/ ggml_tensor_overhead(),
1145
+ /*.mem_buffer =*/ nullptr,
1146
+ /*.no_alloc =*/ true,
1147
+ };
1148
+
1149
+ sctx->inp_ctx.reset(ggml_init(params));
1150
+
1151
+ // Create the uniform random scalar input tensor. This will be set by
1152
+ // llama_sampler_dist_backend_set_input after this graph is built.
1153
+ sctx->inp_uniform = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_F32, 1);
1154
+ ggml_set_name (sctx->inp_uniform, "uniform");
1155
+ ggml_set_input(sctx->inp_uniform);
1156
+
1157
+ // Allocate all tensors from our context to the backend
1158
+ sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft));
1159
+
1160
+ ggml_backend_buffer_clear(sctx->inp_buf.get(), 0);
1161
+ }
1162
+
1163
+ const bool res = llama_sampler_backend_support(smpl, buft);
1164
+
1165
+ sctx->init(res);
1166
+
1167
+ if (!res) {
1168
+ sctx->inp_ctx.reset(nullptr);
1169
+ sctx->inp_buf.reset(nullptr);
1170
+ }
1171
+
1172
+ return res;
1173
+ }
1174
+
1175
+ static void llama_sampler_dist_backend_apply(
1176
+ struct llama_sampler * smpl,
1177
+ struct ggml_context * ctx,
1178
+ struct ggml_cgraph * gf,
1179
+ struct llama_sampler_data * data) {
1180
+ GGML_UNUSED(gf);
1181
+ auto * sctx = (llama_sampler_dist *) smpl->ctx;
1182
+
1183
+ struct ggml_tensor * probs = ggml_soft_max(ctx, data->logits);
1184
+ ggml_set_name(probs, "dist_probs");
1185
+
1186
+ struct ggml_tensor * cumsum = ggml_cumsum(ctx, probs);
1187
+ ggml_set_name(cumsum, "dist_cumsum");
1188
+
1189
+ // The uniform tensor has a random value and we subtract this tensor with
1190
+ // the cumsum tensor (the uniform tensor will be broadcasted by ggml_sub).
1191
+ // Recall that each entry in cumsum is the cumulative probability up to that
1192
+ // index so values stay negative while the cumulative total is below the
1193
+ // random value, and become zero/positive once the threshold is crossed.
1194
+ struct ggml_tensor * diff = ggml_sub(ctx, cumsum, sctx->inp_uniform);
1195
+ ggml_set_name(diff, "dist_cumsum");
1196
+
1197
+ // The ggml_step function produces a tensor where entries are 1 if the
1198
+ // corresponding entry in diff is > 0, and 0 otherwise. So all values up to
1199
+ // the index where the cumulative probability exceeds the random value are 0,
1200
+ // and all entries after that are 1.
1201
+ struct ggml_tensor * mask = ggml_step(ctx, diff);
1202
+ ggml_set_name(mask, "dist_mask");
1203
+
1204
+ // Taking the sum of the mask gives us the sum of elements after the threshold
1205
+ // we are interested in.
1206
+ struct ggml_tensor * idxf = ggml_sum(ctx, mask);
1207
+ ggml_set_name(idxf, "dist_index_f32");
1208
+
1209
+ // Use ggml_scale_bias to scale the index value by -1 and then add the size
1210
+ // of the mask to that value so we get the correct index ((-1 * idxf) + n).
1211
+ struct ggml_tensor * idx = ggml_cast(ctx, ggml_scale_bias(ctx, idxf, -1.0f, mask->ne[0]), GGML_TYPE_I32);
1212
+ ggml_set_name(idx, "dist_index_i32");
1213
+
1214
+ // Map back to original vocab ids if a candidates tensor is available.
1215
+ struct ggml_tensor * sampled_token = idx;
1216
+ if (data->candidates != nullptr) {
1217
+ struct ggml_tensor * candidates = ggml_reshape_2d(ctx, data->candidates, 1, ggml_nelements(data->candidates));
1218
+
1219
+ sampled_token = ggml_get_rows(ctx, candidates, idx);
1220
+ ggml_set_name(sampled_token, "dist_sampled_token");
1221
+ }
1222
+
1223
+ data->sampled = sampled_token;
1224
+ data->probs = probs;
1225
+ }
1226
+
1227
+ static void llama_sampler_dist_backend_set_input(struct llama_sampler * smpl) {
1228
+ auto * sctx = (llama_sampler_dist *) smpl->ctx;
1229
+ GGML_ASSERT(sctx->inp_uniform != nullptr);
1230
+
1231
+ // We sample in double precision and cast to float to match rnd numbers of
1232
+ // llama_dampler_dist which uses double precision (sampling from
1233
+ // std::uniform_real_distribution<double> and
1234
+ // std::uniform_real_distribution<float> with same rng will produce
1235
+ // different sequences).
1236
+ std::uniform_real_distribution<double> dist(0.0f, 1.0f);
1237
+ const float rnd = dist(sctx->rng);
1238
+
1239
+ ggml_backend_tensor_set(sctx->inp_uniform, &rnd, 0, sizeof(float));
1240
+ }
1241
+
700
1242
  static struct llama_sampler_i llama_sampler_dist_i = {
701
- /* .name = */ llama_sampler_dist_name,
702
- /* .accept = */ nullptr,
703
- /* .apply = */ llama_sampler_dist_apply,
704
- /* .reset = */ llama_sampler_dist_reset,
705
- /* .clone = */ llama_sampler_dist_clone,
706
- /* .free = */ llama_sampler_dist_free,
1243
+ /* .name = */ llama_sampler_dist_name,
1244
+ /* .accept = */ nullptr,
1245
+ /* .apply = */ llama_sampler_dist_apply,
1246
+ /* .reset = */ llama_sampler_dist_reset,
1247
+ /* .clone = */ llama_sampler_dist_clone,
1248
+ /* .free = */ llama_sampler_dist_free,
1249
+ /* .backend_init = */ llama_sampler_dist_backend_init,
1250
+ /* .backend_accept = */ nullptr,
1251
+ /* .backend_apply = */ llama_sampler_dist_backend_apply,
1252
+ /* .backend_set_input = */ llama_sampler_dist_backend_set_input,
707
1253
  };
708
1254
 
709
1255
  struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
@@ -711,21 +1257,26 @@ struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
711
1257
  return llama_sampler_init(
712
1258
  /* .iface = */ &llama_sampler_dist_i,
713
1259
  /* .ctx = */ new llama_sampler_dist {
714
- /* .seed = */ seed,
715
- /* .seed_cur = */ seed_cur,
716
- /* .rng = */ std::mt19937(seed_cur),
1260
+ ("dist"),
1261
+ /* .seed = */ seed,
1262
+ /* .seed_cur = */ seed_cur,
1263
+ /* .rng = */ std::mt19937(seed_cur),
1264
+ /* .inp_uniform = */ nullptr,
1265
+ /* .inp_ctx = */ nullptr,
1266
+ /* .inp_buf = */ nullptr,
717
1267
  }
718
1268
  );
719
1269
  }
720
1270
 
721
1271
  // top-k
722
1272
 
723
- struct llama_sampler_top_k {
1273
+ struct llama_sampler_top_k : public llama_sampler_backend {
724
1274
  const int32_t k;
725
1275
  };
726
1276
 
727
- static const char * llama_sampler_top_k_name(const struct llama_sampler * /*smpl*/) {
728
- return "top-k";
1277
+ static const char * llama_sampler_top_k_name(const struct llama_sampler * smpl) {
1278
+ auto * sctx = (llama_sampler_top_k *) smpl->ctx;
1279
+ return sctx->get_name();
729
1280
  }
730
1281
 
731
1282
  static void llama_sampler_top_k_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
@@ -742,19 +1293,69 @@ static void llama_sampler_top_k_free(struct llama_sampler * smpl) {
742
1293
  delete (llama_sampler_top_k *) smpl->ctx;
743
1294
  }
744
1295
 
1296
+ static bool llama_sampler_top_k_backend_init(
1297
+ struct llama_sampler * smpl,
1298
+ ggml_backend_buffer_type_t buft) {
1299
+ auto * sctx = (llama_sampler_top_k *) smpl->ctx;
1300
+
1301
+ const bool res = llama_sampler_backend_support(smpl, buft);
1302
+
1303
+ sctx->init(res);
1304
+
1305
+ return res;
1306
+ }
1307
+
1308
+ static void llama_sampler_top_k_backend_apply(
1309
+ struct llama_sampler * smpl,
1310
+ struct ggml_context * ctx,
1311
+ struct ggml_cgraph * gf,
1312
+ struct llama_sampler_data * data) {
1313
+ auto * sctx = (llama_sampler_top_k *) smpl->ctx;
1314
+
1315
+ struct ggml_tensor * top_k = ggml_top_k(ctx, data->logits, sctx->k);
1316
+ ggml_set_name(top_k, "top_k");
1317
+
1318
+ if (data->candidates) {
1319
+ struct ggml_tensor * candidates_rows = ggml_reshape_2d(ctx, data->candidates, 1, data->candidates->ne[0]);
1320
+ data->candidates = ggml_get_rows(ctx, candidates_rows, top_k);
1321
+ data->candidates = ggml_reshape_1d(ctx, data->candidates, sctx->k);
1322
+ ggml_set_name(data->candidates, "top_k_candidates");
1323
+ } else {
1324
+ data->candidates = top_k;
1325
+ }
1326
+
1327
+ struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]);
1328
+ struct ggml_tensor * top_k_rows = ggml_get_rows(ctx, logits_rows, top_k);
1329
+ data->logits = ggml_reshape_1d(ctx, top_k_rows, sctx->k);
1330
+ ggml_set_name(top_k_rows, "top_k_rows");
1331
+
1332
+ GGML_UNUSED(gf);
1333
+ }
1334
+
745
1335
  static struct llama_sampler_i llama_sampler_top_k_i = {
746
- /* .name = */ llama_sampler_top_k_name,
747
- /* .accept = */ nullptr,
748
- /* .apply = */ llama_sampler_top_k_apply,
749
- /* .reset = */ nullptr,
750
- /* .clone = */ llama_sampler_top_k_clone,
751
- /* .free = */ llama_sampler_top_k_free,
1336
+ /* .name = */ llama_sampler_top_k_name,
1337
+ /* .accept = */ nullptr,
1338
+ /* .apply = */ llama_sampler_top_k_apply,
1339
+ /* .reset = */ nullptr,
1340
+ /* .clone = */ llama_sampler_top_k_clone,
1341
+ /* .free = */ llama_sampler_top_k_free,
1342
+ /* .backend_init = */ llama_sampler_top_k_backend_init,
1343
+ /* .backend_accept = */ nullptr,
1344
+ /* .backend_apply = */ llama_sampler_top_k_backend_apply,
1345
+ /* .backend_set_input = */ nullptr,
752
1346
  };
753
1347
 
754
1348
  struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
1349
+ const bool is_empty = (k <= 0);
1350
+
1351
+ if (is_empty) {
1352
+ return llama_sampler_init_empty("?top-k");
1353
+ }
1354
+
755
1355
  return llama_sampler_init(
756
1356
  /* .iface = */ &llama_sampler_top_k_i,
757
1357
  /* .ctx = */ new llama_sampler_top_k {
1358
+ ("top-k"),
758
1359
  /* .k = */ k,
759
1360
  }
760
1361
  );
@@ -762,15 +1363,16 @@ struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
762
1363
 
763
1364
  // top-p
764
1365
 
765
- struct llama_sampler_top_p {
1366
+ struct llama_sampler_top_p : public llama_sampler_backend {
766
1367
  const float p;
767
1368
  const size_t min_keep;
768
1369
 
769
1370
  std::vector<llama_token_data> buf_sort;
770
1371
  };
771
1372
 
772
- static const char * llama_sampler_top_p_name(const struct llama_sampler * /*smpl*/) {
773
- return "top-p";
1373
+ static const char * llama_sampler_top_p_name(const struct llama_sampler * smpl) {
1374
+ auto * sctx = (llama_sampler_top_p *) smpl->ctx;
1375
+ return sctx->get_name();
774
1376
  }
775
1377
 
776
1378
  static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
@@ -837,19 +1439,118 @@ static void llama_sampler_top_p_free(struct llama_sampler * smpl) {
837
1439
  delete (llama_sampler_top_p *) smpl->ctx;
838
1440
  }
839
1441
 
1442
+ static bool llama_sampler_top_p_backend_init(
1443
+ struct llama_sampler * smpl,
1444
+ ggml_backend_buffer_type_t buft) {
1445
+ auto * sctx = (llama_sampler_top_p *) smpl->ctx;
1446
+
1447
+ const bool res = llama_sampler_backend_support(smpl, buft);
1448
+
1449
+ sctx->init(res);
1450
+
1451
+ return res;
1452
+ }
1453
+
1454
+ static void llama_sampler_top_p_backend_apply(
1455
+ struct llama_sampler * smpl,
1456
+ struct ggml_context * ctx,
1457
+ struct ggml_cgraph * gf,
1458
+ struct llama_sampler_data * data) {
1459
+ auto * sctx = (llama_sampler_top_p *) smpl->ctx;
1460
+
1461
+ auto ggml_sort = [ctx](struct ggml_tensor * a, struct ggml_tensor * b) {
1462
+ GGML_ASSERT(ggml_nrows(a) == 1);
1463
+ struct ggml_tensor * a_reshaped = ggml_reshape_2d(ctx, a, 1, a->ne[0]);
1464
+ struct ggml_tensor * a_sorted = ggml_get_rows(ctx, a_reshaped, b);
1465
+ return ggml_reshape_1d(ctx, a_sorted, a->ne[0]);
1466
+ };
1467
+
1468
+ // Get the sorted logits in descending order.
1469
+ struct ggml_tensor * sorted_idx = ggml_argsort(ctx, data->logits, GGML_SORT_ORDER_DESC);
1470
+ ggml_set_name(sorted_idx, "top_p_sorted_idx");
1471
+
1472
+ // Do the sorting via reshape + get_rows
1473
+ struct ggml_tensor * sorted_logits = ggml_sort(data->logits, sorted_idx);
1474
+ ggml_set_name(sorted_logits, "top_p_sorted_logits");
1475
+
1476
+ struct ggml_tensor * softmax = ggml_soft_max(ctx, sorted_logits);
1477
+ ggml_set_name(softmax, "top_p_softmax");
1478
+
1479
+ // If candidates are provided, sort them as well. Otherwise, set sorted indices as candidates.
1480
+ if (data->candidates) {
1481
+ data->candidates = ggml_sort(data->candidates, sorted_idx);
1482
+ } else {
1483
+ data->candidates = sorted_idx;
1484
+ }
1485
+ ggml_set_name(data->candidates, "top_p_candidates");
1486
+
1487
+ // Compute Cumulative Distribution Function (CDF) by means of GGML_OP_CUMSUM.
1488
+ struct ggml_tensor * cdf = ggml_cumsum(ctx, softmax);
1489
+ ggml_set_name(cdf, "top_p_cdf");
1490
+
1491
+ // Invert CDF and add top-p value so that ggml_step yields 1 for values we want to keep
1492
+ struct ggml_tensor * cdf_scaled = ggml_scale_bias(ctx, cdf, -1.0f, sctx->p);
1493
+ ggml_set_name(cdf_scaled, "top_p_cdf_scaled");
1494
+
1495
+ struct ggml_tensor * mask = ggml_step(ctx, cdf_scaled);
1496
+ ggml_set_name(mask, "top_p_mask");
1497
+
1498
+ // Taking the sum of the mask gives us the sum of elements after the threshold
1499
+ // we are interested in.
1500
+ struct ggml_tensor * idxf = ggml_sum(ctx, mask);
1501
+ ggml_set_name(idxf, "top_p_index_f32");
1502
+
1503
+ // prevent out-of-bounds access
1504
+ idxf = ggml_clamp(ctx, idxf, 0.0f, mask->ne[0] - 1);
1505
+
1506
+ // construct ones tensor to set the value in the mask
1507
+ struct ggml_tensor * ones = ggml_scale_bias(ctx, idxf, 0.0f, 1.0f);
1508
+ ggml_set_name(ones, "top_p_ones");
1509
+
1510
+ // Make top-p inclusive (i.e. return all values such that cum_sum/cdf >= p)
1511
+ struct ggml_tensor * mask_reshaped = ggml_reshape_2d(ctx, mask, 1, mask->ne[0]);
1512
+
1513
+ mask_reshaped = ggml_set_rows(ctx, mask_reshaped, ones, ggml_cast(ctx, idxf, GGML_TYPE_I32));
1514
+ mask = ggml_reshape_1d(ctx, mask_reshaped, mask->ne[0]);
1515
+
1516
+ // Use ggml_scale_bias (output = (a * s) + b) which in this case becomes:
1517
+ // top_p_bias = (mask * 1e9f) - 1e9f.
1518
+ // So entries in the mask that we want to discard will become -1e9f, and
1519
+ // others will be 0 (meaning that will not effect the logits).
1520
+ const float large_val = 1e9f;
1521
+ struct ggml_tensor * top_p_bias = ggml_scale_bias(ctx, mask, large_val, -large_val);
1522
+ ggml_set_name(top_p_bias, "top_p_bias");
1523
+
1524
+ data->logits = ggml_add(ctx, sorted_logits, top_p_bias);
1525
+ ggml_set_name(data->logits, "top_p_logits");
1526
+
1527
+ GGML_UNUSED(gf);
1528
+ }
1529
+
840
1530
  static struct llama_sampler_i llama_sampler_top_p_i = {
841
- /* .name = */ llama_sampler_top_p_name,
842
- /* .accept = */ nullptr,
843
- /* .apply = */ llama_sampler_top_p_apply,
844
- /* .reset = */ nullptr,
845
- /* .clone = */ llama_sampler_top_p_clone,
846
- /* .free = */ llama_sampler_top_p_free,
1531
+ /* .name = */ llama_sampler_top_p_name,
1532
+ /* .accept = */ nullptr,
1533
+ /* .apply = */ llama_sampler_top_p_apply,
1534
+ /* .reset = */ nullptr,
1535
+ /* .clone = */ llama_sampler_top_p_clone,
1536
+ /* .free = */ llama_sampler_top_p_free,
1537
+ /* .backend_init = */ llama_sampler_top_p_backend_init,
1538
+ /* .backend_accept = */ nullptr,
1539
+ /* .backend_apply = */ llama_sampler_top_p_backend_apply,
1540
+ /* .backend_set_input = */ nullptr,
847
1541
  };
848
1542
 
849
1543
  struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
1544
+ const bool is_empty = p >= 1.0f;
1545
+
1546
+ if (is_empty) {
1547
+ return llama_sampler_init_empty("?top-p");
1548
+ }
1549
+
850
1550
  return llama_sampler_init(
851
1551
  /* .iface = */ &llama_sampler_top_p_i,
852
1552
  /* .ctx = */ new llama_sampler_top_p {
1553
+ ("top-p"),
853
1554
  /* .p = */ p,
854
1555
  /* .min_keep = */ min_keep,
855
1556
  /* .buf_sort = */ {},
@@ -859,13 +1560,14 @@ struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
859
1560
 
860
1561
  // min-p
861
1562
 
862
- struct llama_sampler_min_p {
1563
+ struct llama_sampler_min_p : public llama_sampler_backend {
863
1564
  const float p;
864
1565
  const size_t min_keep;
865
1566
  };
866
1567
 
867
- static const char * llama_sampler_min_p_name(const struct llama_sampler * /*smpl*/) {
868
- return "min-p";
1568
+ static const char * llama_sampler_min_p_name(const struct llama_sampler * smpl) {
1569
+ auto * sctx = (llama_sampler_min_p *) smpl->ctx;
1570
+ return sctx->get_name();
869
1571
  }
870
1572
 
871
1573
  static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
@@ -931,19 +1633,85 @@ static void llama_sampler_min_p_free(struct llama_sampler * smpl) {
931
1633
  delete (llama_sampler_min_p *) smpl->ctx;
932
1634
  }
933
1635
 
1636
+ static bool llama_sampler_min_p_backend_init(
1637
+ struct llama_sampler * smpl,
1638
+ ggml_backend_buffer_type_t buft) {
1639
+ auto * sctx = (llama_sampler_min_p *) smpl->ctx;
1640
+
1641
+ const bool res = llama_sampler_backend_support(smpl, buft);
1642
+
1643
+ sctx->init(res);
1644
+
1645
+ return res;
1646
+ }
1647
+
1648
+ static void llama_sampler_min_p_backend_apply(
1649
+ struct llama_sampler * smpl,
1650
+ struct ggml_context * ctx,
1651
+ struct ggml_cgraph * gf,
1652
+ struct llama_sampler_data * data) {
1653
+ auto * sctx = (llama_sampler_min_p *) smpl->ctx;
1654
+
1655
+ struct ggml_tensor * max_idx = ggml_argmax(ctx, data->logits);
1656
+ ggml_set_name(max_idx, "max_idx");
1657
+
1658
+ struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]);
1659
+ ggml_set_name(logits_rows, "logits_rows");
1660
+
1661
+ struct ggml_tensor * max_logit = ggml_get_rows(ctx, logits_rows, max_idx);
1662
+ ggml_set_name(max_logit, "max_logit");
1663
+
1664
+ // Calculate the threshold value.
1665
+ struct ggml_tensor * threshold = ggml_scale_bias(ctx, max_logit, 1.0f, logf(sctx->p));
1666
+ ggml_set_name(threshold, "min_p_threshold");
1667
+
1668
+ // Subtract the threshold from logits.
1669
+ struct ggml_tensor * sub = ggml_sub(ctx, data->logits, threshold);
1670
+
1671
+ // Create a mask where logits below the threshold are 0 (discard),
1672
+ // and others are 1 (keep).
1673
+ struct ggml_tensor * mask = ggml_step(ctx, sub);
1674
+ ggml_set_name(mask, "min_p_mask");
1675
+
1676
+ // Use ggml_scale_bias (output = (a * s) + b) which in this case becomes:
1677
+ // min_p_bias = (mask * 1e9f) - 1e9f.
1678
+ // So entries in the mask that we want to discard will become -1e9f, and
1679
+ // others will be 0 (meaning that will not effect the logits).
1680
+ const float large_val = 1e9f;
1681
+ struct ggml_tensor * min_p_bias = ggml_scale_bias(ctx, mask, large_val, -large_val);
1682
+ ggml_set_name(min_p_bias, "min_p_bias");
1683
+
1684
+ // Add the min_p bias to the logits.
1685
+ data->logits = ggml_add(ctx, data->logits, min_p_bias);
1686
+ ggml_set_name(data->logits, "min_p_logits");
1687
+
1688
+ GGML_UNUSED(gf);
1689
+ }
1690
+
934
1691
  static struct llama_sampler_i llama_sampler_min_p_i = {
935
- /* .name = */ llama_sampler_min_p_name,
936
- /* .accept = */ nullptr,
937
- /* .apply = */ llama_sampler_min_p_apply,
938
- /* .reset = */ nullptr,
939
- /* .clone = */ llama_sampler_min_p_clone,
940
- /* .free = */ llama_sampler_min_p_free,
1692
+ /* .name = */ llama_sampler_min_p_name,
1693
+ /* .accept = */ nullptr,
1694
+ /* .apply = */ llama_sampler_min_p_apply,
1695
+ /* .reset = */ nullptr,
1696
+ /* .clone = */ llama_sampler_min_p_clone,
1697
+ /* .free = */ llama_sampler_min_p_free,
1698
+ /* .backend_init = */ llama_sampler_min_p_backend_init,
1699
+ /* .backend_accept = */ nullptr,
1700
+ /* .backend_apply = */ llama_sampler_min_p_backend_apply,
1701
+ /* .backend_set_input = */ nullptr,
941
1702
  };
942
1703
 
943
1704
  struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
1705
+ const bool is_empty = (p <= 0.0f);
1706
+
1707
+ if (is_empty) {
1708
+ return llama_sampler_init_empty("?min-p");
1709
+ }
1710
+
944
1711
  return llama_sampler_init(
945
1712
  /* .iface = */ &llama_sampler_min_p_i,
946
1713
  /* .ctx = */ new llama_sampler_min_p {
1714
+ ("min-p"),
947
1715
  /* .p = */ p,
948
1716
  /* .min_keep = */ min_keep,
949
1717
  }
@@ -1031,15 +1799,25 @@ static void llama_sampler_typical_free(struct llama_sampler * smpl) {
1031
1799
  }
1032
1800
 
1033
1801
  static struct llama_sampler_i llama_sampler_typical_i = {
1034
- /* .name = */ llama_sampler_typical_name,
1035
- /* .accept = */ nullptr,
1036
- /* .apply = */ llama_sampler_typical_apply,
1037
- /* .reset = */ nullptr,
1038
- /* .clone = */ llama_sampler_typical_clone,
1039
- /* .free = */ llama_sampler_typical_free,
1802
+ /* .name = */ llama_sampler_typical_name,
1803
+ /* .accept = */ nullptr,
1804
+ /* .apply = */ llama_sampler_typical_apply,
1805
+ /* .reset = */ nullptr,
1806
+ /* .clone = */ llama_sampler_typical_clone,
1807
+ /* .free = */ llama_sampler_typical_free,
1808
+ /* .backend_init = */ nullptr,
1809
+ /* .backend_accept = */ nullptr,
1810
+ /* .backend_apply = */ nullptr,
1811
+ /* .backend_set_input = */ nullptr,
1040
1812
  };
1041
1813
 
1042
1814
  struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
1815
+ const bool is_empty = (p >= 1.0f);
1816
+
1817
+ if (is_empty) {
1818
+ return llama_sampler_init_empty("?typical");
1819
+ }
1820
+
1043
1821
  return llama_sampler_init(
1044
1822
  /* .iface = */ &llama_sampler_typical_i,
1045
1823
  /* .ctx = */ new llama_sampler_typical {
@@ -1051,12 +1829,13 @@ struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
1051
1829
 
1052
1830
  // temp
1053
1831
 
1054
- struct llama_sampler_temp {
1832
+ struct llama_sampler_temp : public llama_sampler_backend {
1055
1833
  const float temp;
1056
1834
  };
1057
1835
 
1058
- static const char * llama_sampler_temp_name(const struct llama_sampler * /*smpl*/) {
1059
- return "temp";
1836
+ static const char * llama_sampler_temp_name(const struct llama_sampler * smpl) {
1837
+ auto * sctx = (llama_sampler_temp *) smpl->ctx;
1838
+ return sctx->get_name();
1060
1839
  }
1061
1840
 
1062
1841
  static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
@@ -1074,19 +1853,79 @@ static void llama_sampler_temp_free(struct llama_sampler * smpl) {
1074
1853
  delete (llama_sampler_temp *) smpl->ctx;
1075
1854
  }
1076
1855
 
1856
+ static void llama_sampler_backend_temp_sampling(
1857
+ struct ggml_context * ctx,
1858
+ struct ggml_cgraph * gf,
1859
+ struct llama_sampler_data * data,
1860
+ float temp) {
1861
+ if (temp <= 0.0f) {
1862
+ // Find the most probable token index.
1863
+ struct ggml_tensor * max_idx = ggml_argmax(ctx, data->logits);
1864
+ ggml_set_name(max_idx, "temp_max_idx");
1865
+
1866
+ if (data->candidates) {
1867
+ struct ggml_tensor * candidates_rows = ggml_reshape_2d(ctx, data->candidates, 1, data->candidates->ne[0]);
1868
+ data->candidates = ggml_get_rows(ctx, candidates_rows, max_idx);
1869
+ } else {
1870
+ data->candidates = max_idx;
1871
+ }
1872
+
1873
+ struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]);
1874
+ data->logits = ggml_get_rows(ctx, logits_rows, max_idx);
1875
+
1876
+ return;
1877
+ }
1878
+
1879
+ data->logits = ggml_scale(ctx, data->logits, 1.0f / temp);
1880
+
1881
+ GGML_UNUSED(gf);
1882
+ }
1883
+
1884
+ static bool llama_sampler_temp_backend_init(
1885
+ struct llama_sampler * smpl,
1886
+ ggml_backend_buffer_type_t buft) {
1887
+ auto * sctx = (llama_sampler_temp *) smpl->ctx;
1888
+
1889
+ const bool res = llama_sampler_backend_support(smpl, buft);
1890
+
1891
+ sctx->init(res);
1892
+
1893
+ return res;
1894
+ }
1895
+
1896
+ static void llama_sampler_temp_backend_apply(
1897
+ struct llama_sampler * smpl,
1898
+ struct ggml_context * ctx,
1899
+ struct ggml_cgraph * gf,
1900
+ struct llama_sampler_data * data) {
1901
+ auto * sctx = (llama_sampler_temp *) smpl->ctx;
1902
+ llama_sampler_backend_temp_sampling(ctx, gf, data, sctx->temp);
1903
+ }
1904
+
1077
1905
  static struct llama_sampler_i llama_sampler_temp_i = {
1078
- /* .name = */ llama_sampler_temp_name,
1079
- /* .accept = */ nullptr,
1080
- /* .apply = */ llama_sampler_temp_apply,
1081
- /* .reset = */ nullptr,
1082
- /* .clone = */ llama_sampler_temp_clone,
1083
- /* .free = */ llama_sampler_temp_free,
1906
+ /* .name = */ llama_sampler_temp_name,
1907
+ /* .accept = */ nullptr,
1908
+ /* .apply = */ llama_sampler_temp_apply,
1909
+ /* .reset = */ nullptr,
1910
+ /* .clone = */ llama_sampler_temp_clone,
1911
+ /* .free = */ llama_sampler_temp_free,
1912
+ /* .backend_init = */ llama_sampler_temp_backend_init,
1913
+ /* .backend_accept = */ nullptr,
1914
+ /* .backend_apply = */ llama_sampler_temp_backend_apply,
1915
+ /* .backend_set_input = */ nullptr,
1084
1916
  };
1085
1917
 
1086
1918
  struct llama_sampler * llama_sampler_init_temp(float temp) {
1919
+ const bool is_empty = temp == 1.0f;
1920
+
1921
+ if (is_empty) {
1922
+ return llama_sampler_init_empty("?temp");
1923
+ }
1924
+
1087
1925
  return llama_sampler_init(
1088
1926
  /* .iface = */ &llama_sampler_temp_i,
1089
1927
  /* .ctx = */ new llama_sampler_temp {
1928
+ ("temp"),
1090
1929
  /*.temp = */ temp,
1091
1930
  }
1092
1931
  );
@@ -1094,14 +1933,15 @@ struct llama_sampler * llama_sampler_init_temp(float temp) {
1094
1933
 
1095
1934
  // temp-ext
1096
1935
 
1097
- struct llama_sampler_temp_ext {
1936
+ struct llama_sampler_temp_ext : public llama_sampler_backend {
1098
1937
  const float temp;
1099
1938
  const float delta;
1100
1939
  const float exponent;
1101
1940
  };
1102
1941
 
1103
- static const char * llama_sampler_temp_ext_name(const struct llama_sampler * /*smpl*/) {
1104
- return "temp-ext";
1942
+ static const char * llama_sampler_temp_ext_name(const struct llama_sampler * smpl) {
1943
+ auto * sctx = (llama_sampler_temp_ext *) smpl->ctx;
1944
+ return sctx->get_name();
1105
1945
  }
1106
1946
 
1107
1947
  static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
@@ -1184,24 +2024,112 @@ static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) {
1184
2024
  delete (llama_sampler_temp_ext *) smpl->ctx;
1185
2025
  }
1186
2026
 
2027
+ static bool llama_sampler_temp_ext_backend_init(
2028
+ struct llama_sampler * smpl,
2029
+ ggml_backend_buffer_type_t buft) {
2030
+ auto * sctx = (llama_sampler_temp_ext *) smpl->ctx;
2031
+
2032
+ const bool res = llama_sampler_backend_support(smpl, buft);
2033
+
2034
+ sctx->init(res);
2035
+
2036
+ return res;
2037
+ }
2038
+
2039
+ static void llama_sampler_temp_ext_backend_apply(
2040
+ struct llama_sampler * smpl,
2041
+ struct ggml_context * ctx,
2042
+ struct ggml_cgraph * gf,
2043
+ struct llama_sampler_data * data) {
2044
+ auto * sctx = (llama_sampler_temp_ext *) smpl->ctx;
2045
+
2046
+ // Revert to standard temperature scaling if delta or temp are non-positive.
2047
+ if (sctx->delta <= 0.0f || sctx->temp <= 0.0f) {
2048
+ llama_sampler_backend_temp_sampling(ctx, gf, data, sctx->temp);
2049
+ return;
2050
+ }
2051
+
2052
+ // Calculate min_temp, max_temp, and max_entropy.
2053
+ const float min_temp = std::max(0.0f, sctx->temp - sctx->delta);
2054
+ const float max_temp = sctx->temp + sctx->delta;
2055
+ const float max_entropy = logf(data->logits->ne[0]);
2056
+
2057
+ // Calculate the probabilities.
2058
+ struct ggml_tensor * probs = ggml_soft_max(ctx, data->logits);
2059
+ ggml_set_name(probs, "temp_ext_softmax_probs");
2060
+
2061
+ // Clamp probabilities to avoid log(0) which would give -inf
2062
+ struct ggml_tensor * probs_clamped = ggml_clamp(ctx, probs, 1e-10f, 1.0f);
2063
+ ggml_set_name(probs_clamped, "temp_ext_probs_clamped");
2064
+
2065
+ // Calculate the entropy, entropy = -Σ(p * log(p)).
2066
+ struct ggml_tensor * log_probs = ggml_log(ctx, probs_clamped);
2067
+ struct ggml_tensor * p_log_p = ggml_mul(ctx, probs_clamped, log_probs);
2068
+ struct ggml_tensor * sum_p_log_p = ggml_sum(ctx, p_log_p);
2069
+ struct ggml_tensor * entropy = ggml_scale(ctx, sum_p_log_p, -1.0f);
2070
+ ggml_set_name(log_probs, "temp_ext_log_probs");
2071
+ ggml_set_name(p_log_p, "temp_ext_p_log_p");
2072
+ ggml_set_name(sum_p_log_p, "temp_ext_sum_p_log_p");
2073
+ ggml_set_name(entropy, "temp_ext_entropy");
2074
+
2075
+ // Normalize the entropy, norm_entropy = entropy / max_entropy
2076
+ struct ggml_tensor * norm_entropy = ggml_scale(ctx, entropy, 1.0f / max_entropy);
2077
+ ggml_set_name(norm_entropy, "temp_ext_norm_entropy");
2078
+
2079
+ // Calculate the dynamic temperature:
2080
+ // dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent);
2081
+ //
2082
+ // Calculate powf(normalized_entropy, exponent) as
2083
+ // norm_entropy^exponent = exp(exponent * log(norm_entropy))
2084
+ struct ggml_tensor * log_norm_entropy = ggml_log(ctx, norm_entropy);
2085
+ struct ggml_tensor * scaled_log = ggml_scale(ctx, log_norm_entropy, sctx->exponent);
2086
+ struct ggml_tensor * pow_entropy = ggml_exp(ctx, scaled_log);
2087
+ // With pow_entropy computed we can now compute dyn_temp, scaling by
2088
+ // (max_temp - min_temp) and then adding min_temp.
2089
+ struct ggml_tensor * dyn_temp = ggml_scale_bias(ctx, pow_entropy, max_temp - min_temp, min_temp);
2090
+ ggml_set_name(log_norm_entropy, "temp_ext_log_norm_entropy");
2091
+ ggml_set_name(scaled_log, "temp_ext_scaled_log");
2092
+ ggml_set_name(pow_entropy, "temp_ext_pow_entropy");
2093
+ ggml_set_name(dyn_temp, "temp_ext_dyn_temp");
2094
+
2095
+ // Scale the logits by the dynamic temperature
2096
+ struct ggml_tensor * scaled_logits = ggml_div(ctx, data->logits, dyn_temp);
2097
+ ggml_set_name(scaled_logits, "temp_ext_scaled_logits");
2098
+
2099
+ data->logits = scaled_logits;
2100
+ }
2101
+
1187
2102
  static struct llama_sampler_i llama_sampler_temp_ext_i = {
1188
- /* .name = */ llama_sampler_temp_ext_name,
1189
- /* .accept = */ nullptr,
1190
- /* .apply = */ llama_sampler_temp_ext_apply,
1191
- /* .reset = */ nullptr,
1192
- /* .clone = */ llama_sampler_temp_ext_clone,
1193
- /* .free = */ llama_sampler_temp_ext_free,
2103
+ /* .name = */ llama_sampler_temp_ext_name,
2104
+ /* .accept = */ nullptr,
2105
+ /* .apply = */ llama_sampler_temp_ext_apply,
2106
+ /* .reset = */ nullptr,
2107
+ /* .clone = */ llama_sampler_temp_ext_clone,
2108
+ /* .free = */ llama_sampler_temp_ext_free,
2109
+ /* .backend_init = */ llama_sampler_temp_ext_backend_init,
2110
+ /* .backend_accept = */ nullptr,
2111
+ /* .backend_apply = */ llama_sampler_temp_ext_backend_apply,
2112
+ /* .backend_set_input = */ nullptr,
1194
2113
  };
1195
2114
 
1196
2115
  struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
1197
- return llama_sampler_init(
2116
+ const bool is_empty = temp == 1.0f && delta <= 0.0f;
2117
+
2118
+ if (is_empty) {
2119
+ return llama_sampler_init_empty("?temp-ext");
2120
+ }
2121
+
2122
+ auto * res = llama_sampler_init(
1198
2123
  /* .iface = */ &llama_sampler_temp_ext_i,
1199
2124
  /* .ctx = */ new llama_sampler_temp_ext {
2125
+ ("temp-ext"),
1200
2126
  /* .temp = */ temp,
1201
2127
  /* .delta = */ delta,
1202
2128
  /* .exponent = */ exponent,
1203
2129
  }
1204
2130
  );
2131
+
2132
+ return res;
1205
2133
  }
1206
2134
 
1207
2135
  // xtc
@@ -1214,7 +2142,7 @@ struct llama_sampler_xtc {
1214
2142
  const uint32_t seed;
1215
2143
  uint32_t seed_cur;
1216
2144
 
1217
- std::mt19937 rng;
2145
+ std::mt19937 rng;
1218
2146
  };
1219
2147
 
1220
2148
  static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/) {
@@ -1279,16 +2207,27 @@ static void llama_sampler_xtc_reset(struct llama_sampler * smpl) {
1279
2207
  }
1280
2208
 
1281
2209
  static struct llama_sampler_i llama_sampler_xtc_i = {
1282
- /* .name = */ llama_sampler_xtc_name,
1283
- /* .accept = */ nullptr,
1284
- /* .apply = */ llama_sample_xtc_apply,
1285
- /* .reset = */ llama_sampler_xtc_reset,
1286
- /* .clone = */ llama_sampler_xtc_clone,
1287
- /* .free = */ llama_sampler_xtc_free,
2210
+ /* .name = */ llama_sampler_xtc_name,
2211
+ /* .accept = */ nullptr,
2212
+ /* .apply = */ llama_sample_xtc_apply,
2213
+ /* .reset = */ llama_sampler_xtc_reset,
2214
+ /* .clone = */ llama_sampler_xtc_clone,
2215
+ /* .free = */ llama_sampler_xtc_free,
2216
+ /* .backend_init = */ nullptr,
2217
+ /* .backend_accept = */ nullptr,
2218
+ /* .backend_apply = */ nullptr,
2219
+ /* .backend_set_input = */ nullptr,
1288
2220
  };
1289
2221
 
1290
2222
  struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) {
1291
- auto seed_cur = get_rng_seed(seed);
2223
+ const bool is_empty = (p <= 0.0f || t > 0.5f);
2224
+
2225
+ if (is_empty) {
2226
+ return llama_sampler_init_empty("?xtc");
2227
+ }
2228
+
2229
+ const auto seed_cur = get_rng_seed(seed);
2230
+
1292
2231
  return llama_sampler_init(
1293
2232
  /* .iface = */ &llama_sampler_xtc_i,
1294
2233
  /* .ctx = */ new llama_sampler_xtc {
@@ -1387,16 +2326,21 @@ static void llama_sampler_mirostat_free(struct llama_sampler * smpl) {
1387
2326
  }
1388
2327
 
1389
2328
  static struct llama_sampler_i llama_sampler_mirostat_i = {
1390
- /* .name = */ llama_sampler_mirostat_name,
1391
- /* .accept = */ nullptr,
1392
- /* .apply = */ llama_sampler_mirostat_apply,
1393
- /* .reset = */ llama_sampler_mirostat_reset,
1394
- /* .clone = */ llama_sampler_mirostat_clone,
1395
- /* .free = */ llama_sampler_mirostat_free,
2329
+ /* .name = */ llama_sampler_mirostat_name,
2330
+ /* .accept = */ nullptr,
2331
+ /* .apply = */ llama_sampler_mirostat_apply,
2332
+ /* .reset = */ llama_sampler_mirostat_reset,
2333
+ /* .clone = */ llama_sampler_mirostat_clone,
2334
+ /* .free = */ llama_sampler_mirostat_free,
2335
+ /* .backend_init = */ nullptr,
2336
+ /* .backend_accept = */ nullptr,
2337
+ /* .backend_apply = */ nullptr,
2338
+ /* .backend_set_input = */ nullptr,
1396
2339
  };
1397
2340
 
1398
2341
  struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
1399
- auto seed_cur = get_rng_seed(seed);
2342
+ const auto seed_cur = get_rng_seed(seed);
2343
+
1400
2344
  return llama_sampler_init(
1401
2345
  /* .iface = */ &llama_sampler_mirostat_i,
1402
2346
  /* .ctx = */ new llama_sampler_mirostat {
@@ -1486,12 +2430,16 @@ static void llama_sampler_mirostat_v2_free(struct llama_sampler * smpl) {
1486
2430
  }
1487
2431
 
1488
2432
  static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
1489
- /* .name = */ llama_sampler_mirostat_v2_name,
1490
- /* .accept = */ nullptr,
1491
- /* .apply = */ llama_sampler_mirostat_v2_apply,
1492
- /* .reset = */ llama_sampler_mirostat_v2_reset,
1493
- /* .clone = */ llama_sampler_mirostat_v2_clone,
1494
- /* .free = */ llama_sampler_mirostat_v2_free,
2433
+ /* .name = */ llama_sampler_mirostat_v2_name,
2434
+ /* .accept = */ nullptr,
2435
+ /* .apply = */ llama_sampler_mirostat_v2_apply,
2436
+ /* .reset = */ llama_sampler_mirostat_v2_reset,
2437
+ /* .clone = */ llama_sampler_mirostat_v2_clone,
2438
+ /* .free = */ llama_sampler_mirostat_v2_free,
2439
+ /* .backend_init = */ nullptr,
2440
+ /* .backend_accept = */ nullptr,
2441
+ /* .backend_apply = */ nullptr,
2442
+ /* .backend_set_input = */ nullptr,
1495
2443
  };
1496
2444
 
1497
2445
  struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
@@ -1603,12 +2551,16 @@ static void llama_sampler_grammar_free(struct llama_sampler * smpl) {
1603
2551
  }
1604
2552
 
1605
2553
  static struct llama_sampler_i llama_sampler_grammar_i = {
1606
- /* .name = */ llama_sampler_grammar_name,
1607
- /* .accept = */ llama_sampler_grammar_accept_impl,
1608
- /* .apply = */ llama_sampler_grammar_apply,
1609
- /* .reset = */ llama_sampler_grammar_reset,
1610
- /* .clone = */ llama_sampler_grammar_clone,
1611
- /* .free = */ llama_sampler_grammar_free,
2554
+ /* .name = */ llama_sampler_grammar_name,
2555
+ /* .accept = */ llama_sampler_grammar_accept_impl,
2556
+ /* .apply = */ llama_sampler_grammar_apply,
2557
+ /* .reset = */ llama_sampler_grammar_reset,
2558
+ /* .clone = */ llama_sampler_grammar_clone,
2559
+ /* .free = */ llama_sampler_grammar_free,
2560
+ /* .backend_init = */ nullptr,
2561
+ /* .backend_accept = */ nullptr,
2562
+ /* .backend_apply = */ nullptr,
2563
+ /* .backend_set_input = */ nullptr,
1612
2564
  };
1613
2565
 
1614
2566
  static struct llama_sampler * llama_sampler_init_grammar_impl(
@@ -1625,10 +2577,12 @@ static struct llama_sampler * llama_sampler_init_grammar_impl(
1625
2577
  auto * ctx = new llama_sampler_grammar;
1626
2578
 
1627
2579
  if (grammar_str != nullptr && grammar_str[0] != '\0') {
2580
+ std::string trigger_pattern;
2581
+ llama_grammar * grammar = nullptr;
1628
2582
  // TODO: remove trigger_words support.
1629
2583
  if (trigger_words != nullptr && num_trigger_words > 0) {
1630
2584
  GGML_ASSERT(trigger_patterns == nullptr && num_trigger_patterns == 0);
1631
- std::string trigger_pattern("[\\s\\S]*?(");
2585
+ trigger_pattern = "[\\s\\S]*?(";
1632
2586
  for (size_t i = 0; i < num_trigger_words; ++i) {
1633
2587
  static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");
1634
2588
  if (i > 0) {
@@ -1637,15 +2591,17 @@ static struct llama_sampler * llama_sampler_init_grammar_impl(
1637
2591
  trigger_pattern += std::regex_replace(trigger_words[i], special_chars, "\\$0");
1638
2592
  }
1639
2593
  trigger_pattern += ")[\\s\\S]*";
1640
- const auto * trigger_pattern_c = trigger_pattern.c_str();
1641
- trigger_patterns = &trigger_pattern_c;
1642
- num_trigger_patterns = 1;
2594
+
2595
+ std::array<const char *, 1> tmp_trigger_patterns = { trigger_pattern.c_str() };
2596
+ grammar = llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, tmp_trigger_patterns.data(), tmp_trigger_patterns.size(), trigger_tokens, num_trigger_tokens);
2597
+ } else {
2598
+ grammar = llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens);
1643
2599
  }
1644
2600
  *ctx = {
1645
2601
  /* .vocab = */ vocab,
1646
2602
  /* .grammar_str = */ grammar_str,
1647
2603
  /* .grammar_root = */ grammar_root,
1648
- /* .grammar = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens),
2604
+ /* .grammar = */ grammar,
1649
2605
  };
1650
2606
  if (!ctx->grammar) {
1651
2607
  delete ctx;
@@ -1806,12 +2762,16 @@ static void llama_sampler_penalties_free(struct llama_sampler * smpl) {
1806
2762
  }
1807
2763
 
1808
2764
  static struct llama_sampler_i llama_sampler_penalties_i = {
1809
- /* .name = */ llama_sampler_penalties_name,
1810
- /* .accept = */ llama_sampler_penalties_accept,
1811
- /* .apply = */ llama_sampler_penalties_apply,
1812
- /* .reset = */ llama_sampler_penalties_reset,
1813
- /* .clone = */ llama_sampler_penalties_clone,
1814
- /* .free = */ llama_sampler_penalties_free,
2765
+ /* .name = */ llama_sampler_penalties_name,
2766
+ /* .accept = */ llama_sampler_penalties_accept,
2767
+ /* .apply = */ llama_sampler_penalties_apply,
2768
+ /* .reset = */ llama_sampler_penalties_reset,
2769
+ /* .clone = */ llama_sampler_penalties_clone,
2770
+ /* .free = */ llama_sampler_penalties_free,
2771
+ /* .backend_init = */ nullptr,
2772
+ /* .backend_accept = */ nullptr,
2773
+ /* .backend_apply = */ nullptr,
2774
+ /* .backend_set_input = */ nullptr,
1815
2775
  };
1816
2776
 
1817
2777
  struct llama_sampler * llama_sampler_init_penalties(
@@ -1821,6 +2781,12 @@ struct llama_sampler * llama_sampler_init_penalties(
1821
2781
  float penalty_present) {
1822
2782
  penalty_last_n = std::max(penalty_last_n, 0);
1823
2783
 
2784
+ const bool is_empty = (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f));
2785
+
2786
+ if (is_empty) {
2787
+ return llama_sampler_init_empty("?penalties");
2788
+ }
2789
+
1824
2790
  return llama_sampler_init(
1825
2791
  /* .iface = */ &llama_sampler_penalties_i,
1826
2792
  /* .ctx = */ new llama_sampler_penalties {
@@ -1858,9 +2824,7 @@ static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_t
1858
2824
  for (size_t i = 0; i < cur_p->size; ++i) {
1859
2825
  // Only count non-negative infinity values
1860
2826
  if (cur_p->data[i].logit != -INFINITY) {
1861
- if (cur_p->data[i].logit > max) {
1862
- max = cur_p->data[i].logit;
1863
- }
2827
+ max = std::max(max, cur_p->data[i].logit);
1864
2828
  logits_sum += cur_p->data[i].logit;
1865
2829
  valid_count++;
1866
2830
  }
@@ -1897,15 +2861,25 @@ static void llama_sampler_top_n_sigma_free(struct llama_sampler * smpl) {
1897
2861
  }
1898
2862
 
1899
2863
  static struct llama_sampler_i llama_sampler_top_n_sigma_i = {
1900
- /* .name = */ llama_sampler_top_n_sigma_name,
1901
- /* .accept = */ nullptr,
1902
- /* .apply = */ llama_sampler_top_n_sigma_apply,
1903
- /* .reset = */ nullptr,
1904
- /* .clone = */ llama_sampler_top_n_sigma_clone,
1905
- /* .free = */ llama_sampler_top_n_sigma_free,
2864
+ /* .name = */ llama_sampler_top_n_sigma_name,
2865
+ /* .accept = */ nullptr,
2866
+ /* .apply = */ llama_sampler_top_n_sigma_apply,
2867
+ /* .reset = */ nullptr,
2868
+ /* .clone = */ llama_sampler_top_n_sigma_clone,
2869
+ /* .free = */ llama_sampler_top_n_sigma_free,
2870
+ /* .backend_init = */ nullptr,
2871
+ /* .backend_accept = */ nullptr,
2872
+ /* .backend_apply = */ nullptr,
2873
+ /* .backend_set_input = */ nullptr,
1906
2874
  };
1907
2875
 
1908
2876
  struct llama_sampler * llama_sampler_init_top_n_sigma(float n) {
2877
+ const bool is_empty = (n <= 0.0f);
2878
+
2879
+ if (is_empty) {
2880
+ return llama_sampler_init_empty("?top-n-sigma");
2881
+ }
2882
+
1909
2883
  return llama_sampler_init(
1910
2884
  /* .iface = */ &llama_sampler_top_n_sigma_i,
1911
2885
  /* .ctx = */ new llama_sampler_top_n_sigma {
@@ -2227,12 +3201,16 @@ static void llama_sampler_dry_free(struct llama_sampler * smpl) {
2227
3201
  }
2228
3202
 
2229
3203
  static struct llama_sampler_i llama_sampler_dry_i = {
2230
- /* .name = */ llama_sampler_dry_name,
2231
- /* .accept = */ llama_sampler_dry_accept,
2232
- /* .apply = */ llama_sampler_dry_apply,
2233
- /* .reset = */ llama_sampler_dry_reset,
2234
- /* .clone = */ llama_sampler_dry_clone,
2235
- /* .free = */ llama_sampler_dry_free,
3204
+ /* .name = */ llama_sampler_dry_name,
3205
+ /* .accept = */ llama_sampler_dry_accept,
3206
+ /* .apply = */ llama_sampler_dry_apply,
3207
+ /* .reset = */ llama_sampler_dry_reset,
3208
+ /* .clone = */ llama_sampler_dry_clone,
3209
+ /* .free = */ llama_sampler_dry_free,
3210
+ /* .backend_init = */ nullptr,
3211
+ /* .backend_accept = */ nullptr,
3212
+ /* .backend_apply = */ nullptr,
3213
+ /* .backend_set_input = */ nullptr,
2236
3214
  };
2237
3215
 
2238
3216
  struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, int32_t n_ctx_train, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
@@ -2243,6 +3221,10 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab,
2243
3221
 
2244
3222
  const bool dry_enabled = (dry_multiplier != 0.0f && dry_base >= 1.0f && dry_penalty_last_n != 0);
2245
3223
 
3224
+ if (!dry_enabled) {
3225
+ return llama_sampler_init_empty("?dry");
3226
+ }
3227
+
2246
3228
  if (dry_enabled && seq_breakers != nullptr && num_breakers > 0) {
2247
3229
  // Process sequence breakers
2248
3230
  for (size_t i = 0; i < num_breakers; ++i) {
@@ -2313,16 +3295,23 @@ struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, floa
2313
3295
 
2314
3296
  // logit-bias
2315
3297
 
2316
- struct llama_sampler_logit_bias {
3298
+ struct llama_sampler_logit_bias : public llama_sampler_backend {
2317
3299
  const int32_t n_vocab;
2318
3300
 
2319
3301
  const std::vector<llama_logit_bias> logit_bias;
2320
3302
 
2321
3303
  std::vector<llama_logit_bias> to_search;
3304
+
3305
+ struct ggml_tensor * inp_logit_bias;
3306
+ struct ggml_tensor * inp_logit_idxs;
3307
+
3308
+ ggml_context_ptr inp_ctx;
3309
+ ggml_backend_buffer_ptr inp_buf;
2322
3310
  };
2323
3311
 
2324
- static const char * llama_sampler_logit_bias_name(const struct llama_sampler * /*smpl*/) {
2325
- return "logit-bias";
3312
+ static const char * llama_sampler_logit_bias_name(const struct llama_sampler * smpl) {
3313
+ auto * ctx = (llama_sampler_logit_bias *) smpl->ctx;
3314
+ return ctx->get_name();
2326
3315
  }
2327
3316
 
2328
3317
  static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
@@ -2367,25 +3356,123 @@ static void llama_sampler_logit_bias_free(struct llama_sampler * smpl) {
2367
3356
  delete (llama_sampler_logit_bias *) smpl->ctx;
2368
3357
  }
2369
3358
 
3359
+ static void llama_sampler_logit_bias_backend_apply(
3360
+ struct llama_sampler * smpl,
3361
+ struct ggml_context * ctx,
3362
+ struct ggml_cgraph * gf,
3363
+ struct llama_sampler_data * data) {
3364
+ GGML_UNUSED(gf);
3365
+ GGML_UNUSED(ctx);
3366
+
3367
+ auto * sctx = (llama_sampler_logit_bias *) smpl->ctx;
3368
+ if (sctx->logit_bias.empty()) {
3369
+ return;
3370
+ }
3371
+
3372
+ ggml_tensor * cur = ggml_fill(ctx, data->logits, 0.0f);
3373
+
3374
+ cur = ggml_reshape_2d(ctx, cur, 1, ggml_nelements(cur));
3375
+ cur = ggml_set_rows(ctx, cur, sctx->inp_logit_bias, sctx->inp_logit_idxs);
3376
+ cur = ggml_reshape_1d(ctx, cur, ggml_nelements(cur));
3377
+
3378
+ data->logits = ggml_add(ctx, data->logits, cur);
3379
+ }
3380
+
3381
+ static void llama_sampler_logit_bias_backend_set_input(struct llama_sampler * smpl) {
3382
+ auto * sctx = (llama_sampler_logit_bias *) smpl->ctx;
3383
+ if (sctx->logit_bias.empty()) {
3384
+ return;
3385
+ }
3386
+
3387
+ GGML_ASSERT(sctx->inp_logit_bias != nullptr);
3388
+ GGML_ASSERT(sctx->inp_logit_idxs != nullptr);
3389
+
3390
+ const size_t n = sctx->logit_bias.size();
3391
+
3392
+ std::vector<float> data_logit_bias(n, 0.0f);
3393
+ std::vector<int32_t> data_logit_idxs(n, 0);
3394
+ for (size_t i = 0; i < n; ++i) {
3395
+ const auto & lb = sctx->logit_bias[i];
3396
+ GGML_ASSERT(lb.token >= 0 && lb.token < (int32_t) sctx->n_vocab);
3397
+ data_logit_bias[i] = lb.bias;
3398
+ data_logit_idxs[i] = lb.token;
3399
+ }
3400
+
3401
+ ggml_backend_tensor_set(sctx->inp_logit_bias, data_logit_bias.data(), 0, ggml_nbytes(sctx->inp_logit_bias));
3402
+ ggml_backend_tensor_set(sctx->inp_logit_idxs, data_logit_idxs.data(), 0, ggml_nbytes(sctx->inp_logit_idxs));
3403
+ }
3404
+
3405
+ static bool llama_sampler_logit_bias_backend_init(
3406
+ struct llama_sampler * smpl,
3407
+ ggml_backend_buffer_type_t buft) {
3408
+ auto * sctx = (llama_sampler_logit_bias *) smpl->ctx;
3409
+
3410
+ sctx->init(true);
3411
+
3412
+ if (sctx->logit_bias.empty()) {
3413
+ return true;
3414
+ }
3415
+
3416
+ ggml_init_params params = {
3417
+ /*.mem_size =*/ 2*ggml_tensor_overhead(),
3418
+ /*.mem_buffer =*/ nullptr,
3419
+ /*.no_alloc =*/ true,
3420
+ };
3421
+
3422
+ sctx->inp_ctx.reset(ggml_init(params));
3423
+
3424
+ const size_t n = sctx->logit_bias.size();
3425
+
3426
+ sctx->inp_logit_bias = ggml_new_tensor_2d(sctx->inp_ctx.get(), GGML_TYPE_F32, 1, n);
3427
+ ggml_set_name(sctx->inp_logit_bias, "logit_bias");
3428
+ ggml_set_input(sctx->inp_logit_bias);
3429
+
3430
+ sctx->inp_logit_idxs = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_I32, n);
3431
+ ggml_set_name(sctx->inp_logit_idxs, "logit_idxs");
3432
+ ggml_set_input(sctx->inp_logit_idxs);
3433
+
3434
+ // Allocate all tensors from our context to the backend
3435
+ sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft));
3436
+
3437
+ ggml_backend_buffer_clear(sctx->inp_buf.get(), 0);
3438
+
3439
+ return true;
3440
+ }
3441
+
2370
3442
  static struct llama_sampler_i llama_sampler_logit_bias_i = {
2371
- /* .name = */ llama_sampler_logit_bias_name,
2372
- /* .accept = */ nullptr,
2373
- /* .apply = */ llama_sampler_logit_bias_apply,
2374
- /* .reset = */ nullptr,
2375
- /* .clone = */ llama_sampler_logit_bias_clone,
2376
- /* .free = */ llama_sampler_logit_bias_free,
3443
+ /* .name = */ llama_sampler_logit_bias_name,
3444
+ /* .accept = */ nullptr,
3445
+ /* .apply = */ llama_sampler_logit_bias_apply,
3446
+ /* .reset = */ nullptr,
3447
+ /* .clone = */ llama_sampler_logit_bias_clone,
3448
+ /* .free = */ llama_sampler_logit_bias_free,
3449
+ /* .backend_init = */ llama_sampler_logit_bias_backend_init,
3450
+ /* .backend_accept = */ nullptr,
3451
+ /* .backend_apply = */ llama_sampler_logit_bias_backend_apply,
3452
+ /* .backend_set_input = */ llama_sampler_logit_bias_backend_set_input,
2377
3453
  };
2378
3454
 
2379
3455
  struct llama_sampler * llama_sampler_init_logit_bias(
2380
3456
  int32_t n_vocab,
2381
3457
  int32_t n_logit_bias,
2382
3458
  const llama_logit_bias * logit_bias) {
3459
+ const bool is_empty = n_logit_bias <= 0;
3460
+
3461
+ if (is_empty) {
3462
+ return llama_sampler_init_empty("?logit-bias");
3463
+ }
3464
+
2383
3465
  return llama_sampler_init(
2384
3466
  /* .iface = */ &llama_sampler_logit_bias_i,
2385
3467
  /* .ctx = */ new llama_sampler_logit_bias {
2386
- /* .n_vocab = */ n_vocab,
2387
- /* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
2388
- /* .to_search = */ {},
3468
+ ("logit-bias"),
3469
+ /* .n_vocab = */ n_vocab,
3470
+ /* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
3471
+ /* .to_search = */ {},
3472
+ /* .inp_logit_bias = */ nullptr,
3473
+ /* .inp_logit_idxs = */ nullptr,
3474
+ /* .inp_ctx = */ nullptr,
3475
+ /* .inp_buf = */ nullptr,
2389
3476
  }
2390
3477
  );
2391
3478
  }
@@ -2541,8 +3628,13 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
2541
3628
  if (n_non_eog == 0) {
2542
3629
  cur_p->size = 1;
2543
3630
  cur_p->data[0].id = ctx->vocab->token_eot();
3631
+ if (cur_p->data[0].id == LLAMA_TOKEN_NULL) {
3632
+ cur_p->data[0].id = ctx->vocab->token_eos();
3633
+ }
2544
3634
  cur_p->data[0].logit = 1.0f;
2545
3635
 
3636
+ GGML_ASSERT(cur_p->data[0].id != LLAMA_TOKEN_NULL);
3637
+
2546
3638
  return;
2547
3639
  }
2548
3640
 
@@ -2593,12 +3685,16 @@ static void llama_sampler_infill_free(struct llama_sampler * smpl) {
2593
3685
  }
2594
3686
 
2595
3687
  static struct llama_sampler_i llama_sampler_infill_i = {
2596
- /* .name = */ llama_sampler_infill_name,
2597
- /* .accept = */ nullptr,
2598
- /* .apply = */ llama_sampler_infill_apply,
2599
- /* .reset = */ nullptr,
2600
- /* .clone = */ llama_sampler_infill_clone,
2601
- /* .free = */ llama_sampler_infill_free,
3688
+ /* .name = */ llama_sampler_infill_name,
3689
+ /* .accept = */ nullptr,
3690
+ /* .apply = */ llama_sampler_infill_apply,
3691
+ /* .reset = */ nullptr,
3692
+ /* .clone = */ llama_sampler_infill_clone,
3693
+ /* .free = */ llama_sampler_infill_free,
3694
+ /* .backend_apply = */ nullptr,
3695
+ /* .backend_accept = */ nullptr,
3696
+ /* .backend_set_input = */ nullptr,
3697
+ /* .backend_init = */ nullptr,
2602
3698
  };
2603
3699
 
2604
3700
  struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab) {
@@ -2630,7 +3726,7 @@ uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {
2630
3726
  if (smpl->iface == &llama_sampler_chain_i) {
2631
3727
  const auto * ctx = (const llama_sampler_chain *) smpl->ctx;
2632
3728
  for (auto it = ctx->samplers.rbegin(); it != ctx->samplers.rend(); ++it) {
2633
- const uint32_t seed = llama_sampler_get_seed(*it);
3729
+ const uint32_t seed = llama_sampler_get_seed(it->ptr);
2634
3730
  if (seed != LLAMA_DEFAULT_SEED) {
2635
3731
  return seed;
2636
3732
  }
@@ -2660,8 +3756,7 @@ struct llama_perf_sampler_data llama_perf_sampler(const struct llama_sampler * c
2660
3756
  void llama_perf_sampler_print(const struct llama_sampler * chain) {
2661
3757
  const auto data = llama_perf_sampler(chain);
2662
3758
 
2663
- LLAMA_LOG_INFO("%s: sampling time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
2664
- __func__, data.t_sample_ms, data.n_sample, data.t_sample_ms / data.n_sample, 1e3 / data.t_sample_ms * data.n_sample);
3759
+ LLAMA_LOG_INFO("%s: samplers time = %10.2f ms / %5d runs\n", __func__, data.t_sample_ms, data.n_sample);
2665
3760
  }
2666
3761
 
2667
3762
  void llama_perf_sampler_reset(struct llama_sampler * chain) {
@@ -2671,5 +3766,6 @@ void llama_perf_sampler_reset(struct llama_sampler * chain) {
2671
3766
 
2672
3767
  auto * ctx = (struct llama_sampler_chain *) chain->ctx;
2673
3768
 
2674
- ctx->t_sample_us = ctx->n_sample = 0;
3769
+ ctx->t_sample_us = 0;
3770
+ ctx->n_sample = 0;
2675
3771
  }