whispercpp 1.3.3 → 1.3.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (963)
  1. checksums.yaml +4 -4
  2. data/README.md +60 -43
  3. data/ext/extconf.rb +2 -2
  4. data/ext/ruby_whisper.c +14 -2
  5. data/ext/ruby_whisper.h +39 -0
  6. data/ext/ruby_whisper_context.c +22 -22
  7. data/ext/ruby_whisper_model.c +12 -12
  8. data/ext/ruby_whisper_params.c +79 -25
  9. data/ext/ruby_whisper_segment.c +84 -19
  10. data/ext/ruby_whisper_token.c +351 -0
  11. data/ext/ruby_whisper_transcribe.cpp +1 -1
  12. data/ext/ruby_whisper_vad_context.c +75 -0
  13. data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
  14. data/ext/ruby_whisper_vad_segment.c +139 -0
  15. data/ext/ruby_whisper_vad_segments.c +106 -0
  16. data/ext/sources/CMakeLists.txt +4 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
  19. data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
  20. data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  21. data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
  22. data/ext/sources/examples/CMakeLists.txt +1 -0
  23. data/ext/sources/examples/addon.node/addon.cpp +19 -19
  24. data/ext/sources/examples/addon.node/index.js +7 -5
  25. data/ext/sources/examples/addon.node/vad-example.js +2 -2
  26. data/ext/sources/examples/bench/bench.cpp +26 -16
  27. data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
  28. data/ext/sources/examples/cli/cli.cpp +122 -111
  29. data/ext/sources/examples/command/command.cpp +26 -24
  30. data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
  31. data/ext/sources/examples/common-ggml.cpp +2 -0
  32. data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
  33. data/ext/sources/examples/lsp/lsp.cpp +19 -17
  34. data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
  35. data/ext/sources/examples/server/server.cpp +34 -24
  36. data/ext/sources/examples/server.py +6 -1
  37. data/ext/sources/examples/stream/stream.cpp +4 -2
  38. data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
  39. data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
  40. data/ext/sources/examples/talk-llama/CMakeLists.txt +7 -3
  41. data/ext/sources/examples/talk-llama/llama-adapter.cpp +113 -7
  42. data/ext/sources/examples/talk-llama/llama-adapter.h +13 -1
  43. data/ext/sources/examples/talk-llama/llama-arch.cpp +2136 -1491
  44. data/ext/sources/examples/talk-llama/llama-arch.h +125 -3
  45. data/ext/sources/examples/talk-llama/llama-batch.cpp +174 -100
  46. data/ext/sources/examples/talk-llama/llama-batch.h +46 -20
  47. data/ext/sources/examples/talk-llama/llama-chat.cpp +199 -8
  48. data/ext/sources/examples/talk-llama/llama-chat.h +11 -0
  49. data/ext/sources/examples/talk-llama/llama-context.cpp +1213 -413
  50. data/ext/sources/examples/talk-llama/llama-context.h +99 -36
  51. data/ext/sources/examples/talk-llama/llama-cparams.h +5 -4
  52. data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
  53. data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
  54. data/ext/sources/examples/talk-llama/llama-graph.cpp +883 -294
  55. data/ext/sources/examples/talk-llama/llama-graph.h +361 -161
  56. data/ext/sources/examples/talk-llama/llama-hparams.cpp +144 -6
  57. data/ext/sources/examples/talk-llama/llama-hparams.h +100 -23
  58. data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
  59. data/ext/sources/examples/talk-llama/llama-impl.h +3 -1
  60. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +328 -0
  61. data/ext/sources/examples/talk-llama/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +38 -29
  62. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2100 -0
  63. data/ext/sources/examples/talk-llama/llama-kv-cache.h +373 -27
  64. data/ext/sources/examples/talk-llama/llama-kv-cells.h +124 -30
  65. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +63 -41
  66. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +30 -29
  67. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +77 -35
  68. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +15 -16
  69. data/ext/sources/examples/talk-llama/llama-memory.h +16 -10
  70. data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
  71. data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
  72. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +93 -9
  73. data/ext/sources/examples/talk-llama/llama-model-loader.h +9 -2
  74. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
  75. data/ext/sources/examples/talk-llama/llama-model.cpp +3369 -10145
  76. data/ext/sources/examples/talk-llama/llama-model.h +104 -12
  77. data/ext/sources/examples/talk-llama/llama-quant.cpp +53 -30
  78. data/ext/sources/examples/talk-llama/llama-sampling.cpp +1520 -324
  79. data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
  80. data/ext/sources/examples/talk-llama/llama-vocab.cpp +562 -39
  81. data/ext/sources/examples/talk-llama/llama-vocab.h +50 -0
  82. data/ext/sources/examples/talk-llama/llama.cpp +794 -12
  83. data/ext/sources/examples/talk-llama/llama.h +246 -190
  84. data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
  85. data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
  86. data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
  87. data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
  88. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
  89. data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
  90. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
  91. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
  92. data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
  93. data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
  94. data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
  95. data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
  96. data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
  97. data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
  98. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
  99. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
  100. data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
  101. data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
  102. data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
  103. data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
  104. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
  105. data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
  106. data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
  107. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
  108. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
  109. data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
  110. data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
  111. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
  112. data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
  113. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
  114. data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
  115. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
  116. data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
  117. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
  118. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
  119. data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
  120. data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
  121. data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
  122. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
  123. data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
  124. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
  125. data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
  126. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
  127. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
  128. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
  129. data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
  130. data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
  131. data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
  132. data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
  133. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
  134. data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
  135. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
  136. data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
  137. data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
  138. data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
  139. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
  140. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
  141. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
  142. data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
  143. data/ext/sources/examples/talk-llama/models/models.h +569 -0
  144. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
  145. data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
  146. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
  147. data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
  148. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
  149. data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
  150. data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
  151. data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
  152. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
  153. data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
  154. data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
  155. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
  156. data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
  157. data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
  158. data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
  159. data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
  160. data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
  161. data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
  162. data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
  163. data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
  164. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
  165. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
  166. data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
  167. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
  168. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
  169. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
  170. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
  171. data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
  172. data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
  173. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
  174. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
  175. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
  176. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
  177. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
  178. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
  179. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
  180. data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
  181. data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
  182. data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
  183. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
  184. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
  185. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
  186. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
  187. data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
  188. data/ext/sources/examples/talk-llama/talk-llama.cpp +9 -6
  189. data/ext/sources/examples/talk-llama/unicode.cpp +309 -16
  190. data/ext/sources/examples/talk-llama/unicode.h +45 -0
  191. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
  192. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +4 -2
  193. data/ext/sources/examples/whisper.wasm/index-tmpl.html +18 -17
  194. data/ext/sources/ggml/CMakeLists.txt +135 -79
  195. data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
  196. data/ext/sources/ggml/include/ggml-alloc.h +9 -0
  197. data/ext/sources/ggml/include/ggml-backend.h +21 -2
  198. data/ext/sources/ggml/include/ggml-cpu.h +2 -1
  199. data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
  200. data/ext/sources/ggml/include/ggml-metal.h +1 -6
  201. data/ext/sources/ggml/include/ggml-opt.h +25 -6
  202. data/ext/sources/ggml/include/ggml-rpc.h +8 -11
  203. data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
  204. data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
  205. data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
  206. data/ext/sources/ggml/include/ggml.h +406 -23
  207. data/ext/sources/ggml/src/CMakeLists.txt +99 -13
  208. data/ext/sources/ggml/src/ggml-alloc.c +368 -161
  209. data/ext/sources/ggml/src/ggml-backend-impl.h +5 -5
  210. data/ext/sources/ggml/src/ggml-backend-reg.cpp +55 -14
  211. data/ext/sources/ggml/src/ggml-backend.cpp +290 -57
  212. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
  213. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -13
  214. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +14 -0
  215. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +59 -45
  216. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
  217. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2586 -1917
  218. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +348 -309
  219. data/ext/sources/ggml/src/ggml-cann/common.h +350 -133
  220. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +894 -625
  221. data/ext/sources/ggml/src/ggml-common.h +17 -0
  222. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +167 -75
  223. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +5 -2
  224. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  225. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +560 -622
  226. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1002 -270
  227. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +107 -587
  228. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
  229. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  230. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +373 -486
  231. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
  232. data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  233. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +521 -353
  234. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
  235. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
  236. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +4682 -1660
  237. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +82 -4
  238. data/ext/sources/ggml/src/ggml-cpu/common.h +14 -0
  239. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +18 -9
  240. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +263 -111
  241. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +39 -28
  242. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +683 -82
  243. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +38 -43
  244. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +435 -119
  245. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
  246. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1234 -1182
  247. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
  248. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +2167 -1480
  249. data/ext/sources/ggml/src/ggml-cpu/ops.h +10 -12
  250. data/ext/sources/ggml/src/ggml-cpu/quants.c +35 -0
  251. data/ext/sources/ggml/src/ggml-cpu/quants.h +8 -0
  252. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1132 -81
  253. data/ext/sources/ggml/src/ggml-cpu/repack.h +36 -0
  254. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +120 -93
  255. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1025 -0
  256. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
  257. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
  258. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
  259. data/ext/sources/ggml/src/ggml-cpu/traits.cpp +2 -2
  260. data/ext/sources/ggml/src/ggml-cpu/traits.h +1 -1
  261. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
  262. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
  263. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +294 -27
  264. data/ext/sources/ggml/src/ggml-cpu/vec.h +606 -48
  265. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +92 -17
  266. data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
  267. data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
  268. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
  269. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
  270. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
  271. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
  272. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
  273. data/ext/sources/ggml/src/ggml-cuda/common.cuh +588 -128
  274. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
  275. data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
  276. data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
  277. data/ext/sources/ggml/src/ggml-cuda/convert.cu +95 -22
  278. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +25 -0
  279. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
  280. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +335 -485
  281. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
  282. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  283. data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
  284. data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
  285. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
  286. data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
  287. data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
  288. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +519 -378
  289. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +750 -637
  290. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +49 -0
  291. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1244 -0
  292. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +586 -0
  293. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +98 -61
  294. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  295. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +230 -197
  296. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  297. data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
  298. data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
  299. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
  300. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1557 -294
  301. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
  302. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
  303. data/ext/sources/ggml/src/ggml-cuda/mean.cu +57 -2
  304. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +915 -69
  305. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +171 -0
  306. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +835 -0
  307. data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
  308. data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
  309. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +109 -67
  310. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1601 -733
  311. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +802 -0
  312. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +12 -0
  313. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +286 -149
  314. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
  315. data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
  316. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
  317. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
  318. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
  319. data/ext/sources/ggml/src/ggml-cuda/pad.cu +86 -32
  320. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
  321. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
  322. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +163 -10
  323. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
  324. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
  325. data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
  326. data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
  327. data/ext/sources/ggml/src/ggml-cuda/rope.cu +207 -98
  328. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
  329. data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
  330. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +330 -0
  331. data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  332. data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
  333. data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
  334. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
  335. data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
  336. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +325 -61
  337. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
  338. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
  339. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -12
  340. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +291 -104
  341. data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
  342. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +21 -4
  343. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  344. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  345. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  346. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  347. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  348. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  349. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  350. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  351. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  352. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
  353. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
  354. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
  355. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
  356. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
  357. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
  358. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
  359. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
  360. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
  361. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
  362. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
  363. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
  364. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
  365. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
  366. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
  367. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
  368. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
  369. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
  370. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
  371. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
  372. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
  373. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
  374. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
  375. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
  376. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
  377. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
  378. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
  379. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
  380. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
  381. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
  382. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
  383. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
  384. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
  385. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
  386. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
  387. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
  388. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +40 -19
  389. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
  390. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
  391. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
  392. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
  393. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
  394. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
  395. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
  396. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
  397. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
  398. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
  399. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
  400. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
  401. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
  402. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
  403. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
  404. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
  405. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
  406. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
  407. data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
  408. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +351 -0
  409. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +21 -0
  410. data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
  411. data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
  412. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
  413. data/ext/sources/ggml/src/ggml-cuda/unary.cu +189 -5
  414. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +44 -0
  415. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +248 -6
  416. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
  417. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +8 -0
  418. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +70 -37
  419. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +10 -3
  420. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
  421. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
  422. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
  423. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
  424. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
  425. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  426. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
  427. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
  428. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  429. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
  430. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
  431. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
  432. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
  433. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  434. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
  435. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
  436. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
  437. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
  438. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
  439. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
  440. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
  441. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
  442. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
  443. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
  444. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
  445. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
  446. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
  447. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  448. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
  449. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
  450. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
  451. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +16 -13
  452. data/ext/sources/ggml/src/ggml-impl.h +186 -15
  453. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -7
  454. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
  455. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
  456. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
  457. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +609 -0
  458. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1743 -0
  459. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +273 -0
  460. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1686 -0
  461. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +356 -61
  462. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +4161 -0
  463. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +94 -0
  464. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +724 -0
  465. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +4495 -1876
  466. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +21 -9
  467. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +29 -0
  468. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4005 -427
  469. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
  470. data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  471. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  472. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  473. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +147 -0
  474. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +66 -0
  475. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
  476. data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  477. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
  478. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +371 -0
  479. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
  480. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  481. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  482. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  483. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  484. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +177 -0
  485. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
  486. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  487. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  488. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
  489. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
  490. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  491. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  492. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +146 -0
  493. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +147 -0
  494. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  495. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  496. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
  497. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
  498. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
  499. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  500. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
  501. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
  502. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
  503. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
  504. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  505. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +94 -0
  506. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  507. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  508. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +208 -0
  509. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
  510. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
  511. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
  512. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
  513. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
  514. data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  515. data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  516. data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  517. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
  518. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +33 -0
  519. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +2 -2
  520. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  521. data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
  522. data/ext/sources/ggml/src/ggml-quants.c +111 -16
  523. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  524. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +497 -195
  525. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
  526. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
  527. data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
  528. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +8 -0
  529. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +6 -5
  530. data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
  531. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +50 -30
  532. data/ext/sources/ggml/src/ggml-sycl/conv.cpp +10 -4
  533. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +200 -99
  534. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  535. data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  536. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +72 -309
  537. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
  538. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
  539. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +67 -49
  540. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +77 -34
  541. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +397 -314
  542. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +12 -2
  543. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +14 -26
  544. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +9 -6
  545. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +643 -413
  546. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
  547. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +2 -2
  548. data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +80 -60
  549. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +223 -132
  550. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +230 -55
  551. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
  552. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
  553. data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
  554. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
  555. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
  556. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
  557. data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
  558. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +8 -9
  559. data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  560. data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  561. data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
  562. data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
  563. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +65 -59
  564. data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
  565. data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
  566. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
  567. data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  568. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +330 -165
  569. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
  570. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  571. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  572. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +12 -6
  573. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +60 -6
  574. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +16 -12
  575. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
  576. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +7398 -2635
  577. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  578. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
  579. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +43 -3
  580. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  581. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
  582. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  583. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +15 -6
  584. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +56 -39
  585. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  586. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  587. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  588. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  589. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  590. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  591. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +347 -0
  592. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  593. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  594. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +5 -5
  595. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +67 -13
  596. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  597. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  598. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  599. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
  600. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
  601. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
  602. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
  603. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  604. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +158 -16
  605. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +38 -3
  606. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  607. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  608. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  609. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +2 -2
  610. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  611. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +3 -2
  612. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +7 -6
  613. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +5 -3
  614. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  615. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +4 -4
  618. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +2 -2
  619. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  620. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  621. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +4 -4
  622. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  623. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  624. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +4 -4
  625. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +2 -2
  626. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  627. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
  628. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  629. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  630. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
  631. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  632. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +103 -36
  633. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +220 -0
  634. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +139 -45
  635. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +113 -38
  636. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +75 -14
  637. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  638. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  639. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  640. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  641. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  642. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  643. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  644. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +19 -17
  645. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
  646. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +21 -12
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +28 -18
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +4 -0
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +33 -17
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +125 -0
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +227 -0
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
  663. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
  664. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
  665. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
  666. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
  667. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
  668. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
  669. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +20 -14
  670. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  671. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
  672. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
  673. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
  674. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
  675. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
  676. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +143 -0
  677. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
  678. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -556
  679. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +230 -51
  680. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +566 -0
  681. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
  682. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +90 -223
  683. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
  684. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  685. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +195 -0
  686. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  687. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  688. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  689. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
  690. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +41 -5
  691. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  692. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +59 -9
  693. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  694. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  695. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  696. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  697. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +104 -14
  698. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  699. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
  700. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  701. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
  702. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  703. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -52
  704. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -35
  705. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -35
  706. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
  707. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
  708. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  709. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +5 -0
  710. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +3 -3
  711. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  712. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  713. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  714. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  715. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +30 -8
  716. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +6 -2
  717. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
  718. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
  719. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
  720. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
  721. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  722. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
  723. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
  724. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  725. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
  726. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
  727. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  728. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  729. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +16 -6
  730. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
  731. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  732. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
  733. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  734. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +5 -4
  735. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
  736. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
  737. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
  738. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
  739. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  740. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +435 -24
  741. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +148 -6
  742. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.glsl +25 -0
  743. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +619 -177
  744. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
  745. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +80 -0
  746. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
  747. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3087 -0
  748. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  749. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
  750. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
  751. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
  752. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
  753. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +147 -0
  754. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
  755. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
  756. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  757. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  758. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
  759. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
  760. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
  761. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
  762. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
  763. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +123 -0
  764. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  765. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
  766. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
  767. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
  768. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  769. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
  770. data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  771. data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
  772. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
  773. data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
  774. data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
  775. data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
  776. data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
  777. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
  778. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
  779. data/ext/sources/ggml/src/ggml.c +901 -129
  780. data/ext/sources/ggml/src/gguf.cpp +8 -1
  781. data/ext/sources/include/whisper.h +1 -0
  782. data/ext/sources/src/CMakeLists.txt +3 -1
  783. data/ext/sources/src/whisper.cpp +124 -81
  784. data/ext/sources/tests/CMakeLists.txt +8 -1
  785. data/ext/sources/tests/test-vad-full.cpp +7 -5
  786. data/ext/sources/tests/test-vad.cpp +3 -3
  787. data/extsources.rb +1 -0
  788. data/lib/whisper/model/uri.rb +17 -18
  789. data/sig/whisper.rbs +126 -2
  790. data/test/test_params.rb +24 -8
  791. data/test/test_segment.rb +0 -1
  792. data/test/test_token.rb +70 -0
  793. data/test/test_vad.rb +1 -1
  794. data/test/test_vad_context.rb +50 -0
  795. data/test/test_vad_segment.rb +19 -0
  796. data/test/test_vad_segments.rb +16 -0
  797. data/test/test_whisper.rb +8 -1
  798. data/whispercpp.gemspec +1 -1
  799. metadata +439 -179
  800. data/ext/sources/build-xcframework.sh +0 -547
  801. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +0 -279
  802. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +0 -1841
  803. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +0 -303
  804. data/ext/sources/ggml/include/ggml-kompute.h +0 -50
  805. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
  806. data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
  807. data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
  808. data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
  809. data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
  810. data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
  811. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  812. data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  813. data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  814. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  815. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  816. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  817. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  818. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  819. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  820. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
  821. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
  822. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
  823. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
  824. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
  825. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
  826. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
  827. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +0 -506
  828. data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +0 -11
  829. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
  830. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
  831. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
  832. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
  833. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
  834. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
  835. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
  836. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
  837. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
  838. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
  839. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
  840. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
  841. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
  842. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
  843. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
  844. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
  845. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
  846. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
  847. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
  848. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
  849. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
  850. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
  851. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
  852. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
  853. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
  854. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
  855. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
  856. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
  857. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
  858. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
  859. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
  860. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
  861. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
  862. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
  863. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
  864. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
  865. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
  866. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
  867. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
  868. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
  869. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
  870. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
  871. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
  872. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
  873. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
  874. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
  875. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
  876. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
  877. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
  878. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
  879. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
  880. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
  881. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
  882. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
  883. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
  884. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
  885. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
  886. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
  887. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
  888. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
  889. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
  890. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
  891. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
  892. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
  893. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
  894. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
  895. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
  896. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
  897. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
  898. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
  899. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
  900. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
  901. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
  902. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
  903. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
  904. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
  905. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
  906. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
  907. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
  908. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
  909. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
  910. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
  911. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
  912. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
  913. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
  914. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
  915. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  916. data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  917. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  918. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  919. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  920. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  921. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  922. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  923. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  924. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  925. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  926. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  927. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  928. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  929. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  930. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  931. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  932. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  933. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  934. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  935. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  936. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  937. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  938. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  939. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  940. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  941. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  942. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  943. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  944. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  945. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  946. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  947. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  948. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  949. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  950. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  951. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  952. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  953. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
  954. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -6280
  955. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +0 -162
  956. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +0 -118
  957. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -99
  958. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -58
  959. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  960. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  961. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  962. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  963. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
@@ -4,6 +4,9 @@
4
4
  #include "llama-vocab.h"
5
5
  #include "llama-grammar.h"
6
6
 
7
+ #include "ggml-cpp.h"
8
+
9
+ #include <array>
7
10
  #include <algorithm>
8
11
  #include <cassert>
9
12
  #include <cfloat>
@@ -128,6 +131,89 @@ struct ring_buffer {
128
131
  std::vector<T> data;
129
132
  };
130
133
 
134
+ // writes result in res, does not mutate cur
135
+ static void llama_token_data_array_partial_sort(const llama_token_data_array & cur, int npartial, std::vector<llama_token_data> & res) {
136
+ static const auto comp = [](const llama_token_data & a, const llama_token_data & b) {
137
+ return a.logit > b.logit;
138
+ };
139
+
140
+ constexpr int nbuckets = 128;
141
+ constexpr float bucket_low = -10.0f;
142
+ constexpr float bucket_high = 10.0f;
143
+ constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low);
144
+ constexpr float bucket_inter = -bucket_low * bucket_scale;
145
+
146
+ std::vector<int> bucket_idx;
147
+ std::vector<int> histo(nbuckets, 0);
148
+
149
+ std::vector<llama_token_data*> bucket_ptrs;
150
+
151
+ bucket_idx.reserve(cur.size);
152
+
153
+ for (int i = 0; i < (int)cur.size; ++i) {
154
+ const float val = cur.data[i].logit;
155
+ int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
156
+ ib = std::max(0, std::min(nbuckets - 1, ib));
157
+ bucket_idx.push_back(ib);
158
+ ++histo[ib];
159
+ }
160
+ int nhave = 0;
161
+ int ib = nbuckets - 1;
162
+ for ( ; ib >= 0; --ib) {
163
+ nhave += histo[ib];
164
+ if (nhave >= npartial) {
165
+ break;
166
+ }
167
+ }
168
+ res.resize(nhave);
169
+ auto * ptr = res.data();
170
+ bucket_ptrs.reserve(nbuckets - ib);
171
+ for (int j = nbuckets - 1; j >= ib; --j) {
172
+ bucket_ptrs.push_back(ptr);
173
+ ptr += histo[j];
174
+ }
175
+ for (int i = 0; i < (int)cur.size; ++i) {
176
+ int j = bucket_idx[i];
177
+ if (j >= ib) {
178
+ *bucket_ptrs[nbuckets - 1 - j]++ = cur.data[i];
179
+ }
180
+ }
181
+
182
+ ptr = res.data();
183
+ int ndone = 0;
184
+ for (int j = nbuckets - 1; j > ib; --j) {
185
+ std::sort(ptr, ptr + histo[j], comp);
186
+ ptr += histo[j];
187
+ ndone += histo[j];
188
+ }
189
+ std::partial_sort(ptr, ptr + npartial - ndone, ptr + histo[ib], comp);
190
+ }
191
+
192
+ // reduces the size of cur_p to npartial, keeping only the top npartial elements
193
+ static void llama_token_data_array_partial_sort_inplace(llama_token_data_array * cur_p, int npartial) {
194
+ static const auto comp = [](const llama_token_data & a, const llama_token_data & b) {
195
+ return a.logit > b.logit;
196
+ };
197
+
198
+ if (npartial <= 128) {
199
+ std::partial_sort(cur_p->data, cur_p->data + npartial, cur_p->data + cur_p->size, comp);
200
+
201
+ cur_p->size = npartial;
202
+ cur_p->sorted = true;
203
+
204
+ return;
205
+ }
206
+
207
+ std::vector<llama_token_data> tmp;
208
+
209
+ llama_token_data_array_partial_sort(*cur_p, npartial, tmp);
210
+
211
+ std::copy(tmp.data(), tmp.data() + npartial, cur_p->data);
212
+
213
+ cur_p->size = npartial;
214
+ cur_p->sorted = true;
215
+ }
216
+
131
217
  static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng) {
132
218
  // iterator for the probabilities
133
219
  #ifdef __GNUC__
@@ -200,18 +286,21 @@ static void llama_sampler_temp_impl(llama_token_data_array * cur_p, float temp)
200
286
  }
201
287
  }
202
288
 
203
- static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) {
289
+ static void llama_sampler_softmax_impl(llama_token_data_array * cur_p, bool do_sort) {
204
290
  GGML_ASSERT(cur_p->size > 0);
205
291
 
206
- // Sort the logits in descending order
207
- if (!cur_p->sorted) {
208
- std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
209
- return a.logit > b.logit;
210
- });
211
- cur_p->sorted = true;
292
+ // Sort the logits in descending order if requested
293
+ if (do_sort && !cur_p->sorted) {
294
+ llama_token_data_array_partial_sort_inplace(cur_p, cur_p->size);
212
295
  }
213
296
 
214
297
  float max_l = cur_p->data[0].logit;
298
+ if (!cur_p->sorted) {
299
+ for (size_t i = 1; i < cur_p->size; ++i) {
300
+ max_l = std::max(max_l, cur_p->data[i].logit);
301
+ }
302
+ }
303
+
215
304
  float cum_sum = 0.0f;
216
305
 
217
306
  for (size_t i = 0; i < cur_p->size; ++i) {
@@ -226,7 +315,6 @@ static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) {
226
315
  }
227
316
 
228
317
  static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) {
229
- // TODO: move bucket sort to separate function so that top_p/typical/softmax first is equally fast
230
318
  // if (k >= (int32_t)cur_p->size) {
231
319
  // return;
232
320
  // }
@@ -239,64 +327,7 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)
239
327
 
240
328
  // Sort scores in descending order
241
329
  if (!cur_p->sorted) {
242
- auto comp = [](const llama_token_data & a, const llama_token_data & b) {
243
- return a.logit > b.logit;
244
- };
245
- if (k <= 128) {
246
- std::partial_sort(cur_p->data, cur_p->data + k, cur_p->data + cur_p->size, comp);
247
- } else {
248
- constexpr int nbuckets = 128;
249
- constexpr float bucket_low = -10.0f;
250
- constexpr float bucket_high = 10.0f;
251
- constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low);
252
- constexpr float bucket_inter = -bucket_low * bucket_scale;
253
-
254
- std::vector<int> bucket_idx(cur_p->size);
255
- std::vector<int> histo(nbuckets, 0);
256
-
257
- for (int i = 0; i < (int)cur_p->size; ++i) {
258
- const float val = cur_p->data[i].logit;
259
- int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
260
- ib = std::max(0, std::min(nbuckets - 1, ib));
261
- bucket_idx[i] = ib;
262
- ++histo[ib];
263
- }
264
- int nhave = 0;
265
- int ib = nbuckets - 1;
266
- for ( ; ib >= 0; --ib) {
267
- nhave += histo[ib];
268
- if (nhave >= k) {
269
- break;
270
- }
271
- }
272
- std::vector<llama_token_data> tmp_tokens(nhave);
273
- auto * ptr = tmp_tokens.data();
274
- std::vector<llama_token_data*> bucket_ptrs;
275
- bucket_ptrs.reserve(nbuckets - ib);
276
- for (int j = nbuckets - 1; j >= ib; --j) {
277
- bucket_ptrs.push_back(ptr);
278
- ptr += histo[j];
279
- }
280
- for (int i = 0; i < (int)cur_p->size; ++i) {
281
- int j = bucket_idx[i];
282
- if (j >= ib) {
283
- *bucket_ptrs[nbuckets - 1 - j]++ = cur_p->data[i];
284
- }
285
- }
286
-
287
- ptr = tmp_tokens.data();
288
- int ndone = 0;
289
- for (int j = nbuckets - 1; j > ib; --j) {
290
- std::sort(ptr, ptr + histo[j], comp);
291
- ptr += histo[j];
292
- ndone += histo[j];
293
- }
294
- std::partial_sort(ptr, ptr + k - ndone, ptr + histo[ib], comp);
295
-
296
- std::memcpy(cur_p->data, tmp_tokens.data(), k*sizeof(llama_token_data));
297
-
298
- }
299
- cur_p->sorted = true;
330
+ llama_token_data_array_partial_sort_inplace(cur_p, k);
300
331
  }
301
332
 
302
333
  cur_p->size = k;
@@ -317,7 +348,9 @@ static uint32_t get_rng_seed(uint32_t seed) {
317
348
 
318
349
  // llama_sampler API
319
350
 
320
- struct llama_sampler * llama_sampler_init(const struct llama_sampler_i * iface, llama_sampler_context_t ctx) {
351
+ struct llama_sampler * llama_sampler_init(
352
+ struct llama_sampler_i * iface,
353
+ llama_sampler_context_t ctx) {
321
354
  return new llama_sampler {
322
355
  /* .iface = */ iface,
323
356
  /* .ctx = */ ctx,
@@ -333,23 +366,39 @@ const char * llama_sampler_name(const struct llama_sampler * smpl) {
333
366
  }
334
367
 
335
368
  void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) {
369
+ if (!smpl) {
370
+ return;
371
+ }
372
+
336
373
  if (smpl->iface->accept) {
337
374
  smpl->iface->accept(smpl, token);
338
375
  }
339
376
  }
340
377
 
341
378
  void llama_sampler_apply(struct llama_sampler * smpl, struct llama_token_data_array * cur_p) {
379
+ if (!smpl) {
380
+ return;
381
+ }
382
+
342
383
  GGML_ASSERT(smpl->iface->apply);
343
384
  smpl->iface->apply(smpl, cur_p);
344
385
  }
345
386
 
346
387
  void llama_sampler_reset(struct llama_sampler * smpl) {
388
+ if (!smpl) {
389
+ return;
390
+ }
391
+
347
392
  if (smpl->iface->reset) {
348
393
  smpl->iface->reset(smpl);
349
394
  }
350
395
  }
351
396
 
352
397
  struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) {
398
+ if (!smpl) {
399
+ return nullptr;
400
+ }
401
+
353
402
  if (smpl->iface->clone) {
354
403
  return smpl->iface->clone(smpl);
355
404
  }
@@ -376,37 +425,200 @@ void llama_sampler_free(struct llama_sampler * smpl) {
376
425
  delete smpl;
377
426
  }
378
427
 
379
- llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
380
- const auto * logits = llama_get_logits_ith(ctx, idx);
428
+ // empty sampler
381
429
 
382
- const llama_model * model = llama_get_model(ctx);
383
- const llama_vocab * vocab = llama_model_get_vocab(model);
430
+ struct llama_sampler_empty {
431
+ const char * name;
432
+ };
384
433
 
385
- const int n_vocab = llama_vocab_n_tokens(vocab);
434
+ static struct llama_sampler * llama_sampler_init_empty(const char * name);
435
+
436
+ static const char * llama_sampler_empty_name(const struct llama_sampler * smpl) {
437
+ auto * ctx = (llama_sampler_empty *) smpl->ctx;
438
+ return ctx->name;
439
+ }
440
+
441
+ static void llama_sampler_empty_accept(struct llama_sampler * smpl, llama_token token) {
442
+ GGML_UNUSED(smpl);
443
+ GGML_UNUSED(token);
444
+ }
445
+
446
+ static void llama_sampler_empty_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
447
+ GGML_UNUSED(smpl);
448
+ GGML_UNUSED(cur_p);
449
+ }
450
+
451
+ static void llama_sampler_empty_reset(struct llama_sampler * smpl) {
452
+ GGML_UNUSED(smpl);
453
+ }
454
+
455
+ static struct llama_sampler * llama_sampler_empty_clone(const struct llama_sampler * smpl) {
456
+ auto * ctx = (llama_sampler_empty *) smpl->ctx;
457
+ return llama_sampler_init_empty(ctx->name);
458
+ }
459
+
460
+ static void llama_sampler_empty_free(struct llama_sampler * smpl) {
461
+ delete (llama_sampler_empty *) smpl->ctx;
462
+ }
463
+
464
+ static bool llama_sampler_empty_backend_init(
465
+ struct llama_sampler * smpl,
466
+ ggml_backend_buffer_type_t buft) {
467
+ GGML_UNUSED(smpl);
468
+ GGML_UNUSED(buft);
469
+
470
+ return true;
471
+ }
472
+
473
+ static void llama_sampler_empty_backend_accept(
474
+ struct llama_sampler * smpl,
475
+ ggml_context * ctx,
476
+ ggml_cgraph * gf,
477
+ struct ggml_tensor * selected_token) {
478
+ GGML_UNUSED(smpl);
479
+ GGML_UNUSED(ctx);
480
+ GGML_UNUSED(gf);
481
+ GGML_UNUSED(selected_token);
482
+ }
483
+
484
+ static void llama_sampler_empty_backend_apply(
485
+ struct llama_sampler * smpl,
486
+ struct ggml_context * ctx,
487
+ struct ggml_cgraph * gf,
488
+ struct llama_sampler_data * data) {
489
+ GGML_UNUSED(smpl);
490
+ GGML_UNUSED(ctx);
491
+ GGML_UNUSED(gf);
492
+ GGML_UNUSED(data);
493
+ }
494
+
495
+ static void llama_sampler_empty_backend_set_input(struct llama_sampler * smpl) {
496
+ GGML_UNUSED(smpl);
497
+ }
498
+
499
+ static struct llama_sampler_i llama_sampler_empty_i = {
500
+ /* .name = */ llama_sampler_empty_name,
501
+ /* .accept = */ llama_sampler_empty_accept,
502
+ /* .apply = */ llama_sampler_empty_apply,
503
+ /* .reset = */ llama_sampler_empty_reset,
504
+ /* .clone = */ llama_sampler_empty_clone,
505
+ /* .free = */ llama_sampler_empty_free,
506
+ /* .backend_init = */ llama_sampler_empty_backend_init,
507
+ /* .backend_accept = */ llama_sampler_empty_backend_accept,
508
+ /* .backend_apply = */ llama_sampler_empty_backend_apply,
509
+ /* .backend_set_input = */ llama_sampler_empty_backend_set_input,
510
+ };
511
+
512
+ struct llama_sampler * llama_sampler_init_empty(const char * name) {
513
+ return llama_sampler_init(
514
+ /* .iface = */ &llama_sampler_empty_i,
515
+ /* .ctx = */ new llama_sampler_empty {
516
+ /* .name = */ name,
517
+ }
518
+ );
519
+ }
520
+
521
+ // common backend sampler functionality
522
+ //
523
+ // +name : means that the sampler is support and will run on the backend
524
+ // -name : means that a ggml operator is not supported by the backend
525
+ //
526
+ struct llama_sampler_backend {
527
+ llama_sampler_backend(const char * name) : name(name), name_ext(name), is_init(false), support(false) {}
386
528
 
387
- // TODO: do not allocate each time
388
- std::vector<llama_token_data> cur;
389
- cur.reserve(n_vocab);
390
- for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
391
- cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
529
+ const char * get_name() {
530
+ if (!is_init) {
531
+ return name.c_str();
532
+ }
533
+
534
+ if (support) {
535
+ name_ext = "+" + name;
536
+ } else {
537
+ name_ext = "-" + name;
538
+ }
539
+
540
+ return name_ext.c_str();
392
541
  }
393
542
 
394
- llama_token_data_array cur_p = {
395
- /* .data = */ cur.data(),
396
- /* .size = */ cur.size(),
397
- /* .selected = */ -1,
398
- /* .sorted = */ false,
543
+ void init(bool support) {
544
+ GGML_ASSERT(this->is_init == false);
545
+
546
+ this->is_init = true;
547
+ this->support = support;
548
+ }
549
+
550
+ private:
551
+ std::string name;
552
+ std::string name_ext;
553
+
554
+ bool is_init;
555
+ bool support;
556
+ };
557
+
558
+ // check if all ggml ops used by the sampler are supported by the backend
559
+ static bool llama_sampler_backend_support(
560
+ llama_sampler * smpl,
561
+ ggml_backend_buffer_type_t buft) {
562
+ auto * device = ggml_backend_buft_get_device(buft);
563
+ if (!device) {
564
+ // CPU backend always supported
565
+ return true;
566
+ }
567
+
568
+ ggml_init_params params = {
569
+ /*.mem_size =*/ 128*ggml_tensor_overhead() + ggml_graph_overhead(),
570
+ /*.mem_buffer =*/ NULL,
571
+ /*.no_alloc =*/ true,
399
572
  };
400
573
 
401
- llama_sampler_apply(smpl, &cur_p);
574
+ ggml_context_ptr ctx_ptr { ggml_init(params) };
575
+ if (!ctx_ptr) {
576
+ throw std::runtime_error(format("failed to create ggml context"));
577
+ }
402
578
 
403
- GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (int32_t) cur_p.size);
579
+ ggml_context * ctx = ctx_ptr.get();
404
580
 
405
- auto token = cur_p.data[cur_p.selected].id;
581
+ const int64_t n = 1024*1024;
406
582
 
407
- llama_sampler_accept(smpl, token);
583
+ llama_sampler_data data = {
584
+ /*.logits = */ ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n),
585
+ /*.probs = */ nullptr,
586
+ /*.sampled = */ nullptr,
587
+ /*.candidates = */ ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n),
588
+ };
408
589
 
409
- return token;
590
+ ggml_cgraph * gf = ggml_new_graph(ctx);
591
+
592
+ smpl->iface->backend_apply(smpl, ctx, gf, &data);
593
+
594
+ if (data.logits) {
595
+ ggml_build_forward_expand(gf, data.logits);
596
+ }
597
+
598
+ if (data.probs) {
599
+ ggml_build_forward_expand(gf, data.probs);
600
+ }
601
+
602
+ if (data.sampled) {
603
+ ggml_build_forward_expand(gf, data.sampled);
604
+ }
605
+
606
+ if (data.candidates) {
607
+ ggml_build_forward_expand(gf, data.candidates);
608
+ }
609
+
610
+ for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
611
+ struct ggml_tensor * op = ggml_graph_node(gf, i);
612
+
613
+ if (!ggml_backend_dev_supports_op(device, op)) {
614
+ LLAMA_LOG_WARN("%s: device '%s' does not have support for op %s needed for sampler '%s'\n",
615
+ __func__, ggml_backend_dev_name(device), ggml_op_name(op->op), smpl->iface->name(smpl));
616
+
617
+ return false;
618
+ }
619
+ }
620
+
621
+ return true;
410
622
  }
411
623
 
412
624
  // sampler chain
@@ -420,8 +632,8 @@ static void llama_sampler_chain_accept(struct llama_sampler * smpl, llama_token
420
632
 
421
633
  time_meas tm(chain->t_sample_us, chain->params.no_perf);
422
634
 
423
- for (auto * smpl : chain->samplers) {
424
- llama_sampler_accept(smpl, token);
635
+ for (auto & smpl : chain->samplers) {
636
+ llama_sampler_accept(smpl.ptr, token);
425
637
  }
426
638
 
427
639
  chain->n_sample++;
@@ -432,20 +644,29 @@ static void llama_sampler_chain_apply(struct llama_sampler * smpl, llama_token_d
432
644
 
433
645
  time_meas tm(chain->t_sample_us, chain->params.no_perf);
434
646
 
435
- for (auto * smpl : chain->samplers) {
436
- llama_sampler_apply(smpl, cur_p);
647
+ bool is_backend = chain->is_init;
648
+
649
+ for (auto & smpl : chain->samplers) {
650
+ if (is_backend && smpl.is_backend) {
651
+ continue;
652
+ }
653
+
654
+ is_backend = false;
655
+
656
+ if (smpl.ptr->iface->apply == nullptr) {
657
+ continue;
658
+ }
659
+
660
+ llama_sampler_apply(smpl.ptr, cur_p);
437
661
  }
438
662
  }
439
663
 
440
664
  static void llama_sampler_chain_reset(struct llama_sampler * smpl) {
441
665
  auto * chain = (llama_sampler_chain *) smpl->ctx;
442
666
 
443
- for (auto * smpl : chain->samplers) {
444
- llama_sampler_reset(smpl);
667
+ for (auto & smpl : chain->samplers) {
668
+ llama_sampler_reset(smpl.ptr);
445
669
  }
446
-
447
- chain->t_sample_us = 0;
448
- chain->n_sample = 0;
449
670
  }
450
671
 
451
672
  static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampler * smpl) {
@@ -453,8 +674,8 @@ static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampl
453
674
 
454
675
  auto * result = llama_sampler_chain_init(chain_src->params);
455
676
 
456
- for (auto * smpl : chain_src->samplers) {
457
- llama_sampler_chain_add(result, llama_sampler_clone(smpl));
677
+ for (const auto & smpl : chain_src->samplers) {
678
+ llama_sampler_chain_add(result, llama_sampler_clone(smpl.ptr));
458
679
  }
459
680
 
460
681
  return result;
@@ -463,20 +684,109 @@ static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampl
463
684
  static void llama_sampler_chain_free(struct llama_sampler * smpl) {
464
685
  auto * chain = (llama_sampler_chain *) smpl->ctx;
465
686
 
466
- for (auto * smpl : chain->samplers) {
467
- llama_sampler_free(smpl);
687
+ for (auto & smpl : chain->samplers) {
688
+ llama_sampler_free(smpl.ptr);
468
689
  }
469
690
 
470
691
  delete chain;
471
692
  }
472
693
 
694
+ static bool llama_sampler_chain_backend_init(
695
+ struct llama_sampler * smpl,
696
+ ggml_backend_buffer_type_t buft) {
697
+ auto * chain = (llama_sampler_chain *) smpl->ctx;
698
+
699
+ GGML_ASSERT(chain->is_init == false && "llama_sampler_chain_backend_init() called twice");
700
+
701
+ chain->is_init = true;
702
+
703
+ bool res = true;
704
+
705
+ for (auto & smpl : chain->samplers) {
706
+ bool res_cur = true;
707
+
708
+ // to be able to run a sampler on the backend, it has to:
709
+ // - have the .backend_init() API implemented
710
+ // - return true during .backend_init()
711
+ if (smpl.ptr->iface->backend_init) {
712
+ if (!smpl.ptr->iface->backend_init(smpl.ptr, buft)) {
713
+ res_cur = false;
714
+ }
715
+ } else {
716
+ res_cur = false;
717
+ }
718
+
719
+ smpl.is_backend = res_cur;
720
+
721
+ res = res && res_cur;
722
+ }
723
+
724
+ return res;
725
+ }
726
+
727
+ static void llama_sampler_chain_backend_accept(
728
+ struct llama_sampler * smpl,
729
+ ggml_context * ctx,
730
+ ggml_cgraph * gf,
731
+ struct ggml_tensor * selected_token) {
732
+ auto * chain = (llama_sampler_chain *) smpl->ctx;
733
+
734
+ for (auto & smpl : chain->samplers) {
735
+ if (!smpl.is_backend) {
736
+ break;
737
+ }
738
+
739
+ if (smpl.ptr->iface->backend_accept) {
740
+ smpl.ptr->iface->backend_accept(smpl.ptr, ctx, gf, selected_token);
741
+ }
742
+ }
743
+ }
744
+
745
+ static void llama_sampler_chain_backend_apply(
746
+ struct llama_sampler * smpl,
747
+ struct ggml_context * ctx,
748
+ struct ggml_cgraph * gf,
749
+ struct llama_sampler_data * data) {
750
+ auto * chain = (llama_sampler_chain *) smpl->ctx;
751
+
752
+ GGML_ASSERT(chain->is_init && "llama_sampler_chain_backend_init() not called");
753
+
754
+ for (auto & smpl : chain->samplers) {
755
+ if (!smpl.is_backend) {
756
+ break;
757
+ }
758
+
759
+ if (smpl.ptr->iface->backend_apply) {
760
+ smpl.ptr->iface->backend_apply(smpl.ptr, ctx, gf, data);
761
+ }
762
+ }
763
+ }
764
+
765
+ static void llama_sampler_chain_backend_set_input(struct llama_sampler * smpl) {
766
+ auto * chain = (llama_sampler_chain *) smpl->ctx;
767
+
768
+ for (auto & smpl : chain->samplers) {
769
+ if (!smpl.is_backend) {
770
+ break;
771
+ }
772
+
773
+ if (smpl.ptr->iface->backend_set_input) {
774
+ smpl.ptr->iface->backend_set_input(smpl.ptr);
775
+ }
776
+ }
777
+ }
778
+
473
779
  static struct llama_sampler_i llama_sampler_chain_i = {
474
- /* .name = */ llama_sampler_chain_name,
475
- /* .accept = */ llama_sampler_chain_accept,
476
- /* .apply = */ llama_sampler_chain_apply,
477
- /* .reset = */ llama_sampler_chain_reset,
478
- /* .clone = */ llama_sampler_chain_clone,
479
- /* .free = */ llama_sampler_chain_free,
780
+ /* .name = */ llama_sampler_chain_name,
781
+ /* .accept = */ llama_sampler_chain_accept,
782
+ /* .apply = */ llama_sampler_chain_apply,
783
+ /* .reset = */ llama_sampler_chain_reset,
784
+ /* .clone = */ llama_sampler_chain_clone,
785
+ /* .free = */ llama_sampler_chain_free,
786
+ /* .backend_init = */ llama_sampler_chain_backend_init,
787
+ /* .backend_accept = */ llama_sampler_chain_backend_accept,
788
+ /* .backend_apply = */ llama_sampler_chain_backend_apply,
789
+ /* .backend_set_input = */ llama_sampler_chain_backend_set_input,
480
790
  };
481
791
 
482
792
  struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) {
@@ -484,26 +794,113 @@ struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_param
484
794
  /* .iface = */ &llama_sampler_chain_i,
485
795
  /* .ctx = */ new llama_sampler_chain {
486
796
  /* .params = */ params,
797
+ /* .is_init = */ false,
487
798
  /* .samplers = */ {},
799
+ /* .cur = */ {},
488
800
  /* .t_sample_us = */ 0,
489
801
  /* .n_sample = */ 0,
490
802
  }
491
803
  );
492
804
  }
493
805
 
806
+ llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
807
+ const llama_token sampled_token = llama_get_sampled_token_ith (ctx, idx);
808
+ const float * sampled_probs = llama_get_sampled_probs_ith (ctx, idx);
809
+ const float * sampled_logits = llama_get_sampled_logits_ith (ctx, idx);
810
+ const llama_token * sampled_ids = llama_get_sampled_candidates_ith(ctx, idx);
811
+
812
+ // If a backend sampler has already sampled a token, return it.
813
+ if (sampled_token != LLAMA_TOKEN_NULL) {
814
+ LLAMA_LOG_DEBUG("%s: Backend sampler selected token for idx %d. Skipping CPU samplers\n", __func__, idx);
815
+ return sampled_token;
816
+ }
817
+
818
+ const llama_model * model = llama_get_model(ctx);
819
+ const llama_vocab * vocab = llama_model_get_vocab(model);
820
+
821
+ const int n_vocab = llama_vocab_n_tokens(vocab);
822
+
823
+ // use pre-allocated buffer from chain if available, otherwise allocate locally
824
+ std::vector<llama_token_data> * cur_ptr;
825
+ std::vector<llama_token_data> cur_local;
826
+
827
+ if (smpl->iface == &llama_sampler_chain_i) {
828
+ auto * chain = (llama_sampler_chain *) smpl->ctx;
829
+ cur_ptr = &chain->cur;
830
+ } else {
831
+ cur_ptr = &cur_local;
832
+ }
833
+
834
+ auto & cur = *cur_ptr;
835
+
836
+ if (sampled_probs) {
837
+ const uint32_t sampled_probs_count = llama_get_sampled_probs_count_ith(ctx, idx);
838
+ cur.resize(sampled_probs_count);
839
+ for (uint32_t i = 0; i < sampled_probs_count; ++i) {
840
+ cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], sampled_probs[i]};
841
+ }
842
+ } else if (sampled_logits) {
843
+ const uint32_t sampled_logits_count = llama_get_sampled_logits_count_ith(ctx, idx);
844
+ cur.resize(sampled_logits_count);
845
+ for (llama_token i = 0; i < (int)sampled_logits_count; i++) {
846
+ cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f};
847
+ }
848
+ } else {
849
+ const auto * logits = llama_get_logits_ith(ctx, idx);
850
+ GGML_ASSERT(logits != nullptr);
851
+ cur.resize(n_vocab);
852
+ for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
853
+ cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
854
+ }
855
+ }
856
+
857
+ llama_token_data_array cur_p = {
858
+ /* .data = */ cur.data(),
859
+ /* .size = */ cur.size(),
860
+ /* .selected = */ -1,
861
+ /* .sorted = */ false,
862
+ };
863
+
864
+ llama_sampler_apply(smpl, &cur_p);
865
+
866
+ GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (int32_t) cur_p.size);
867
+
868
+ auto token = cur_p.data[cur_p.selected].id;
869
+
870
+ llama_sampler_accept(smpl, token);
871
+
872
+ return token;
873
+ }
874
+
875
+
494
876
  void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
495
877
  auto * p = (llama_sampler_chain *) chain->ctx;
496
- p->samplers.push_back(smpl);
878
+ p->samplers.push_back({
879
+ /* .is_backend = */ false,
880
+ /* .ptr = */ smpl,
881
+ });
497
882
  }
498
883
 
499
- struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i) {
884
+ struct llama_sampler * llama_sampler_chain_get(struct llama_sampler * chain, int32_t i) {
885
+ if (chain == nullptr) {
886
+ return nullptr;
887
+ }
888
+
889
+ if (chain->iface != &llama_sampler_chain_i) {
890
+ return nullptr;
891
+ }
892
+
893
+ if (i == -1) {
894
+ return chain;
895
+ }
896
+
500
897
  const auto * p = (const llama_sampler_chain *) chain->ctx;
501
898
 
502
899
  if (i < 0 || (size_t) i >= p->samplers.size()) {
503
900
  return nullptr;
504
901
  }
505
902
 
506
- return p->samplers[i];
903
+ return p->samplers[i].ptr;
507
904
  }
508
905
 
509
906
  struct llama_sampler * llama_sampler_chain_remove(struct llama_sampler * chain, int32_t i) {
@@ -513,7 +910,7 @@ struct llama_sampler * llama_sampler_chain_remove(struct llama_sampler * chain,
513
910
  return nullptr;
514
911
  }
515
912
 
516
- auto * result = p->samplers[i];
913
+ auto * result = p->samplers[i].ptr;
517
914
  p->samplers.erase(p->samplers.begin() + i);
518
915
 
519
916
  return result;
@@ -531,8 +928,36 @@ int llama_sampler_chain_n(const struct llama_sampler * chain) {
531
928
 
532
929
  // greedy
533
930
 
534
- static const char * llama_sampler_greedy_name(const struct llama_sampler * /*smpl*/) {
535
- return "greedy";
931
+ struct llama_sampler_greedy : public llama_sampler_backend {
932
+ };
933
+
934
+ static const char * llama_sampler_greedy_name(const struct llama_sampler * smpl) {
935
+ auto * sctx = (llama_sampler_greedy *) smpl->ctx;
936
+ return sctx->get_name();
937
+ }
938
+
939
+ static void llama_sampler_greedy_reset(struct llama_sampler * smpl) {
940
+ auto * ctx = (llama_sampler_greedy *) smpl->ctx;
941
+ GGML_UNUSED(ctx);
942
+ }
943
+
944
+ static struct llama_sampler * llama_sampler_greedy_clone(const struct llama_sampler * smpl) {
945
+ const auto * ctx = (const llama_sampler_greedy *) smpl->ctx;
946
+ auto * result = llama_sampler_init_greedy();
947
+
948
+ // copy the state
949
+ {
950
+ auto * result_ctx = (llama_sampler_greedy *) result->ctx;
951
+
952
+ GGML_UNUSED(ctx);
953
+ GGML_UNUSED(result_ctx);
954
+ }
955
+
956
+ return result;
957
+ }
958
+
959
+ static void llama_sampler_greedy_free(struct llama_sampler * smpl) {
960
+ delete (llama_sampler_greedy *) smpl->ctx;
536
961
  }
537
962
 
538
963
  static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
@@ -544,41 +969,150 @@ static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_to
544
969
  }
545
970
  }
546
971
 
972
+ static bool llama_sampler_greedy_backend_init(
973
+ struct llama_sampler * smpl,
974
+ ggml_backend_buffer_type_t buft) {
975
+ auto * sctx = (llama_sampler_greedy *) smpl->ctx;
976
+
977
+ const bool res = llama_sampler_backend_support(smpl, buft);
978
+
979
+ sctx->init(res);
980
+
981
+ return res;
982
+ }
983
+
984
+ static void llama_sampler_greedy_backend_apply(
985
+ struct llama_sampler * smpl,
986
+ struct ggml_context * ctx,
987
+ struct ggml_cgraph * gf,
988
+ struct llama_sampler_data * data) {
989
+ GGML_UNUSED(gf);
990
+ GGML_UNUSED(smpl);
991
+
992
+ struct ggml_tensor * curl = ggml_argmax(ctx, data->logits);
993
+ ggml_set_name(curl, "greedy_argmax");
994
+
995
+ data->sampled = curl;
996
+ }
997
+
547
998
  static struct llama_sampler_i llama_sampler_greedy_i = {
548
- /* .name = */ llama_sampler_greedy_name,
549
- /* .accept = */ nullptr,
550
- /* .apply = */ llama_sampler_greedy_apply,
551
- /* .reset = */ nullptr,
552
- /* .clone = */ nullptr,
553
- /* .free = */ nullptr,
999
+ /* .name = */ llama_sampler_greedy_name,
1000
+ /* .accept = */ nullptr,
1001
+ /* .apply = */ llama_sampler_greedy_apply,
1002
+ /* .reset = */ llama_sampler_greedy_reset,
1003
+ /* .clone = */ llama_sampler_greedy_clone,
1004
+ /* .free = */ llama_sampler_greedy_free,
1005
+ /* .backend_init = */ llama_sampler_greedy_backend_init,
1006
+ /* .backend_accept = */ nullptr,
1007
+ /* .backend_apply = */ llama_sampler_greedy_backend_apply,
1008
+ /* .backend_set_input = */ nullptr,
554
1009
  };
555
1010
 
556
1011
  struct llama_sampler * llama_sampler_init_greedy() {
557
1012
  return llama_sampler_init(
558
1013
  /* .iface = */ &llama_sampler_greedy_i,
559
- /* .ctx = */ nullptr
1014
+ /* .ctx = */ new llama_sampler_greedy {
1015
+ ("greedy"),
1016
+ }
560
1017
  );
561
1018
  }
562
1019
 
563
1020
  // dist
564
1021
 
565
- struct llama_sampler_dist {
1022
+ struct llama_sampler_dist : public llama_sampler_backend {
566
1023
  const uint32_t seed;
567
1024
  uint32_t seed_cur;
568
1025
 
569
- std::mt19937 rng;
570
- };
1026
+ std::mt19937 rng;
1027
+
1028
+ // backend input
1029
+ struct ggml_tensor * inp_uniform;
1030
+
1031
+ ggml_context_ptr inp_ctx;
1032
+ ggml_backend_buffer_ptr inp_buf;
1033
+ };
1034
+
1035
+ static const char * llama_sampler_dist_name(const struct llama_sampler * smpl) {
1036
+ auto * sctx = (llama_sampler_dist *) smpl->ctx;
1037
+ return sctx->get_name();
1038
+ }
1039
+
1040
+ static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1041
+ auto * ctx = (llama_sampler_dist *) smpl->ctx;
1042
+
1043
+ // edge cases
1044
+ if (cur_p->size == 0) {
1045
+ cur_p->selected = -1;
1046
+ return;
1047
+ }
1048
+
1049
+ cur_p->selected = 0;
1050
+
1051
+ if (cur_p->size == 1) {
1052
+ cur_p->data[0].p = 1.0f;
1053
+ return;
1054
+ }
1055
+
1056
+ // max logit for numerical stability
1057
+ float max_l = cur_p->data[0].logit;
1058
+ if (!cur_p->sorted) {
1059
+ for (size_t i = 1; i < cur_p->size; ++i) {
1060
+ max_l = std::max(max_l, cur_p->data[i].logit);
1061
+ }
1062
+ }
1063
+
1064
+ // apply softmax to obtain the probabilities
1065
+ double sum_cum = 0.0f;
1066
+ for (size_t i = 0; i < cur_p->size; ++i) {
1067
+ float p = expf(cur_p->data[i].logit - max_l);
1068
+ cur_p->data[i].p = p;
1069
+ sum_cum += p;
1070
+ }
1071
+
1072
+ #if 1
1073
+ // sample from the obtained probabilities and normalize the probs in a single pass
1074
+ // this is ~3x faster on Mac with full gpt-oss vocab than the version below
1075
+ //
1076
+ std::uniform_real_distribution<double> dist(0.0f, 1.0f);
1077
+ const double rnd = dist(ctx->rng);
1078
+
1079
+ double sum_run = 0.0f;
1080
+ const double sum_tgt = sum_cum*rnd;
571
1081
 
572
- static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*/) {
573
- return "dist";
574
- }
1082
+ bool found = false;
1083
+ for (size_t i = 0; i < cur_p->size; ++i) {
1084
+ if (!found) {
1085
+ // accumulate probs until we reach the target sum
1086
+ sum_run += cur_p->data[i].p;
1087
+ if (sum_run >= sum_tgt) {
1088
+ cur_p->selected = i;
1089
+ found = true;
1090
+ }
1091
+ }
575
1092
 
576
- static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
577
- auto * ctx = (llama_sampler_dist *) smpl->ctx;
1093
+ // normalize probs
1094
+ cur_p->data[i].p /= sum_cum;
1095
+ }
578
1096
 
579
- llama_sampler_softmax_impl(cur_p);
1097
+ // fallback to the last token (don't think this can happen)
1098
+ assert(found);
1099
+ if (!found) {
1100
+ cur_p->selected = cur_p->size - 1;
1101
+ }
1102
+ #else
1103
+ // for clarity, this is the same as above but does one pass for normalization and one extra pass for sampling
1104
+ for (size_t i = 0; i < cur_p->size; ++i) {
1105
+ cur_p->data[i].p /= sum_cum;
1106
+ }
580
1107
 
581
1108
  cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
1109
+ #endif
1110
+ }
1111
+
1112
+ static void llama_sampler_dist_reset(struct llama_sampler * smpl) {
1113
+ auto * ctx = (llama_sampler_dist *) smpl->ctx;
1114
+ ctx->seed_cur = get_rng_seed(ctx->seed);
1115
+ ctx->rng.seed(ctx->seed_cur);
582
1116
  }
583
1117
 
584
1118
  static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) {
@@ -595,75 +1129,158 @@ static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sample
595
1129
  return result;
596
1130
  }
597
1131
 
598
- static void llama_sampler_dist_reset(struct llama_sampler * smpl) {
599
- auto * ctx = (llama_sampler_dist *) smpl->ctx;
600
- ctx->seed_cur = get_rng_seed(ctx->seed);
601
- ctx->rng.seed(ctx->seed_cur);
602
- }
603
-
604
1132
  static void llama_sampler_dist_free(struct llama_sampler * smpl) {
605
1133
  delete (llama_sampler_dist *) smpl->ctx;
606
1134
  }
607
1135
 
608
- static struct llama_sampler_i llama_sampler_dist_i = {
609
- /* .name = */ llama_sampler_dist_name,
610
- /* .accept = */ nullptr,
611
- /* .apply = */ llama_sampler_dist_apply,
612
- /* .reset = */ llama_sampler_dist_reset,
613
- /* .clone = */ llama_sampler_dist_clone,
614
- /* .free = */ llama_sampler_dist_free,
615
- };
1136
+ static bool llama_sampler_dist_backend_init(
1137
+ struct llama_sampler * smpl,
1138
+ ggml_backend_buffer_type_t buft) {
1139
+ auto * sctx = (llama_sampler_dist *) smpl->ctx;
616
1140
 
617
- struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
618
- auto seed_cur = get_rng_seed(seed);
619
- return llama_sampler_init(
620
- /* .iface = */ &llama_sampler_dist_i,
621
- /* .ctx = */ new llama_sampler_dist {
622
- /* .seed = */ seed,
623
- /* .seed_cur = */ seed_cur,
624
- /* .rng = */ std::mt19937(seed_cur),
625
- }
626
- );
1141
+ // allocate inputs
1142
+ {
1143
+ ggml_init_params params = {
1144
+ /*.mem_size =*/ ggml_tensor_overhead(),
1145
+ /*.mem_buffer =*/ nullptr,
1146
+ /*.no_alloc =*/ true,
1147
+ };
1148
+
1149
+ sctx->inp_ctx.reset(ggml_init(params));
1150
+
1151
+ // Create the uniform random scalar input tensor. This will be set by
1152
+ // llama_sampler_dist_backend_set_input after this graph is built.
1153
+ sctx->inp_uniform = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_F32, 1);
1154
+ ggml_set_name (sctx->inp_uniform, "uniform");
1155
+ ggml_set_input(sctx->inp_uniform);
1156
+
1157
+ // Allocate all tensors from our context to the backend
1158
+ sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft));
1159
+
1160
+ ggml_backend_buffer_clear(sctx->inp_buf.get(), 0);
1161
+ }
1162
+
1163
+ const bool res = llama_sampler_backend_support(smpl, buft);
1164
+
1165
+ sctx->init(res);
1166
+
1167
+ if (!res) {
1168
+ sctx->inp_ctx.reset(nullptr);
1169
+ sctx->inp_buf.reset(nullptr);
1170
+ }
1171
+
1172
+ return res;
627
1173
  }
628
1174
 
629
- // softmax
1175
+ static void llama_sampler_dist_backend_apply(
1176
+ struct llama_sampler * smpl,
1177
+ struct ggml_context * ctx,
1178
+ struct ggml_cgraph * gf,
1179
+ struct llama_sampler_data * data) {
1180
+ GGML_UNUSED(gf);
1181
+ auto * sctx = (llama_sampler_dist *) smpl->ctx;
1182
+
1183
+ struct ggml_tensor * probs = ggml_soft_max(ctx, data->logits);
1184
+ ggml_set_name(probs, "dist_probs");
1185
+
1186
+ struct ggml_tensor * cumsum = ggml_cumsum(ctx, probs);
1187
+ ggml_set_name(cumsum, "dist_cumsum");
1188
+
1189
+ // The uniform tensor has a random value and we subtract this tensor with
1190
+ // the cumsum tensor (the uniform tensor will be broadcasted by ggml_sub).
1191
+ // Recall that each entry in cumsum is the cumulative probability up to that
1192
+ // index so values stay negative while the cumulative total is below the
1193
+ // random value, and become zero/positive once the threshold is crossed.
1194
+ struct ggml_tensor * diff = ggml_sub(ctx, cumsum, sctx->inp_uniform);
1195
+ ggml_set_name(diff, "dist_cumsum");
1196
+
1197
+ // The ggml_step function produces a tensor where entries are 1 if the
1198
+ // corresponding entry in diff is > 0, and 0 otherwise. So all values up to
1199
+ // the index where the cumulative probability exceeds the random value are 0,
1200
+ // and all entries after that are 1.
1201
+ struct ggml_tensor * mask = ggml_step(ctx, diff);
1202
+ ggml_set_name(mask, "dist_mask");
1203
+
1204
+ // Taking the sum of the mask gives us the sum of elements after the threshold
1205
+ // we are interested in.
1206
+ struct ggml_tensor * idxf = ggml_sum(ctx, mask);
1207
+ ggml_set_name(idxf, "dist_index_f32");
630
1208
 
631
- static const char * llama_sampler_softmax_name(const struct llama_sampler * /*smpl*/) {
632
- return "softmax";
1209
+ // Use ggml_scale_bias to scale the index value by -1 and then add the size
1210
+ // of the mask to that value so we get the correct index ((-1 * idxf) + n).
1211
+ struct ggml_tensor * idx = ggml_cast(ctx, ggml_scale_bias(ctx, idxf, -1.0f, mask->ne[0]), GGML_TYPE_I32);
1212
+ ggml_set_name(idx, "dist_index_i32");
1213
+
1214
+ // Map back to original vocab ids if a candidates tensor is available.
1215
+ struct ggml_tensor * sampled_token = idx;
1216
+ if (data->candidates != nullptr) {
1217
+ struct ggml_tensor * candidates = ggml_reshape_2d(ctx, data->candidates, 1, ggml_nelements(data->candidates));
1218
+
1219
+ sampled_token = ggml_get_rows(ctx, candidates, idx);
1220
+ ggml_set_name(sampled_token, "dist_sampled_token");
1221
+ }
1222
+
1223
+ data->sampled = sampled_token;
1224
+ data->probs = probs;
633
1225
  }
634
1226
 
635
- static void llama_sampler_softmax_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
636
- llama_sampler_softmax_impl(cur_p);
1227
+ static void llama_sampler_dist_backend_set_input(struct llama_sampler * smpl) {
1228
+ auto * sctx = (llama_sampler_dist *) smpl->ctx;
1229
+ GGML_ASSERT(sctx->inp_uniform != nullptr);
1230
+
1231
+ // We sample in double precision and cast to float to match rnd numbers of
1232
+ // llama_dampler_dist which uses double precision (sampling from
1233
+ // std::uniform_real_distribution<double> and
1234
+ // std::uniform_real_distribution<float> with same rng will produce
1235
+ // different sequences).
1236
+ std::uniform_real_distribution<double> dist(0.0f, 1.0f);
1237
+ const float rnd = dist(sctx->rng);
1238
+
1239
+ ggml_backend_tensor_set(sctx->inp_uniform, &rnd, 0, sizeof(float));
637
1240
  }
638
1241
 
639
- static struct llama_sampler_i llama_sampler_softmax_i = {
640
- /* .name = */ llama_sampler_softmax_name,
641
- /* .accept = */ nullptr,
642
- /* .apply = */ llama_sampler_softmax_apply,
643
- /* .reset = */ nullptr,
644
- /* .clone = */ nullptr,
645
- /* .free = */ nullptr,
1242
+ static struct llama_sampler_i llama_sampler_dist_i = {
1243
+ /* .name = */ llama_sampler_dist_name,
1244
+ /* .accept = */ nullptr,
1245
+ /* .apply = */ llama_sampler_dist_apply,
1246
+ /* .reset = */ llama_sampler_dist_reset,
1247
+ /* .clone = */ llama_sampler_dist_clone,
1248
+ /* .free = */ llama_sampler_dist_free,
1249
+ /* .backend_init = */ llama_sampler_dist_backend_init,
1250
+ /* .backend_accept = */ nullptr,
1251
+ /* .backend_apply = */ llama_sampler_dist_backend_apply,
1252
+ /* .backend_set_input = */ llama_sampler_dist_backend_set_input,
646
1253
  };
647
1254
 
648
- struct llama_sampler * llama_sampler_init_softmax() {
1255
+ struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
1256
+ auto seed_cur = get_rng_seed(seed);
649
1257
  return llama_sampler_init(
650
- /* .iface = */ &llama_sampler_softmax_i,
651
- /* .ctx = */ nullptr
1258
+ /* .iface = */ &llama_sampler_dist_i,
1259
+ /* .ctx = */ new llama_sampler_dist {
1260
+ ("dist"),
1261
+ /* .seed = */ seed,
1262
+ /* .seed_cur = */ seed_cur,
1263
+ /* .rng = */ std::mt19937(seed_cur),
1264
+ /* .inp_uniform = */ nullptr,
1265
+ /* .inp_ctx = */ nullptr,
1266
+ /* .inp_buf = */ nullptr,
1267
+ }
652
1268
  );
653
1269
  }
654
1270
 
655
1271
  // top-k
656
1272
 
657
- struct llama_sampler_top_k {
1273
+ struct llama_sampler_top_k : public llama_sampler_backend {
658
1274
  const int32_t k;
659
1275
  };
660
1276
 
661
- static const char * llama_sampler_top_k_name(const struct llama_sampler * /*smpl*/) {
662
- return "top-k";
1277
+ static const char * llama_sampler_top_k_name(const struct llama_sampler * smpl) {
1278
+ auto * sctx = (llama_sampler_top_k *) smpl->ctx;
1279
+ return sctx->get_name();
663
1280
  }
664
1281
 
665
1282
  static void llama_sampler_top_k_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
666
- const auto * ctx = (llama_sampler_top_k *) smpl->ctx;
1283
+ auto * ctx = (llama_sampler_top_k *) smpl->ctx;
667
1284
  llama_sampler_top_k_impl(cur_p, ctx->k);
668
1285
  }
669
1286
 
@@ -676,19 +1293,69 @@ static void llama_sampler_top_k_free(struct llama_sampler * smpl) {
676
1293
  delete (llama_sampler_top_k *) smpl->ctx;
677
1294
  }
678
1295
 
1296
+ static bool llama_sampler_top_k_backend_init(
1297
+ struct llama_sampler * smpl,
1298
+ ggml_backend_buffer_type_t buft) {
1299
+ auto * sctx = (llama_sampler_top_k *) smpl->ctx;
1300
+
1301
+ const bool res = llama_sampler_backend_support(smpl, buft);
1302
+
1303
+ sctx->init(res);
1304
+
1305
+ return res;
1306
+ }
1307
+
1308
+ static void llama_sampler_top_k_backend_apply(
1309
+ struct llama_sampler * smpl,
1310
+ struct ggml_context * ctx,
1311
+ struct ggml_cgraph * gf,
1312
+ struct llama_sampler_data * data) {
1313
+ auto * sctx = (llama_sampler_top_k *) smpl->ctx;
1314
+
1315
+ struct ggml_tensor * top_k = ggml_top_k(ctx, data->logits, sctx->k);
1316
+ ggml_set_name(top_k, "top_k");
1317
+
1318
+ if (data->candidates) {
1319
+ struct ggml_tensor * candidates_rows = ggml_reshape_2d(ctx, data->candidates, 1, data->candidates->ne[0]);
1320
+ data->candidates = ggml_get_rows(ctx, candidates_rows, top_k);
1321
+ data->candidates = ggml_reshape_1d(ctx, data->candidates, sctx->k);
1322
+ ggml_set_name(data->candidates, "top_k_candidates");
1323
+ } else {
1324
+ data->candidates = top_k;
1325
+ }
1326
+
1327
+ struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]);
1328
+ struct ggml_tensor * top_k_rows = ggml_get_rows(ctx, logits_rows, top_k);
1329
+ data->logits = ggml_reshape_1d(ctx, top_k_rows, sctx->k);
1330
+ ggml_set_name(top_k_rows, "top_k_rows");
1331
+
1332
+ GGML_UNUSED(gf);
1333
+ }
1334
+
679
1335
  static struct llama_sampler_i llama_sampler_top_k_i = {
680
- /* .name = */ llama_sampler_top_k_name,
681
- /* .accept = */ nullptr,
682
- /* .apply = */ llama_sampler_top_k_apply,
683
- /* .reset = */ nullptr,
684
- /* .clone = */ llama_sampler_top_k_clone,
685
- /* .free = */ llama_sampler_top_k_free,
1336
+ /* .name = */ llama_sampler_top_k_name,
1337
+ /* .accept = */ nullptr,
1338
+ /* .apply = */ llama_sampler_top_k_apply,
1339
+ /* .reset = */ nullptr,
1340
+ /* .clone = */ llama_sampler_top_k_clone,
1341
+ /* .free = */ llama_sampler_top_k_free,
1342
+ /* .backend_init = */ llama_sampler_top_k_backend_init,
1343
+ /* .backend_accept = */ nullptr,
1344
+ /* .backend_apply = */ llama_sampler_top_k_backend_apply,
1345
+ /* .backend_set_input = */ nullptr,
686
1346
  };
687
1347
 
688
1348
  struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
1349
+ const bool is_empty = (k <= 0);
1350
+
1351
+ if (is_empty) {
1352
+ return llama_sampler_init_empty("?top-k");
1353
+ }
1354
+
689
1355
  return llama_sampler_init(
690
1356
  /* .iface = */ &llama_sampler_top_k_i,
691
1357
  /* .ctx = */ new llama_sampler_top_k {
1358
+ ("top-k"),
692
1359
  /* .k = */ k,
693
1360
  }
694
1361
  );
@@ -696,30 +1363,48 @@ struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
696
1363
 
697
1364
  // top-p
698
1365
 
699
- struct llama_sampler_top_p {
1366
+ struct llama_sampler_top_p : public llama_sampler_backend {
700
1367
  const float p;
701
1368
  const size_t min_keep;
1369
+
1370
+ std::vector<llama_token_data> buf_sort;
702
1371
  };
703
1372
 
704
- static const char * llama_sampler_top_p_name(const struct llama_sampler * /*smpl*/) {
705
- return "top-p";
1373
+ static const char * llama_sampler_top_p_name(const struct llama_sampler * smpl) {
1374
+ auto * sctx = (llama_sampler_top_p *) smpl->ctx;
1375
+ return sctx->get_name();
706
1376
  }
707
1377
 
708
1378
  static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
709
- const auto * ctx = (llama_sampler_top_p *) smpl->ctx;
1379
+ auto * ctx = (llama_sampler_top_p *) smpl->ctx;
710
1380
 
711
1381
  if (ctx->p >= 1.0f) {
712
1382
  return;
713
1383
  }
714
1384
 
715
- llama_sampler_softmax_impl(cur_p);
1385
+ llama_sampler_softmax_impl(cur_p, false);
1386
+
1387
+ size_t k = cur_p->size;
1388
+ auto * pdata = cur_p->data;
1389
+
1390
+ auto & buf_sort = ctx->buf_sort;
1391
+
1392
+ // if not sorted, try adaptive top-k sorting
1393
+ if (!cur_p->sorted && cur_p->size > 1024) {
1394
+ k = std::min<size_t>(256, cur_p->size);
1395
+ llama_token_data_array_partial_sort(*cur_p, k, buf_sort);
1396
+ pdata = buf_sort.data();
1397
+ } else if (!cur_p->sorted) {
1398
+ // small candidates -> sort inplace
1399
+ llama_token_data_array_partial_sort_inplace(cur_p, k);
1400
+ }
716
1401
 
717
1402
  // Compute the cumulative probabilities
718
1403
  float cum_sum = 0.0f;
719
1404
  size_t last_idx = cur_p->size;
720
1405
 
721
1406
  for (size_t i = 0; i < cur_p->size; ++i) {
722
- cum_sum += cur_p->data[i].p;
1407
+ cum_sum += pdata[i].p;
723
1408
 
724
1409
  // Check if the running sum is at least p or if we have kept at least min_keep tokens
725
1410
  // we set the last index to i+1 to indicate that the current iterate should be included in the set
@@ -727,9 +1412,21 @@ static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_d
727
1412
  last_idx = i + 1;
728
1413
  break;
729
1414
  }
1415
+
1416
+ // we exceeded the current top-k heuristic -> increase k and continue
1417
+ if (!cur_p->sorted && i == k - 1) {
1418
+ k = cur_p->size;
1419
+ llama_token_data_array_partial_sort(*cur_p, k, buf_sort);
1420
+ pdata = buf_sort.data();
1421
+ }
730
1422
  }
731
1423
 
732
1424
  // Resize the output vector to keep only the top-p tokens
1425
+ if (!cur_p->sorted) {
1426
+ std::copy(buf_sort.data(), buf_sort.data() + last_idx, cur_p->data);
1427
+ cur_p->sorted = true;
1428
+ }
1429
+
733
1430
  cur_p->size = last_idx;
734
1431
  }
735
1432
 
@@ -742,38 +1439,139 @@ static void llama_sampler_top_p_free(struct llama_sampler * smpl) {
742
1439
  delete (llama_sampler_top_p *) smpl->ctx;
743
1440
  }
744
1441
 
1442
+ static bool llama_sampler_top_p_backend_init(
1443
+ struct llama_sampler * smpl,
1444
+ ggml_backend_buffer_type_t buft) {
1445
+ auto * sctx = (llama_sampler_top_p *) smpl->ctx;
1446
+
1447
+ const bool res = llama_sampler_backend_support(smpl, buft);
1448
+
1449
+ sctx->init(res);
1450
+
1451
+ return res;
1452
+ }
1453
+
1454
+ static void llama_sampler_top_p_backend_apply(
1455
+ struct llama_sampler * smpl,
1456
+ struct ggml_context * ctx,
1457
+ struct ggml_cgraph * gf,
1458
+ struct llama_sampler_data * data) {
1459
+ auto * sctx = (llama_sampler_top_p *) smpl->ctx;
1460
+
1461
+ auto ggml_sort = [ctx](struct ggml_tensor * a, struct ggml_tensor * b) {
1462
+ GGML_ASSERT(ggml_nrows(a) == 1);
1463
+ struct ggml_tensor * a_reshaped = ggml_reshape_2d(ctx, a, 1, a->ne[0]);
1464
+ struct ggml_tensor * a_sorted = ggml_get_rows(ctx, a_reshaped, b);
1465
+ return ggml_reshape_1d(ctx, a_sorted, a->ne[0]);
1466
+ };
1467
+
1468
+ // Get the sorted logits in descending order.
1469
+ struct ggml_tensor * sorted_idx = ggml_argsort(ctx, data->logits, GGML_SORT_ORDER_DESC);
1470
+ ggml_set_name(sorted_idx, "top_p_sorted_idx");
1471
+
1472
+ // Do the sorting via reshape + get_rows
1473
+ struct ggml_tensor * sorted_logits = ggml_sort(data->logits, sorted_idx);
1474
+ ggml_set_name(sorted_logits, "top_p_sorted_logits");
1475
+
1476
+ struct ggml_tensor * softmax = ggml_soft_max(ctx, sorted_logits);
1477
+ ggml_set_name(softmax, "top_p_softmax");
1478
+
1479
+ // If candidates are provided, sort them as well. Otherwise, set sorted indices as candidates.
1480
+ if (data->candidates) {
1481
+ data->candidates = ggml_sort(data->candidates, sorted_idx);
1482
+ } else {
1483
+ data->candidates = sorted_idx;
1484
+ }
1485
+ ggml_set_name(data->candidates, "top_p_candidates");
1486
+
1487
+ // Compute Cumulative Distribution Function (CDF) by means of GGML_OP_CUMSUM.
1488
+ struct ggml_tensor * cdf = ggml_cumsum(ctx, softmax);
1489
+ ggml_set_name(cdf, "top_p_cdf");
1490
+
1491
+ // Invert CDF and add top-p value so that ggml_step yields 1 for values we want to keep
1492
+ struct ggml_tensor * cdf_scaled = ggml_scale_bias(ctx, cdf, -1.0f, sctx->p);
1493
+ ggml_set_name(cdf_scaled, "top_p_cdf_scaled");
1494
+
1495
+ struct ggml_tensor * mask = ggml_step(ctx, cdf_scaled);
1496
+ ggml_set_name(mask, "top_p_mask");
1497
+
1498
+ // Taking the sum of the mask gives us the sum of elements after the threshold
1499
+ // we are interested in.
1500
+ struct ggml_tensor * idxf = ggml_sum(ctx, mask);
1501
+ ggml_set_name(idxf, "top_p_index_f32");
1502
+
1503
+ // prevent out-of-bounds access
1504
+ idxf = ggml_clamp(ctx, idxf, 0.0f, mask->ne[0] - 1);
1505
+
1506
+ // construct ones tensor to set the value in the mask
1507
+ struct ggml_tensor * ones = ggml_scale_bias(ctx, idxf, 0.0f, 1.0f);
1508
+ ggml_set_name(ones, "top_p_ones");
1509
+
1510
+ // Make top-p inclusive (i.e. return all values such that cum_sum/cdf >= p)
1511
+ struct ggml_tensor * mask_reshaped = ggml_reshape_2d(ctx, mask, 1, mask->ne[0]);
1512
+
1513
+ mask_reshaped = ggml_set_rows(ctx, mask_reshaped, ones, ggml_cast(ctx, idxf, GGML_TYPE_I32));
1514
+ mask = ggml_reshape_1d(ctx, mask_reshaped, mask->ne[0]);
1515
+
1516
+ // Use ggml_scale_bias (output = (a * s) + b) which in this case becomes:
1517
+ // top_p_bias = (mask * 1e9f) - 1e9f.
1518
+ // So entries in the mask that we want to discard will become -1e9f, and
1519
+ // others will be 0 (meaning that will not effect the logits).
1520
+ const float large_val = 1e9f;
1521
+ struct ggml_tensor * top_p_bias = ggml_scale_bias(ctx, mask, large_val, -large_val);
1522
+ ggml_set_name(top_p_bias, "top_p_bias");
1523
+
1524
+ data->logits = ggml_add(ctx, sorted_logits, top_p_bias);
1525
+ ggml_set_name(data->logits, "top_p_logits");
1526
+
1527
+ GGML_UNUSED(gf);
1528
+ }
1529
+
745
1530
  static struct llama_sampler_i llama_sampler_top_p_i = {
746
- /* .name = */ llama_sampler_top_p_name,
747
- /* .accept = */ nullptr,
748
- /* .apply = */ llama_sampler_top_p_apply,
749
- /* .reset = */ nullptr,
750
- /* .clone = */ llama_sampler_top_p_clone,
751
- /* .free = */ llama_sampler_top_p_free,
1531
+ /* .name = */ llama_sampler_top_p_name,
1532
+ /* .accept = */ nullptr,
1533
+ /* .apply = */ llama_sampler_top_p_apply,
1534
+ /* .reset = */ nullptr,
1535
+ /* .clone = */ llama_sampler_top_p_clone,
1536
+ /* .free = */ llama_sampler_top_p_free,
1537
+ /* .backend_init = */ llama_sampler_top_p_backend_init,
1538
+ /* .backend_accept = */ nullptr,
1539
+ /* .backend_apply = */ llama_sampler_top_p_backend_apply,
1540
+ /* .backend_set_input = */ nullptr,
752
1541
  };
753
1542
 
754
1543
  struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
1544
+ const bool is_empty = p >= 1.0f;
1545
+
1546
+ if (is_empty) {
1547
+ return llama_sampler_init_empty("?top-p");
1548
+ }
1549
+
755
1550
  return llama_sampler_init(
756
1551
  /* .iface = */ &llama_sampler_top_p_i,
757
1552
  /* .ctx = */ new llama_sampler_top_p {
1553
+ ("top-p"),
758
1554
  /* .p = */ p,
759
1555
  /* .min_keep = */ min_keep,
1556
+ /* .buf_sort = */ {},
760
1557
  }
761
1558
  );
762
1559
  }
763
1560
 
764
1561
  // min-p
765
1562
 
766
- struct llama_sampler_min_p {
1563
+ struct llama_sampler_min_p : public llama_sampler_backend {
767
1564
  const float p;
768
1565
  const size_t min_keep;
769
1566
  };
770
1567
 
771
- static const char * llama_sampler_min_p_name(const struct llama_sampler * /*smpl*/) {
772
- return "min-p";
1568
+ static const char * llama_sampler_min_p_name(const struct llama_sampler * smpl) {
1569
+ auto * sctx = (llama_sampler_min_p *) smpl->ctx;
1570
+ return sctx->get_name();
773
1571
  }
774
1572
 
775
1573
  static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
776
- const auto * ctx = (llama_sampler_min_p *) smpl->ctx;
1574
+ auto * ctx = (llama_sampler_min_p *) smpl->ctx;
777
1575
 
778
1576
  if (ctx->p <= 0.0f || !cur_p->size) {
779
1577
  return;
@@ -799,7 +1597,7 @@ static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_d
799
1597
 
800
1598
  // if we have enough values the operation was a success
801
1599
  if (!filtered_tokens.empty() && filtered_tokens.size() >= ctx->min_keep) {
802
- memcpy(cur_p->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data));
1600
+ std::copy(filtered_tokens.begin(), filtered_tokens.end(), cur_p->data);
803
1601
  cur_p->size = filtered_tokens.size();
804
1602
  min_p_applied = true;
805
1603
  }
@@ -809,10 +1607,7 @@ static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_d
809
1607
  if (!min_p_applied) {
810
1608
  // Sort the logits in descending order
811
1609
  if (!cur_p->sorted) {
812
- std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
813
- return a.logit > b.logit;
814
- });
815
- cur_p->sorted = true;
1610
+ llama_token_data_array_partial_sort_inplace(cur_p, cur_p->size);
816
1611
  }
817
1612
 
818
1613
  const float min_logit = cur_p->data[0].logit + logf(ctx->p); // min logit for p_i >= p * p_max
@@ -838,19 +1633,85 @@ static void llama_sampler_min_p_free(struct llama_sampler * smpl) {
838
1633
  delete (llama_sampler_min_p *) smpl->ctx;
839
1634
  }
840
1635
 
1636
+ static bool llama_sampler_min_p_backend_init(
1637
+ struct llama_sampler * smpl,
1638
+ ggml_backend_buffer_type_t buft) {
1639
+ auto * sctx = (llama_sampler_min_p *) smpl->ctx;
1640
+
1641
+ const bool res = llama_sampler_backend_support(smpl, buft);
1642
+
1643
+ sctx->init(res);
1644
+
1645
+ return res;
1646
+ }
1647
+
1648
+ static void llama_sampler_min_p_backend_apply(
1649
+ struct llama_sampler * smpl,
1650
+ struct ggml_context * ctx,
1651
+ struct ggml_cgraph * gf,
1652
+ struct llama_sampler_data * data) {
1653
+ auto * sctx = (llama_sampler_min_p *) smpl->ctx;
1654
+
1655
+ struct ggml_tensor * max_idx = ggml_argmax(ctx, data->logits);
1656
+ ggml_set_name(max_idx, "max_idx");
1657
+
1658
+ struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]);
1659
+ ggml_set_name(logits_rows, "logits_rows");
1660
+
1661
+ struct ggml_tensor * max_logit = ggml_get_rows(ctx, logits_rows, max_idx);
1662
+ ggml_set_name(max_logit, "max_logit");
1663
+
1664
+ // Calculate the threshold value.
1665
+ struct ggml_tensor * threshold = ggml_scale_bias(ctx, max_logit, 1.0f, logf(sctx->p));
1666
+ ggml_set_name(threshold, "min_p_threshold");
1667
+
1668
+ // Subtract the threshold from logits.
1669
+ struct ggml_tensor * sub = ggml_sub(ctx, data->logits, threshold);
1670
+
1671
+ // Create a mask where logits below the threshold are 0 (discard),
1672
+ // and others are 1 (keep).
1673
+ struct ggml_tensor * mask = ggml_step(ctx, sub);
1674
+ ggml_set_name(mask, "min_p_mask");
1675
+
1676
+ // Use ggml_scale_bias (output = (a * s) + b) which in this case becomes:
1677
+ // min_p_bias = (mask * 1e9f) - 1e9f.
1678
+ // So entries in the mask that we want to discard will become -1e9f, and
1679
+ // others will be 0 (meaning that will not effect the logits).
1680
+ const float large_val = 1e9f;
1681
+ struct ggml_tensor * min_p_bias = ggml_scale_bias(ctx, mask, large_val, -large_val);
1682
+ ggml_set_name(min_p_bias, "min_p_bias");
1683
+
1684
+ // Add the min_p bias to the logits.
1685
+ data->logits = ggml_add(ctx, data->logits, min_p_bias);
1686
+ ggml_set_name(data->logits, "min_p_logits");
1687
+
1688
+ GGML_UNUSED(gf);
1689
+ }
1690
+
841
1691
  static struct llama_sampler_i llama_sampler_min_p_i = {
842
- /* .name = */ llama_sampler_min_p_name,
843
- /* .accept = */ nullptr,
844
- /* .apply = */ llama_sampler_min_p_apply,
845
- /* .reset = */ nullptr,
846
- /* .clone = */ llama_sampler_min_p_clone,
847
- /* .free = */ llama_sampler_min_p_free,
1692
+ /* .name = */ llama_sampler_min_p_name,
1693
+ /* .accept = */ nullptr,
1694
+ /* .apply = */ llama_sampler_min_p_apply,
1695
+ /* .reset = */ nullptr,
1696
+ /* .clone = */ llama_sampler_min_p_clone,
1697
+ /* .free = */ llama_sampler_min_p_free,
1698
+ /* .backend_init = */ llama_sampler_min_p_backend_init,
1699
+ /* .backend_accept = */ nullptr,
1700
+ /* .backend_apply = */ llama_sampler_min_p_backend_apply,
1701
+ /* .backend_set_input = */ nullptr,
848
1702
  };
849
1703
 
850
1704
  struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
1705
+ const bool is_empty = (p <= 0.0f);
1706
+
1707
+ if (is_empty) {
1708
+ return llama_sampler_init_empty("?min-p");
1709
+ }
1710
+
851
1711
  return llama_sampler_init(
852
1712
  /* .iface = */ &llama_sampler_min_p_i,
853
1713
  /* .ctx = */ new llama_sampler_min_p {
1714
+ ("min-p"),
854
1715
  /* .p = */ p,
855
1716
  /* .min_keep = */ min_keep,
856
1717
  }
@@ -869,7 +1730,7 @@ static const char * llama_sampler_typical_name(const struct llama_sampler * /*sm
869
1730
  }
870
1731
 
871
1732
  static void llama_sampler_typical_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
872
- const auto * ctx = (llama_sampler_typical *) smpl->ctx;
1733
+ auto * ctx = (llama_sampler_typical *) smpl->ctx;
873
1734
 
874
1735
  // Reference implementation:
875
1736
  // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr
@@ -878,7 +1739,7 @@ static void llama_sampler_typical_apply(struct llama_sampler * smpl, llama_token
878
1739
  }
879
1740
 
880
1741
  // Compute the softmax of logits and calculate entropy
881
- llama_sampler_softmax_impl(cur_p);
1742
+ llama_sampler_softmax_impl(cur_p, true);
882
1743
 
883
1744
  float entropy = 0.0f;
884
1745
  for (size_t i = 0; i < cur_p->size; ++i) {
@@ -938,15 +1799,25 @@ static void llama_sampler_typical_free(struct llama_sampler * smpl) {
938
1799
  }
939
1800
 
940
1801
  static struct llama_sampler_i llama_sampler_typical_i = {
941
- /* .name = */ llama_sampler_typical_name,
942
- /* .accept = */ nullptr,
943
- /* .apply = */ llama_sampler_typical_apply,
944
- /* .reset = */ nullptr,
945
- /* .clone = */ llama_sampler_typical_clone,
946
- /* .free = */ llama_sampler_typical_free,
1802
+ /* .name = */ llama_sampler_typical_name,
1803
+ /* .accept = */ nullptr,
1804
+ /* .apply = */ llama_sampler_typical_apply,
1805
+ /* .reset = */ nullptr,
1806
+ /* .clone = */ llama_sampler_typical_clone,
1807
+ /* .free = */ llama_sampler_typical_free,
1808
+ /* .backend_init = */ nullptr,
1809
+ /* .backend_accept = */ nullptr,
1810
+ /* .backend_apply = */ nullptr,
1811
+ /* .backend_set_input = */ nullptr,
947
1812
  };
948
1813
 
949
1814
  struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
1815
+ const bool is_empty = (p >= 1.0f);
1816
+
1817
+ if (is_empty) {
1818
+ return llama_sampler_init_empty("?typical");
1819
+ }
1820
+
950
1821
  return llama_sampler_init(
951
1822
  /* .iface = */ &llama_sampler_typical_i,
952
1823
  /* .ctx = */ new llama_sampler_typical {
@@ -958,12 +1829,13 @@ struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
958
1829
 
959
1830
  // temp
960
1831
 
961
- struct llama_sampler_temp {
1832
+ struct llama_sampler_temp : public llama_sampler_backend {
962
1833
  const float temp;
963
1834
  };
964
1835
 
965
- static const char * llama_sampler_temp_name(const struct llama_sampler * /*smpl*/) {
966
- return "temp";
1836
+ static const char * llama_sampler_temp_name(const struct llama_sampler * smpl) {
1837
+ auto * sctx = (llama_sampler_temp *) smpl->ctx;
1838
+ return sctx->get_name();
967
1839
  }
968
1840
 
969
1841
  static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
@@ -981,19 +1853,79 @@ static void llama_sampler_temp_free(struct llama_sampler * smpl) {
981
1853
  delete (llama_sampler_temp *) smpl->ctx;
982
1854
  }
983
1855
 
1856
+ static void llama_sampler_backend_temp_sampling(
1857
+ struct ggml_context * ctx,
1858
+ struct ggml_cgraph * gf,
1859
+ struct llama_sampler_data * data,
1860
+ float temp) {
1861
+ if (temp <= 0.0f) {
1862
+ // Find the most probable token index.
1863
+ struct ggml_tensor * max_idx = ggml_argmax(ctx, data->logits);
1864
+ ggml_set_name(max_idx, "temp_max_idx");
1865
+
1866
+ if (data->candidates) {
1867
+ struct ggml_tensor * candidates_rows = ggml_reshape_2d(ctx, data->candidates, 1, data->candidates->ne[0]);
1868
+ data->candidates = ggml_get_rows(ctx, candidates_rows, max_idx);
1869
+ } else {
1870
+ data->candidates = max_idx;
1871
+ }
1872
+
1873
+ struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]);
1874
+ data->logits = ggml_get_rows(ctx, logits_rows, max_idx);
1875
+
1876
+ return;
1877
+ }
1878
+
1879
+ data->logits = ggml_scale(ctx, data->logits, 1.0f / temp);
1880
+
1881
+ GGML_UNUSED(gf);
1882
+ }
1883
+
1884
+ static bool llama_sampler_temp_backend_init(
1885
+ struct llama_sampler * smpl,
1886
+ ggml_backend_buffer_type_t buft) {
1887
+ auto * sctx = (llama_sampler_temp *) smpl->ctx;
1888
+
1889
+ const bool res = llama_sampler_backend_support(smpl, buft);
1890
+
1891
+ sctx->init(res);
1892
+
1893
+ return res;
1894
+ }
1895
+
1896
+ static void llama_sampler_temp_backend_apply(
1897
+ struct llama_sampler * smpl,
1898
+ struct ggml_context * ctx,
1899
+ struct ggml_cgraph * gf,
1900
+ struct llama_sampler_data * data) {
1901
+ auto * sctx = (llama_sampler_temp *) smpl->ctx;
1902
+ llama_sampler_backend_temp_sampling(ctx, gf, data, sctx->temp);
1903
+ }
1904
+
984
1905
  static struct llama_sampler_i llama_sampler_temp_i = {
985
- /* .name = */ llama_sampler_temp_name,
986
- /* .accept = */ nullptr,
987
- /* .apply = */ llama_sampler_temp_apply,
988
- /* .reset = */ nullptr,
989
- /* .clone = */ llama_sampler_temp_clone,
990
- /* .free = */ llama_sampler_temp_free,
1906
+ /* .name = */ llama_sampler_temp_name,
1907
+ /* .accept = */ nullptr,
1908
+ /* .apply = */ llama_sampler_temp_apply,
1909
+ /* .reset = */ nullptr,
1910
+ /* .clone = */ llama_sampler_temp_clone,
1911
+ /* .free = */ llama_sampler_temp_free,
1912
+ /* .backend_init = */ llama_sampler_temp_backend_init,
1913
+ /* .backend_accept = */ nullptr,
1914
+ /* .backend_apply = */ llama_sampler_temp_backend_apply,
1915
+ /* .backend_set_input = */ nullptr,
991
1916
  };
992
1917
 
993
1918
  struct llama_sampler * llama_sampler_init_temp(float temp) {
1919
+ const bool is_empty = temp == 1.0f;
1920
+
1921
+ if (is_empty) {
1922
+ return llama_sampler_init_empty("?temp");
1923
+ }
1924
+
994
1925
  return llama_sampler_init(
995
1926
  /* .iface = */ &llama_sampler_temp_i,
996
1927
  /* .ctx = */ new llama_sampler_temp {
1928
+ ("temp"),
997
1929
  /*.temp = */ temp,
998
1930
  }
999
1931
  );
@@ -1001,18 +1933,19 @@ struct llama_sampler * llama_sampler_init_temp(float temp) {
1001
1933
 
1002
1934
  // temp-ext
1003
1935
 
1004
- struct llama_sampler_temp_ext {
1936
+ struct llama_sampler_temp_ext : public llama_sampler_backend {
1005
1937
  const float temp;
1006
1938
  const float delta;
1007
1939
  const float exponent;
1008
1940
  };
1009
1941
 
1010
- static const char * llama_sampler_temp_ext_name(const struct llama_sampler * /*smpl*/) {
1011
- return "temp-ext";
1942
+ static const char * llama_sampler_temp_ext_name(const struct llama_sampler * smpl) {
1943
+ auto * sctx = (llama_sampler_temp_ext *) smpl->ctx;
1944
+ return sctx->get_name();
1012
1945
  }
1013
1946
 
1014
1947
  static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1015
- const auto * ctx = (llama_sampler_temp_ext *) smpl->ctx;
1948
+ auto * ctx = (llama_sampler_temp_ext *) smpl->ctx;
1016
1949
  if (ctx->delta > 0) {
1017
1950
  const float min_temp = std::max(0.0f, ctx->temp - ctx->delta);
1018
1951
  const float max_temp = ctx->temp + ctx->delta;
@@ -1027,7 +1960,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke
1027
1960
  // Calculate maximum possible entropy
1028
1961
  float max_entropy = -logf(1.0f / cur_p->size);
1029
1962
 
1030
- llama_sampler_softmax_impl(cur_p);
1963
+ llama_sampler_softmax_impl(cur_p, true);
1031
1964
 
1032
1965
  // Calculate entropy of the softmax probabilities
1033
1966
  float entropy = 0.0f;
@@ -1091,24 +2024,112 @@ static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) {
1091
2024
  delete (llama_sampler_temp_ext *) smpl->ctx;
1092
2025
  }
1093
2026
 
2027
+ static bool llama_sampler_temp_ext_backend_init(
2028
+ struct llama_sampler * smpl,
2029
+ ggml_backend_buffer_type_t buft) {
2030
+ auto * sctx = (llama_sampler_temp_ext *) smpl->ctx;
2031
+
2032
+ const bool res = llama_sampler_backend_support(smpl, buft);
2033
+
2034
+ sctx->init(res);
2035
+
2036
+ return res;
2037
+ }
2038
+
2039
+ static void llama_sampler_temp_ext_backend_apply(
2040
+ struct llama_sampler * smpl,
2041
+ struct ggml_context * ctx,
2042
+ struct ggml_cgraph * gf,
2043
+ struct llama_sampler_data * data) {
2044
+ auto * sctx = (llama_sampler_temp_ext *) smpl->ctx;
2045
+
2046
+ // Revert to standard temperature scaling if delta or temp are non-positive.
2047
+ if (sctx->delta <= 0.0f || sctx->temp <= 0.0f) {
2048
+ llama_sampler_backend_temp_sampling(ctx, gf, data, sctx->temp);
2049
+ return;
2050
+ }
2051
+
2052
+ // Calculate min_temp, max_temp, and max_entropy.
2053
+ const float min_temp = std::max(0.0f, sctx->temp - sctx->delta);
2054
+ const float max_temp = sctx->temp + sctx->delta;
2055
+ const float max_entropy = logf(data->logits->ne[0]);
2056
+
2057
+ // Calculate the probabilities.
2058
+ struct ggml_tensor * probs = ggml_soft_max(ctx, data->logits);
2059
+ ggml_set_name(probs, "temp_ext_softmax_probs");
2060
+
2061
+ // Clamp probabilities to avoid log(0) which would give -inf
2062
+ struct ggml_tensor * probs_clamped = ggml_clamp(ctx, probs, 1e-10f, 1.0f);
2063
+ ggml_set_name(probs_clamped, "temp_ext_probs_clamped");
2064
+
2065
+ // Calculate the entropy, entropy = -Σ(p * log(p)).
2066
+ struct ggml_tensor * log_probs = ggml_log(ctx, probs_clamped);
2067
+ struct ggml_tensor * p_log_p = ggml_mul(ctx, probs_clamped, log_probs);
2068
+ struct ggml_tensor * sum_p_log_p = ggml_sum(ctx, p_log_p);
2069
+ struct ggml_tensor * entropy = ggml_scale(ctx, sum_p_log_p, -1.0f);
2070
+ ggml_set_name(log_probs, "temp_ext_log_probs");
2071
+ ggml_set_name(p_log_p, "temp_ext_p_log_p");
2072
+ ggml_set_name(sum_p_log_p, "temp_ext_sum_p_log_p");
2073
+ ggml_set_name(entropy, "temp_ext_entropy");
2074
+
2075
+ // Normalize the entropy, norm_entropy = entropy / max_entropy
2076
+ struct ggml_tensor * norm_entropy = ggml_scale(ctx, entropy, 1.0f / max_entropy);
2077
+ ggml_set_name(norm_entropy, "temp_ext_norm_entropy");
2078
+
2079
+ // Calculate the dynamic temperature:
2080
+ // dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent);
2081
+ //
2082
+ // Calculate powf(normalized_entropy, exponent) as
2083
+ // norm_entropy^exponent = exp(exponent * log(norm_entropy))
2084
+ struct ggml_tensor * log_norm_entropy = ggml_log(ctx, norm_entropy);
2085
+ struct ggml_tensor * scaled_log = ggml_scale(ctx, log_norm_entropy, sctx->exponent);
2086
+ struct ggml_tensor * pow_entropy = ggml_exp(ctx, scaled_log);
2087
+ // With pow_entropy computed we can now compute dyn_temp, scaling by
2088
+ // (max_temp - min_temp) and then adding min_temp.
2089
+ struct ggml_tensor * dyn_temp = ggml_scale_bias(ctx, pow_entropy, max_temp - min_temp, min_temp);
2090
+ ggml_set_name(log_norm_entropy, "temp_ext_log_norm_entropy");
2091
+ ggml_set_name(scaled_log, "temp_ext_scaled_log");
2092
+ ggml_set_name(pow_entropy, "temp_ext_pow_entropy");
2093
+ ggml_set_name(dyn_temp, "temp_ext_dyn_temp");
2094
+
2095
+ // Scale the logits by the dynamic temperature
2096
+ struct ggml_tensor * scaled_logits = ggml_div(ctx, data->logits, dyn_temp);
2097
+ ggml_set_name(scaled_logits, "temp_ext_scaled_logits");
2098
+
2099
+ data->logits = scaled_logits;
2100
+ }
2101
+
1094
2102
  static struct llama_sampler_i llama_sampler_temp_ext_i = {
1095
- /* .name = */ llama_sampler_temp_ext_name,
1096
- /* .accept = */ nullptr,
1097
- /* .apply = */ llama_sampler_temp_ext_apply,
1098
- /* .reset = */ nullptr,
1099
- /* .clone = */ llama_sampler_temp_ext_clone,
1100
- /* .free = */ llama_sampler_temp_ext_free,
2103
+ /* .name = */ llama_sampler_temp_ext_name,
2104
+ /* .accept = */ nullptr,
2105
+ /* .apply = */ llama_sampler_temp_ext_apply,
2106
+ /* .reset = */ nullptr,
2107
+ /* .clone = */ llama_sampler_temp_ext_clone,
2108
+ /* .free = */ llama_sampler_temp_ext_free,
2109
+ /* .backend_init = */ llama_sampler_temp_ext_backend_init,
2110
+ /* .backend_accept = */ nullptr,
2111
+ /* .backend_apply = */ llama_sampler_temp_ext_backend_apply,
2112
+ /* .backend_set_input = */ nullptr,
1101
2113
  };
1102
2114
 
1103
2115
  struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
1104
- return llama_sampler_init(
2116
+ const bool is_empty = temp == 1.0f && delta <= 0.0f;
2117
+
2118
+ if (is_empty) {
2119
+ return llama_sampler_init_empty("?temp-ext");
2120
+ }
2121
+
2122
+ auto * res = llama_sampler_init(
1105
2123
  /* .iface = */ &llama_sampler_temp_ext_i,
1106
2124
  /* .ctx = */ new llama_sampler_temp_ext {
2125
+ ("temp-ext"),
1107
2126
  /* .temp = */ temp,
1108
2127
  /* .delta = */ delta,
1109
2128
  /* .exponent = */ exponent,
1110
2129
  }
1111
2130
  );
2131
+
2132
+ return res;
1112
2133
  }
1113
2134
 
1114
2135
  // xtc
@@ -1139,17 +2160,20 @@ static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data
1139
2160
 
1140
2161
  std::uniform_real_distribution<float> distribution(0.0f, 1.0f);
1141
2162
  float chance = distribution(ctx->rng);
1142
- if (chance > ctx->probability) return;
2163
+ if (chance > ctx->probability) {
2164
+ return;
2165
+ }
1143
2166
 
1144
- // in case it's not sorted/recalculated yet
1145
- llama_sampler_softmax_impl(cur_p);
2167
+ llama_sampler_softmax_impl(cur_p, true);
1146
2168
 
1147
2169
  int pos_last = 0;
1148
2170
 
1149
2171
  for (size_t i = 0; i < cur_p->size; ++i) {
1150
2172
  if (cur_p->data[i].p >= ctx->threshold) {
1151
2173
  pos_last = i;
1152
- } else break;
2174
+ } else {
2175
+ break;
2176
+ }
1153
2177
  }
1154
2178
 
1155
2179
  if (cur_p->size - pos_last >= ctx->min_keep && pos_last > 0) {
@@ -1183,16 +2207,27 @@ static void llama_sampler_xtc_reset(struct llama_sampler * smpl) {
1183
2207
  }
1184
2208
 
1185
2209
  static struct llama_sampler_i llama_sampler_xtc_i = {
1186
- /* .name = */ llama_sampler_xtc_name,
1187
- /* .accept = */ nullptr,
1188
- /* .apply = */ llama_sample_xtc_apply,
1189
- /* .reset = */ llama_sampler_xtc_reset,
1190
- /* .clone = */ llama_sampler_xtc_clone,
1191
- /* .free = */ llama_sampler_xtc_free,
2210
+ /* .name = */ llama_sampler_xtc_name,
2211
+ /* .accept = */ nullptr,
2212
+ /* .apply = */ llama_sample_xtc_apply,
2213
+ /* .reset = */ llama_sampler_xtc_reset,
2214
+ /* .clone = */ llama_sampler_xtc_clone,
2215
+ /* .free = */ llama_sampler_xtc_free,
2216
+ /* .backend_init = */ nullptr,
2217
+ /* .backend_accept = */ nullptr,
2218
+ /* .backend_apply = */ nullptr,
2219
+ /* .backend_set_input = */ nullptr,
1192
2220
  };
1193
2221
 
1194
2222
  struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) {
1195
- auto seed_cur = get_rng_seed(seed);
2223
+ const bool is_empty = (p <= 0.0f || t > 0.5f);
2224
+
2225
+ if (is_empty) {
2226
+ return llama_sampler_init_empty("?xtc");
2227
+ }
2228
+
2229
+ const auto seed_cur = get_rng_seed(seed);
2230
+
1196
2231
  return llama_sampler_init(
1197
2232
  /* .iface = */ &llama_sampler_xtc_i,
1198
2233
  /* .ctx = */ new llama_sampler_xtc {
@@ -1221,7 +2256,7 @@ struct llama_sampler_mirostat {
1221
2256
 
1222
2257
  float mu;
1223
2258
 
1224
- std::mt19937 rng;
2259
+ std::mt19937 rng;
1225
2260
  };
1226
2261
 
1227
2262
  static const char * llama_sampler_mirostat_name(const struct llama_sampler * /*smpl*/) {
@@ -1231,7 +2266,7 @@ static const char * llama_sampler_mirostat_name(const struct llama_sampler * /*s
1231
2266
  static void llama_sampler_mirostat_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1232
2267
  auto * ctx = (llama_sampler_mirostat *) smpl->ctx;
1233
2268
 
1234
- llama_sampler_softmax_impl(cur_p);
2269
+ llama_sampler_softmax_impl(cur_p, true);
1235
2270
 
1236
2271
  // Estimate s_hat using the most probable m tokens
1237
2272
  float s_hat = 0.0;
@@ -1250,7 +2285,8 @@ static void llama_sampler_mirostat_apply(struct llama_sampler * smpl, llama_toke
1250
2285
  float k = powf((epsilon_hat * powf(2, ctx->mu)) / (1 - powf(ctx->n_vocab, -epsilon_hat)), 1 / s_hat);
1251
2286
 
1252
2287
  llama_sampler_top_k_impl(cur_p, std::max(int(k), 1));
1253
- llama_sampler_softmax_impl(cur_p);
2288
+
2289
+ llama_sampler_softmax_impl(cur_p, true);
1254
2290
 
1255
2291
  const int idx = llama_sample_dist(cur_p, ctx->rng);
1256
2292
 
@@ -1290,16 +2326,21 @@ static void llama_sampler_mirostat_free(struct llama_sampler * smpl) {
1290
2326
  }
1291
2327
 
1292
2328
  static struct llama_sampler_i llama_sampler_mirostat_i = {
1293
- /* .name = */ llama_sampler_mirostat_name,
1294
- /* .accept = */ nullptr,
1295
- /* .apply = */ llama_sampler_mirostat_apply,
1296
- /* .reset = */ llama_sampler_mirostat_reset,
1297
- /* .clone = */ llama_sampler_mirostat_clone,
1298
- /* .free = */ llama_sampler_mirostat_free,
2329
+ /* .name = */ llama_sampler_mirostat_name,
2330
+ /* .accept = */ nullptr,
2331
+ /* .apply = */ llama_sampler_mirostat_apply,
2332
+ /* .reset = */ llama_sampler_mirostat_reset,
2333
+ /* .clone = */ llama_sampler_mirostat_clone,
2334
+ /* .free = */ llama_sampler_mirostat_free,
2335
+ /* .backend_init = */ nullptr,
2336
+ /* .backend_accept = */ nullptr,
2337
+ /* .backend_apply = */ nullptr,
2338
+ /* .backend_set_input = */ nullptr,
1299
2339
  };
1300
2340
 
1301
2341
  struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
1302
- auto seed_cur = get_rng_seed(seed);
2342
+ const auto seed_cur = get_rng_seed(seed);
2343
+
1303
2344
  return llama_sampler_init(
1304
2345
  /* .iface = */ &llama_sampler_mirostat_i,
1305
2346
  /* .ctx = */ new llama_sampler_mirostat {
@@ -1336,7 +2377,7 @@ static const char * llama_sampler_mirostat_v2_name(const struct llama_sampler *
1336
2377
  static void llama_sampler_mirostat_v2_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1337
2378
  auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx;
1338
2379
 
1339
- llama_sampler_softmax_impl(cur_p);
2380
+ llama_sampler_softmax_impl(cur_p, true);
1340
2381
 
1341
2382
  // Truncate the words with surprise values greater than mu
1342
2383
  cur_p->size = std::distance(cur_p->data, std::find_if(cur_p->data, cur_p->data + cur_p->size, [&](const llama_token_data & candidate) {
@@ -1348,7 +2389,7 @@ static void llama_sampler_mirostat_v2_apply(struct llama_sampler * smpl, llama_t
1348
2389
  }
1349
2390
 
1350
2391
  // Normalize the probabilities of the remaining words
1351
- llama_sampler_softmax_impl(cur_p);
2392
+ llama_sampler_softmax_impl(cur_p, true);
1352
2393
 
1353
2394
  const int idx = llama_sample_dist(cur_p, ctx->rng);
1354
2395
 
@@ -1389,12 +2430,16 @@ static void llama_sampler_mirostat_v2_free(struct llama_sampler * smpl) {
1389
2430
  }
1390
2431
 
1391
2432
  static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
1392
- /* .name = */ llama_sampler_mirostat_v2_name,
1393
- /* .accept = */ nullptr,
1394
- /* .apply = */ llama_sampler_mirostat_v2_apply,
1395
- /* .reset = */ llama_sampler_mirostat_v2_reset,
1396
- /* .clone = */ llama_sampler_mirostat_v2_clone,
1397
- /* .free = */ llama_sampler_mirostat_v2_free,
2433
+ /* .name = */ llama_sampler_mirostat_v2_name,
2434
+ /* .accept = */ nullptr,
2435
+ /* .apply = */ llama_sampler_mirostat_v2_apply,
2436
+ /* .reset = */ llama_sampler_mirostat_v2_reset,
2437
+ /* .clone = */ llama_sampler_mirostat_v2_clone,
2438
+ /* .free = */ llama_sampler_mirostat_v2_free,
2439
+ /* .backend_init = */ nullptr,
2440
+ /* .backend_accept = */ nullptr,
2441
+ /* .backend_apply = */ nullptr,
2442
+ /* .backend_set_input = */ nullptr,
1398
2443
  };
1399
2444
 
1400
2445
  struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
@@ -1506,12 +2551,16 @@ static void llama_sampler_grammar_free(struct llama_sampler * smpl) {
1506
2551
  }
1507
2552
 
1508
2553
  static struct llama_sampler_i llama_sampler_grammar_i = {
1509
- /* .name = */ llama_sampler_grammar_name,
1510
- /* .accept = */ llama_sampler_grammar_accept_impl,
1511
- /* .apply = */ llama_sampler_grammar_apply,
1512
- /* .reset = */ llama_sampler_grammar_reset,
1513
- /* .clone = */ llama_sampler_grammar_clone,
1514
- /* .free = */ llama_sampler_grammar_free,
2554
+ /* .name = */ llama_sampler_grammar_name,
2555
+ /* .accept = */ llama_sampler_grammar_accept_impl,
2556
+ /* .apply = */ llama_sampler_grammar_apply,
2557
+ /* .reset = */ llama_sampler_grammar_reset,
2558
+ /* .clone = */ llama_sampler_grammar_clone,
2559
+ /* .free = */ llama_sampler_grammar_free,
2560
+ /* .backend_init = */ nullptr,
2561
+ /* .backend_accept = */ nullptr,
2562
+ /* .backend_apply = */ nullptr,
2563
+ /* .backend_set_input = */ nullptr,
1515
2564
  };
1516
2565
 
1517
2566
  static struct llama_sampler * llama_sampler_init_grammar_impl(
@@ -1528,10 +2577,12 @@ static struct llama_sampler * llama_sampler_init_grammar_impl(
1528
2577
  auto * ctx = new llama_sampler_grammar;
1529
2578
 
1530
2579
  if (grammar_str != nullptr && grammar_str[0] != '\0') {
2580
+ std::string trigger_pattern;
2581
+ llama_grammar * grammar = nullptr;
1531
2582
  // TODO: remove trigger_words support.
1532
2583
  if (trigger_words != nullptr && num_trigger_words > 0) {
1533
2584
  GGML_ASSERT(trigger_patterns == nullptr && num_trigger_patterns == 0);
1534
- std::string trigger_pattern("[\\s\\S]*?(");
2585
+ trigger_pattern = "[\\s\\S]*?(";
1535
2586
  for (size_t i = 0; i < num_trigger_words; ++i) {
1536
2587
  static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");
1537
2588
  if (i > 0) {
@@ -1540,15 +2591,17 @@ static struct llama_sampler * llama_sampler_init_grammar_impl(
1540
2591
  trigger_pattern += std::regex_replace(trigger_words[i], special_chars, "\\$0");
1541
2592
  }
1542
2593
  trigger_pattern += ")[\\s\\S]*";
1543
- auto trigger_pattern_c = trigger_pattern.c_str();
1544
- trigger_patterns = &trigger_pattern_c;
1545
- num_trigger_patterns = 1;
2594
+
2595
+ std::array<const char *, 1> tmp_trigger_patterns = { trigger_pattern.c_str() };
2596
+ grammar = llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, tmp_trigger_patterns.data(), tmp_trigger_patterns.size(), trigger_tokens, num_trigger_tokens);
2597
+ } else {
2598
+ grammar = llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens);
1546
2599
  }
1547
2600
  *ctx = {
1548
2601
  /* .vocab = */ vocab,
1549
2602
  /* .grammar_str = */ grammar_str,
1550
2603
  /* .grammar_root = */ grammar_root,
1551
- /* .grammar = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens),
2604
+ /* .grammar = */ grammar,
1552
2605
  };
1553
2606
  if (!ctx->grammar) {
1554
2607
  delete ctx;
@@ -1709,12 +2762,16 @@ static void llama_sampler_penalties_free(struct llama_sampler * smpl) {
1709
2762
  }
1710
2763
 
1711
2764
  static struct llama_sampler_i llama_sampler_penalties_i = {
1712
- /* .name = */ llama_sampler_penalties_name,
1713
- /* .accept = */ llama_sampler_penalties_accept,
1714
- /* .apply = */ llama_sampler_penalties_apply,
1715
- /* .reset = */ llama_sampler_penalties_reset,
1716
- /* .clone = */ llama_sampler_penalties_clone,
1717
- /* .free = */ llama_sampler_penalties_free,
2765
+ /* .name = */ llama_sampler_penalties_name,
2766
+ /* .accept = */ llama_sampler_penalties_accept,
2767
+ /* .apply = */ llama_sampler_penalties_apply,
2768
+ /* .reset = */ llama_sampler_penalties_reset,
2769
+ /* .clone = */ llama_sampler_penalties_clone,
2770
+ /* .free = */ llama_sampler_penalties_free,
2771
+ /* .backend_init = */ nullptr,
2772
+ /* .backend_accept = */ nullptr,
2773
+ /* .backend_apply = */ nullptr,
2774
+ /* .backend_set_input = */ nullptr,
1718
2775
  };
1719
2776
 
1720
2777
  struct llama_sampler * llama_sampler_init_penalties(
@@ -1724,6 +2781,12 @@ struct llama_sampler * llama_sampler_init_penalties(
1724
2781
  float penalty_present) {
1725
2782
  penalty_last_n = std::max(penalty_last_n, 0);
1726
2783
 
2784
+ const bool is_empty = (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f));
2785
+
2786
+ if (is_empty) {
2787
+ return llama_sampler_init_empty("?penalties");
2788
+ }
2789
+
1727
2790
  return llama_sampler_init(
1728
2791
  /* .iface = */ &llama_sampler_penalties_i,
1729
2792
  /* .ctx = */ new llama_sampler_penalties {
@@ -1748,7 +2811,7 @@ static const char * llama_sampler_top_n_sigma_name(const struct llama_sampler *
1748
2811
  }
1749
2812
 
1750
2813
  static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1751
- const auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx;
2814
+ auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx;
1752
2815
 
1753
2816
  if (ctx->n <= 0.0f || cur_p->size <= 1) {
1754
2817
  return;
@@ -1761,9 +2824,7 @@ static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_t
1761
2824
  for (size_t i = 0; i < cur_p->size; ++i) {
1762
2825
  // Only count non-negative infinity values
1763
2826
  if (cur_p->data[i].logit != -INFINITY) {
1764
- if (cur_p->data[i].logit > max) {
1765
- max = cur_p->data[i].logit;
1766
- }
2827
+ max = std::max(max, cur_p->data[i].logit);
1767
2828
  logits_sum += cur_p->data[i].logit;
1768
2829
  valid_count++;
1769
2830
  }
@@ -1780,13 +2841,14 @@ static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_t
1780
2841
  }
1781
2842
  float std = valid_count > 0 ? sqrt(acc/valid_count) : 0;
1782
2843
 
1783
- //apply mask
2844
+ // apply mask
1784
2845
  for (size_t i = 0; i < cur_p->size; ++i) {
1785
2846
  if (cur_p->data[i].logit < max - (ctx->n * std)) {
1786
2847
  cur_p->data[i].logit = -INFINITY;
1787
2848
  }
1788
2849
  }
1789
- llama_sampler_softmax_impl(cur_p);
2850
+
2851
+ llama_sampler_softmax_impl(cur_p, true);
1790
2852
  }
1791
2853
 
1792
2854
  static struct llama_sampler * llama_sampler_top_n_sigma_clone(const struct llama_sampler * smpl) {
@@ -1799,15 +2861,25 @@ static void llama_sampler_top_n_sigma_free(struct llama_sampler * smpl) {
1799
2861
  }
1800
2862
 
1801
2863
  static struct llama_sampler_i llama_sampler_top_n_sigma_i = {
1802
- /* .name = */ llama_sampler_top_n_sigma_name,
1803
- /* .accept = */ nullptr,
1804
- /* .apply = */ llama_sampler_top_n_sigma_apply,
1805
- /* .reset = */ nullptr,
1806
- /* .clone = */ llama_sampler_top_n_sigma_clone,
1807
- /* .free = */ llama_sampler_top_n_sigma_free,
2864
+ /* .name = */ llama_sampler_top_n_sigma_name,
2865
+ /* .accept = */ nullptr,
2866
+ /* .apply = */ llama_sampler_top_n_sigma_apply,
2867
+ /* .reset = */ nullptr,
2868
+ /* .clone = */ llama_sampler_top_n_sigma_clone,
2869
+ /* .free = */ llama_sampler_top_n_sigma_free,
2870
+ /* .backend_init = */ nullptr,
2871
+ /* .backend_accept = */ nullptr,
2872
+ /* .backend_apply = */ nullptr,
2873
+ /* .backend_set_input = */ nullptr,
1808
2874
  };
1809
2875
 
1810
2876
  struct llama_sampler * llama_sampler_init_top_n_sigma(float n) {
2877
+ const bool is_empty = (n <= 0.0f);
2878
+
2879
+ if (is_empty) {
2880
+ return llama_sampler_init_empty("?top-n-sigma");
2881
+ }
2882
+
1811
2883
  return llama_sampler_init(
1812
2884
  /* .iface = */ &llama_sampler_top_n_sigma_i,
1813
2885
  /* .ctx = */ new llama_sampler_top_n_sigma {
@@ -1991,7 +3063,9 @@ static void llama_sampler_dry_apply(struct llama_sampler * smpl, llama_token_dat
1991
3063
 
1992
3064
  {
1993
3065
  const int last = last_n_repeat - 1;
1994
- int rt = 0, lt = 0;
3066
+
3067
+ int rt = 0;
3068
+ int lt = 0;
1995
3069
 
1996
3070
  for (int k = 1; k < last_n_repeat; ++k) {
1997
3071
  if (k > rt) {
@@ -2127,22 +3201,30 @@ static void llama_sampler_dry_free(struct llama_sampler * smpl) {
2127
3201
  }
2128
3202
 
2129
3203
  static struct llama_sampler_i llama_sampler_dry_i = {
2130
- /* .name = */ llama_sampler_dry_name,
2131
- /* .accept = */ llama_sampler_dry_accept,
2132
- /* .apply = */ llama_sampler_dry_apply,
2133
- /* .reset = */ llama_sampler_dry_reset,
2134
- /* .clone = */ llama_sampler_dry_clone,
2135
- /* .free = */ llama_sampler_dry_free,
3204
+ /* .name = */ llama_sampler_dry_name,
3205
+ /* .accept = */ llama_sampler_dry_accept,
3206
+ /* .apply = */ llama_sampler_dry_apply,
3207
+ /* .reset = */ llama_sampler_dry_reset,
3208
+ /* .clone = */ llama_sampler_dry_clone,
3209
+ /* .free = */ llama_sampler_dry_free,
3210
+ /* .backend_init = */ nullptr,
3211
+ /* .backend_accept = */ nullptr,
3212
+ /* .backend_apply = */ nullptr,
3213
+ /* .backend_set_input = */ nullptr,
2136
3214
  };
2137
3215
 
2138
- struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
2139
- int32_t effective_dry_penalty_last_n = (dry_penalty_last_n == -1) ? context_size : std::max(dry_penalty_last_n, 0);
3216
+ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, int32_t n_ctx_train, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
3217
+ int32_t effective_dry_penalty_last_n = (dry_penalty_last_n == -1) ? n_ctx_train : std::max(dry_penalty_last_n, 0);
2140
3218
  std::unordered_multimap<llama_token, std::vector<llama_token>> processed_breakers;
2141
3219
  const int MAX_CHAR_LEN = 40;
2142
3220
  const int MAX_SEQ_LEN = 20;
2143
3221
 
2144
3222
  const bool dry_enabled = (dry_multiplier != 0.0f && dry_base >= 1.0f && dry_penalty_last_n != 0);
2145
3223
 
3224
+ if (!dry_enabled) {
3225
+ return llama_sampler_init_empty("?dry");
3226
+ }
3227
+
2146
3228
  if (dry_enabled && seq_breakers != nullptr && num_breakers > 0) {
2147
3229
  // Process sequence breakers
2148
3230
  for (size_t i = 0; i < num_breakers; ++i) {
@@ -2169,7 +3251,7 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab,
2169
3251
  return llama_sampler_init(
2170
3252
  /* .iface = */ &llama_sampler_dry_i,
2171
3253
  /* .ctx = */ new llama_sampler_dry {
2172
- /* .total_context_size = */ context_size,
3254
+ /* .total_context_size = */ n_ctx_train,
2173
3255
  /* .dry_multiplier = */ dry_multiplier,
2174
3256
  /* .dry_base = */ dry_base,
2175
3257
  /* .dry_allowed_length = */ dry_allowed_length,
@@ -2213,16 +3295,23 @@ struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, floa
2213
3295
 
2214
3296
  // logit-bias
2215
3297
 
2216
- struct llama_sampler_logit_bias {
3298
+ struct llama_sampler_logit_bias : public llama_sampler_backend {
2217
3299
  const int32_t n_vocab;
2218
3300
 
2219
3301
  const std::vector<llama_logit_bias> logit_bias;
2220
3302
 
2221
3303
  std::vector<llama_logit_bias> to_search;
3304
+
3305
+ struct ggml_tensor * inp_logit_bias;
3306
+ struct ggml_tensor * inp_logit_idxs;
3307
+
3308
+ ggml_context_ptr inp_ctx;
3309
+ ggml_backend_buffer_ptr inp_buf;
2222
3310
  };
2223
3311
 
2224
- static const char * llama_sampler_logit_bias_name(const struct llama_sampler * /*smpl*/) {
2225
- return "logit-bias";
3312
+ static const char * llama_sampler_logit_bias_name(const struct llama_sampler * smpl) {
3313
+ auto * ctx = (llama_sampler_logit_bias *) smpl->ctx;
3314
+ return ctx->get_name();
2226
3315
  }
2227
3316
 
2228
3317
  static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
@@ -2267,25 +3356,123 @@ static void llama_sampler_logit_bias_free(struct llama_sampler * smpl) {
2267
3356
  delete (llama_sampler_logit_bias *) smpl->ctx;
2268
3357
  }
2269
3358
 
3359
+ static void llama_sampler_logit_bias_backend_apply(
3360
+ struct llama_sampler * smpl,
3361
+ struct ggml_context * ctx,
3362
+ struct ggml_cgraph * gf,
3363
+ struct llama_sampler_data * data) {
3364
+ GGML_UNUSED(gf);
3365
+ GGML_UNUSED(ctx);
3366
+
3367
+ auto * sctx = (llama_sampler_logit_bias *) smpl->ctx;
3368
+ if (sctx->logit_bias.empty()) {
3369
+ return;
3370
+ }
3371
+
3372
+ ggml_tensor * cur = ggml_fill(ctx, data->logits, 0.0f);
3373
+
3374
+ cur = ggml_reshape_2d(ctx, cur, 1, ggml_nelements(cur));
3375
+ cur = ggml_set_rows(ctx, cur, sctx->inp_logit_bias, sctx->inp_logit_idxs);
3376
+ cur = ggml_reshape_1d(ctx, cur, ggml_nelements(cur));
3377
+
3378
+ data->logits = ggml_add(ctx, data->logits, cur);
3379
+ }
3380
+
3381
+ static void llama_sampler_logit_bias_backend_set_input(struct llama_sampler * smpl) {
3382
+ auto * sctx = (llama_sampler_logit_bias *) smpl->ctx;
3383
+ if (sctx->logit_bias.empty()) {
3384
+ return;
3385
+ }
3386
+
3387
+ GGML_ASSERT(sctx->inp_logit_bias != nullptr);
3388
+ GGML_ASSERT(sctx->inp_logit_idxs != nullptr);
3389
+
3390
+ const size_t n = sctx->logit_bias.size();
3391
+
3392
+ std::vector<float> data_logit_bias(n, 0.0f);
3393
+ std::vector<int32_t> data_logit_idxs(n, 0);
3394
+ for (size_t i = 0; i < n; ++i) {
3395
+ const auto & lb = sctx->logit_bias[i];
3396
+ GGML_ASSERT(lb.token >= 0 && lb.token < (int32_t) sctx->n_vocab);
3397
+ data_logit_bias[i] = lb.bias;
3398
+ data_logit_idxs[i] = lb.token;
3399
+ }
3400
+
3401
+ ggml_backend_tensor_set(sctx->inp_logit_bias, data_logit_bias.data(), 0, ggml_nbytes(sctx->inp_logit_bias));
3402
+ ggml_backend_tensor_set(sctx->inp_logit_idxs, data_logit_idxs.data(), 0, ggml_nbytes(sctx->inp_logit_idxs));
3403
+ }
3404
+
3405
+ static bool llama_sampler_logit_bias_backend_init(
3406
+ struct llama_sampler * smpl,
3407
+ ggml_backend_buffer_type_t buft) {
3408
+ auto * sctx = (llama_sampler_logit_bias *) smpl->ctx;
3409
+
3410
+ sctx->init(true);
3411
+
3412
+ if (sctx->logit_bias.empty()) {
3413
+ return true;
3414
+ }
3415
+
3416
+ ggml_init_params params = {
3417
+ /*.mem_size =*/ 2*ggml_tensor_overhead(),
3418
+ /*.mem_buffer =*/ nullptr,
3419
+ /*.no_alloc =*/ true,
3420
+ };
3421
+
3422
+ sctx->inp_ctx.reset(ggml_init(params));
3423
+
3424
+ const size_t n = sctx->logit_bias.size();
3425
+
3426
+ sctx->inp_logit_bias = ggml_new_tensor_2d(sctx->inp_ctx.get(), GGML_TYPE_F32, 1, n);
3427
+ ggml_set_name(sctx->inp_logit_bias, "logit_bias");
3428
+ ggml_set_input(sctx->inp_logit_bias);
3429
+
3430
+ sctx->inp_logit_idxs = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_I32, n);
3431
+ ggml_set_name(sctx->inp_logit_idxs, "logit_idxs");
3432
+ ggml_set_input(sctx->inp_logit_idxs);
3433
+
3434
+ // Allocate all tensors from our context to the backend
3435
+ sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft));
3436
+
3437
+ ggml_backend_buffer_clear(sctx->inp_buf.get(), 0);
3438
+
3439
+ return true;
3440
+ }
3441
+
2270
3442
  static struct llama_sampler_i llama_sampler_logit_bias_i = {
2271
- /* .name = */ llama_sampler_logit_bias_name,
2272
- /* .accept = */ nullptr,
2273
- /* .apply = */ llama_sampler_logit_bias_apply,
2274
- /* .reset = */ nullptr,
2275
- /* .clone = */ llama_sampler_logit_bias_clone,
2276
- /* .free = */ llama_sampler_logit_bias_free,
3443
+ /* .name = */ llama_sampler_logit_bias_name,
3444
+ /* .accept = */ nullptr,
3445
+ /* .apply = */ llama_sampler_logit_bias_apply,
3446
+ /* .reset = */ nullptr,
3447
+ /* .clone = */ llama_sampler_logit_bias_clone,
3448
+ /* .free = */ llama_sampler_logit_bias_free,
3449
+ /* .backend_init = */ llama_sampler_logit_bias_backend_init,
3450
+ /* .backend_accept = */ nullptr,
3451
+ /* .backend_apply = */ llama_sampler_logit_bias_backend_apply,
3452
+ /* .backend_set_input = */ llama_sampler_logit_bias_backend_set_input,
2277
3453
  };
2278
3454
 
2279
3455
  struct llama_sampler * llama_sampler_init_logit_bias(
2280
3456
  int32_t n_vocab,
2281
3457
  int32_t n_logit_bias,
2282
3458
  const llama_logit_bias * logit_bias) {
3459
+ const bool is_empty = n_logit_bias <= 0;
3460
+
3461
+ if (is_empty) {
3462
+ return llama_sampler_init_empty("?logit-bias");
3463
+ }
3464
+
2283
3465
  return llama_sampler_init(
2284
3466
  /* .iface = */ &llama_sampler_logit_bias_i,
2285
3467
  /* .ctx = */ new llama_sampler_logit_bias {
2286
- /* .n_vocab = */ n_vocab,
2287
- /* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
2288
- /* .to_search = */ {},
3468
+ ("logit-bias"),
3469
+ /* .n_vocab = */ n_vocab,
3470
+ /* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
3471
+ /* .to_search = */ {},
3472
+ /* .inp_logit_bias = */ nullptr,
3473
+ /* .inp_logit_idxs = */ nullptr,
3474
+ /* .inp_ctx = */ nullptr,
3475
+ /* .inp_buf = */ nullptr,
2289
3476
  }
2290
3477
  );
2291
3478
  }
@@ -2308,7 +3495,7 @@ static const char * llama_sampler_infill_name(const struct llama_sampler * /*smp
2308
3495
  static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
2309
3496
  auto * ctx = (llama_sampler_infill *) smpl->ctx;
2310
3497
 
2311
- llama_sampler_softmax_impl(cur_p);
3498
+ llama_sampler_softmax_impl(cur_p, true);
2312
3499
 
2313
3500
  #if defined(GGML_DEBUG_SAMPLER_INFILL)
2314
3501
  #define LOG_DBG_CUR LLAMA_LOG_DEBUG
@@ -2441,8 +3628,13 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
2441
3628
  if (n_non_eog == 0) {
2442
3629
  cur_p->size = 1;
2443
3630
  cur_p->data[0].id = ctx->vocab->token_eot();
3631
+ if (cur_p->data[0].id == LLAMA_TOKEN_NULL) {
3632
+ cur_p->data[0].id = ctx->vocab->token_eos();
3633
+ }
2444
3634
  cur_p->data[0].logit = 1.0f;
2445
3635
 
3636
+ GGML_ASSERT(cur_p->data[0].id != LLAMA_TOKEN_NULL);
3637
+
2446
3638
  return;
2447
3639
  }
2448
3640
 
@@ -2493,12 +3685,16 @@ static void llama_sampler_infill_free(struct llama_sampler * smpl) {
2493
3685
  }
2494
3686
 
2495
3687
  static struct llama_sampler_i llama_sampler_infill_i = {
2496
- /* .name = */ llama_sampler_infill_name,
2497
- /* .accept = */ nullptr,
2498
- /* .apply = */ llama_sampler_infill_apply,
2499
- /* .reset = */ nullptr,
2500
- /* .clone = */ llama_sampler_infill_clone,
2501
- /* .free = */ llama_sampler_infill_free,
3688
+ /* .name = */ llama_sampler_infill_name,
3689
+ /* .accept = */ nullptr,
3690
+ /* .apply = */ llama_sampler_infill_apply,
3691
+ /* .reset = */ nullptr,
3692
+ /* .clone = */ llama_sampler_infill_clone,
3693
+ /* .free = */ llama_sampler_infill_free,
3694
+ /* .backend_apply = */ nullptr,
3695
+ /* .backend_accept = */ nullptr,
3696
+ /* .backend_set_input = */ nullptr,
3697
+ /* .backend_init = */ nullptr,
2502
3698
  };
2503
3699
 
2504
3700
  struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab) {
@@ -2530,7 +3726,7 @@ uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {
2530
3726
  if (smpl->iface == &llama_sampler_chain_i) {
2531
3727
  const auto * ctx = (const llama_sampler_chain *) smpl->ctx;
2532
3728
  for (auto it = ctx->samplers.rbegin(); it != ctx->samplers.rend(); ++it) {
2533
- const uint32_t seed = llama_sampler_get_seed(*it);
3729
+ const uint32_t seed = llama_sampler_get_seed(it->ptr);
2534
3730
  if (seed != LLAMA_DEFAULT_SEED) {
2535
3731
  return seed;
2536
3732
  }
@@ -2560,8 +3756,7 @@ struct llama_perf_sampler_data llama_perf_sampler(const struct llama_sampler * c
2560
3756
  void llama_perf_sampler_print(const struct llama_sampler * chain) {
2561
3757
  const auto data = llama_perf_sampler(chain);
2562
3758
 
2563
- LLAMA_LOG_INFO("%s: sampling time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
2564
- __func__, data.t_sample_ms, data.n_sample, data.t_sample_ms / data.n_sample, 1e3 / data.t_sample_ms * data.n_sample);
3759
+ LLAMA_LOG_INFO("%s: samplers time = %10.2f ms / %5d runs\n", __func__, data.t_sample_ms, data.n_sample);
2565
3760
  }
2566
3761
 
2567
3762
  void llama_perf_sampler_reset(struct llama_sampler * chain) {
@@ -2571,5 +3766,6 @@ void llama_perf_sampler_reset(struct llama_sampler * chain) {
2571
3766
 
2572
3767
  auto * ctx = (struct llama_sampler_chain *) chain->ctx;
2573
3768
 
2574
- ctx->t_sample_us = ctx->n_sample = 0;
3769
+ ctx->t_sample_us = 0;
3770
+ ctx->n_sample = 0;
2575
3771
  }