whispercpp 1.3.4 → 1.3.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (630)
  1. checksums.yaml +4 -4
  2. data/README.md +60 -43
  3. data/ext/extconf.rb +2 -2
  4. data/ext/ruby_whisper.c +14 -2
  5. data/ext/ruby_whisper.h +39 -0
  6. data/ext/ruby_whisper_context.c +22 -22
  7. data/ext/ruby_whisper_model.c +12 -12
  8. data/ext/ruby_whisper_params.c +47 -23
  9. data/ext/ruby_whisper_segment.c +84 -19
  10. data/ext/ruby_whisper_token.c +351 -0
  11. data/ext/ruby_whisper_transcribe.cpp +1 -1
  12. data/ext/ruby_whisper_vad_context.c +75 -0
  13. data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
  14. data/ext/ruby_whisper_vad_segment.c +139 -0
  15. data/ext/ruby_whisper_vad_segments.c +106 -0
  16. data/ext/sources/CMakeLists.txt +4 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
  19. data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
  20. data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  21. data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
  22. data/ext/sources/examples/addon.node/vad-example.js +2 -2
  23. data/ext/sources/examples/cli/cli.cpp +121 -112
  24. data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
  25. data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
  26. data/ext/sources/examples/server/server.cpp +10 -11
  27. data/ext/sources/examples/talk-llama/CMakeLists.txt +5 -1
  28. data/ext/sources/examples/talk-llama/llama-adapter.cpp +12 -3
  29. data/ext/sources/examples/talk-llama/llama-adapter.h +7 -1
  30. data/ext/sources/examples/talk-llama/llama-arch.cpp +2046 -1974
  31. data/ext/sources/examples/talk-llama/llama-arch.h +67 -2
  32. data/ext/sources/examples/talk-llama/llama-batch.cpp +75 -33
  33. data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
  34. data/ext/sources/examples/talk-llama/llama-chat.cpp +79 -3
  35. data/ext/sources/examples/talk-llama/llama-chat.h +4 -0
  36. data/ext/sources/examples/talk-llama/llama-context.cpp +775 -78
  37. data/ext/sources/examples/talk-llama/llama-context.h +57 -9
  38. data/ext/sources/examples/talk-llama/llama-cparams.h +1 -0
  39. data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
  40. data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
  41. data/ext/sources/examples/talk-llama/llama-graph.cpp +381 -64
  42. data/ext/sources/examples/talk-llama/llama-graph.h +103 -13
  43. data/ext/sources/examples/talk-llama/llama-hparams.cpp +26 -2
  44. data/ext/sources/examples/talk-llama/llama-hparams.h +41 -10
  45. data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
  46. data/ext/sources/examples/talk-llama/llama-impl.h +1 -1
  47. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +5 -3
  48. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +145 -65
  49. data/ext/sources/examples/talk-llama/llama-kv-cache.h +22 -7
  50. data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
  51. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
  52. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +32 -19
  53. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
  54. data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
  55. data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
  56. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +91 -9
  57. data/ext/sources/examples/talk-llama/llama-model-loader.h +6 -0
  58. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
  59. data/ext/sources/examples/talk-llama/llama-model.cpp +1529 -13134
  60. data/ext/sources/examples/talk-llama/llama-model.h +44 -3
  61. data/ext/sources/examples/talk-llama/llama-quant.cpp +8 -23
  62. data/ext/sources/examples/talk-llama/llama-sampling.cpp +1294 -198
  63. data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
  64. data/ext/sources/examples/talk-llama/llama-vocab.cpp +133 -37
  65. data/ext/sources/examples/talk-llama/llama-vocab.h +45 -40
  66. data/ext/sources/examples/talk-llama/llama.cpp +729 -2
  67. data/ext/sources/examples/talk-llama/llama.h +152 -14
  68. data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
  69. data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
  70. data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
  71. data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
  72. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
  73. data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
  74. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
  75. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
  76. data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
  77. data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
  78. data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
  79. data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
  80. data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
  81. data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
  82. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
  83. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
  84. data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
  85. data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
  86. data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
  87. data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
  88. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
  89. data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
  90. data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
  91. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
  92. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
  93. data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
  94. data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
  95. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
  96. data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
  97. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
  98. data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
  99. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
  100. data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
  101. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
  102. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
  103. data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
  104. data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
  105. data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
  106. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
  107. data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
  108. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
  109. data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
  110. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
  111. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
  112. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
  113. data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
  114. data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
  115. data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
  116. data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
  117. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
  118. data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
  119. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
  120. data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
  121. data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
  122. data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
  123. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
  124. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
  125. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
  126. data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
  127. data/ext/sources/examples/talk-llama/models/models.h +569 -0
  128. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
  129. data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
  130. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
  131. data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
  132. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
  133. data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
  134. data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
  135. data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
  136. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
  137. data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
  138. data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
  139. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
  140. data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
  141. data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
  142. data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
  143. data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
  144. data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
  145. data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
  146. data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
  147. data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
  148. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
  149. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
  150. data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
  151. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
  152. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
  153. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
  154. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
  155. data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
  156. data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
  157. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
  158. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
  159. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
  160. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
  161. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
  162. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
  163. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
  164. data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
  165. data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
  166. data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
  167. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
  168. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
  169. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
  170. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
  171. data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
  172. data/ext/sources/examples/talk-llama/unicode.cpp +102 -16
  173. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
  174. data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
  175. data/ext/sources/ggml/CMakeLists.txt +82 -54
  176. data/ext/sources/ggml/include/ggml-alloc.h +9 -0
  177. data/ext/sources/ggml/include/ggml-backend.h +4 -1
  178. data/ext/sources/ggml/include/ggml-cpu.h +1 -0
  179. data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
  180. data/ext/sources/ggml/include/ggml-rpc.h +8 -11
  181. data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
  182. data/ext/sources/ggml/include/ggml.h +190 -12
  183. data/ext/sources/ggml/src/CMakeLists.txt +82 -11
  184. data/ext/sources/ggml/src/ggml-alloc.c +124 -41
  185. data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
  186. data/ext/sources/ggml/src/ggml-backend-reg.cpp +27 -3
  187. data/ext/sources/ggml/src/ggml-backend.cpp +71 -21
  188. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
  189. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -9
  190. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +57 -45
  191. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
  192. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2179 -1696
  193. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -317
  194. data/ext/sources/ggml/src/ggml-cann/common.h +283 -208
  195. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +626 -776
  196. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +156 -86
  197. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -0
  198. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  199. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +428 -26
  200. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1004 -0
  201. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
  202. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  203. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +108 -49
  204. data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  205. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6 -6
  206. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +50 -2
  207. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +5 -3
  208. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +195 -71
  209. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
  210. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +573 -106
  211. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
  212. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +298 -112
  213. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
  214. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +819 -125
  215. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
  216. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +708 -431
  217. data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -4
  218. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +671 -31
  219. data/ext/sources/ggml/src/ggml-cpu/repack.h +14 -0
  220. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +41 -43
  221. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
  222. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
  223. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
  224. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +124 -1
  225. data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
  226. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
  227. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
  228. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
  229. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
  230. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +1 -1
  231. data/ext/sources/ggml/src/ggml-cuda/common.cuh +353 -80
  232. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  233. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
  234. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +339 -246
  235. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
  236. data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
  237. data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
  238. data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
  239. data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
  240. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +31 -21
  241. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +663 -596
  242. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
  243. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1241 -0
  244. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +30 -37
  245. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +14 -13
  246. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  247. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +83 -37
  248. data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
  249. data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
  250. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1155 -164
  251. data/ext/sources/ggml/src/ggml-cuda/mean.cu +5 -4
  252. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +741 -48
  253. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +60 -12
  254. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +381 -42
  255. data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
  256. data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
  257. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
  258. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +498 -171
  259. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +375 -79
  260. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +3 -2
  261. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +241 -95
  262. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
  263. data/ext/sources/ggml/src/ggml-cuda/pad.cu +64 -33
  264. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +151 -0
  265. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
  266. data/ext/sources/ggml/src/ggml-cuda/rope.cu +192 -77
  267. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
  268. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
  269. data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
  270. data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
  271. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +203 -6
  272. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
  273. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
  274. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -20
  275. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
  276. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  277. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  278. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  279. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  280. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  281. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  282. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  283. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  284. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  285. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +19 -1
  286. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
  287. data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
  288. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +168 -76
  289. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +11 -4
  290. data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
  291. data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
  292. data/ext/sources/ggml/src/ggml-cuda/unary.cu +105 -11
  293. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +36 -0
  294. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
  295. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  296. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +12 -1
  297. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
  302. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
  303. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  304. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
  305. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
  306. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  307. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
  308. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
  309. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
  310. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
  311. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  312. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
  313. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
  314. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
  315. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
  316. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
  317. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
  318. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
  319. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
  320. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
  321. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
  322. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
  323. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
  324. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
  325. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  326. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
  327. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
  328. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
  329. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +8 -13
  330. data/ext/sources/ggml/src/ggml-impl.h +67 -6
  331. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +2 -2
  332. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +29 -20
  333. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +652 -285
  334. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +103 -56
  335. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +496 -118
  336. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +231 -9
  337. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1227 -224
  338. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +12 -0
  339. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +14 -8
  340. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1972 -704
  341. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
  342. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +11 -0
  343. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1430 -120
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +63 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +13 -0
  365. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
  366. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
  367. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
  368. data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
  369. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +6 -0
  370. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +0 -9
  371. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
  372. data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
  373. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
  374. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +34 -0
  375. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  376. data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  377. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
  378. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
  379. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +76 -3
  380. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +333 -300
  381. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +10 -2
  382. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +335 -110
  383. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
  384. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +156 -0
  385. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
  386. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
  387. data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
  388. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
  389. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
  390. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
  391. data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  392. data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  393. data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
  394. data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
  395. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +30 -17
  396. data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
  397. data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
  398. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
  399. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
  400. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  401. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  402. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +58 -0
  403. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
  404. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5013 -2859
  405. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  406. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
  407. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
  408. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  409. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
  410. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  411. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
  412. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
  413. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  414. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  415. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  416. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  417. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  418. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  419. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
  420. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  421. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  422. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
  423. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
  424. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  425. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  426. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  427. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
  428. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
  429. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
  430. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
  431. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  432. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
  433. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
  434. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  435. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  436. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  437. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  438. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  439. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
  440. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
  441. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
  442. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  443. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  444. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
  445. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
  446. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  447. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  448. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  449. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
  450. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  451. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  452. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
  453. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  454. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  455. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
  456. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  457. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  458. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
  459. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  460. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +39 -17
  461. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +19 -1
  462. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +45 -7
  463. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +50 -12
  464. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +1 -1
  465. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  466. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  467. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
  468. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
  469. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  470. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
  471. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  472. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
  473. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
  474. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
  475. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
  476. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
  477. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
  478. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  479. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
  480. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
  481. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
  482. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
  483. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
  484. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  485. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  486. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  487. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
  488. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +70 -25
  489. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
  490. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
  491. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
  492. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
  493. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
  494. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
  495. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
  496. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
  497. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
  498. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  499. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
  500. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
  501. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
  502. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
  503. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
  504. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
  505. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
  506. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +78 -103
  507. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +34 -23
  508. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
  509. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
  510. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +88 -228
  511. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
  512. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  513. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
  514. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  515. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  516. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  517. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
  518. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
  519. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  520. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
  521. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  522. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  523. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  524. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  525. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +50 -4
  526. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  527. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
  528. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
  529. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
  530. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  531. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -50
  532. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -33
  533. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -33
  534. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
  535. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
  536. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  537. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
  538. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  539. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  540. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  541. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  542. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
  543. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
  544. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
  545. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
  546. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
  547. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
  548. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  549. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
  550. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
  551. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  552. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
  553. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
  554. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  555. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  556. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
  557. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
  558. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  559. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
  560. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  561. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
  562. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
  563. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
  564. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
  565. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
  566. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  567. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
  568. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
  569. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -151
  570. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
  571. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
  572. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
  573. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1964 -435
  574. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  575. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
  576. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
  577. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +33 -10
  578. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
  579. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +1 -1
  580. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  581. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +6 -6
  582. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
  583. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
  584. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
  585. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
  586. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
  587. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  588. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
  589. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
  590. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  591. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
  592. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
  593. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
  594. data/ext/sources/ggml/src/ggml.c +425 -33
  595. data/ext/sources/include/whisper.h +1 -0
  596. data/ext/sources/src/CMakeLists.txt +3 -1
  597. data/ext/sources/src/whisper.cpp +101 -35
  598. data/ext/sources/tests/CMakeLists.txt +2 -2
  599. data/ext/sources/tests/test-vad-full.cpp +4 -2
  600. data/ext/sources/tests/test-vad.cpp +1 -1
  601. data/extsources.rb +1 -0
  602. data/lib/whisper/model/uri.rb +17 -18
  603. data/sig/whisper.rbs +119 -2
  604. data/test/test_params.rb +16 -8
  605. data/test/test_segment.rb +0 -1
  606. data/test/test_token.rb +70 -0
  607. data/test/test_vad.rb +1 -1
  608. data/test/test_vad_context.rb +50 -0
  609. data/test/test_vad_segment.rb +19 -0
  610. data/test/test_vad_segments.rb +16 -0
  611. data/test/test_whisper.rb +7 -0
  612. data/whispercpp.gemspec +1 -1
  613. metadata +287 -34
  614. data/ext/sources/build-xcframework.sh +0 -571
  615. data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
  618. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
  619. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
  620. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
  621. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
  622. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
  623. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
  624. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  625. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  626. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  627. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  628. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
  629. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
  630. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
@@ -26,6 +26,10 @@ struct llama_memory_breakdown_data {
26
26
  size_t model = 0; // memory allocated for the model
27
27
  size_t context = 0; // memory allocated for the context
28
28
  size_t compute = 0; // memory allocated for temporary compute buffers
29
+
30
+ size_t total() const {
31
+ return model + context + compute;
32
+ }
29
33
  };
30
34
 
31
35
  struct llama_context {
@@ -43,11 +47,11 @@ struct llama_context {
43
47
 
44
48
  ggml_backend_sched_t get_sched() const;
45
49
 
46
- uint32_t n_ctx() const;
47
- uint32_t n_ctx_per_seq() const;
48
- uint32_t n_batch() const;
49
- uint32_t n_ubatch() const;
50
- uint32_t n_seq_max() const;
50
+ uint32_t n_ctx() const;
51
+ uint32_t n_ctx_seq() const;
52
+ uint32_t n_batch() const;
53
+ uint32_t n_ubatch() const;
54
+ uint32_t n_seq_max() const;
51
55
 
52
56
  uint32_t n_threads() const;
53
57
  uint32_t n_threads_batch() const;
@@ -66,6 +70,18 @@ struct llama_context {
66
70
  float * get_embeddings_ith(int32_t i);
67
71
  float * get_embeddings_seq(llama_seq_id seq_id);
68
72
 
73
+ llama_token * get_sampled_tokens() const;
74
+ llama_token get_sampled_token_ith(int32_t idx);
75
+
76
+ float * get_sampled_logits_ith(int32_t idx);
77
+ size_t get_sampled_logits_count(int32_t idx);
78
+
79
+ float * get_sampled_probs_ith(int32_t idx);
80
+ size_t get_sampled_probs_count(int32_t idx);
81
+
82
+ const llama_token * get_sampled_candidates_ith(int32_t idx);
83
+ size_t get_sampled_candidates_count(int32_t idx);
84
+
69
85
  void attach_threadpool(
70
86
  ggml_threadpool_t threadpool,
71
87
  ggml_threadpool_t threadpool_batch);
@@ -188,16 +204,19 @@ private:
188
204
 
189
205
  // Make sure enough space is available for outputs.
190
206
  // Returns max number of outputs for which space was reserved.
191
- uint32_t output_reserve(int32_t n_outputs);
207
+ uint32_t output_reserve(int32_t n_outputs, const llama_batch & batch);
192
208
 
193
209
  void output_reorder();
194
210
 
211
+ // map the output row index `i` to batch index
212
+ int64_t output_resolve_row(int32_t i) const;
213
+
195
214
  //
196
215
  // graph
197
216
  //
198
217
 
199
218
  public:
200
- uint32_t graph_max_nodes() const;
219
+ uint32_t graph_max_nodes(uint32_t n_tokens) const;
201
220
 
202
221
  // can reuse the llm_graph_result instance of the context (for example to update a memory module)
203
222
  llm_graph_result * get_gf_res_reserve() const;
@@ -206,7 +225,10 @@ public:
206
225
  ggml_status graph_compute(ggml_cgraph * gf, bool batched);
207
226
 
208
227
  // reserve a graph with a dummy ubatch of the specified size
209
- ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only = false);
228
+ ggml_cgraph * graph_reserve(
229
+ uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only = false, size_t * sizes = nullptr);
230
+
231
+ bool set_sampler(llama_seq_id seq_id, llama_sampler * sampler);
210
232
 
211
233
  private:
212
234
  llm_graph_params graph_params(
@@ -247,6 +269,31 @@ private:
247
269
  size_t embd_size = 0; // capacity (of floats) for embeddings
248
270
  float * embd = nullptr;
249
271
 
272
+ // TODO: simplify
273
+ struct sampling_info {
274
+ std::map<llama_seq_id, llama_sampler *> samplers;
275
+
276
+ float * logits = nullptr;
277
+ size_t logits_size = 0;
278
+
279
+ llama_token * sampled = nullptr;
280
+ size_t sampled_size = 0;
281
+
282
+ float * probs = nullptr;
283
+ size_t probs_size = 0;
284
+
285
+ llama_token * candidates = nullptr;
286
+ size_t candidates_size = 0;
287
+
288
+ std::vector<uint32_t> logits_count;
289
+ std::vector<uint32_t> probs_count;
290
+ std::vector<uint32_t> candidates_count;
291
+
292
+ std::vector<llama_token> token_ids_full_vocab;
293
+ };
294
+
295
+ sampling_info sampling;
296
+
250
297
  // sequence embeddings output (map of [n_embd] vectors)
251
298
  // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
252
299
  std::map<llama_seq_id, std::vector<float>> embd_seq;
@@ -281,9 +328,10 @@ private:
281
328
 
282
329
  std::vector<std::pair<ggml_backend_t, ggml_backend_set_n_threads_t>> set_n_threads_fns;
283
330
 
284
- // buffer types used for the compute buffer of each backend
331
+ // pointers and buffer types used for the compute buffer of each backend
285
332
  std::vector<ggml_backend_t> backend_ptrs;
286
333
  std::vector<ggml_backend_buffer_type_t> backend_buft;
334
+ std::vector<size_t> backend_buf_exp_size; // expected buffer sizes
287
335
 
288
336
  llm_graph_result_ptr gf_res_prev;
289
337
  llm_graph_result_ptr gf_res_reserve;
@@ -8,6 +8,7 @@
8
8
 
9
9
  struct llama_cparams {
10
10
  uint32_t n_ctx; // context size used during inference
11
+ uint32_t n_ctx_seq; // context for a single sequence
11
12
  uint32_t n_batch;
12
13
  uint32_t n_ubatch;
13
14
  uint32_t n_seq_max;
@@ -6,8 +6,10 @@
6
6
 
7
7
  #include <cmath>
8
8
  #include <algorithm>
9
+ #include <cstdint>
9
10
  #include <stdexcept>
10
11
 
12
+ #define MAX_REPETITION_THRESHOLD 2000
11
13
  //
12
14
  // helpers
13
15
  //
@@ -179,6 +181,52 @@ static std::pair<uint32_t, const char *> parse_char(const char * src) {
179
181
  throw std::runtime_error("unexpected end of input");
180
182
  }
181
183
 
184
+ static std::pair<uint32_t, const char *> parse_token(const llama_vocab * vocab, const char * src) {
185
+ const char * pos = src;
186
+ if (*pos != '<') {
187
+ throw std::runtime_error(std::string("expecting '<' at ") + pos);
188
+ }
189
+ pos++;
190
+
191
+ // Parse <[id]>
192
+ if (*pos == '[') {
193
+ pos++;
194
+ const char * int_end = parse_int(pos);
195
+ uint32_t token_id = std::stoul(std::string(pos, int_end - pos));
196
+ pos = int_end;
197
+ if (*pos != ']') {
198
+ throw std::runtime_error(std::string("expecting ']' at ") + pos);
199
+ }
200
+ pos++;
201
+ if (*pos != '>') {
202
+ throw std::runtime_error(std::string("expecting '>' at ") + pos);
203
+ }
204
+ pos++;
205
+ return std::make_pair(token_id, pos);
206
+ }
207
+
208
+ if (vocab == nullptr) {
209
+ throw std::runtime_error(std::string("no vocab to parse token at ") + src);
210
+ }
211
+
212
+ // Parse <token> and tokenize to obtain the token id
213
+ while (*pos != 0 && *pos != '>') {
214
+ pos++;
215
+ }
216
+ if (*pos != '>') {
217
+ throw std::runtime_error(std::string("expecting '>' at ") + pos);
218
+ }
219
+ pos++;
220
+
221
+ llama_token tokens[2];
222
+ int32_t n_tokens = vocab->tokenize(src, static_cast<int32_t>(pos - src), tokens, 2, false, true);
223
+ if (n_tokens != 1) {
224
+ // must tokenize to exactly 1 token
225
+ throw std::runtime_error("invalid token '" + std::string(src, pos - src) + "'");
226
+ }
227
+ return std::make_pair(tokens[0], pos);
228
+ }
229
+
182
230
  static void print_grammar_char(FILE * file, uint32_t c) {
183
231
  if (0x20 <= c && c <= 0x7f) {
184
232
  fprintf(file, "%c", static_cast<char>(c));
@@ -210,6 +258,8 @@ static void print_rule_binary(FILE * file, const llama_grammar_rule & rule) {
210
258
  case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break;
211
259
  case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break;
212
260
  case LLAMA_GRETYPE_CHAR_ANY: fprintf(file, "CHAR_ANY"); break;
261
+ case LLAMA_GRETYPE_TOKEN: fprintf(file, "TOKEN"); break;
262
+ case LLAMA_GRETYPE_TOKEN_NOT: fprintf(file, "TOKEN_NOT"); break;
213
263
  }
214
264
  switch (elem.type) {
215
265
  case LLAMA_GRETYPE_END:
@@ -226,6 +276,17 @@ static void print_rule_binary(FILE * file, const llama_grammar_rule & rule) {
226
276
  print_grammar_char(file, elem.value);
227
277
  fprintf(file, "\") ");
228
278
  break;
279
+ case LLAMA_GRETYPE_TOKEN:
280
+ fprintf(file, "<[");
281
+ fprintf(file, "%u", elem.value);
282
+ fprintf(file, "]> ");
283
+ break;
284
+ case LLAMA_GRETYPE_TOKEN_NOT:
285
+ fprintf(file, "!");
286
+ fprintf(file, "<[");
287
+ fprintf(file, "%u", elem.value);
288
+ fprintf(file, "]> ");
289
+ break;
229
290
  }
230
291
  }
231
292
  fprintf(file, "\n");
@@ -282,6 +343,17 @@ static void print_rule(
282
343
  case LLAMA_GRETYPE_CHAR_ANY:
283
344
  fprintf(file, ".");
284
345
  break;
346
+ case LLAMA_GRETYPE_TOKEN:
347
+ fprintf(file, "<[");
348
+ fprintf(file, "%u", elem.value);
349
+ fprintf(file, "]> ");
350
+ break;
351
+ case LLAMA_GRETYPE_TOKEN_NOT:
352
+ fprintf(file, "!");
353
+ fprintf(file, "<[");
354
+ fprintf(file, "%u", elem.value);
355
+ fprintf(file, "]> ");
356
+ break;
285
357
  }
286
358
  if (is_char_element(elem)) {
287
359
  switch (rule[i + 1].type) {
@@ -297,6 +369,44 @@ static void print_rule(
297
369
  fprintf(file, "\n");
298
370
  }
299
371
 
372
+ //
373
+ // Regex utilities
374
+ //
375
+
376
+ size_t llama_grammar_trigger_pattern::find(const std::string & input) const {
377
+ auto find_start_pos = [](const std::smatch & match) {
378
+ // get from the first matched capturing group to the end of the string
379
+ size_t start = std::string::npos;
380
+ for (auto i = 1u; i < match.size(); i++) {
381
+ if (match.length(i) > 0) {
382
+ start = match.position(i);
383
+ break;
384
+ }
385
+ }
386
+ if (start == std::string::npos) {
387
+ start = match.position(0);
388
+ }
389
+ return start;
390
+ };
391
+
392
+ if (!pattern.empty() && pattern.front() == '^' && pattern.back() == '$') {
393
+ // match against the entire input
394
+ std::smatch match;
395
+ if (std::regex_match(input, match, regex)) {
396
+ return find_start_pos(match);
397
+ }
398
+ }
399
+
400
+ // search anywhere
401
+ std::smatch match;
402
+ if (std::regex_search(input, match, regex)) {
403
+ return find_start_pos(match);
404
+ }
405
+
406
+ return std::string::npos;
407
+ }
408
+
409
+
300
410
  //
301
411
  // implementation
302
412
  //
@@ -345,8 +455,10 @@ const char * llama_grammar_parser::parse_sequence(
345
455
  size_t last_sym_start = rule.size();
346
456
  const char * pos = src;
347
457
 
348
- auto handle_repetitions = [&](int min_times, int max_times) {
349
-
458
+ // use UINT64_MAX as the empty value because we aligned to the proper uint64_t type so -1 can't be used
459
+ // (though it's technically the same as -1 now)
460
+ auto handle_repetitions = [&](uint64_t min_times, uint64_t max_times) {
461
+ bool no_max = max_times == UINT64_MAX;
350
462
  if (last_sym_start == rule.size()) {
351
463
  throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
352
464
  }
@@ -373,20 +485,20 @@ const char * llama_grammar_parser::parse_sequence(
373
485
  rule.resize(last_sym_start);
374
486
  } else {
375
487
  // Repeat the previous elements (min_times - 1) times
376
- for (int i = 1; i < min_times; i++) {
488
+ for (uint64_t i = 1; i < min_times; i++) {
377
489
  rule.insert(rule.end(), prev_rule.begin(), prev_rule.end());
378
490
  }
379
491
  }
380
492
 
381
493
  uint32_t last_rec_rule_id = 0;
382
- auto n_opt = max_times < 0 ? 1 : max_times - min_times;
494
+ auto n_opt = no_max ? 1 : max_times - min_times;
383
495
 
384
496
  llama_grammar_rule rec_rule(prev_rule);
385
- for (int i = 0; i < n_opt; i++) {
497
+ for (uint64_t i = 0; i < n_opt; i++) {
386
498
  rec_rule.resize(prev_rule.size());
387
499
  uint32_t rec_rule_id = generate_symbol_id( rule_name);
388
- if (i > 0 || max_times < 0) {
389
- rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id});
500
+ if (i > 0 || no_max) {
501
+ rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, no_max ? rec_rule_id : last_rec_rule_id});
390
502
  }
391
503
  rec_rule.push_back({LLAMA_GRETYPE_ALT, 0});
392
504
  rec_rule.push_back({LLAMA_GRETYPE_END, 0});
@@ -440,6 +552,17 @@ const char * llama_grammar_parser::parse_sequence(
440
552
  }
441
553
  }
442
554
  pos = parse_space(pos + 1, is_nested);
555
+ } else if (*pos == '<' || *pos == '!') { // token
556
+ auto type = LLAMA_GRETYPE_TOKEN;
557
+ if (*pos == '!') { // token inverse
558
+ type = LLAMA_GRETYPE_TOKEN_NOT;
559
+ pos++;
560
+ }
561
+ auto token_pair = parse_token(vocab, pos);
562
+ const char * token_end = token_pair.second;
563
+ last_sym_start = rule.size();
564
+ rule.push_back({type, token_pair.first});
565
+ pos = parse_space(token_end, is_nested);
443
566
  } else if (is_word_char(*pos)) { // rule reference
444
567
  const char * name_end = parse_name(pos);
445
568
  uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos);
@@ -478,10 +601,10 @@ const char * llama_grammar_parser::parse_sequence(
478
601
  throw std::runtime_error(std::string("expecting an int at ") + pos);
479
602
  }
480
603
  const char * int_end = parse_int(pos);
481
- int min_times = std::stoul(std::string(pos, int_end - pos));
604
+ uint64_t min_times = std::stoul(std::string(pos, int_end - pos));
482
605
  pos = parse_space(int_end, is_nested);
483
606
 
484
- int max_times = -1;
607
+ uint64_t max_times = UINT64_MAX; // default: no max limit
485
608
 
486
609
  if (*pos == '}') {
487
610
  max_times = min_times;
@@ -502,6 +625,10 @@ const char * llama_grammar_parser::parse_sequence(
502
625
  } else {
503
626
  throw std::runtime_error(std::string("expecting ',' at ") + pos);
504
627
  }
628
+ bool has_max = max_times != UINT64_MAX;
629
+ if (min_times > MAX_REPETITION_THRESHOLD || (has_max && max_times > MAX_REPETITION_THRESHOLD)) {
630
+ throw std::runtime_error(std::string("number of repetitions exceeds sane defaults, please reduce the number of repetitions"));
631
+ }
505
632
  handle_repetitions(min_times, max_times);
506
633
  } else {
507
634
  break;
@@ -683,6 +810,21 @@ static bool llama_grammar_match_partial_char(
683
810
  return !is_positive_char;
684
811
  }
685
812
 
813
+ // returns true iff token matches the rule at pos (regular or inverse)
814
+ // asserts that pos is pointing to a token element
815
+ static bool llama_grammar_match_token(
816
+ const llama_grammar_element * pos,
817
+ const llama_token token) {
818
+ GGML_ASSERT(pos->type == LLAMA_GRETYPE_TOKEN || pos->type == LLAMA_GRETYPE_TOKEN_NOT);
819
+ if (pos->type == LLAMA_GRETYPE_TOKEN) {
820
+ return pos->value == static_cast<uint32_t>(token);
821
+ }
822
+ if (pos->type == LLAMA_GRETYPE_TOKEN_NOT) {
823
+ return pos->value != static_cast<uint32_t>(token);
824
+ }
825
+ return false;
826
+ }
827
+
686
828
  // transforms a grammar pushdown stack into N possible stacks, all ending
687
829
  // at a character range (terminal element)
688
830
  static void llama_grammar_advance_stack(
@@ -730,6 +872,8 @@ static void llama_grammar_advance_stack(
730
872
  case LLAMA_GRETYPE_CHAR:
731
873
  case LLAMA_GRETYPE_CHAR_NOT:
732
874
  case LLAMA_GRETYPE_CHAR_ANY:
875
+ case LLAMA_GRETYPE_TOKEN:
876
+ case LLAMA_GRETYPE_TOKEN_NOT:
733
877
  if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
734
878
  // only add the stack if it's not a duplicate of one we already have
735
879
  new_stacks.emplace_back(stack);
@@ -823,26 +967,38 @@ llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar)
823
967
  return grammar->stacks;
824
968
  }
825
969
 
970
+ static void llama_grammar_accept_chr(
971
+ struct llama_grammar & grammar,
972
+ const llama_grammar_stack & stack,
973
+ uint32_t chr,
974
+ llama_grammar_stacks & new_stacks) {
975
+ if (stack.empty()) {
976
+ return;
977
+ }
978
+
979
+ const llama_grammar_element * pos = stack.back();
980
+
981
+ // ignore if this turns into a token
982
+ if (pos->type == LLAMA_GRETYPE_TOKEN || pos->type == LLAMA_GRETYPE_TOKEN_NOT) {
983
+ return;
984
+ }
985
+
986
+ auto match = llama_grammar_match_char(pos, chr);
987
+ if (match.first) {
988
+ llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
989
+ if (!llama_grammar_is_end_of_sequence(match.second)) {
990
+ new_stack.push_back(match.second);
991
+ }
992
+ llama_grammar_advance_stack(grammar.rules, new_stack, new_stacks);
993
+ }
994
+ }
995
+
826
996
  void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr) {
827
997
  llama_grammar_stacks stacks_new;
828
998
  stacks_new.reserve(grammar->stacks.size());
829
999
 
830
1000
  for (const auto & stack : grammar->stacks) {
831
- if (stack.empty()) {
832
- continue;
833
- }
834
-
835
- auto match = llama_grammar_match_char(stack.back(), chr);
836
- if (match.first) {
837
- const llama_grammar_element * pos = match.second;
838
-
839
- // update top of stack to next element, if any
840
- llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
841
- if (!llama_grammar_is_end_of_sequence(pos)) {
842
- new_stack.push_back(pos);
843
- }
844
- llama_grammar_advance_stack(grammar->rules, new_stack, stacks_new);
845
- }
1001
+ llama_grammar_accept_chr(*grammar, stack, chr, stacks_new);
846
1002
  }
847
1003
 
848
1004
  grammar->stacks = std::move(stacks_new);
@@ -867,6 +1023,22 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
867
1023
 
868
1024
  const llama_grammar_element * stack_pos = stack.back();
869
1025
 
1026
+ // if the top of the stack is a token rule, then we only need to check the token id
1027
+ if (stack_pos->type == LLAMA_GRETYPE_TOKEN || stack_pos->type == LLAMA_GRETYPE_TOKEN_NOT) {
1028
+ for (const auto & tok : candidates) {
1029
+ if (*tok.code_points == 0) {
1030
+ // reached the end of a token consumed by char rules, reject iff it ended
1031
+ // in a partial response
1032
+ if (tok.partial_utf8.n_remain != 0) {
1033
+ rejects.push_back(tok);
1034
+ }
1035
+ } else if (!llama_grammar_match_token(stack_pos, tok.id)) {
1036
+ rejects.push_back(tok);
1037
+ }
1038
+ }
1039
+ return rejects;
1040
+ }
1041
+
870
1042
  llama_grammar_candidates next_candidates;
871
1043
  next_candidates.reserve(candidates.size());
872
1044
 
@@ -879,7 +1051,7 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
879
1051
  rejects.push_back(tok);
880
1052
  }
881
1053
  } else if (llama_grammar_match_char(stack_pos, *tok.code_points).first) {
882
- next_candidates.push_back({ tok.index, tok.code_points + 1, tok.partial_utf8 });
1054
+ next_candidates.push_back({ tok.index, tok.code_points + 1, tok.partial_utf8, tok.id });
883
1055
  } else {
884
1056
  rejects.push_back(tok);
885
1057
  }
@@ -897,7 +1069,7 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
897
1069
 
898
1070
  auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates);
899
1071
  for (const auto & tok : next_rejects) {
900
- rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8 });
1072
+ rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8, tok.id });
901
1073
  }
902
1074
 
903
1075
  return rejects;
@@ -964,12 +1136,13 @@ struct llama_grammar * llama_grammar_init_impl(
964
1136
  vocab,
965
1137
  std::move(vec_rules),
966
1138
  std::move(stacks),
967
- /* .partial_utf8 = */ {},
968
- /* .lazy =*/ false,
969
- /* .awaiting_trigger = */ false,
970
- /* .trigger_buffer = */ "",
971
- /* .trigger_tokens = */ {},
972
- /* .trigger_patterns = */ {},
1139
+ /* .partial_utf8 = */ {},
1140
+ /* .lazy = */ false,
1141
+ /* .awaiting_trigger = */ false,
1142
+ /* .trigger_buffer = */ "",
1143
+ /* .trigger_buffer_positions = */ {},
1144
+ /* .trigger_tokens = */ {},
1145
+ /* .trigger_patterns = */ {},
973
1146
  };
974
1147
  }
975
1148
 
@@ -982,7 +1155,7 @@ struct llama_grammar * llama_grammar_init_impl(
982
1155
  size_t num_trigger_patterns,
983
1156
  const llama_token * trigger_tokens,
984
1157
  size_t num_trigger_tokens) {
985
- llama_grammar_parser parser;
1158
+ llama_grammar_parser parser(vocab);
986
1159
 
987
1160
  // if there is a grammar, parse it
988
1161
  // rules will be empty (default) if there are parse errors
@@ -1069,10 +1242,11 @@ struct llama_grammar * llama_grammar_init_impl(
1069
1242
  vocab,
1070
1243
  std::move(vec_rules),
1071
1244
  std::move(stacks),
1072
- /* .partial_utf8 = */ {},
1073
- /* .lazy = */ lazy,
1074
- /* .awaiting_trigger = */ lazy,
1075
- /* .trigger_buffer = */ "",
1245
+ /* .partial_utf8 = */ {},
1246
+ /* .lazy = */ lazy,
1247
+ /* .awaiting_trigger = */ lazy,
1248
+ /* .trigger_buffer = */ "",
1249
+ /* .trigger_buffer_positions = */ {},
1076
1250
  std::move(vec_trigger_tokens),
1077
1251
  std::move(vec_trigger_patterns),
1078
1252
  };
@@ -1095,6 +1269,7 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra
1095
1269
  grammar.lazy,
1096
1270
  grammar.awaiting_trigger,
1097
1271
  grammar.trigger_buffer,
1272
+ grammar.trigger_buffer_positions,
1098
1273
  grammar.trigger_tokens,
1099
1274
  grammar.trigger_patterns,
1100
1275
  };
@@ -1148,7 +1323,7 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_
1148
1323
  cur_p->data[i].logit = -INFINITY;
1149
1324
  } else {
1150
1325
  candidates_decoded.push_back(decode_utf8(piece, grammar.partial_utf8));
1151
- candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second });
1326
+ candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second, id });
1152
1327
  }
1153
1328
  }
1154
1329
 
@@ -1167,31 +1342,35 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
1167
1342
  if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) {
1168
1343
  grammar.awaiting_trigger = false;
1169
1344
  grammar.trigger_buffer.clear();
1170
- llama_grammar_accept_str(grammar, piece);
1345
+ llama_grammar_accept_token(grammar, token, piece);
1171
1346
  LLAMA_LOG_DEBUG("Grammar triggered on token %u (`%s`)", token, piece.c_str());
1172
1347
  return;
1173
1348
  } else {
1349
+ auto position = std::make_pair(grammar.trigger_buffer.size(), grammar.trigger_buffer.size() + piece.size());
1350
+ grammar.trigger_buffer_positions.push_back(std::make_pair(token, position));
1174
1351
  grammar.trigger_buffer += piece;
1175
1352
 
1176
- std::smatch match;
1177
1353
  for (const auto & trigger_pattern : grammar.trigger_patterns) {
1178
- if (std::regex_match(grammar.trigger_buffer, match, trigger_pattern.regex)) {
1354
+ auto start = trigger_pattern.find(grammar.trigger_buffer);
1355
+ if (start != std::string::npos) {
1179
1356
  grammar.awaiting_trigger = false;
1180
- // get from the first matched capturing group to the end of the string
1181
- size_t start = std::string::npos;
1182
- for (auto i = 1u; i < match.size(); i++) {
1183
- if (match.length(i) > 0) {
1184
- start = match.position(i);
1185
- break;
1357
+
1358
+ // replay tokens that overlap with [start, end)
1359
+ for (const auto & [tok, tok_pos] : grammar.trigger_buffer_positions) {
1360
+ auto [tok_start, tok_end] = tok_pos;
1361
+ if (tok_end <= start) {
1362
+ continue;
1186
1363
  }
1364
+
1365
+ size_t piece_start = (tok_start < start) ? start : tok_start; // allow for partial token pieces
1366
+ size_t piece_len = tok_end - piece_start;
1367
+ auto tok_piece = grammar.trigger_buffer.substr(piece_start, piece_len);
1368
+ llama_grammar_accept_token(grammar, tok, tok_piece);
1187
1369
  }
1188
- if (start == std::string::npos) {
1189
- start = match.position(0);
1190
- }
1370
+
1191
1371
  auto constrained_str = grammar.trigger_buffer.substr(start);
1192
- // std::string constrained_str(match[1].first, grammar.trigger_buffer.end());
1193
1372
  grammar.trigger_buffer.clear();
1194
- llama_grammar_accept_str(grammar, constrained_str);
1373
+ grammar.trigger_buffer_positions.clear();
1195
1374
  LLAMA_LOG_DEBUG("Grammar triggered on regex: '%s'\n", constrained_str.c_str());
1196
1375
  return;
1197
1376
  }
@@ -1210,7 +1389,7 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
1210
1389
  GGML_ABORT("fatal error");
1211
1390
  }
1212
1391
 
1213
- llama_grammar_accept_str(grammar, piece);
1392
+ llama_grammar_accept_token(grammar, token, piece);
1214
1393
  }
1215
1394
 
1216
1395
  void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string & piece) {
@@ -1227,3 +1406,59 @@ void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string
1227
1406
  throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece);
1228
1407
  }
1229
1408
  }
1409
+
1410
+ void llama_grammar_accept_token(struct llama_grammar & grammar, llama_token token, const std::string & piece) {
1411
+ // Note terminating 0 in decoded string
1412
+ const auto decoded = decode_utf8(piece, grammar.partial_utf8);
1413
+ const auto & code_points = decoded.first;
1414
+
1415
+ llama_grammar_stacks stacks_new;
1416
+ stacks_new.reserve(grammar.stacks.size());
1417
+
1418
+ for (const auto & stack : grammar.stacks) {
1419
+ if (stack.empty()) {
1420
+ continue;
1421
+ }
1422
+
1423
+ const llama_grammar_element * pos = stack.back();
1424
+
1425
+ if (pos->type == LLAMA_GRETYPE_TOKEN || pos->type == LLAMA_GRETYPE_TOKEN_NOT) {
1426
+ if (llama_grammar_match_token(pos, token)) {
1427
+ llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
1428
+ if (!llama_grammar_is_end_of_sequence(pos + 1)) {
1429
+ new_stack.push_back(pos + 1);
1430
+ }
1431
+ llama_grammar_advance_stack(grammar.rules, new_stack, stacks_new);
1432
+ }
1433
+ } else {
1434
+ llama_grammar_stacks current_stacks = {stack};
1435
+
1436
+ for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
1437
+ llama_grammar_stacks next_stacks;
1438
+
1439
+ for (const auto & cur_stack : current_stacks) {
1440
+ llama_grammar_accept_chr(grammar, cur_stack, *it, next_stacks);
1441
+ }
1442
+
1443
+ current_stacks = std::move(next_stacks);
1444
+ if (current_stacks.empty()) {
1445
+ break;
1446
+ }
1447
+ }
1448
+
1449
+ for (auto & surviving_stack : current_stacks) {
1450
+ if (std::find(stacks_new.begin(), stacks_new.end(), surviving_stack) == stacks_new.end()) {
1451
+ stacks_new.emplace_back(surviving_stack);
1452
+ }
1453
+ }
1454
+ }
1455
+ }
1456
+
1457
+ grammar.stacks = std::move(stacks_new);
1458
+ grammar.partial_utf8 = decoded.second;
1459
+
1460
+ if (grammar.stacks.empty()) {
1461
+ throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece + " (" + std::to_string(token) + ")");
1462
+ }
1463
+ }
1464
+