whispercpp 1.3.3 → 1.3.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (963)
  1. checksums.yaml +4 -4
  2. data/README.md +60 -43
  3. data/ext/extconf.rb +2 -2
  4. data/ext/ruby_whisper.c +14 -2
  5. data/ext/ruby_whisper.h +39 -0
  6. data/ext/ruby_whisper_context.c +22 -22
  7. data/ext/ruby_whisper_model.c +12 -12
  8. data/ext/ruby_whisper_params.c +79 -25
  9. data/ext/ruby_whisper_segment.c +84 -19
  10. data/ext/ruby_whisper_token.c +351 -0
  11. data/ext/ruby_whisper_transcribe.cpp +1 -1
  12. data/ext/ruby_whisper_vad_context.c +75 -0
  13. data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
  14. data/ext/ruby_whisper_vad_segment.c +139 -0
  15. data/ext/ruby_whisper_vad_segments.c +106 -0
  16. data/ext/sources/CMakeLists.txt +4 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
  19. data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
  20. data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  21. data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
  22. data/ext/sources/examples/CMakeLists.txt +1 -0
  23. data/ext/sources/examples/addon.node/addon.cpp +19 -19
  24. data/ext/sources/examples/addon.node/index.js +7 -5
  25. data/ext/sources/examples/addon.node/vad-example.js +2 -2
  26. data/ext/sources/examples/bench/bench.cpp +26 -16
  27. data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
  28. data/ext/sources/examples/cli/cli.cpp +122 -111
  29. data/ext/sources/examples/command/command.cpp +26 -24
  30. data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
  31. data/ext/sources/examples/common-ggml.cpp +2 -0
  32. data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
  33. data/ext/sources/examples/lsp/lsp.cpp +19 -17
  34. data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
  35. data/ext/sources/examples/server/server.cpp +34 -24
  36. data/ext/sources/examples/server.py +6 -1
  37. data/ext/sources/examples/stream/stream.cpp +4 -2
  38. data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
  39. data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
  40. data/ext/sources/examples/talk-llama/CMakeLists.txt +7 -3
  41. data/ext/sources/examples/talk-llama/llama-adapter.cpp +113 -7
  42. data/ext/sources/examples/talk-llama/llama-adapter.h +13 -1
  43. data/ext/sources/examples/talk-llama/llama-arch.cpp +2136 -1491
  44. data/ext/sources/examples/talk-llama/llama-arch.h +125 -3
  45. data/ext/sources/examples/talk-llama/llama-batch.cpp +174 -100
  46. data/ext/sources/examples/talk-llama/llama-batch.h +46 -20
  47. data/ext/sources/examples/talk-llama/llama-chat.cpp +199 -8
  48. data/ext/sources/examples/talk-llama/llama-chat.h +11 -0
  49. data/ext/sources/examples/talk-llama/llama-context.cpp +1213 -413
  50. data/ext/sources/examples/talk-llama/llama-context.h +99 -36
  51. data/ext/sources/examples/talk-llama/llama-cparams.h +5 -4
  52. data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
  53. data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
  54. data/ext/sources/examples/talk-llama/llama-graph.cpp +883 -294
  55. data/ext/sources/examples/talk-llama/llama-graph.h +361 -161
  56. data/ext/sources/examples/talk-llama/llama-hparams.cpp +144 -6
  57. data/ext/sources/examples/talk-llama/llama-hparams.h +100 -23
  58. data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
  59. data/ext/sources/examples/talk-llama/llama-impl.h +3 -1
  60. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +328 -0
  61. data/ext/sources/examples/talk-llama/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +38 -29
  62. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2100 -0
  63. data/ext/sources/examples/talk-llama/llama-kv-cache.h +373 -27
  64. data/ext/sources/examples/talk-llama/llama-kv-cells.h +124 -30
  65. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +63 -41
  66. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +30 -29
  67. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +77 -35
  68. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +15 -16
  69. data/ext/sources/examples/talk-llama/llama-memory.h +16 -10
  70. data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
  71. data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
  72. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +93 -9
  73. data/ext/sources/examples/talk-llama/llama-model-loader.h +9 -2
  74. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
  75. data/ext/sources/examples/talk-llama/llama-model.cpp +3369 -10145
  76. data/ext/sources/examples/talk-llama/llama-model.h +104 -12
  77. data/ext/sources/examples/talk-llama/llama-quant.cpp +53 -30
  78. data/ext/sources/examples/talk-llama/llama-sampling.cpp +1520 -324
  79. data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
  80. data/ext/sources/examples/talk-llama/llama-vocab.cpp +562 -39
  81. data/ext/sources/examples/talk-llama/llama-vocab.h +50 -0
  82. data/ext/sources/examples/talk-llama/llama.cpp +794 -12
  83. data/ext/sources/examples/talk-llama/llama.h +246 -190
  84. data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
  85. data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
  86. data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
  87. data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
  88. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
  89. data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
  90. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
  91. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
  92. data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
  93. data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
  94. data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
  95. data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
  96. data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
  97. data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
  98. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
  99. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
  100. data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
  101. data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
  102. data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
  103. data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
  104. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
  105. data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
  106. data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
  107. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
  108. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
  109. data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
  110. data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
  111. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
  112. data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
  113. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
  114. data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
  115. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
  116. data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
  117. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
  118. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
  119. data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
  120. data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
  121. data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
  122. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
  123. data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
  124. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
  125. data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
  126. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
  127. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
  128. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
  129. data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
  130. data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
  131. data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
  132. data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
  133. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
  134. data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
  135. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
  136. data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
  137. data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
  138. data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
  139. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
  140. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
  141. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
  142. data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
  143. data/ext/sources/examples/talk-llama/models/models.h +569 -0
  144. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
  145. data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
  146. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
  147. data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
  148. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
  149. data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
  150. data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
  151. data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
  152. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
  153. data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
  154. data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
  155. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
  156. data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
  157. data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
  158. data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
  159. data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
  160. data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
  161. data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
  162. data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
  163. data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
  164. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
  165. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
  166. data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
  167. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
  168. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
  169. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
  170. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
  171. data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
  172. data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
  173. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
  174. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
  175. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
  176. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
  177. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
  178. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
  179. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
  180. data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
  181. data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
  182. data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
  183. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
  184. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
  185. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
  186. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
  187. data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
  188. data/ext/sources/examples/talk-llama/talk-llama.cpp +9 -6
  189. data/ext/sources/examples/talk-llama/unicode.cpp +309 -16
  190. data/ext/sources/examples/talk-llama/unicode.h +45 -0
  191. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
  192. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +4 -2
  193. data/ext/sources/examples/whisper.wasm/index-tmpl.html +18 -17
  194. data/ext/sources/ggml/CMakeLists.txt +135 -79
  195. data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
  196. data/ext/sources/ggml/include/ggml-alloc.h +9 -0
  197. data/ext/sources/ggml/include/ggml-backend.h +21 -2
  198. data/ext/sources/ggml/include/ggml-cpu.h +2 -1
  199. data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
  200. data/ext/sources/ggml/include/ggml-metal.h +1 -6
  201. data/ext/sources/ggml/include/ggml-opt.h +25 -6
  202. data/ext/sources/ggml/include/ggml-rpc.h +8 -11
  203. data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
  204. data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
  205. data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
  206. data/ext/sources/ggml/include/ggml.h +406 -23
  207. data/ext/sources/ggml/src/CMakeLists.txt +99 -13
  208. data/ext/sources/ggml/src/ggml-alloc.c +368 -161
  209. data/ext/sources/ggml/src/ggml-backend-impl.h +5 -5
  210. data/ext/sources/ggml/src/ggml-backend-reg.cpp +55 -14
  211. data/ext/sources/ggml/src/ggml-backend.cpp +290 -57
  212. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
  213. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -13
  214. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +14 -0
  215. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +59 -45
  216. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
  217. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2586 -1917
  218. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +348 -309
  219. data/ext/sources/ggml/src/ggml-cann/common.h +350 -133
  220. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +894 -625
  221. data/ext/sources/ggml/src/ggml-common.h +17 -0
  222. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +167 -75
  223. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +5 -2
  224. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  225. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +560 -622
  226. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1002 -270
  227. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +107 -587
  228. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
  229. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  230. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +373 -486
  231. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
  232. data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  233. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +521 -353
  234. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
  235. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
  236. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +4682 -1660
  237. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +82 -4
  238. data/ext/sources/ggml/src/ggml-cpu/common.h +14 -0
  239. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +18 -9
  240. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +263 -111
  241. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +39 -28
  242. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +683 -82
  243. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +38 -43
  244. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +435 -119
  245. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
  246. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1234 -1182
  247. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
  248. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +2167 -1480
  249. data/ext/sources/ggml/src/ggml-cpu/ops.h +10 -12
  250. data/ext/sources/ggml/src/ggml-cpu/quants.c +35 -0
  251. data/ext/sources/ggml/src/ggml-cpu/quants.h +8 -0
  252. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1132 -81
  253. data/ext/sources/ggml/src/ggml-cpu/repack.h +36 -0
  254. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +120 -93
  255. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1025 -0
  256. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
  257. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
  258. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
  259. data/ext/sources/ggml/src/ggml-cpu/traits.cpp +2 -2
  260. data/ext/sources/ggml/src/ggml-cpu/traits.h +1 -1
  261. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
  262. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
  263. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +294 -27
  264. data/ext/sources/ggml/src/ggml-cpu/vec.h +606 -48
  265. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +92 -17
  266. data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
  267. data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
  268. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
  269. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
  270. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
  271. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
  272. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
  273. data/ext/sources/ggml/src/ggml-cuda/common.cuh +588 -128
  274. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
  275. data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
  276. data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
  277. data/ext/sources/ggml/src/ggml-cuda/convert.cu +95 -22
  278. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +25 -0
  279. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
  280. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +335 -485
  281. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
  282. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  283. data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
  284. data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
  285. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
  286. data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
  287. data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
  288. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +519 -378
  289. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +750 -637
  290. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +49 -0
  291. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1244 -0
  292. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +586 -0
  293. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +98 -61
  294. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  295. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +230 -197
  296. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  297. data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
  298. data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
  299. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
  300. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1557 -294
  301. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
  302. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
  303. data/ext/sources/ggml/src/ggml-cuda/mean.cu +57 -2
  304. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +915 -69
  305. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +171 -0
  306. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +835 -0
  307. data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
  308. data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
  309. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +109 -67
  310. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1601 -733
  311. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +802 -0
  312. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +12 -0
  313. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +286 -149
  314. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
  315. data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
  316. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
  317. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
  318. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
  319. data/ext/sources/ggml/src/ggml-cuda/pad.cu +86 -32
  320. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
  321. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
  322. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +163 -10
  323. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
  324. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
  325. data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
  326. data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
  327. data/ext/sources/ggml/src/ggml-cuda/rope.cu +207 -98
  328. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
  329. data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
  330. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +330 -0
  331. data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  332. data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
  333. data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
  334. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
  335. data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
  336. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +325 -61
  337. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
  338. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
  339. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -12
  340. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +291 -104
  341. data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
  342. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +21 -4
  343. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  344. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  345. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  346. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  347. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  348. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  349. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  350. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  351. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  352. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
  353. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
  354. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
  355. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
  356. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
  357. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
  358. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
  359. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
  360. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
  361. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
  362. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
  363. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
  364. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
  365. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
  366. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
  367. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
  368. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
  369. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
  370. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
  371. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
  372. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
  373. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
  374. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
  375. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
  376. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
  377. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
  378. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
  379. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
  380. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
  381. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
  382. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
  383. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
  384. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
  385. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
  386. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
  387. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
  388. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +40 -19
  389. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
  390. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
  391. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
  392. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
  393. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
  394. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
  395. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
  396. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
  397. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
  398. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
  399. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
  400. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
  401. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
  402. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
  403. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
  404. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
  405. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
  406. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
  407. data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
  408. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +351 -0
  409. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +21 -0
  410. data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
  411. data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
  412. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
  413. data/ext/sources/ggml/src/ggml-cuda/unary.cu +189 -5
  414. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +44 -0
  415. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +248 -6
  416. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
  417. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +8 -0
  418. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +70 -37
  419. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +10 -3
  420. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
  421. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
  422. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
  423. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
  424. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
  425. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  426. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
  427. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
  428. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  429. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
  430. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
  431. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
  432. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
  433. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  434. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
  435. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
  436. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
  437. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
  438. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
  439. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
  440. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
  441. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
  442. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
  443. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
  444. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
  445. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
  446. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
  447. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  448. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
  449. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
  450. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
  451. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +16 -13
  452. data/ext/sources/ggml/src/ggml-impl.h +186 -15
  453. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -7
  454. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
  455. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
  456. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
  457. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +609 -0
  458. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1743 -0
  459. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +273 -0
  460. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1686 -0
  461. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +356 -61
  462. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +4161 -0
  463. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +94 -0
  464. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +724 -0
  465. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +4495 -1876
  466. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +21 -9
  467. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +29 -0
  468. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4005 -427
  469. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
  470. data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  471. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  472. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  473. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +147 -0
  474. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +66 -0
  475. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
  476. data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  477. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
  478. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +371 -0
  479. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
  480. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  481. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  482. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  483. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  484. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +177 -0
  485. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
  486. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  487. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  488. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
  489. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
  490. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  491. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  492. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +146 -0
  493. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +147 -0
  494. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  495. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  496. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
  497. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
  498. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
  499. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  500. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
  501. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
  502. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
  503. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
  504. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  505. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +94 -0
  506. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  507. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  508. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +208 -0
  509. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
  510. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
  511. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
  512. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
  513. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
  514. data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  515. data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  516. data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  517. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
  518. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +33 -0
  519. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +2 -2
  520. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  521. data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
  522. data/ext/sources/ggml/src/ggml-quants.c +111 -16
  523. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  524. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +497 -195
  525. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
  526. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
  527. data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
  528. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +8 -0
  529. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +6 -5
  530. data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
  531. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +50 -30
  532. data/ext/sources/ggml/src/ggml-sycl/conv.cpp +10 -4
  533. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +200 -99
  534. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  535. data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  536. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +72 -309
  537. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
  538. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
  539. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +67 -49
  540. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +77 -34
  541. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +397 -314
  542. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +12 -2
  543. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +14 -26
  544. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +9 -6
  545. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +643 -413
  546. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
  547. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +2 -2
  548. data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +80 -60
  549. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +223 -132
  550. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +230 -55
  551. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
  552. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
  553. data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
  554. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
  555. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
  556. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
  557. data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
  558. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +8 -9
  559. data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  560. data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  561. data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
  562. data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
  563. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +65 -59
  564. data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
  565. data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
  566. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
  567. data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  568. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +330 -165
  569. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
  570. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  571. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  572. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +12 -6
  573. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +60 -6
  574. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +16 -12
  575. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
  576. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +7398 -2635
  577. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  578. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
  579. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +43 -3
  580. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  581. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
  582. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  583. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +15 -6
  584. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +56 -39
  585. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  586. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  587. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  588. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  589. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  590. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  591. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +347 -0
  592. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  593. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  594. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +5 -5
  595. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +67 -13
  596. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  597. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  598. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  599. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
  600. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
  601. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
  602. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
  603. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  604. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +158 -16
  605. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +38 -3
  606. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  607. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  608. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  609. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +2 -2
  610. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  611. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +3 -2
  612. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +7 -6
  613. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +5 -3
  614. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  615. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +4 -4
  618. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +2 -2
  619. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  620. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  621. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +4 -4
  622. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  623. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  624. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +4 -4
  625. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +2 -2
  626. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  627. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
  628. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  629. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  630. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
  631. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  632. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +103 -36
  633. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +220 -0
  634. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +139 -45
  635. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +113 -38
  636. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +75 -14
  637. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  638. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  639. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  640. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  641. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  642. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  643. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  644. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +19 -17
  645. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
  646. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +21 -12
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +28 -18
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +4 -0
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +33 -17
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +125 -0
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +227 -0
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
  663. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
  664. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
  665. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
  666. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
  667. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
  668. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
  669. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +20 -14
  670. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  671. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
  672. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
  673. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
  674. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
  675. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
  676. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +143 -0
  677. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
  678. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -556
  679. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +230 -51
  680. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +566 -0
  681. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
  682. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +90 -223
  683. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
  684. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  685. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +195 -0
  686. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  687. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  688. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  689. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
  690. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +41 -5
  691. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  692. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +59 -9
  693. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  694. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  695. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  696. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  697. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +104 -14
  698. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  699. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
  700. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  701. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
  702. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  703. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -52
  704. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -35
  705. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -35
  706. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
  707. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
  708. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  709. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +5 -0
  710. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +3 -3
  711. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  712. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  713. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  714. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  715. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +30 -8
  716. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +6 -2
  717. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
  718. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
  719. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
  720. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
  721. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  722. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
  723. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
  724. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  725. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
  726. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
  727. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  728. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  729. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +16 -6
  730. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
  731. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  732. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
  733. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  734. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +5 -4
  735. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
  736. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
  737. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
  738. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
  739. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  740. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +435 -24
  741. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +148 -6
  742. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.glsl +25 -0
  743. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +619 -177
  744. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
  745. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +80 -0
  746. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
  747. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3087 -0
  748. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  749. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
  750. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
  751. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
  752. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
  753. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +147 -0
  754. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
  755. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
  756. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  757. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  758. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
  759. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
  760. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
  761. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
  762. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
  763. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +123 -0
  764. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  765. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
  766. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
  767. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
  768. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  769. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
  770. data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  771. data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
  772. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
  773. data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
  774. data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
  775. data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
  776. data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
  777. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
  778. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
  779. data/ext/sources/ggml/src/ggml.c +901 -129
  780. data/ext/sources/ggml/src/gguf.cpp +8 -1
  781. data/ext/sources/include/whisper.h +1 -0
  782. data/ext/sources/src/CMakeLists.txt +3 -1
  783. data/ext/sources/src/whisper.cpp +124 -81
  784. data/ext/sources/tests/CMakeLists.txt +8 -1
  785. data/ext/sources/tests/test-vad-full.cpp +7 -5
  786. data/ext/sources/tests/test-vad.cpp +3 -3
  787. data/extsources.rb +1 -0
  788. data/lib/whisper/model/uri.rb +17 -18
  789. data/sig/whisper.rbs +126 -2
  790. data/test/test_params.rb +24 -8
  791. data/test/test_segment.rb +0 -1
  792. data/test/test_token.rb +70 -0
  793. data/test/test_vad.rb +1 -1
  794. data/test/test_vad_context.rb +50 -0
  795. data/test/test_vad_segment.rb +19 -0
  796. data/test/test_vad_segments.rb +16 -0
  797. data/test/test_whisper.rb +8 -1
  798. data/whispercpp.gemspec +1 -1
  799. metadata +439 -179
  800. data/ext/sources/build-xcframework.sh +0 -547
  801. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +0 -279
  802. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +0 -1841
  803. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +0 -303
  804. data/ext/sources/ggml/include/ggml-kompute.h +0 -50
  805. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
  806. data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
  807. data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
  808. data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
  809. data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
  810. data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
  811. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  812. data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  813. data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  814. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  815. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  816. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  817. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  818. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  819. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  820. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
  821. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
  822. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
  823. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
  824. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
  825. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
  826. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
  827. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +0 -506
  828. data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +0 -11
  829. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
  830. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
  831. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
  832. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
  833. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
  834. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
  835. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
  836. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
  837. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
  838. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
  839. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
  840. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
  841. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
  842. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
  843. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
  844. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
  845. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
  846. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
  847. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
  848. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
  849. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
  850. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
  851. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
  852. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
  853. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
  854. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
  855. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
  856. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
  857. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
  858. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
  859. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
  860. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
  861. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
  862. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
  863. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
  864. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
  865. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
  866. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
  867. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
  868. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
  869. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
  870. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
  871. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
  872. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
  873. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
  874. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
  875. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
  876. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
  877. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
  878. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
  879. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
  880. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
  881. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
  882. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
  883. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
  884. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
  885. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
  886. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
  887. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
  888. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
  889. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
  890. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
  891. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
  892. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
  893. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
  894. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
  895. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
  896. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
  897. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
  898. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
  899. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
  900. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
  901. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
  902. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
  903. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
  904. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
  905. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
  906. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
  907. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
  908. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
  909. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
  910. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
  911. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
  912. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
  913. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
  914. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
  915. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  916. data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  917. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  918. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  919. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  920. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  921. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  922. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  923. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  924. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  925. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  926. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  927. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  928. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  929. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  930. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  931. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  932. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  933. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  934. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  935. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  936. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  937. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  938. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  939. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  940. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  941. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  942. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  943. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  944. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  945. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  946. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  947. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  948. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  949. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  950. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  951. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  952. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  953. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
  954. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -6280
  955. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +0 -162
  956. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +0 -118
  957. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -99
  958. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -58
  959. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  960. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  961. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  962. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  963. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
@@ -10,11 +10,21 @@
10
10
  #define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
11
11
  #define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.
12
12
 
13
+ // log(2) = 0.6931, by adding this to the KQ maximum used for the softmax the numerical range representable
14
+ // by the VKQ accumulators is effectively being shifted up by a factor of 2.
15
+ // This reduces issues with numerical overflow but also causes larger values to be flushed to zero.
16
+ // However, as the output from FlashAttention will usually be used as an input for a matrix multiplication this should be negligible.
17
+ // Still, the value range should be shifted as much as necessary but as little as possible.
18
+ // The macro on the following line shifts it by a factor of 2**3=8, as was needed to fix https://github.com/ggml-org/llama.cpp/issues/18606 .
19
+ #define FATTN_KQ_MAX_OFFSET (3.0f*0.6931f)
20
+
13
21
  typedef void (* fattn_kernel_t)(
14
22
  const char * __restrict__ Q,
15
23
  const char * __restrict__ K,
16
24
  const char * __restrict__ V,
17
25
  const char * __restrict__ mask,
26
+ const char * __restrict__ sinks,
27
+ const int * __restrict__ KV_max,
18
28
  float * __restrict__ dst,
19
29
  float2 * __restrict__ dst_meta,
20
30
  const float scale,
@@ -23,300 +33,238 @@ typedef void (* fattn_kernel_t)(
23
33
  const float m1,
24
34
  const uint32_t n_head_log2,
25
35
  const float logit_softcap,
26
- const int ne00,
27
- const int ne01,
28
- const int ne02,
29
- const int ne03,
30
- const int ne10,
31
- const int ne11,
32
- const int ne12,
33
- const int ne13,
34
- const int ne31,
35
- const int nb31,
36
- const int nb01,
37
- const int nb02,
38
- const int nb03,
39
- const int nb11,
40
- const int nb12,
41
- const int nb13,
42
- const int nb21,
43
- const int nb22,
44
- const int nb23,
45
- const int ne0,
46
- const int ne1,
47
- const int ne2,
48
- const int ne3);
49
-
50
- typedef half (*vec_dot_KQ_f16_t)(
51
- const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);
52
- typedef float (*vec_dot_KQ_f32_t)(
36
+ const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03,
37
+ const int32_t nb01, const int32_t nb02, const int32_t nb03,
38
+ const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
39
+ const int32_t nb11, const int32_t nb12, const int64_t nb13,
40
+ const int32_t nb21, const int32_t nb22, const int64_t nb23,
41
+ const int32_t ne31, const int32_t ne32, const int32_t ne33,
42
+ const int32_t nb31, const int32_t nb32, const int64_t nb33);
43
+
44
+ typedef float (*vec_dot_KQ_t)(
53
45
  const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);
54
46
 
55
- template<typename T, int D, int warp_size>
56
- static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
47
+ template <int D, int nthreads>
48
+ static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_f16(
49
+ const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) {
50
+
51
+ const half2 * K_h2 = (const half2 *) K_c;
52
+ GGML_UNUSED(Q_q8);
53
+ GGML_UNUSED(Q_ds_v);
54
+
55
+ constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
56
+ constexpr int cpy_ne = cpy_nb / 4;
57
+
58
+ float sum = 0.0f;
59
+
60
+ #pragma unroll
61
+ for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += nthreads*cpy_ne) {
62
+ half2 tmp[cpy_ne];
63
+ ggml_cuda_memcpy_1<sizeof(tmp)>(tmp, K_h2 + k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne);
64
+ #pragma unroll
65
+ for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) {
66
+ #ifdef V_DOT2_F32_F16_AVAILABLE
67
+ ggml_cuda_mad(sum, tmp[k_KQ_1] , ((const half2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
68
+ #else
69
+ ggml_cuda_mad(sum, __half22float2(tmp[k_KQ_1]), ((const float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
70
+ #endif // V_DOT2_F32_F16_AVAILABLE
71
+ }
72
+ }
73
+
74
+ return sum;
75
+ }
76
+
77
+ template<int D, int nthreads>
78
+ static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q4_0(
57
79
  const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
58
80
 
59
81
  const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c;
60
82
  GGML_UNUSED(Q_v);
61
83
 
62
- T sum = 0.0f;
84
+ float sum = 0.0f;
63
85
 
64
86
  #pragma unroll
65
- for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += warp_size) {
66
- const int k_KQ = k_KQ_0 + threadIdx.x;
87
+ for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
88
+ const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads);
67
89
 
68
90
  const int ib = k_KQ / QI8_1;
69
91
  const int iqs4 = k_KQ % QI4_0;
70
92
  const int shift = k_KQ & (QI8_1/2);
71
93
 
72
- const int v = (get_int_b2(K_q4_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
73
- const int u = Q_q8[k_KQ_0/warp_size];
94
+ int v;
95
+ ggml_cuda_memcpy_1<sizeof(int), 2>(&v, K_q4_0[ib].qs + sizeof(int)*iqs4);
96
+ v = (v >> shift) & 0x0F0F0F0F;
97
+ const int u = Q_q8[k_KQ_0/nthreads];
74
98
 
75
99
  const int sumi = ggml_cuda_dp4a(v, u, 0);
76
100
 
77
- #ifdef FP16_AVAILABLE
78
- if (std::is_same<T, half>::value) {
79
- const half2 * Q_ds = (const half2 *) Q_ds_v;
80
-
81
- const half2 sum2 = __half2half2(K_q4_0[ib].d) * Q_ds[k_KQ_0/warp_size];
82
- sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2) /* *8/QI8_1 == 1 */);
83
- } else
84
- #endif // FP16_AVAILABLE
85
- {
86
- const float2 * Q_ds = (const float2 *) Q_ds_v;
87
-
88
- sum += (T) (__half2float(K_q4_0[ib].d) * (sumi*Q_ds[k_KQ_0/warp_size].x - (8/QI8_1)*Q_ds[k_KQ_0/warp_size].y));
89
- }
101
+ const float2 Q_ds = ((const float2 *) Q_ds_v)[k_KQ_0/nthreads];
102
+ sum += __half2float(K_q4_0[ib].d) * (sumi*Q_ds.x - (8/QI8_1)*Q_ds.y);
90
103
  }
91
104
 
92
105
  return sum;
93
106
  }
94
107
 
95
- template<typename T, int D, int warp_size>
96
- static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
108
+ template<int D, int nthreads>
109
+ static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q4_1(
97
110
  const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
98
111
 
99
112
  const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c;
100
113
  GGML_UNUSED(Q_v);
101
114
 
102
- T sum = 0.0f;
115
+ float sum = 0.0f;
103
116
 
104
117
  #pragma unroll
105
- for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += warp_size) {
106
- const int k_KQ = k_KQ_0 + threadIdx.x;
118
+ for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
119
+ const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads);
107
120
 
108
121
  const int ib = k_KQ / QI8_1;
109
122
  const int iqs4 = k_KQ % QI4_1;
110
123
  const int shift = k_KQ & (QI8_1/2);
111
124
 
112
- const int v = (get_int_b4(K_q4_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
113
- const int u = Q_q8[k_KQ_0/warp_size];
125
+ int v;
126
+ ggml_cuda_memcpy_1<sizeof(int)>(&v, K_q4_1[ib].qs + sizeof(int)*iqs4);
127
+ v = (v >> shift) & 0x0F0F0F0F;
128
+ const int u = Q_q8[k_KQ_0/nthreads];
114
129
 
115
130
  const int sumi = ggml_cuda_dp4a(v, u, 0);
116
131
 
117
- #ifdef FP16_AVAILABLE
118
- if (std::is_same<T, half>::value) {
119
- const half2 * Q_ds = (const half2 *) Q_ds_v;
120
-
121
- const half2 d4d8_m4s8 = K_q4_1[ib].dm * Q_ds[k_KQ_0/warp_size];
122
- const half2 sumid4d8_m4s8scaled = d4d8_m4s8 * make_half2(sumi, 1.0f/QI8_1);
123
- sum += (T) (__low2half(sumid4d8_m4s8scaled) + __high2half(sumid4d8_m4s8scaled));
124
- } else
125
- #endif // FP16_AVAILABLE
126
- {
127
- const float2 * Q_ds = (const float2 *) Q_ds_v;
128
-
129
- const float sumid4d8 = __low2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/warp_size].x * sumi;
130
- const float m4s8scaled = __high2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/warp_size].y / QI8_1;
132
+ const float2 K_dm = __half22float2(K_q4_1[ib].dm);
133
+ const float2 Q_ds = ((const float2 *) Q_ds_v)[k_KQ_0/nthreads];
131
134
 
132
- sum += (T) (sumid4d8 + m4s8scaled);
133
- }
135
+ sum += K_dm.x*Q_ds.x*sumi + K_dm.y*Q_ds.y/QI8_1;
134
136
  }
135
137
 
136
138
  return sum;
137
139
  }
138
140
 
139
- template<typename T, int D, int warp_size>
140
- static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
141
+ template<int D, int nthreads>
142
+ static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q5_0(
141
143
  const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
142
144
 
143
145
  const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c;
144
146
  GGML_UNUSED(Q_v);
145
147
 
146
- T sum = 0.0f;
148
+ float sum = 0.0f;
147
149
 
148
150
  #pragma unroll
149
- for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += warp_size) {
150
- const int k_KQ = k_KQ_0 + threadIdx.x;
151
+ for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
152
+ const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads);
151
153
 
152
154
  const int ib = k_KQ / QI8_1;
153
155
  const int iqs4 = k_KQ % QI5_0;
154
156
  const int iqs8 = k_KQ % QI8_1;
155
157
  const int shift = k_KQ & (QI8_1/2);
156
158
 
157
- int v = (get_int_b2(K_q5_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
158
- const int vh = get_int_b2(K_q5_0[ib].qh, 0) >> (iqs8 * QI5_0);
159
- v |= (vh << 4) & 0x00000010; // 0 -> 4
160
- v |= (vh << 11) & 0x00001000; // 1 -> 12
161
- v |= (vh << 18) & 0x00100000; // 2 -> 20
162
- v |= (vh << 25) & 0x10000000; // 3 -> 28
159
+ int v;
160
+ ggml_cuda_memcpy_1<sizeof(int), 2>(&v, K_q5_0[ib].qs + sizeof(int)*iqs4);
161
+ v = (v >> shift) & 0x0F0F0F0F;
163
162
 
164
- const int u = Q_q8[k_KQ_0/warp_size];
163
+ {
164
+ int vh;
165
+ ggml_cuda_memcpy_1<sizeof(int), 2>(&vh, K_q5_0[ib].qh);
166
+ vh >>= iqs8 * QI5_0;
167
+
168
+ v |= (vh << 4) & 0x00000010; // 0 -> 4
169
+ v |= (vh << 11) & 0x00001000; // 1 -> 12
170
+ v |= (vh << 18) & 0x00100000; // 2 -> 20
171
+ v |= (vh << 25) & 0x10000000; // 3 -> 28
172
+ }
165
173
 
166
- const int sumi = ggml_cuda_dp4a(v, u, 0);
174
+ const int u = Q_q8[k_KQ_0/nthreads];
167
175
 
168
- #ifdef FP16_AVAILABLE
169
- if (std::is_same<T, half>::value) {
170
- const half2 * Q_ds = (const half2 *) Q_ds_v;
176
+ const int sumi = ggml_cuda_dp4a(v, u, 0);
171
177
 
172
- const half2 sum2 = __half2half2(K_q5_0[ib].d) * Q_ds[k_KQ_0/warp_size];
173
- sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2)*__float2half(2.0f)) /* *16/QI8_1 == 2 */;
174
- } else
175
- #endif // FP16_AVAILABLE
176
- {
177
- const float2 * Q_ds = (const float2 *) Q_ds_v;
178
+ const float2 Q_ds = ((const float2 *) Q_ds_v)[k_KQ_0/nthreads];
178
179
 
179
- sum += (T) (__half2float(K_q5_0[ib].d) * (sumi*Q_ds[k_KQ_0/warp_size].x - (16/QI8_1)*Q_ds[k_KQ_0/warp_size].y));
180
- }
180
+ sum += __half2float(K_q5_0[ib].d) * (sumi*Q_ds.x - (16/QI8_1)*Q_ds.y);
181
181
  }
182
182
 
183
183
  return sum;
184
184
  }
185
185
 
186
- template<typename T, int D, int warp_size>
187
- static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
186
+ template<int D, int nthreads>
187
+ static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q5_1(
188
188
  const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
189
189
 
190
190
  const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c;
191
191
  GGML_UNUSED(Q_v);
192
192
 
193
- T sum = 0.0f;
193
+ float sum = 0.0f;
194
194
 
195
195
  #pragma unroll
196
- for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += warp_size) {
197
- const int k_KQ = k_KQ_0 + threadIdx.x;
196
+ for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
197
+ const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads);
198
198
 
199
199
  const int ib = k_KQ / QI8_1;
200
200
  const int iqs4 = k_KQ % QI5_1;
201
201
  const int iqs8 = k_KQ % QI8_1;
202
202
  const int shift = k_KQ & (QI8_1/2);
203
203
 
204
- int v = (get_int_b2(K_q5_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
205
- const int vh = get_int_b2(K_q5_1[ib].qh, 0) >> (iqs8 * QI5_1);
206
- v |= (vh << 4) & 0x00000010; // 0 -> 4
207
- v |= (vh << 11) & 0x00001000; // 1 -> 12
208
- v |= (vh << 18) & 0x00100000; // 2 -> 20
209
- v |= (vh << 25) & 0x10000000; // 3 -> 28
210
-
211
- const int u = Q_q8[k_KQ_0/warp_size];
204
+ int v;
205
+ ggml_cuda_memcpy_1<sizeof(int)>(&v, K_q5_1[ib].qs + sizeof(int)*iqs4);
206
+ v = (v >> shift) & 0x0F0F0F0F;
212
207
 
213
- const int sumi = ggml_cuda_dp4a(v, u, 0);
208
+ {
209
+ int vh;
210
+ ggml_cuda_memcpy_1<sizeof(int)>(&vh, K_q5_1[ib].qh);
211
+ vh >>= iqs8 * QI5_0;
212
+
213
+ v |= (vh << 4) & 0x00000010; // 0 -> 4
214
+ v |= (vh << 11) & 0x00001000; // 1 -> 12
215
+ v |= (vh << 18) & 0x00100000; // 2 -> 20
216
+ v |= (vh << 25) & 0x10000000; // 3 -> 28
217
+ }
214
218
 
215
- #ifdef FP16_AVAILABLE
216
- if (std::is_same<T, half>::value) {
217
- const half2 * Q_ds = (const half2 *) Q_ds_v;
219
+ const int u = Q_q8[k_KQ_0/nthreads];
218
220
 
219
- const half2 d5d8_m5s8 = K_q5_1[ib].dm * Q_ds[k_KQ_0/warp_size];
220
- const half2 sumid5d8_m5s8scaled = d5d8_m5s8 * make_half2(sumi, 1.0f/QI8_1);
221
- sum += (T) (__low2half(sumid5d8_m5s8scaled) + __high2half(sumid5d8_m5s8scaled));
222
- } else
223
- #endif // FP16_AVAILABLE
224
- {
225
- const float2 * Q_ds = (const float2 *) Q_ds_v;
221
+ const int sumi = ggml_cuda_dp4a(v, u, 0);
226
222
 
227
- const float sumid5d8 = __low2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/warp_size].x * sumi;
228
- const float m5s8scaled = __high2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/warp_size].y / QI8_1;
223
+ const float2 K_dm = __half22float2(K_q5_1[ib].dm);
224
+ const float2 Q_ds = ((const float2 *) Q_ds_v)[k_KQ_0/nthreads];
229
225
 
230
- sum += (T) (sumid5d8 + m5s8scaled);
231
- }
226
+ sum += K_dm.x*Q_ds.x*sumi + K_dm.y*Q_ds.y/QI8_1;
232
227
  }
233
228
 
234
229
  return sum;
235
230
  }
236
231
 
237
- template <typename T, int D, int warp_size>
238
- static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
232
+ template <int D, int nthreads>
233
+ static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q8_0(
239
234
  const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
240
235
 
241
236
  const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c;
242
237
  GGML_UNUSED(Q_v);
243
238
 
244
- T sum = 0.0f;
239
+ float sum = 0.0f;
245
240
 
246
241
  #pragma unroll
247
- for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += warp_size) {
248
- const int k_KQ = k_KQ_0 + threadIdx.x;
242
+ for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
243
+ const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads);
249
244
 
250
245
  const int ib = k_KQ / QI8_0;
251
246
  const int iqs = k_KQ % QI8_0;
252
247
 
253
- const int v = get_int_b2(K_q8_0[ib].qs, iqs);
254
-
255
- T Q_d;
256
- if (std::is_same<T, half>::value) {
257
- const half2 * Q_ds = (const half2 *) Q_ds_v;
258
- Q_d = __low2half(Q_ds[k_KQ_0/warp_size]);
259
- } else {
260
- const float2 * Q_ds = (const float2 *) Q_ds_v;
261
- Q_d = Q_ds[k_KQ_0/warp_size].x;
262
- }
263
-
264
- sum += vec_dot_q8_0_q8_1_impl<T, 1>(&v, &Q_q8[k_KQ_0/warp_size], K_q8_0[ib].d, Q_d);
265
- }
266
-
267
- return sum;
268
- }
269
-
270
- template <typename T, int D, int warp_size>
271
- static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
272
- const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) {
273
-
274
- const half2 * K_h2 = (const half2 *) K_c;
275
- GGML_UNUSED(Q_q8);
276
- GGML_UNUSED(Q_ds_v);
277
-
278
- #ifdef FP16_AVAILABLE
279
- if (std::is_same<T, half>::value) {
280
- const half2 * Q_h2 = (const half2 *) Q_v;
281
-
282
- half2 sum2 = make_half2(0.0f, 0.0f);
283
-
284
- #pragma unroll
285
- for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += warp_size) {
286
- const int k_KQ = k_KQ_0 + threadIdx.x;
287
-
288
- const half2 K_ik = K_h2[k_KQ];
289
- sum2 += K_ik * Q_h2[k_KQ_0/warp_size];
290
- }
291
-
292
- return __low2half(sum2) + __high2half(sum2);
293
- }
294
- #endif // FP16_AVAILABLE
248
+ int v;
249
+ ggml_cuda_memcpy_1<sizeof(v), 2>(&v, K_q8_0[ib].qs + 4*iqs);
295
250
 
296
- const float2 * Q_f2 = (const float2 *) Q_v;
297
-
298
- float sum = 0.0f;
251
+ const float2 * Q_ds = (const float2 *) Q_ds_v;
252
+ const float Q_d = Q_ds[k_KQ_0/nthreads].x;
299
253
 
300
- #pragma unroll
301
- for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += warp_size) {
302
- const int k_KQ = k_KQ_0 + threadIdx.x;
303
-
304
- const half2 K_ik = K_h2[k_KQ];
305
- sum += __low2float(K_ik) * Q_f2[k_KQ_0/warp_size].x;
306
- sum += __high2float(K_ik) * Q_f2[k_KQ_0/warp_size].y;
254
+ sum += vec_dot_q8_0_q8_1_impl<float, 1>(&v, &Q_q8[k_KQ_0/nthreads], K_q8_0[ib].d, Q_d);
307
255
  }
308
256
 
309
257
  return sum;
310
258
  }
311
259
 
312
- template <typename Tds>
260
+ template <typename Tds, int ni>
313
261
  static __device__ __forceinline__ void quantize_q8_1_to_shared(
314
262
  const float * __restrict__ x, const float scale, int * __restrict__ yq32, void * __restrict__ yds) {
315
263
 
316
264
  float vals[sizeof(int)] = {0.0f};
317
265
  #pragma unroll
318
266
  for (int l = 0; l < int(sizeof(int)); ++l) {
319
- vals[l] = scale * x[4*threadIdx.x + l];
267
+ vals[l] = (ni == WARP_SIZE || threadIdx.x < ni) ? scale * x[4*threadIdx.x + l] : 0.0f;
320
268
  }
321
269
 
322
270
  float amax = fabsf(vals[0]);
@@ -344,7 +292,7 @@ static __device__ __forceinline__ void quantize_q8_1_to_shared(
344
292
  }
345
293
 
346
294
  yq32[threadIdx.x] = q32;
347
- if (threadIdx.x % QI8_1 == 0) {
295
+ if (threadIdx.x % QI8_1 == 0 && (ni == WARP_SIZE || threadIdx.x < ni)) {
348
296
  if (std::is_same<Tds, half2>::value) {
349
297
  ((half2 *) yds)[threadIdx.x/QI8_1] = make_half2(d, sum);
350
298
  } else {
@@ -353,173 +301,336 @@ static __device__ __forceinline__ void quantize_q8_1_to_shared(
353
301
  }
354
302
  }
355
303
 
356
- typedef half (*dequantize_1_f16_t)(const void *, const int64_t);
357
- typedef float (*dequantize_1_f32_t)(const void *, const int64_t);
304
+ typedef void (*dequantize_V_t)(const void *, void *, const int64_t);
305
+
306
+ template <typename T, int ne>
307
+ static __device__ __forceinline__ void dequantize_V_f16(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
308
+ if constexpr (std::is_same_v<T, half>) {
309
+ ggml_cuda_memcpy_1<ne*sizeof(half)>(dst, (const half *) vx + i0);
310
+ } else if constexpr (std::is_same_v<T, float>) {
311
+ static_assert(ne % 2 == 0, "bad ne");
312
+ half2 tmp[ne/2];
313
+ ggml_cuda_memcpy_1<ne*sizeof(half)>(tmp, (const half *) vx + i0);
314
+ float2 * dst_f2 = (float2 *) dst;
315
+ #pragma unroll
316
+ for (int l = 0; l < ne/2; ++l) {
317
+ dst_f2[l] = __half22float2(tmp[l]);
318
+ }
319
+ } else {
320
+ static_assert(std::is_same_v<T, void>, "unsupported type");
321
+ }
322
+ }
358
323
 
359
- template <typename T>
360
- static __device__ __forceinline__ T dequantize_1_q4_0(const void * __restrict__ vx, const int64_t i) {
324
+ template <typename T, int ne>
325
+ static __device__ __forceinline__ void dequantize_V_q4_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
361
326
  const block_q4_0 * x = (const block_q4_0 *) vx;
362
327
 
363
- const int64_t ib = i / QK4_0;
364
- const int iqs = i % (QK4_0/2);
365
- const int shift = (i % QK4_0) / (QK4_0/2);
328
+ const int64_t ib = i0 / QK4_0;
329
+ const int iqs = i0 % (QK4_0/2);
330
+ const int shift = (i0 % QK4_0) / (QK4_0/2);
331
+
332
+ int q;
333
+ static_assert(ne == 2 || ne == 4, "bad ne");
334
+ ggml_cuda_memcpy_1<ne, 2>(&q, x[ib].qs + iqs);
335
+ q >>= 4*shift;
336
+ q &= 0x0F0F0F0F;
337
+ q = __vsubss4(q, 0x08080808);
366
338
 
367
- const T d = x[ib].d;
368
- const int q0 = x[ib].qs[iqs];
369
- const int q = ((q0 >> (4*shift)) & 0x0F) - 8;
339
+ const int8_t * q8 = (const int8_t *) &q;
370
340
 
371
341
  #ifdef FP16_AVAILABLE
372
- if (std::is_same<T, half>::value) {
373
- return ((half) d)*((half) q);
374
- }
342
+ if constexpr (std::is_same_v<T, half>) {
343
+ const half2 d = __half2half2(x[ib].d);
344
+
345
+ #pragma unroll
346
+ for (int l0 = 0; l0 < ne; l0 += 2) {
347
+ ((half2 *) dst)[l0/2] = d * make_half2(q8[l0 + 0], q8[l0 + 1]);
348
+ }
349
+ } else
375
350
  #endif // FP16_AVAILABLE
351
+ if constexpr (std::is_same_v<T, float>) {
352
+ const float d = x[ib].d;
376
353
 
377
- return ((float) d)*((float) q);
354
+ #pragma unroll
355
+ for (int l = 0; l < ne; ++l) {
356
+ ((float *) dst)[l] = d * q8[l];
357
+ }
358
+ } else {
359
+ static_assert(std::is_same_v<T, void>, "bad type");
360
+ }
378
361
  }
379
362
 
380
- template <typename T>
381
- static __device__ __forceinline__ T dequantize_1_q4_1(const void * __restrict__ vx, const int64_t i) {
363
+ template <typename T, int ne>
364
+ static __device__ __forceinline__ void dequantize_V_q4_1(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
382
365
  const block_q4_1 * x = (const block_q4_1 *) vx;
383
366
 
384
- const int64_t ib = i / QK4_1;
385
- const int iqs = i % (QK4_1/2);
386
- const int shift = (i % QK4_1) / (QK4_1/2);
367
+ const int64_t ib = i0 / QK4_1;
368
+ const int iqs = i0 % (QK4_1/2);
369
+ const int shift = (i0 % QK4_1) / (QK4_1/2);
370
+
371
+ int q;
372
+ static_assert(ne == 2 || ne == 4, "bad ne");
373
+ ggml_cuda_memcpy_1<ne>(&q, x[ib].qs + iqs);
374
+ q >>= 4*shift;
375
+ q &= 0x0F0F0F0F;
387
376
 
388
- const half2 dm = x[ib].dm;
389
- const int q0 = x[ib].qs[iqs];
390
- const int q = ((q0 >> (4*shift)) & 0x0F);
377
+ const int8_t * q8 = (const int8_t *) &q;
391
378
 
392
379
  #ifdef FP16_AVAILABLE
393
- if (std::is_same<T, half>::value) {
394
- return __low2half(dm)*((half) q) + __high2half(dm);
395
- }
380
+ if constexpr (std::is_same_v<T, half>) {
381
+ const half2 dm = x[ib].dm;
382
+ const half2 d = __half2half2( __low2half(dm));
383
+ const half2 m = __half2half2(__high2half(dm));
384
+
385
+ #pragma unroll
386
+ for (int l0 = 0; l0 < ne; l0 += 2) {
387
+ ((half2 *) dst)[l0/2] = d * make_half2(q8[l0 + 0], q8[l0 + 1]) + m;
388
+ }
389
+ } else
396
390
  #endif // FP16_AVAILABLE
391
+ if constexpr (std::is_same_v<T, float>) {
392
+ const float2 dm = __half22float2(x[ib].dm);
397
393
 
398
- return __low2float(dm)*((float) q) + __high2float(dm);
394
+ #pragma unroll
395
+ for (int l = 0; l < ne; ++l) {
396
+ ((float *) dst)[l] = dm.x * q8[l] + dm.y;
397
+ }
398
+ } else {
399
+ static_assert(std::is_same_v<T, void>, "bad type");
400
+ }
399
401
  }
400
402
 
401
- template <typename T>
402
- static __device__ __forceinline__ T dequantize_1_q5_0(const void * __restrict__ vx, const int64_t i) {
403
+ template <typename T, int ne>
404
+ static __device__ __forceinline__ void dequantize_V_q5_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
403
405
  const block_q5_0 * x = (const block_q5_0 *) vx;
404
406
 
405
- const int64_t ib = i / QK5_0;
406
- const int idq = i % QK5_0;
407
- const int iqs = i % (QK5_0/2);
408
- const int shift = (i % QK5_0) / (QK5_0/2);
407
+ const int64_t ib = i0 / QK5_0;
408
+ const int idq = i0 % QK5_0;
409
+ const int iqs = i0 % (QK5_0/2);
410
+ const int shift = (i0 % QK5_0) / (QK5_0/2);
409
411
 
410
- const T d = x[ib].d;
411
- const int ql0 = x[ib].qs[iqs];
412
- const int qh0 = get_int_b2(x[ib].qh, 0);
413
- const int ql = ((ql0 >> (4*shift)) & 0x0F);
414
- const int qh = ((qh0 >> idq) << 4) & 0x10;
415
- const int q = (ql | qh) - 16;
412
+ int q;
413
+ static_assert(ne == 2 || ne == 4, "bad ne");
414
+ ggml_cuda_memcpy_1<ne, 2>(&q, x[ib].qs + iqs);
415
+ q >>= 4*shift;
416
+ q &= 0x0F0F0F0F;
416
417
 
417
- #ifdef FP16_AVAILABLE
418
- if (std::is_same<T, half>::value) {
419
- return ((half) d)*((half) q);
418
+ {
419
+ int qh;
420
+ ggml_cuda_memcpy_1<ne, 2>(&qh, x[ib].qh);
421
+ #pragma unroll
422
+ for (int l = 0; l < ne; ++l) {
423
+ q |= ((qh >> (idq + l)) & 0x00000001) << (8*l + 4);
424
+ }
420
425
  }
426
+
427
+ q = __vsubss4(q, 0x10101010);
428
+
429
+ const int8_t * q8 = (const int8_t *) &q;
430
+
431
+ #ifdef FP16_AVAILABLE
432
+ if constexpr (std::is_same_v<T, half>) {
433
+ const half2 d = __half2half2(x[ib].d);
434
+
435
+ #pragma unroll
436
+ for (int l0 = 0; l0 < ne; l0 += 2) {
437
+ ((half2 *) dst)[l0/2] = d * make_half2(q8[l0 + 0], q8[l0 + 1]);
438
+ }
439
+ } else
421
440
  #endif // FP16_AVAILABLE
441
+ if constexpr (std::is_same_v<T, float>) {
442
+ const float d = x[ib].d;
422
443
 
423
- return ((float) d)*((float) q);
444
+ #pragma unroll
445
+ for (int l = 0; l < ne; ++l) {
446
+ ((float *) dst)[l] = d * q8[l];
447
+ }
448
+ } else {
449
+ static_assert(std::is_same_v<T, void>, "bad type");
450
+ }
424
451
  }
425
452
 
426
- template <typename T>
427
- static __device__ __forceinline__ T dequantize_1_q5_1(const void * __restrict__ vx, const int64_t i) {
453
+ template <typename T, int ne>
454
+ static __device__ __forceinline__ void dequantize_V_q5_1(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
428
455
  const block_q5_1 * x = (const block_q5_1 *) vx;
429
456
 
430
- const int64_t ib = i / QK5_1;
431
- const int idq = i % QK5_1;
432
- const int iqs = i % (QK5_1/2);
433
- const int shift = (i % QK5_1) / (QK5_1/2);
457
+ const int64_t ib = i0 / QK5_1;
458
+ const int idq = i0 % QK5_1;
459
+ const int iqs = i0 % (QK5_1/2);
460
+ const int shift = (i0 % QK5_1) / (QK5_1/2);
434
461
 
435
- const half2 dm = x[ib].dm;
436
- const int ql0 = x[ib].qs[iqs];
437
- const int qh0 = get_int_b4(x[ib].qh, 0);
438
- const int ql = ((ql0 >> (4*shift)) & 0x0F);
439
- const int qh = ((qh0 >> idq) << 4) & 0x10;
440
- const int q = (ql | qh);
462
+ int q;
463
+ static_assert(ne == 2 || ne == 4, "bad ne");
464
+ ggml_cuda_memcpy_1<ne>(&q, x[ib].qs + iqs);
465
+ q >>= 4*shift;
466
+ q &= 0x0F0F0F0F;
441
467
 
442
- #ifdef FP16_AVAILABLE
443
- if (std::is_same<T, half>::value) {
444
- return __low2half(dm)*((half) q) + __high2half(dm);
468
+ {
469
+ int qh;
470
+ ggml_cuda_memcpy_1<ne>(&qh, x[ib].qh);
471
+ #pragma unroll
472
+ for (int l = 0; l < ne; ++l) {
473
+ q |= ((qh >> (idq + l)) & 0x00000001) << (8*l + 4);
474
+ }
445
475
  }
476
+
477
+ const int8_t * q8 = (const int8_t *) &q;
478
+
479
+ #ifdef FP16_AVAILABLE
480
+ if constexpr (std::is_same_v<T, half>) {
481
+ const half2 dm = x[ib].dm;
482
+ const half2 d = __half2half2( __low2half(dm));
483
+ const half2 m = __half2half2(__high2half(dm));
484
+
485
+ #pragma unroll
486
+ for (int l0 = 0; l0 < ne; l0 += 2) {
487
+ ((half2 *) dst)[l0/2] = d * make_half2(q8[l0 + 0], q8[l0 + 1]) + m;
488
+ }
489
+ } else
446
490
  #endif // FP16_AVAILABLE
491
+ if constexpr (std::is_same_v<T, float>) {
492
+ const float2 dm = __half22float2(x[ib].dm);
447
493
 
448
- return __low2float(dm)*((float) q) + __high2float(dm);
494
+ #pragma unroll
495
+ for (int l = 0; l < ne; ++l) {
496
+ ((float *) dst)[l] = dm.x * q8[l] + dm.y;
497
+ }
498
+ } else {
499
+ static_assert(std::is_same_v<T, void>, "bad type");
500
+ }
449
501
  }
450
502
 
451
- template <typename T>
452
- static __device__ __forceinline__ T dequantize_1_q8_0(const void * __restrict__ vx, const int64_t i) {
503
+ template <typename T, int ne>
504
+ static __device__ __forceinline__ void dequantize_V_q8_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
453
505
  const block_q8_0 * x = (const block_q8_0 *) vx;
454
506
 
455
- const int64_t ib = i / QK8_0;
456
- const int iqs = i % QK8_0;
507
+ const int64_t ib = i0 / QK8_0;
508
+ const int iqs = i0 % QK8_0;
457
509
 
458
- const T d = x[ib].d;
459
- const int q = x[ib].qs[iqs];
510
+ static_assert(ne % 2 == 0, "bad ne");
511
+ int8_t qs[ne];
512
+ ggml_cuda_memcpy_1<ne, 2>(qs, x[ib].qs + iqs);
460
513
 
461
514
  #ifdef FP16_AVAILABLE
462
- if (std::is_same<T, half>::value) {
463
- return ((half) d)*((half) q);
464
- }
515
+ if constexpr (std::is_same<T, half>::value) {
516
+ const half2 d = __half2half2(x[ib].d);
517
+
518
+ #pragma unroll
519
+ for (int l0 = 0; l0 < ne; l0 += 2) {
520
+ ((half2 *) dst)[l0/2] = d * make_half2(qs[l0 + 0], qs[l0 + 1]);
521
+ }
522
+ } else
465
523
  #endif // FP16_AVAILABLE
524
+ if constexpr (std::is_same<T, float>::value) {
525
+ const float d = x[ib].d;
466
526
 
467
- return ((float) d)*((float) q);
527
+ #pragma unroll
528
+ for (int l = 0; l < ne; ++l) {
529
+ ((float *) dst)[l] = d * qs[l];
530
+ }
531
+ } else {
532
+ static_assert(std::is_same_v<T, void>, "unsupported type");
533
+ }
468
534
  }
469
535
 
470
- template <typename T>
471
- static __device__ __forceinline__ T dequantize_1_f16(const void * __restrict__ vx, const int64_t i) {
472
- const half * x = (const half *) vx;
473
-
474
- return x[i];
536
+ template <ggml_type type_K, int D, int nthreads>
537
+ constexpr __device__ vec_dot_KQ_t get_vec_dot_KQ() {
538
+ if constexpr (type_K == GGML_TYPE_F16) {
539
+ return vec_dot_fattn_vec_KQ_f16<D, nthreads>;
540
+ } else if constexpr (type_K == GGML_TYPE_Q4_0) {
541
+ return vec_dot_fattn_vec_KQ_q4_0<D, nthreads>;
542
+ } else if constexpr (type_K == GGML_TYPE_Q4_1) {
543
+ return vec_dot_fattn_vec_KQ_q4_1<D, nthreads>;
544
+ } else if constexpr (type_K == GGML_TYPE_Q5_0) {
545
+ return vec_dot_fattn_vec_KQ_q5_0<D, nthreads>;
546
+ } else if constexpr (type_K == GGML_TYPE_Q5_1) {
547
+ return vec_dot_fattn_vec_KQ_q5_1<D, nthreads>;
548
+ } else if constexpr (type_K == GGML_TYPE_Q8_0) {
549
+ return vec_dot_fattn_vec_KQ_q8_0<D, nthreads>;
550
+ } else {
551
+ static_assert(type_K == -1, "bad type");
552
+ return nullptr;
553
+ }
475
554
  }
476
555
 
477
- template <int D, int warp_size = WARP_SIZE>
478
- constexpr __device__ vec_dot_KQ_f16_t get_vec_dot_KQ_f16(ggml_type type_K) {
479
- return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<half, D, warp_size> :
480
- type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<half, D, warp_size> :
481
- type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<half, D, warp_size> :
482
- type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<half, D, warp_size> :
483
- type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<half, D, warp_size> :
484
- type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<half, D, warp_size> :
485
- nullptr;
556
+ template <ggml_type type_V, typename T, int ne>
557
+ constexpr __device__ dequantize_V_t get_dequantize_V() {
558
+ if constexpr (type_V == GGML_TYPE_F16) {
559
+ return dequantize_V_f16<T, ne>;
560
+ } else if constexpr (type_V == GGML_TYPE_Q4_0) {
561
+ return dequantize_V_q4_0<T, ne>;
562
+ } else if constexpr (type_V == GGML_TYPE_Q4_1) {
563
+ return dequantize_V_q4_1<T, ne>;
564
+ } else if constexpr (type_V == GGML_TYPE_Q5_0) {
565
+ return dequantize_V_q5_0<T, ne>;
566
+ } else if constexpr (type_V == GGML_TYPE_Q5_1) {
567
+ return dequantize_V_q5_1<T, ne>;
568
+ } else if constexpr (type_V == GGML_TYPE_Q8_0) {
569
+ return dequantize_V_q8_0<T, ne>;
570
+ } else {
571
+ static_assert(type_V == -1, "bad type");
572
+ return nullptr;
573
+ }
486
574
  }
487
575
 
488
- template <int D, int warp_size = WARP_SIZE>
489
- constexpr __device__ vec_dot_KQ_f32_t get_vec_dot_KQ_f32(ggml_type type_K) {
490
- return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<float, D, warp_size> :
491
- type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<float, D, warp_size> :
492
- type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<float, D, warp_size> :
493
- type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<float, D, warp_size> :
494
- type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<float, D, warp_size> :
495
- type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<float, D, warp_size> :
496
- nullptr;
497
- }
576
+ template <int ncols1>
577
+ __launch_bounds__(FATTN_KQ_STRIDE/2, 1)
578
+ static __global__ void flash_attn_mask_to_KV_max(
579
+ const half2 * __restrict__ mask, int * __restrict__ KV_max, const int ne30, const int s31, const int s33) {
580
+ const int ne31 = gridDim.x;
581
+ const int tid = threadIdx.x;
582
+ const int sequence = blockIdx.y;
583
+ const int jt = blockIdx.x;
498
584
 
499
- constexpr __device__ dequantize_1_f16_t get_dequantize_1_f16(ggml_type type_V) {
500
- return type_V == GGML_TYPE_Q4_0 ? dequantize_1_q4_0<half> :
501
- type_V == GGML_TYPE_Q4_1 ? dequantize_1_q4_1<half> :
502
- type_V == GGML_TYPE_Q5_0 ? dequantize_1_q5_0<half> :
503
- type_V == GGML_TYPE_Q5_1 ? dequantize_1_q5_1<half> :
504
- type_V == GGML_TYPE_Q8_0 ? dequantize_1_q8_0<half> :
505
- type_V == GGML_TYPE_F16 ? dequantize_1_f16<half> :
506
- nullptr;
507
- }
585
+ mask += sequence*s33 + jt*ncols1*s31;
586
+
587
+ __shared__ int buf_iw[WARP_SIZE];
588
+ if (tid < WARP_SIZE) {
589
+ buf_iw[tid] = 1;
590
+ }
591
+ __syncthreads();
592
+
593
+ int KV_max_sj = (ne30 - 1) * FATTN_KQ_STRIDE;
594
+ for (; KV_max_sj >= 0; KV_max_sj -= FATTN_KQ_STRIDE) {
595
+ int all_inf = 1;
596
+
597
+ #pragma unroll
598
+ for (int j = 0; j < ncols1; ++j) {
599
+ const float2 tmp = __half22float2(mask[j*s31 + KV_max_sj/2 + tid]);
600
+ all_inf = all_inf && int(isinf(tmp.x)) && int(isinf(tmp.y));
601
+ }
602
+
603
+ all_inf = warp_reduce_all(all_inf);
604
+ if (tid % WARP_SIZE == 0) {
605
+ buf_iw[tid / WARP_SIZE] = all_inf;
606
+ }
607
+ __syncthreads();
608
+ all_inf = buf_iw[tid % WARP_SIZE];
609
+ __syncthreads();
610
+ all_inf = warp_reduce_all(all_inf);
508
611
 
509
- constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
510
- return type_V == GGML_TYPE_Q4_0 ? dequantize_1_q4_0<float> :
511
- type_V == GGML_TYPE_Q4_1 ? dequantize_1_q4_1<float> :
512
- type_V == GGML_TYPE_Q5_0 ? dequantize_1_q5_0<float> :
513
- type_V == GGML_TYPE_Q5_1 ? dequantize_1_q5_1<float> :
514
- type_V == GGML_TYPE_Q8_0 ? dequantize_1_q8_0<float> :
515
- type_V == GGML_TYPE_F16 ? dequantize_1_f16<float> :
516
- nullptr;
612
+ if (!all_inf) {
613
+ break;
614
+ }
615
+ }
616
+
617
+ // If the break in the loop was not triggered, KV_max_sj is now -FATTN_KQ_STRIDE.
618
+ // If the break was triggered it's the lower edge of the tile with the first non-masked values.
619
+ // In either case, walk back the decrementation by FATTN_KQ_STRIDE.
620
+ KV_max_sj += FATTN_KQ_STRIDE;
621
+
622
+ if (threadIdx.x != 0) {
623
+ return;
624
+ }
625
+
626
+ KV_max[sequence*ne31 + jt] = KV_max_sj;
517
627
  }
518
628
 
519
629
  template<int D, int ncols1, int ncols2> // D == head size
520
630
  __launch_bounds__(D, 1)
521
631
  static __global__ void flash_attn_stream_k_fixup(
522
- float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
632
+ float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11,
633
+ const int nbatch_fa) {
523
634
  constexpr int ncols = ncols1*ncols2;
524
635
 
525
636
  const int bidx0 = blockIdx.x;
@@ -530,11 +641,11 @@ static __global__ void flash_attn_stream_k_fixup(
530
641
 
531
642
  const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
532
643
 
533
- const int iter_k = ne11 / FATTN_KQ_STRIDE;
534
- const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
644
+ const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
645
+ const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
535
646
 
536
- const int kbc0 = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
537
- const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
647
+ const int kbc0 = int64_t(bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
648
+ const int kbc0_stop = int64_t(bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
538
649
 
539
650
  const bool did_not_have_any_data = kbc0 == kbc0_stop;
540
651
  const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
@@ -543,14 +654,15 @@ static __global__ void flash_attn_stream_k_fixup(
543
654
  return;
544
655
  }
545
656
 
546
- const int channel = kbc0 / (iter_k*iter_j);
547
- const int jt = (kbc0 - channel*iter_k*iter_j) / iter_k;
657
+ const int sequence = kbc0 / (iter_k*iter_j*(ne02/ncols2));
658
+ const int head = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
659
+ const int jt = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
548
660
 
549
661
  if (jt*ncols1 + j >= ne01) {
550
662
  return;
551
663
  }
552
664
 
553
- dst += jt*ne02*(ncols1*D) + channel*(ncols2*D) + (j*ne02 + c)*D + tid;
665
+ dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + head*(ncols2*D) + (j*ne02 + c)*D + tid;
554
666
 
555
667
  // Load the partial result that needs a fixup:
556
668
  float dst_val = 0.0f;
@@ -569,7 +681,7 @@ static __global__ void flash_attn_stream_k_fixup(
569
681
  int bidx = bidx0 - 1;
570
682
  int kbc_stop = kbc0;
571
683
  while(true) {
572
- const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
684
+ const int kbc = int64_t(bidx)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
573
685
  if (kbc == kbc_stop) { // Did not have any data.
574
686
  bidx--;
575
687
  kbc_stop = kbc;
@@ -607,24 +719,37 @@ static __global__ void flash_attn_stream_k_fixup(
607
719
  }
608
720
 
609
721
  template<int D> // D == head size
610
- #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
611
722
  __launch_bounds__(D, 1)
612
- #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
613
723
  static __global__ void flash_attn_combine_results(
614
724
  const float * __restrict__ VKQ_parts,
615
725
  const float2 * __restrict__ VKQ_meta,
616
726
  float * __restrict__ dst,
617
727
  const int parallel_blocks) {
618
- VKQ_parts += parallel_blocks*D * gridDim.z*blockIdx.x;
619
- VKQ_meta += parallel_blocks * gridDim.z*blockIdx.x;
620
- dst += D * gridDim.z*blockIdx.x;
728
+ // Dimension 0: threadIdx.x
729
+ // Dimension 1: blockIdx.x
730
+ // Dimension 2: blockIdx.y
731
+ // Dimension 3: blockIdx.z
732
+ // Memory layout is permuted with [0, 2, 1, 3]
733
+
734
+ const int ne01 = gridDim.x;
735
+ const int ne02 = gridDim.y;
736
+
737
+ const int col = blockIdx.x;
738
+ const int head = blockIdx.y;
739
+ const int sequence = blockIdx.z;
740
+
741
+ const int j_dst_unrolled = (sequence*ne01 + col)*ne02 + head;
742
+
743
+ VKQ_parts += j_dst_unrolled * parallel_blocks*D;
744
+ VKQ_meta += j_dst_unrolled * parallel_blocks;
745
+ dst += j_dst_unrolled * D;
621
746
 
622
747
  const int tid = threadIdx.x;
623
748
  __builtin_assume(tid < D);
624
749
 
625
750
  extern __shared__ float2 meta[];
626
751
  for (int i = tid; i < 2*parallel_blocks; i += D) {
627
- ((float *) meta)[i] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + i];
752
+ ((float *) meta)[i] = ((const float *)VKQ_meta) [i];
628
753
  }
629
754
 
630
755
  __syncthreads();
@@ -637,44 +762,19 @@ static __global__ void flash_attn_combine_results(
637
762
  float VKQ_numerator = 0.0f;
638
763
  float VKQ_denominator = 0.0f;
639
764
  for (int l = 0; l < parallel_blocks; ++l) {
640
- const float diff = meta[l].x - kqmax;
641
- float KQ_max_scale = expf(diff);
642
- const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
643
- *((uint32_t *) &KQ_max_scale) &= ftz_mask;
765
+ const float KQ_max_scale = expf(meta[l].x - kqmax);
644
766
 
645
- VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.z*D + blockIdx.z*D + tid];
767
+ VKQ_numerator += KQ_max_scale * VKQ_parts[l*D + tid];
646
768
  VKQ_denominator += KQ_max_scale * meta[l].y;
647
769
  }
648
770
 
649
- dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator;
650
- }
651
-
652
- [[noreturn]]
653
- static void on_no_fattn_vec_case(const int D) {
654
- if (D == 64) {
655
- fprintf(stderr, "Unsupported KV type combination for head_size 64.\n");
656
- fprintf(stderr, "By default only f16 KV cache is supported.\n");
657
- fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for V cache quantization support.\n");
658
- GGML_ABORT("fatal error");
659
- } else if (D == 128) {
660
- fprintf(stderr, "Unsupported KV type combination for head_size 128.\n");
661
- fprintf(stderr, "Supported combinations:\n");
662
- fprintf(stderr, " - K == q4_0, V == q4_0, 4.50 BPV\n");
663
- fprintf(stderr, " - K == q8_0, V == q8_0, 8.50 BPV\n");
664
- fprintf(stderr, " - K == f16, V == f16, 16.00 BPV\n");
665
- fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, q5_0, q5_1, q8_0, and f16.\n");
666
- GGML_ABORT("fatal error");
667
- } else {
668
- fprintf(stderr, "Unsupported KV type combination for head_size %d.\n", D);
669
- fprintf(stderr, "Only f16 is supported.\n");
670
- GGML_ABORT("fatal error");
671
- }
771
+ dst[tid] = VKQ_numerator / VKQ_denominator;
672
772
  }
673
773
 
674
774
  template <int DV, int ncols1, int ncols2>
675
775
  void launch_fattn(
676
776
  ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
677
- const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
777
+ const int nbatch_fa, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
678
778
  ) {
679
779
  constexpr int ncols = ncols1 * ncols2;
680
780
 
@@ -686,7 +786,8 @@ void launch_fattn(
686
786
 
687
787
  GGML_ASSERT(V || is_mla);
688
788
 
689
- const ggml_tensor * mask = dst->src[3];
789
+ const ggml_tensor * mask = dst->src[3];
790
+ const ggml_tensor * sinks = dst->src[4];
690
791
 
691
792
  ggml_tensor * KQV = dst;
692
793
 
@@ -698,12 +799,6 @@ void launch_fattn(
698
799
  GGML_ASSERT(!V || V->nb[0] == ggml_element_size(V));
699
800
 
700
801
  GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
701
- GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
702
- "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
703
-
704
- GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
705
-
706
- GGML_ASSERT(Q->ne[3] == 1);
707
802
 
708
803
  ggml_cuda_pool & pool = ctx.pool();
709
804
  cudaStream_t main_stream = ctx.stream();
@@ -713,6 +808,7 @@ void launch_fattn(
713
808
 
714
809
  ggml_cuda_pool_alloc<half> K_f16(pool);
715
810
  ggml_cuda_pool_alloc<half> V_f16(pool);
811
+ ggml_cuda_pool_alloc<int> KV_max(pool);
716
812
  ggml_cuda_pool_alloc<float> dst_tmp(pool);
717
813
  ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
718
814
 
@@ -727,43 +823,87 @@ void launch_fattn(
727
823
  size_t nb23 = V ? V->nb[3] : nb13;
728
824
 
729
825
  if (need_f16_K && K->type != GGML_TYPE_F16) {
730
- GGML_ASSERT(ggml_is_contiguously_allocated(K));
731
- K_f16.alloc(ggml_nelements(K));
732
- to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type);
733
- to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);
734
- K_data = (char *) K_f16.ptr;
735
-
736
826
  const size_t bs = ggml_blck_size(K->type);
737
827
  const size_t ts = ggml_type_size(K->type);
738
828
 
739
- nb11 = nb11*bs*sizeof(half)/ts;
740
- nb12 = nb12*bs*sizeof(half)/ts;
741
- nb13 = nb13*bs*sizeof(half)/ts;
829
+ K_f16.alloc(ggml_nelements(K));
830
+ if (ggml_is_contiguously_allocated(K)) {
831
+ to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type);
832
+ to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);
833
+
834
+ nb11 = nb11*bs*sizeof(half)/ts;
835
+ nb12 = nb12*bs*sizeof(half)/ts;
836
+ nb13 = nb13*bs*sizeof(half)/ts;
837
+ } else {
838
+ GGML_ASSERT(K->nb[0] == ts);
839
+ to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(K->type);
840
+ const int64_t s01 = nb11 / ts;
841
+ const int64_t s02 = nb12 / ts;
842
+ const int64_t s03 = nb13 / ts;
843
+ to_fp16(K_data, K_f16.ptr, K->ne[0], K->ne[1], K->ne[2], K->ne[3], s01, s02, s03, main_stream);
844
+
845
+ nb11 = K->ne[0] * sizeof(half);
846
+ nb12 = K->ne[1] * nb11;
847
+ nb13 = K->ne[2] * nb12;
848
+ }
849
+ K_data = (char *) K_f16.ptr;
742
850
  }
743
851
 
744
852
  if (V && need_f16_V && V->type != GGML_TYPE_F16) {
745
- GGML_ASSERT(ggml_is_contiguously_allocated(V));
746
- V_f16.alloc(ggml_nelements(V));
747
- to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
748
- to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
749
- V_data = (char *) V_f16.ptr;
750
-
751
853
  const size_t bs = ggml_blck_size(V->type);
752
854
  const size_t ts = ggml_type_size(V->type);
753
855
 
754
- nb21 = nb21*bs*sizeof(half)/ts;
755
- nb22 = nb22*bs*sizeof(half)/ts;
756
- nb23 = nb23*bs*sizeof(half)/ts;
856
+ V_f16.alloc(ggml_nelements(V));
857
+ if (ggml_is_contiguously_allocated(V)) {
858
+ to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
859
+ to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
860
+ V_data = (char *) V_f16.ptr;
861
+
862
+ nb21 = nb21*bs*sizeof(half)/ts;
863
+ nb22 = nb22*bs*sizeof(half)/ts;
864
+ nb23 = nb23*bs*sizeof(half)/ts;
865
+ } else {
866
+ GGML_ASSERT(V->nb[0] == ts);
867
+ to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(V->type);
868
+ const int64_t s01 = nb21 / ts;
869
+ const int64_t s02 = nb22 / ts;
870
+ const int64_t s03 = nb23 / ts;
871
+ to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream);
872
+
873
+ nb21 = V->ne[0] * sizeof(half);
874
+ nb22 = V->ne[1] * nb21;
875
+ nb23 = V->ne[2] * nb22;
876
+ }
877
+ V_data = (char *) V_f16.ptr;
757
878
  }
758
879
 
759
- int parallel_blocks = 1;
760
-
761
880
  const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
762
881
  const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
763
882
 
883
+ // Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.
884
+ // Only worth the overhead if there is at lease one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
885
+ // multiple sequences of possibly different lengths.
886
+ if (mask && K->ne[1] % FATTN_KQ_STRIDE == 0 && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) {
887
+ const int s31 = mask->nb[1] / sizeof(half2);
888
+ const int s33 = mask->nb[3] / sizeof(half2);
889
+
890
+ const dim3 blocks_num_KV_max(ntiles_x, Q->ne[3], 1);
891
+ const dim3 block_dim_KV_max(FATTN_KQ_STRIDE/2, 1, 1);
892
+
893
+ const int ne_KV_max = blocks_num_KV_max.x*blocks_num_KV_max.y;
894
+ const int iter_k = K->ne[1] / FATTN_KQ_STRIDE;
895
+
896
+ KV_max.alloc(ne_KV_max);
897
+ flash_attn_mask_to_KV_max<ncols1><<<blocks_num_KV_max, block_dim_KV_max, 0, main_stream>>>
898
+ ((const half2 *) mask->data, KV_max.ptr, iter_k, s31, s33);
899
+ CUDA_CHECK(cudaGetLastError());
900
+ }
901
+
764
902
  const dim3 block_dim(warp_size, nwarps, 1);
765
903
  int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
766
904
  CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
905
+ GGML_ASSERT(max_blocks_per_sm > 0);
906
+ int parallel_blocks = max_blocks_per_sm;
767
907
 
768
908
  dim3 blocks_num;
769
909
  if (stream_k) {
@@ -780,13 +920,11 @@ void launch_fattn(
780
920
  blocks_num.y = 1;
781
921
  blocks_num.z = 1;
782
922
 
783
- dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float));
923
+ if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
924
+ dst_tmp_meta.alloc((size_t(blocks_num.x) * ncols * (2 + DV/2)));
925
+ }
784
926
  } else {
785
- GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0);
786
- const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
787
-
788
- // parallel_blocks should be at least large enough to achieve max. occupancy for a single wave:
789
- parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1);
927
+ const int ntiles_KQ = (K->ne[1] + nbatch_fa - 1) / nbatch_fa; // Max. number of parallel blocks limited by tensor size.
790
928
 
791
929
  // parallel_blocks must not be larger than what the tensor size allows:
792
930
  parallel_blocks = std::min(parallel_blocks, ntiles_KQ);
@@ -802,7 +940,7 @@ void launch_fattn(
802
940
  const int efficiency_percent = 100 * nblocks_total / (nwaves*blocks_per_wave);
803
941
 
804
942
  // Stop trying configurations with more waves if we already have good efficiency to avoid excessive overhead.
805
- if (efficiency_percent_best >= 90 && nwaves > nwaves_best) {
943
+ if (efficiency_percent_best >= 95 && nwaves > nwaves_best) {
806
944
  break;
807
945
  }
808
946
 
@@ -815,7 +953,7 @@ void launch_fattn(
815
953
 
816
954
  blocks_num.x = ntiles_x;
817
955
  blocks_num.y = parallel_blocks;
818
- blocks_num.z = Q->ne[2]*Q->ne[3];
956
+ blocks_num.z = (Q->ne[2]/ncols2)*Q->ne[3];
819
957
 
820
958
  if (parallel_blocks > 1) {
821
959
  dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
@@ -841,21 +979,24 @@ void launch_fattn(
841
979
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
842
980
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
843
981
 
982
+ // TODO other tensor dimensions after removal of WMMA kernel:
983
+ const uint3 ne01 = init_fastdiv_values(Q->ne[1]);
984
+
844
985
  GGML_ASSERT(block_dim.x % warp_size == 0);
845
986
  fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(
846
987
  (const char *) Q->data,
847
988
  K_data,
848
989
  V_data,
849
990
  mask ? ((const char *) mask->data) : nullptr,
991
+ sinks ? ((const char *) sinks->data) : nullptr,
992
+ KV_max.ptr,
850
993
  !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
851
994
  scale, max_bias, m0, m1, n_head_log2, logit_softcap,
852
- Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
853
- K->ne[0], K->ne[1], K->ne[2], K->ne[3],
854
- mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
855
- Q->nb[1], Q->nb[2], Q->nb[3],
856
- nb11, nb12, nb13,
995
+ Q->ne[0], ne01, Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3],
996
+ K->ne[0], K->ne[1], K->ne[2], K->ne[3], nb11, nb12, nb13,
857
997
  nb21, nb22, nb23,
858
- KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
998
+ mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,
999
+ mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? mask->nb[3] : 0
859
1000
  );
860
1001
  CUDA_CHECK(cudaGetLastError());
861
1002
 
@@ -866,11 +1007,11 @@ void launch_fattn(
866
1007
 
867
1008
  flash_attn_stream_k_fixup<DV, ncols1, ncols2>
868
1009
  <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
869
- ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
1010
+ ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], nbatch_fa);
870
1011
  }
871
1012
  } else if (parallel_blocks > 1) {
872
1013
  const dim3 block_dim_combine(DV, 1, 1);
873
- const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z);
1014
+ const dim3 blocks_num_combine(Q->ne[1], Q->ne[2], Q->ne[3]);
874
1015
  const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);
875
1016
 
876
1017
  flash_attn_combine_results<DV>