whispercpp 1.3.4 → 1.3.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (630) hide show
  1. checksums.yaml +4 -4
  2. data/README.md +60 -43
  3. data/ext/extconf.rb +2 -2
  4. data/ext/ruby_whisper.c +14 -2
  5. data/ext/ruby_whisper.h +39 -0
  6. data/ext/ruby_whisper_context.c +22 -22
  7. data/ext/ruby_whisper_model.c +12 -12
  8. data/ext/ruby_whisper_params.c +47 -23
  9. data/ext/ruby_whisper_segment.c +84 -19
  10. data/ext/ruby_whisper_token.c +351 -0
  11. data/ext/ruby_whisper_transcribe.cpp +1 -1
  12. data/ext/ruby_whisper_vad_context.c +75 -0
  13. data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
  14. data/ext/ruby_whisper_vad_segment.c +139 -0
  15. data/ext/ruby_whisper_vad_segments.c +106 -0
  16. data/ext/sources/CMakeLists.txt +4 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
  19. data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
  20. data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  21. data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
  22. data/ext/sources/examples/addon.node/vad-example.js +2 -2
  23. data/ext/sources/examples/cli/cli.cpp +121 -112
  24. data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
  25. data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
  26. data/ext/sources/examples/server/server.cpp +10 -11
  27. data/ext/sources/examples/talk-llama/CMakeLists.txt +5 -1
  28. data/ext/sources/examples/talk-llama/llama-adapter.cpp +12 -3
  29. data/ext/sources/examples/talk-llama/llama-adapter.h +7 -1
  30. data/ext/sources/examples/talk-llama/llama-arch.cpp +2046 -1974
  31. data/ext/sources/examples/talk-llama/llama-arch.h +67 -2
  32. data/ext/sources/examples/talk-llama/llama-batch.cpp +75 -33
  33. data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
  34. data/ext/sources/examples/talk-llama/llama-chat.cpp +79 -3
  35. data/ext/sources/examples/talk-llama/llama-chat.h +4 -0
  36. data/ext/sources/examples/talk-llama/llama-context.cpp +775 -78
  37. data/ext/sources/examples/talk-llama/llama-context.h +57 -9
  38. data/ext/sources/examples/talk-llama/llama-cparams.h +1 -0
  39. data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
  40. data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
  41. data/ext/sources/examples/talk-llama/llama-graph.cpp +381 -64
  42. data/ext/sources/examples/talk-llama/llama-graph.h +103 -13
  43. data/ext/sources/examples/talk-llama/llama-hparams.cpp +26 -2
  44. data/ext/sources/examples/talk-llama/llama-hparams.h +41 -10
  45. data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
  46. data/ext/sources/examples/talk-llama/llama-impl.h +1 -1
  47. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +5 -3
  48. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +145 -65
  49. data/ext/sources/examples/talk-llama/llama-kv-cache.h +22 -7
  50. data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
  51. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
  52. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +32 -19
  53. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
  54. data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
  55. data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
  56. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +91 -9
  57. data/ext/sources/examples/talk-llama/llama-model-loader.h +6 -0
  58. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
  59. data/ext/sources/examples/talk-llama/llama-model.cpp +1529 -13134
  60. data/ext/sources/examples/talk-llama/llama-model.h +44 -3
  61. data/ext/sources/examples/talk-llama/llama-quant.cpp +8 -23
  62. data/ext/sources/examples/talk-llama/llama-sampling.cpp +1294 -198
  63. data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
  64. data/ext/sources/examples/talk-llama/llama-vocab.cpp +133 -37
  65. data/ext/sources/examples/talk-llama/llama-vocab.h +45 -40
  66. data/ext/sources/examples/talk-llama/llama.cpp +729 -2
  67. data/ext/sources/examples/talk-llama/llama.h +152 -14
  68. data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
  69. data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
  70. data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
  71. data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
  72. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
  73. data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
  74. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
  75. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
  76. data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
  77. data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
  78. data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
  79. data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
  80. data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
  81. data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
  82. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
  83. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
  84. data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
  85. data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
  86. data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
  87. data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
  88. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
  89. data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
  90. data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
  91. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
  92. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
  93. data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
  94. data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
  95. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
  96. data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
  97. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
  98. data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
  99. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
  100. data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
  101. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
  102. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
  103. data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
  104. data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
  105. data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
  106. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
  107. data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
  108. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
  109. data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
  110. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
  111. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
  112. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
  113. data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
  114. data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
  115. data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
  116. data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
  117. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
  118. data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
  119. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
  120. data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
  121. data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
  122. data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
  123. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
  124. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
  125. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
  126. data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
  127. data/ext/sources/examples/talk-llama/models/models.h +569 -0
  128. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
  129. data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
  130. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
  131. data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
  132. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
  133. data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
  134. data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
  135. data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
  136. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
  137. data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
  138. data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
  139. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
  140. data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
  141. data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
  142. data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
  143. data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
  144. data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
  145. data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
  146. data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
  147. data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
  148. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
  149. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
  150. data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
  151. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
  152. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
  153. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
  154. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
  155. data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
  156. data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
  157. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
  158. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
  159. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
  160. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
  161. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
  162. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
  163. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
  164. data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
  165. data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
  166. data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
  167. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
  168. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
  169. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
  170. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
  171. data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
  172. data/ext/sources/examples/talk-llama/unicode.cpp +102 -16
  173. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
  174. data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
  175. data/ext/sources/ggml/CMakeLists.txt +82 -54
  176. data/ext/sources/ggml/include/ggml-alloc.h +9 -0
  177. data/ext/sources/ggml/include/ggml-backend.h +4 -1
  178. data/ext/sources/ggml/include/ggml-cpu.h +1 -0
  179. data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
  180. data/ext/sources/ggml/include/ggml-rpc.h +8 -11
  181. data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
  182. data/ext/sources/ggml/include/ggml.h +190 -12
  183. data/ext/sources/ggml/src/CMakeLists.txt +82 -11
  184. data/ext/sources/ggml/src/ggml-alloc.c +124 -41
  185. data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
  186. data/ext/sources/ggml/src/ggml-backend-reg.cpp +27 -3
  187. data/ext/sources/ggml/src/ggml-backend.cpp +71 -21
  188. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
  189. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -9
  190. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +57 -45
  191. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
  192. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2179 -1696
  193. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -317
  194. data/ext/sources/ggml/src/ggml-cann/common.h +283 -208
  195. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +626 -776
  196. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +156 -86
  197. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -0
  198. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  199. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +428 -26
  200. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1004 -0
  201. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
  202. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  203. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +108 -49
  204. data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  205. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6 -6
  206. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +50 -2
  207. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +5 -3
  208. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +195 -71
  209. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
  210. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +573 -106
  211. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
  212. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +298 -112
  213. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
  214. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +819 -125
  215. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
  216. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +708 -431
  217. data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -4
  218. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +671 -31
  219. data/ext/sources/ggml/src/ggml-cpu/repack.h +14 -0
  220. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +41 -43
  221. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
  222. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
  223. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
  224. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +124 -1
  225. data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
  226. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
  227. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
  228. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
  229. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
  230. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +1 -1
  231. data/ext/sources/ggml/src/ggml-cuda/common.cuh +353 -80
  232. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  233. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
  234. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +339 -246
  235. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
  236. data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
  237. data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
  238. data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
  239. data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
  240. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +31 -21
  241. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +663 -596
  242. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
  243. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1241 -0
  244. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +30 -37
  245. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +14 -13
  246. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  247. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +83 -37
  248. data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
  249. data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
  250. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1155 -164
  251. data/ext/sources/ggml/src/ggml-cuda/mean.cu +5 -4
  252. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +741 -48
  253. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +60 -12
  254. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +381 -42
  255. data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
  256. data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
  257. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
  258. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +498 -171
  259. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +375 -79
  260. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +3 -2
  261. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +241 -95
  262. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
  263. data/ext/sources/ggml/src/ggml-cuda/pad.cu +64 -33
  264. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +151 -0
  265. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
  266. data/ext/sources/ggml/src/ggml-cuda/rope.cu +192 -77
  267. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
  268. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
  269. data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
  270. data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
  271. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +203 -6
  272. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
  273. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
  274. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -20
  275. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
  276. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  277. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  278. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  279. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  280. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  281. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  282. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  283. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  284. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  285. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +19 -1
  286. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
  287. data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
  288. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +168 -76
  289. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +11 -4
  290. data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
  291. data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
  292. data/ext/sources/ggml/src/ggml-cuda/unary.cu +105 -11
  293. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +36 -0
  294. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
  295. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  296. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +12 -1
  297. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
  302. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
  303. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  304. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
  305. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
  306. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  307. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
  308. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
  309. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
  310. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
  311. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  312. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
  313. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
  314. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
  315. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
  316. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
  317. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
  318. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
  319. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
  320. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
  321. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
  322. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
  323. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
  324. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
  325. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  326. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
  327. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
  328. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
  329. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +8 -13
  330. data/ext/sources/ggml/src/ggml-impl.h +67 -6
  331. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +2 -2
  332. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +29 -20
  333. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +652 -285
  334. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +103 -56
  335. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +496 -118
  336. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +231 -9
  337. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1227 -224
  338. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +12 -0
  339. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +14 -8
  340. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1972 -704
  341. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
  342. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +11 -0
  343. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1430 -120
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +63 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +13 -0
  365. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
  366. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
  367. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
  368. data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
  369. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +6 -0
  370. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +0 -9
  371. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
  372. data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
  373. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
  374. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +34 -0
  375. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  376. data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  377. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
  378. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
  379. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +76 -3
  380. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +333 -300
  381. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +10 -2
  382. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +335 -110
  383. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
  384. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +156 -0
  385. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
  386. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
  387. data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
  388. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
  389. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
  390. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
  391. data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  392. data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  393. data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
  394. data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
  395. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +30 -17
  396. data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
  397. data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
  398. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
  399. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
  400. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  401. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  402. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +58 -0
  403. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
  404. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5013 -2859
  405. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  406. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
  407. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
  408. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  409. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
  410. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  411. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
  412. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
  413. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  414. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  415. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  416. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  417. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  418. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  419. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
  420. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  421. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  422. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
  423. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
  424. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  425. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  426. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  427. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
  428. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
  429. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
  430. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
  431. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  432. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
  433. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
  434. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  435. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  436. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  437. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  438. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  439. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
  440. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
  441. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
  442. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  443. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  444. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
  445. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
  446. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  447. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  448. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  449. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
  450. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  451. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  452. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
  453. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  454. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  455. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
  456. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  457. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  458. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
  459. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  460. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +39 -17
  461. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +19 -1
  462. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +45 -7
  463. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +50 -12
  464. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +1 -1
  465. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  466. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  467. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
  468. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
  469. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  470. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
  471. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  472. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
  473. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
  474. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
  475. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
  476. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
  477. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
  478. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  479. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
  480. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
  481. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
  482. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
  483. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
  484. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  485. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  486. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  487. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
  488. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +70 -25
  489. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
  490. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
  491. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
  492. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
  493. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
  494. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
  495. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
  496. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
  497. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
  498. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  499. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
  500. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
  501. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
  502. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
  503. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
  504. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
  505. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
  506. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +78 -103
  507. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +34 -23
  508. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
  509. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
  510. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +88 -228
  511. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
  512. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  513. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
  514. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  515. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  516. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  517. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
  518. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
  519. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  520. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
  521. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  522. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  523. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  524. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  525. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +50 -4
  526. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  527. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
  528. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
  529. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
  530. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  531. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -50
  532. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -33
  533. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -33
  534. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
  535. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
  536. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  537. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
  538. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  539. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  540. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  541. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  542. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
  543. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
  544. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
  545. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
  546. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
  547. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
  548. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  549. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
  550. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
  551. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  552. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
  553. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
  554. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  555. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  556. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
  557. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
  558. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  559. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
  560. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  561. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
  562. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
  563. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
  564. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
  565. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
  566. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  567. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
  568. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
  569. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -151
  570. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
  571. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
  572. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
  573. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1964 -435
  574. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  575. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
  576. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
  577. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +33 -10
  578. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
  579. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +1 -1
  580. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  581. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +6 -6
  582. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
  583. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
  584. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
  585. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
  586. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
  587. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  588. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
  589. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
  590. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  591. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
  592. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
  593. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
  594. data/ext/sources/ggml/src/ggml.c +425 -33
  595. data/ext/sources/include/whisper.h +1 -0
  596. data/ext/sources/src/CMakeLists.txt +3 -1
  597. data/ext/sources/src/whisper.cpp +101 -35
  598. data/ext/sources/tests/CMakeLists.txt +2 -2
  599. data/ext/sources/tests/test-vad-full.cpp +4 -2
  600. data/ext/sources/tests/test-vad.cpp +1 -1
  601. data/extsources.rb +1 -0
  602. data/lib/whisper/model/uri.rb +17 -18
  603. data/sig/whisper.rbs +119 -2
  604. data/test/test_params.rb +16 -8
  605. data/test/test_segment.rb +0 -1
  606. data/test/test_token.rb +70 -0
  607. data/test/test_vad.rb +1 -1
  608. data/test/test_vad_context.rb +50 -0
  609. data/test/test_vad_segment.rb +19 -0
  610. data/test/test_vad_segments.rb +16 -0
  611. data/test/test_whisper.rb +7 -0
  612. data/whispercpp.gemspec +1 -1
  613. metadata +287 -34
  614. data/ext/sources/build-xcframework.sh +0 -571
  615. data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
  618. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
  619. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
  620. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
  621. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
  622. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
  623. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
  624. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  625. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  626. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  627. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  628. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
  629. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
  630. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
@@ -2,10 +2,6 @@
2
2
 
3
3
  #define CUDA_CPY_BLOCK_SIZE 64
4
4
 
5
- void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection = false);
5
+ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1);
6
6
 
7
7
  void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
8
-
9
- void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1);
10
-
11
- void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, cudaStream_t stream);
@@ -0,0 +1,307 @@
1
+ #include <algorithm>
2
+ #include "cumsum.cuh"
3
+ #include "convert.cuh"
4
+ #include "ggml-cuda/common.cuh"
5
+ #include "ggml.h"
6
+
7
+ #ifdef GGML_CUDA_USE_CUB
8
+ # include <cub/cub.cuh>
9
+ #endif // GGML_CUDA_USE_CUB
10
+
11
+ template<typename T, int BLOCK_SIZE>
12
+ static __global__ void cumsum_cub_kernel(
13
+ const T * __restrict__ src,
14
+ T * __restrict__ dst,
15
+ const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
16
+ const int64_t s01, const int64_t s02, const int64_t s03,
17
+ const int64_t s1, const int64_t s2, const int64_t s3) {
18
+ #ifdef GGML_CUDA_USE_CUB
19
+ using BlockScanT = cub::BlockScan<T, BLOCK_SIZE>;
20
+
21
+ __shared__ typename BlockScanT::TempStorage temp_storage;
22
+ __shared__ T block_carry;
23
+
24
+ const int tid = threadIdx.x;
25
+ constexpr int UNROLL_FACTOR = 4;
26
+ constexpr int TILE_SIZE = BLOCK_SIZE * UNROLL_FACTOR;
27
+
28
+ const int64_t i1 = blockIdx.x;
29
+ const int64_t i2 = blockIdx.y;
30
+ const int64_t i3 = blockIdx.z;
31
+
32
+ if (i1 >= ne01 || i2 >= ne02 || i3 >= ne03) {
33
+ return;
34
+ }
35
+
36
+ const T * src_row = src + i1 * s01 + i2 * s02 + i3 * s03;
37
+ T * dst_row = dst + i1 * s1 + i2 * s2 + i3 * s3;
38
+
39
+ if (tid == 0) {
40
+ block_carry = 0;
41
+ }
42
+ __syncthreads();
43
+
44
+ for (int64_t start = 0; start < ne00; start += TILE_SIZE) {
45
+ T items[UNROLL_FACTOR];
46
+ T thread_sum = T(0);
47
+
48
+ #pragma unroll
49
+ for (int i = 0; i < UNROLL_FACTOR; i++) {
50
+ int64_t idx = start + tid * UNROLL_FACTOR + i;
51
+ T val = (idx < ne00) ? src_row[idx] : T(0);
52
+ thread_sum += val;
53
+ items[i] = thread_sum;
54
+ }
55
+
56
+ // Block-wide scan on thread sums
57
+ T thread_prefix;
58
+ T block_total;
59
+ BlockScanT(temp_storage).InclusiveSum(thread_sum, thread_prefix, block_total);
60
+ __syncthreads();
61
+
62
+ // Add offset to each item and store
63
+ T thread_offset = thread_prefix - thread_sum + block_carry;
64
+ #pragma unroll
65
+ for (int i = 0; i < UNROLL_FACTOR; i++) {
66
+ int64_t idx = start + tid * UNROLL_FACTOR + i;
67
+ if (idx < ne00) {
68
+ dst_row[idx] = items[i] + thread_offset;
69
+ }
70
+ }
71
+
72
+ __syncthreads();
73
+
74
+ // Update carry for next tile
75
+ if (tid == 0) {
76
+ block_carry += block_total;
77
+ }
78
+ }
79
+ #else
80
+ NO_DEVICE_CODE;
81
+ #endif // GGML_CUDA_USE_CUB
82
+ }
83
+
84
+ // Fallback kernel implementation
85
+ template<typename T>
86
+ static __global__ void cumsum_kernel(
87
+ const T * src, T * dst,
88
+ const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
89
+ const int64_t s00, const int64_t s01, const int64_t s02, const int64_t s03,
90
+ const int64_t s0, const int64_t s1, const int64_t s2, const int64_t s3) {
91
+
92
+ GGML_UNUSED_VARS(s00, s0);
93
+
94
+ const int tid = threadIdx.x;
95
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
96
+ const int lane = tid % warp_size;
97
+ const int warp = tid / warp_size;
98
+ const int warps_per_block = blockDim.x / warp_size;
99
+
100
+ extern __shared__ float smem[];
101
+ float * s_vals = smem;
102
+ float * s_warp_sums = smem + blockDim.x;
103
+ float * s_carry = smem + blockDim.x + warps_per_block;
104
+ float * s_chunk_total = s_carry + 1;
105
+
106
+ // Initialize carry
107
+ if (tid == 0) {
108
+ *s_carry = 0.0f;
109
+ }
110
+ __syncthreads();
111
+
112
+ const int64_t i3 = blockIdx.z;
113
+ const int64_t i2 = blockIdx.y;
114
+ const int64_t i1 = blockIdx.x;
115
+ if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
116
+ return;
117
+ }
118
+
119
+ const T * src_row = src + i1 * s01 + i2 * s02 + i3 * s03;
120
+ T * dst_row = dst + i1 * s1 + i2 * s2 + i3 * s3;
121
+
122
+ // register blocking: process 4 elements per thread to hide latency
123
+ // and reduce synchronization overhead
124
+ constexpr int num_unroll = 4;
125
+ T temp[num_unroll];
126
+
127
+ for (int64_t i = 0; i < ne00; i += num_unroll * blockDim.x) {
128
+ int64_t idx = i + tid * num_unroll;
129
+
130
+ // thread local sequential scan
131
+ temp[0] = (idx < ne00 ? src_row[idx] : T(0));
132
+ #pragma unroll
133
+ for (int64_t j = 1; j < num_unroll; j++) {
134
+ temp[j] = temp[j - 1];
135
+ if (idx + j < ne00) {
136
+ temp[j] += src_row[idx + j];
137
+ } else {
138
+ temp[j] += 0;
139
+ }
140
+ }
141
+
142
+ // last emenent is sum of all values assigned to thread
143
+ float val = (idx < ne00) ? ggml_cuda_cast<float, T>(temp[num_unroll - 1]) : 0.0f;
144
+
145
+ // Warp inclusive scan
146
+ val = warp_prefix_inclusive_sum<T, warp_size>(val);
147
+ s_vals[tid] = val;
148
+
149
+ if (lane == warp_size - 1) {
150
+ s_warp_sums[warp] = val;
151
+ }
152
+ __syncthreads();
153
+
154
+ // Exclusive scan of warp sums (warp 0 only)
155
+ if (warp == 0) {
156
+ float w = (tid < warps_per_block) ? s_warp_sums[tid] : 0.0f;
157
+ float inc = warp_prefix_inclusive_sum<T, warp_size>(w);
158
+ if (tid < warps_per_block) {
159
+ s_warp_sums[tid] = inc - w; // exclusive sum
160
+ }
161
+ if (tid == warps_per_block - 1) {
162
+ *s_chunk_total = inc; // total sum of this chunk
163
+ }
164
+ }
165
+ __syncthreads();
166
+
167
+ // write back results
168
+ float carry = *s_carry;
169
+ // calculate sum offset for this thread
170
+ float final_val_offset = s_vals[tid] + s_warp_sums[warp] + carry - temp[num_unroll - 1];
171
+
172
+ #pragma unroll
173
+ for (int32_t j = 0; j < num_unroll; j++) {
174
+ if (idx + j < ne00) {
175
+ dst_row[idx + j] = temp[j] + ggml_cuda_cast<T, float>(final_val_offset);
176
+ }
177
+ }
178
+
179
+ __syncthreads();
180
+
181
+ // Update carry for next chunk
182
+ if (tid == 0) {
183
+ *s_carry += *s_chunk_total;
184
+ }
185
+ }
186
+ }
187
+
188
+ #ifdef GGML_CUDA_USE_CUB
189
+ template <typename T>
190
+ static void cumsum_cub(ggml_cuda_pool & pool,
191
+ const T * src,
192
+ T * dst,
193
+ int64_t ne,
194
+ cudaStream_t stream) {
195
+ size_t tmp_size = 0;
196
+
197
+ // Query how much temp storage CUDA UnBound (CUB) needs
198
+ cub::DeviceScan::InclusiveSum(nullptr, // d_temp_storage (null = just query size)
199
+ tmp_size, // reference to size (will be set by CUB)
200
+ src, // input pointer
201
+ dst, // output pointer
202
+ ne, // number of elements
203
+ stream // CUDA stream to use
204
+ );
205
+
206
+ ggml_cuda_pool_alloc<uint8_t> tmp_alloc(pool, tmp_size);
207
+
208
+ // Perform the inclusive scan
209
+ cub::DeviceScan::InclusiveSum((void *) tmp_alloc.get(), tmp_size, src, dst, ne, stream);
210
+ }
211
+ #endif // GGML_CUDA_USE_CUB
212
+
213
+ template<typename T>
214
+ static void cumsum_cuda(
215
+ [[maybe_unused]] ggml_backend_cuda_context & ctx, const T * src, T * dst,
216
+ const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
217
+ const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
218
+ const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3,
219
+ cudaStream_t stream) {
220
+
221
+ const size_t type_size = sizeof(T);
222
+ bool use_cub = false;
223
+ #ifdef GGML_CUDA_USE_CUB
224
+ // Check if we can use CUB (data must be contiguous along innermost dimension)
225
+ const bool is_contiguous = (nb00 == type_size) && (nb0 == type_size);
226
+
227
+ if (is_contiguous) {
228
+ use_cub = true;
229
+ const int64_t nrows = ne01 * ne02 * ne03;
230
+ // TODO: Compare with DeviceSegmentedScan::InclusiveSegmentedSum for nrows > 1 once InclusiveSegmentedSum is released
231
+ // Heuristics were determined as part of https://github.com/ggml-org/llama.cpp/pull/17004
232
+ if (((nrows == 1) && (ne00 > 1024)) || (ne00 / nrows > 4096)) {
233
+ for (int i=0; i<nrows; i++) {
234
+ cumsum_cub(ctx.pool(), src + i * ne00, dst + i * ne00, ne00, stream);
235
+ }
236
+ return;
237
+ }
238
+ }
239
+ #endif // GGML_CUDA_USE_CUB
240
+ dim3 grid_dims(ne01, ne02, ne03);
241
+ const auto &info = ggml_cuda_info().devices[ggml_cuda_get_device()];
242
+ const int warp_size = info.warp_size;
243
+ const int num_warps = (ne00 + warp_size - 1) / warp_size;
244
+ int block_size = num_warps * warp_size;
245
+ block_size = std::min(block_size, CUDA_CUMSUM_BLOCK_SIZE);
246
+ dim3 block_dims(block_size, 1, 1);
247
+ const int warps_per_block = block_size / warp_size;
248
+ const size_t shmem_size = (block_size + warps_per_block + 2) * sizeof(float);
249
+
250
+ if (use_cub && ne00 >= 1024) {
251
+ cumsum_cub_kernel<T, CUDA_CUMSUM_BLOCK_SIZE><<<grid_dims, CUDA_CUMSUM_BLOCK_SIZE, 0, stream>>>(
252
+ src, dst,
253
+ ne00, ne01, ne02, ne03,
254
+ nb01 / type_size, nb02 / type_size, nb03 / type_size,
255
+ nb1 / type_size, nb2 / type_size, nb3 / type_size
256
+ );
257
+ } else {
258
+ cumsum_kernel<<<grid_dims, block_dims, shmem_size, stream>>>(
259
+ src, dst,
260
+ ne00, ne01, ne02, ne03,
261
+ nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
262
+ nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
263
+ );
264
+ }
265
+ }
266
+
267
+ void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
268
+ const ggml_tensor * src0 = dst->src[0];
269
+ cudaStream_t stream = ctx.stream();
270
+
271
+ GGML_ASSERT(src0->type == dst->type);
272
+ switch(src0->type) {
273
+ case GGML_TYPE_F32:
274
+ {
275
+ cumsum_cuda(
276
+ ctx, (const float *)src0->data, (float *)dst->data,
277
+ src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
278
+ src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
279
+ dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
280
+ stream
281
+ );
282
+ } break;
283
+ // We do not support those on CPU for now anyway, so comment them out because they cause errors on some CI platforms
284
+ /*case GGML_TYPE_F16:
285
+ {
286
+ cumsum_cuda(
287
+ (const half *)src0->data, (half *)dst->data,
288
+ src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
289
+ src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
290
+ dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
291
+ stream
292
+ );
293
+ } break;
294
+ case GGML_TYPE_BF16:
295
+ {
296
+ cumsum_cuda(
297
+ (const nv_bfloat16 *)src0->data, (nv_bfloat16 *)dst->data,
298
+ src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
299
+ src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
300
+ dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
301
+ stream
302
+ );
303
+ } break;*/
304
+ default:
305
+ GGML_ABORT("fatal error");
306
+ }
307
+ }
@@ -0,0 +1,5 @@
1
+ #include "common.cuh"
2
+
3
+ #define CUDA_CUMSUM_BLOCK_SIZE 256
4
+
5
+ void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
@@ -0,0 +1,77 @@
1
+ #include "convert.cuh"
2
+ #include "diag.cuh"
3
+ #include "ggml.h"
4
+
5
+ template <typename T>
6
+ static __global__ void diag_kernel(T * __restrict__ dst,
7
+ const T * __restrict__ src,
8
+ const int64_t ne0,
9
+ const int64_t ne1,
10
+ const int64_t ne2,
11
+ const int64_t ne3,
12
+ const int64_t total_elements) {
13
+ const int64_t global_idx = blockIdx.x * blockDim.x + threadIdx.x;
14
+
15
+ if (global_idx >= total_elements) {
16
+ return;
17
+ }
18
+
19
+ const int64_t i0 = global_idx % ne0;
20
+ const int64_t i1 = (global_idx / ne0) % ne1;
21
+ const int64_t i2 = (global_idx / (ne0 * ne1)) % ne2;
22
+ const int64_t i3 = global_idx / (ne0 * ne1 * ne2);
23
+
24
+ const int64_t dst_idx = ((i3 * ne2 + i2) * ne1 + i1) * ne0 + i0;
25
+
26
+ if (i0 == i1) {
27
+ const int64_t batch_idx = i3 * ne2 + i2;
28
+ const int64_t src_idx = batch_idx * ne0 + i0;
29
+ dst[dst_idx] = src[src_idx];
30
+ } else {
31
+ dst[dst_idx] = ggml_cuda_cast<T>(0);
32
+ }
33
+ GGML_UNUSED_VARS(ne3);
34
+ }
35
+
36
+ void ggml_cuda_op_diag(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
37
+ const ggml_tensor * src0 = dst->src[0];
38
+
39
+ void * dst_d = dst->data;
40
+ const void * src0_d = src0->data;
41
+
42
+ cudaStream_t stream = ctx.stream();
43
+
44
+ GGML_ASSERT(ggml_is_contiguous(dst));
45
+ GGML_ASSERT(ggml_is_contiguous(src0));
46
+
47
+ const int64_t ne00 = src0->ne[0];
48
+ const int64_t ne01 = src0->ne[1];
49
+ const int64_t ne02 = src0->ne[2];
50
+ const int64_t ne03 = src0->ne[3];
51
+
52
+ const int64_t ne0 = dst->ne[0];
53
+ const int64_t ne1 = dst->ne[1];
54
+ const int64_t ne2 = dst->ne[2];
55
+ const int64_t ne3 = dst->ne[3];
56
+
57
+ GGML_ASSERT(ne00 == ne0);
58
+ GGML_ASSERT(ne01 == 1);
59
+ GGML_ASSERT(ne02 == ne2);
60
+ GGML_ASSERT(ne03 == ne3);
61
+
62
+ const int64_t n_elems = ggml_nelements(dst);
63
+ const int64_t num_blocks = (n_elems + CUDA_DIAG_BLOCK_SIZE - 1) / CUDA_DIAG_BLOCK_SIZE;
64
+
65
+ switch (dst->type) {
66
+ case GGML_TYPE_F32:
67
+ diag_kernel<<<num_blocks, CUDA_DIAG_BLOCK_SIZE, 0, stream>>>((float *) dst_d, (const float *) src0_d, ne0,
68
+ ne1, ne2, ne3, n_elems);
69
+ break;
70
+ case GGML_TYPE_F16:
71
+ diag_kernel<<<num_blocks, CUDA_DIAG_BLOCK_SIZE, 0, stream>>>((half *) dst_d, (const half *) src0_d, ne0,
72
+ ne1, ne2, ne3, n_elems);
73
+ break;
74
+ default:
75
+ GGML_ABORT("unsupported type");
76
+ }
77
+ }
@@ -0,0 +1,5 @@
1
+ #include "common.cuh"
2
+
3
+ #define CUDA_DIAG_BLOCK_SIZE 256
4
+
5
+ void ggml_cuda_op_diag(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
@@ -10,6 +10,14 @@
10
10
  #define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
11
11
  #define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.
12
12
 
13
+ // log(2) = 0.6931, by adding this to the KQ maximum used for the softmax the numerical range representable
14
+ // by the VKQ accumulators is effectively being shifted up by a factor of 2.
15
+ // This reduces issues with numerical overflow but also causes larger values to be flushed to zero.
16
+ // However, as the output from FlashAttention will usually be used as an input for a matrix multiplication this should be negligible.
17
+ // Still, the value range should be shifted as much as necessary but as little as possible.
18
+ // The macro on the following line shifts it by a factor of 2**3=8, as was needed to fix https://github.com/ggml-org/llama.cpp/issues/18606 .
19
+ #define FATTN_KQ_MAX_OFFSET (3.0f*0.6931f)
20
+
13
21
  typedef void (* fattn_kernel_t)(
14
22
  const char * __restrict__ Q,
15
23
  const char * __restrict__ K,
@@ -25,7 +33,7 @@ typedef void (* fattn_kernel_t)(
25
33
  const float m1,
26
34
  const uint32_t n_head_log2,
27
35
  const float logit_softcap,
28
- const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
36
+ const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03,
29
37
  const int32_t nb01, const int32_t nb02, const int32_t nb03,
30
38
  const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
31
39
  const int32_t nb11, const int32_t nb12, const int64_t nb13,
@@ -55,11 +63,11 @@ static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_f16(
55
63
  ggml_cuda_memcpy_1<sizeof(tmp)>(tmp, K_h2 + k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne);
56
64
  #pragma unroll
57
65
  for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) {
58
- #ifdef FAST_FP16_AVAILABLE
66
+ #ifdef V_DOT2_F32_F16_AVAILABLE
59
67
  ggml_cuda_mad(sum, tmp[k_KQ_1] , ((const half2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
60
68
  #else
61
69
  ggml_cuda_mad(sum, __half22float2(tmp[k_KQ_1]), ((const float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
62
- #endif // FP16_AVAILABLE
70
+ #endif // V_DOT2_F32_F16_AVAILABLE
63
71
  }
64
72
  }
65
73
 
@@ -621,7 +629,8 @@ static __global__ void flash_attn_mask_to_KV_max(
621
629
  template<int D, int ncols1, int ncols2> // D == head size
622
630
  __launch_bounds__(D, 1)
623
631
  static __global__ void flash_attn_stream_k_fixup(
624
- float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11) {
632
+ float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11,
633
+ const int nbatch_fa) {
625
634
  constexpr int ncols = ncols1*ncols2;
626
635
 
627
636
  const int bidx0 = blockIdx.x;
@@ -632,11 +641,11 @@ static __global__ void flash_attn_stream_k_fixup(
632
641
 
633
642
  const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
634
643
 
635
- const int iter_k = ne11 / FATTN_KQ_STRIDE;
636
- const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
644
+ const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
645
+ const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
637
646
 
638
- const int kbc0 = (bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
639
- const int kbc0_stop = (bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
647
+ const int kbc0 = int64_t(bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
648
+ const int kbc0_stop = int64_t(bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
640
649
 
641
650
  const bool did_not_have_any_data = kbc0 == kbc0_stop;
642
651
  const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
@@ -672,7 +681,7 @@ static __global__ void flash_attn_stream_k_fixup(
672
681
  int bidx = bidx0 - 1;
673
682
  int kbc_stop = kbc0;
674
683
  while(true) {
675
- const int kbc = bidx*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
684
+ const int kbc = int64_t(bidx)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
676
685
  if (kbc == kbc_stop) { // Did not have any data.
677
686
  bidx--;
678
687
  kbc_stop = kbc;
@@ -765,7 +774,7 @@ static __global__ void flash_attn_combine_results(
765
774
  template <int DV, int ncols1, int ncols2>
766
775
  void launch_fattn(
767
776
  ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
768
- const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
777
+ const int nbatch_fa, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
769
778
  ) {
770
779
  constexpr int ncols = ncols1 * ncols2;
771
780
 
@@ -790,10 +799,6 @@ void launch_fattn(
790
799
  GGML_ASSERT(!V || V->nb[0] == ggml_element_size(V));
791
800
 
792
801
  GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
793
- GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
794
- "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
795
-
796
- GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
797
802
 
798
803
  ggml_cuda_pool & pool = ctx.pool();
799
804
  cudaStream_t main_stream = ctx.stream();
@@ -878,7 +883,7 @@ void launch_fattn(
878
883
  // Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.
879
884
  // Only worth the overhead if there is at lease one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
880
885
  // multiple sequences of possibly different lengths.
881
- if (mask && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) {
886
+ if (mask && K->ne[1] % FATTN_KQ_STRIDE == 0 && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) {
882
887
  const int s31 = mask->nb[1] / sizeof(half2);
883
888
  const int s33 = mask->nb[3] / sizeof(half2);
884
889
 
@@ -897,6 +902,7 @@ void launch_fattn(
897
902
  const dim3 block_dim(warp_size, nwarps, 1);
898
903
  int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
899
904
  CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
905
+ GGML_ASSERT(max_blocks_per_sm > 0);
900
906
  int parallel_blocks = max_blocks_per_sm;
901
907
 
902
908
  dim3 blocks_num;
@@ -914,10 +920,11 @@ void launch_fattn(
914
920
  blocks_num.y = 1;
915
921
  blocks_num.z = 1;
916
922
 
917
- dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float));
923
+ if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
924
+ dst_tmp_meta.alloc((size_t(blocks_num.x) * ncols * (2 + DV/2)));
925
+ }
918
926
  } else {
919
- GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0);
920
- const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
927
+ const int ntiles_KQ = (K->ne[1] + nbatch_fa - 1) / nbatch_fa; // Max. number of parallel blocks limited by tensor size.
921
928
 
922
929
  // parallel_blocks must not be larger than what the tensor size allows:
923
930
  parallel_blocks = std::min(parallel_blocks, ntiles_KQ);
@@ -946,7 +953,7 @@ void launch_fattn(
946
953
 
947
954
  blocks_num.x = ntiles_x;
948
955
  blocks_num.y = parallel_blocks;
949
- blocks_num.z = Q->ne[2]*Q->ne[3];
956
+ blocks_num.z = (Q->ne[2]/ncols2)*Q->ne[3];
950
957
 
951
958
  if (parallel_blocks > 1) {
952
959
  dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
@@ -972,6 +979,9 @@ void launch_fattn(
972
979
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
973
980
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
974
981
 
982
+ // TODO other tensor dimensions after removal of WMMA kernel:
983
+ const uint3 ne01 = init_fastdiv_values(Q->ne[1]);
984
+
975
985
  GGML_ASSERT(block_dim.x % warp_size == 0);
976
986
  fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(
977
987
  (const char *) Q->data,
@@ -982,7 +992,7 @@ void launch_fattn(
982
992
  KV_max.ptr,
983
993
  !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
984
994
  scale, max_bias, m0, m1, n_head_log2, logit_softcap,
985
- Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3],
995
+ Q->ne[0], ne01, Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3],
986
996
  K->ne[0], K->ne[1], K->ne[2], K->ne[3], nb11, nb12, nb13,
987
997
  nb21, nb22, nb23,
988
998
  mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,
@@ -997,7 +1007,7 @@ void launch_fattn(
997
1007
 
998
1008
  flash_attn_stream_k_fixup<DV, ncols1, ncols2>
999
1009
  <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
1000
- ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1]);
1010
+ ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], nbatch_fa);
1001
1011
  }
1002
1012
  } else if (parallel_blocks > 1) {
1003
1013
  const dim3 block_dim_combine(DV, 1, 1);