whispercpp 1.3.3 → 1.3.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (586)
  1. checksums.yaml +4 -4
  2. data/ext/ruby_whisper_params.c +55 -25
  3. data/ext/sources/CMakeLists.txt +1 -1
  4. data/ext/sources/bindings/javascript/package.json +1 -1
  5. data/ext/sources/build-xcframework.sh +24 -0
  6. data/ext/sources/examples/CMakeLists.txt +1 -0
  7. data/ext/sources/examples/addon.node/addon.cpp +19 -19
  8. data/ext/sources/examples/addon.node/index.js +7 -5
  9. data/ext/sources/examples/bench/bench.cpp +26 -16
  10. data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
  11. data/ext/sources/examples/cli/cli.cpp +4 -2
  12. data/ext/sources/examples/command/command.cpp +26 -24
  13. data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
  14. data/ext/sources/examples/common-ggml.cpp +2 -0
  15. data/ext/sources/examples/lsp/lsp.cpp +19 -17
  16. data/ext/sources/examples/server/server.cpp +24 -13
  17. data/ext/sources/examples/server.py +6 -1
  18. data/ext/sources/examples/stream/stream.cpp +4 -2
  19. data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
  20. data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
  21. data/ext/sources/examples/talk-llama/CMakeLists.txt +2 -2
  22. data/ext/sources/examples/talk-llama/llama-adapter.cpp +101 -4
  23. data/ext/sources/examples/talk-llama/llama-adapter.h +6 -0
  24. data/ext/sources/examples/talk-llama/llama-arch.cpp +588 -15
  25. data/ext/sources/examples/talk-llama/llama-arch.h +58 -1
  26. data/ext/sources/examples/talk-llama/llama-batch.cpp +103 -71
  27. data/ext/sources/examples/talk-llama/llama-batch.h +31 -18
  28. data/ext/sources/examples/talk-llama/llama-chat.cpp +120 -5
  29. data/ext/sources/examples/talk-llama/llama-chat.h +7 -0
  30. data/ext/sources/examples/talk-llama/llama-context.cpp +460 -357
  31. data/ext/sources/examples/talk-llama/llama-context.h +44 -29
  32. data/ext/sources/examples/talk-llama/llama-cparams.h +4 -4
  33. data/ext/sources/examples/talk-llama/llama-graph.cpp +543 -271
  34. data/ext/sources/examples/talk-llama/llama-graph.h +278 -168
  35. data/ext/sources/examples/talk-llama/llama-hparams.cpp +118 -4
  36. data/ext/sources/examples/talk-llama/llama-hparams.h +61 -15
  37. data/ext/sources/examples/talk-llama/llama-impl.h +2 -0
  38. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +326 -0
  39. data/ext/sources/examples/talk-llama/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +38 -29
  40. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2020 -0
  41. data/ext/sources/examples/talk-llama/llama-kv-cache.h +358 -27
  42. data/ext/sources/examples/talk-llama/llama-kv-cells.h +80 -28
  43. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +56 -36
  44. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +30 -29
  45. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +48 -19
  46. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +13 -14
  47. data/ext/sources/examples/talk-llama/llama-memory.h +16 -10
  48. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +2 -0
  49. data/ext/sources/examples/talk-llama/llama-model-loader.h +3 -2
  50. data/ext/sources/examples/talk-llama/llama-model.cpp +7165 -2336
  51. data/ext/sources/examples/talk-llama/llama-model.h +60 -9
  52. data/ext/sources/examples/talk-llama/llama-quant.cpp +48 -10
  53. data/ext/sources/examples/talk-llama/llama-sampling.cpp +226 -126
  54. data/ext/sources/examples/talk-llama/llama-vocab.cpp +440 -13
  55. data/ext/sources/examples/talk-llama/llama-vocab.h +45 -0
  56. data/ext/sources/examples/talk-llama/llama.cpp +65 -10
  57. data/ext/sources/examples/talk-llama/llama.h +95 -177
  58. data/ext/sources/examples/talk-llama/talk-llama.cpp +9 -6
  59. data/ext/sources/examples/talk-llama/unicode.cpp +207 -0
  60. data/ext/sources/examples/talk-llama/unicode.h +45 -0
  61. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +4 -2
  62. data/ext/sources/examples/whisper.wasm/index-tmpl.html +17 -16
  63. data/ext/sources/ggml/CMakeLists.txt +59 -31
  64. data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
  65. data/ext/sources/ggml/include/ggml-backend.h +17 -1
  66. data/ext/sources/ggml/include/ggml-cpu.h +1 -1
  67. data/ext/sources/ggml/include/ggml-metal.h +1 -6
  68. data/ext/sources/ggml/include/ggml-opt.h +25 -6
  69. data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
  70. data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
  71. data/ext/sources/ggml/include/ggml.h +221 -16
  72. data/ext/sources/ggml/src/CMakeLists.txt +17 -2
  73. data/ext/sources/ggml/src/ggml-alloc.c +265 -141
  74. data/ext/sources/ggml/src/ggml-backend-impl.h +4 -1
  75. data/ext/sources/ggml/src/ggml-backend-reg.cpp +30 -13
  76. data/ext/sources/ggml/src/ggml-backend.cpp +221 -38
  77. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  78. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -4
  79. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +14 -0
  80. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
  81. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +903 -717
  82. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +143 -25
  83. data/ext/sources/ggml/src/ggml-cann/common.h +143 -1
  84. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +488 -69
  85. data/ext/sources/ggml/src/ggml-common.h +17 -0
  86. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +40 -18
  87. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +4 -2
  88. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +132 -596
  89. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +14 -286
  90. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +103 -582
  91. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
  92. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +265 -437
  93. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
  94. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +521 -353
  95. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
  96. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
  97. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +4679 -1657
  98. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +32 -2
  99. data/ext/sources/ggml/src/ggml-cpu/common.h +14 -0
  100. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +13 -6
  101. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +70 -42
  102. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +35 -28
  103. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +152 -18
  104. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +7 -1
  105. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +227 -97
  106. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +474 -1116
  107. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1587 -1177
  108. data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -8
  109. data/ext/sources/ggml/src/ggml-cpu/quants.c +35 -0
  110. data/ext/sources/ggml/src/ggml-cpu/quants.h +8 -0
  111. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +458 -47
  112. data/ext/sources/ggml/src/ggml-cpu/repack.h +22 -0
  113. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +89 -60
  114. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1024 -0
  115. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
  116. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
  117. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
  118. data/ext/sources/ggml/src/ggml-cpu/traits.cpp +2 -2
  119. data/ext/sources/ggml/src/ggml-cpu/traits.h +1 -1
  120. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +170 -26
  121. data/ext/sources/ggml/src/ggml-cpu/vec.h +506 -63
  122. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +20 -16
  123. data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
  124. data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
  125. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
  126. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
  127. data/ext/sources/ggml/src/ggml-cuda/common.cuh +250 -63
  128. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
  129. data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
  130. data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
  131. data/ext/sources/ggml/src/ggml-cuda/convert.cu +95 -22
  132. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +15 -0
  133. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
  134. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +64 -307
  135. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  136. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
  137. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +498 -367
  138. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +137 -91
  139. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +755 -0
  140. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +3 -0
  141. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +593 -0
  142. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +86 -50
  143. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +185 -198
  144. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  145. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
  146. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +379 -107
  147. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
  148. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
  149. data/ext/sources/ggml/src/ggml-cuda/mean.cu +56 -2
  150. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +198 -45
  151. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +123 -0
  152. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +496 -0
  153. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +206 -57
  154. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1262 -721
  155. data/ext/sources/ggml/src/ggml-cuda/{mmv.cu → mmvf.cu} +53 -53
  156. data/ext/sources/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +3 -3
  157. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +64 -73
  158. data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
  159. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
  160. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
  161. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
  162. data/ext/sources/ggml/src/ggml-cuda/pad.cu +46 -23
  163. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
  164. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
  165. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +12 -10
  166. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
  167. data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
  168. data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
  169. data/ext/sources/ggml/src/ggml-cuda/rope.cu +21 -27
  170. data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
  171. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +276 -0
  172. data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  173. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
  174. data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
  175. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +126 -59
  176. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  177. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +322 -100
  178. data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
  179. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +21 -4
  180. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
  181. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
  182. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
  183. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
  184. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
  185. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
  186. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
  187. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
  188. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
  189. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
  190. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
  191. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
  192. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
  193. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
  194. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
  195. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
  196. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
  197. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
  198. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
  199. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
  200. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
  201. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
  202. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
  203. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
  204. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
  205. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
  206. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
  207. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
  208. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
  209. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
  210. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
  211. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
  212. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
  213. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
  214. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
  215. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
  216. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +21 -18
  217. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
  218. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
  219. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
  220. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
  221. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
  222. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
  223. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
  224. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
  225. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
  226. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
  227. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
  228. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
  229. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
  230. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
  231. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
  232. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
  233. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
  234. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +259 -0
  235. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +14 -0
  236. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
  237. data/ext/sources/ggml/src/ggml-cuda/unary.cu +90 -0
  238. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +8 -0
  239. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +92 -6
  240. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
  241. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  242. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +58 -36
  243. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +4 -3
  244. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +10 -2
  245. data/ext/sources/ggml/src/ggml-impl.h +119 -9
  246. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -7
  247. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
  248. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
  249. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
  250. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +600 -0
  251. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1376 -0
  252. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +226 -0
  253. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1308 -0
  254. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +136 -63
  255. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +3158 -0
  256. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +82 -0
  257. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +718 -0
  258. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +2854 -1503
  259. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +18 -8
  260. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +18 -0
  261. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +2510 -242
  262. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
  263. data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  264. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  265. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  266. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +84 -0
  267. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +66 -0
  268. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
  269. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +370 -0
  270. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
  271. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  272. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +177 -0
  273. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
  274. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  275. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  276. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
  277. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  278. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
  279. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
  280. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  281. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
  282. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
  283. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
  284. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  285. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
  286. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
  287. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
  288. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
  289. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
  290. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  291. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +189 -0
  292. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
  293. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
  294. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
  295. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
  296. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
  297. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
  298. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +2 -2
  299. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  300. data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
  301. data/ext/sources/ggml/src/ggml-quants.c +111 -16
  302. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  303. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +67 -47
  304. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +2 -0
  305. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +15 -5
  306. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +6 -0
  307. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +25 -16
  308. data/ext/sources/ggml/src/ggml-sycl/conv.cpp +10 -4
  309. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +166 -99
  310. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +72 -306
  311. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
  312. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +67 -49
  313. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +1 -31
  314. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +79 -29
  315. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
  316. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +14 -26
  317. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +9 -6
  318. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +328 -323
  319. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
  320. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +2 -2
  321. data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +80 -60
  322. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +201 -132
  323. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +74 -55
  324. data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
  325. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +8 -9
  326. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +35 -42
  327. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
  328. data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  329. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
  330. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +12 -6
  331. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
  332. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +16 -12
  333. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +3492 -883
  334. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
  335. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
  336. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
  337. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
  338. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +349 -0
  339. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
  340. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +66 -12
  341. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +154 -0
  342. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
  343. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  344. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +2 -1
  345. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +6 -5
  346. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +4 -2
  347. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
  348. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
  349. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  350. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
  351. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
  352. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  353. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
  354. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +69 -24
  355. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +60 -20
  356. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +98 -42
  357. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +64 -27
  358. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +74 -13
  359. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  360. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  361. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  362. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +4 -17
  363. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +19 -10
  364. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +25 -15
  365. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +4 -0
  366. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
  367. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
  368. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +18 -14
  369. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +126 -0
  370. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +65 -1
  371. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
  372. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +140 -0
  373. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -531
  374. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +206 -38
  375. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp +556 -0
  376. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +12 -5
  377. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +15 -9
  378. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
  379. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
  380. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +24 -3
  381. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +53 -3
  382. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +55 -11
  383. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
  384. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  385. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
  386. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  387. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  388. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  389. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
  390. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  391. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +29 -7
  392. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +4 -0
  393. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
  394. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
  395. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
  396. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +4 -3
  397. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +101 -9
  398. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  399. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
  400. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -77
  401. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
  402. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1558 -0
  403. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +44 -0
  404. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +41 -0
  405. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
  406. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
  407. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
  408. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +124 -0
  409. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
  410. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  411. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +44 -0
  412. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +41 -0
  413. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
  414. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +57 -0
  415. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +48 -0
  416. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
  417. data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  418. data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
  419. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
  420. data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
  421. data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
  422. data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
  423. data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
  424. data/ext/sources/ggml/src/ggml.c +478 -98
  425. data/ext/sources/ggml/src/gguf.cpp +8 -1
  426. data/ext/sources/src/whisper.cpp +23 -46
  427. data/ext/sources/tests/CMakeLists.txt +8 -1
  428. data/ext/sources/tests/test-vad-full.cpp +3 -3
  429. data/ext/sources/tests/test-vad.cpp +2 -2
  430. data/lib/whisper/model/uri.rb +1 -1
  431. data/sig/whisper.rbs +7 -0
  432. data/test/test_params.rb +8 -0
  433. data/test/test_whisper.rb +1 -1
  434. data/whispercpp.gemspec +1 -1
  435. metadata +164 -157
  436. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +0 -279
  437. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +0 -1841
  438. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +0 -303
  439. data/ext/sources/ggml/include/ggml-kompute.h +0 -50
  440. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
  441. data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
  442. data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
  443. data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
  444. data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
  445. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  446. data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  447. data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  448. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  449. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  450. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  451. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  452. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  453. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  454. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
  455. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
  456. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
  457. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
  458. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
  459. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
  460. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
  461. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
  462. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
  463. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
  464. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
  465. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
  466. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
  467. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
  468. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
  469. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
  470. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
  471. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
  472. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
  473. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
  474. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
  475. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
  476. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
  477. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
  478. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
  479. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
  480. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
  481. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
  482. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
  483. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
  484. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
  485. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
  486. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
  487. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
  488. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
  489. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
  490. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
  491. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
  492. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
  493. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
  494. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
  495. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
  496. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
  497. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
  498. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
  499. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
  500. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
  501. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
  502. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
  503. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
  504. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
  505. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
  506. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
  507. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
  508. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
  509. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
  510. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
  511. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
  512. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
  513. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
  514. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
  515. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
  516. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
  517. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
  518. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
  519. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
  520. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
  521. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
  522. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
  523. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
  524. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
  525. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
  526. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
  527. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
  528. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
  529. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
  530. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
  531. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
  532. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
  533. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
  534. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
  535. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
  536. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
  537. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
  538. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
  539. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
  540. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
  541. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
  542. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
  543. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
  544. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
  545. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
  546. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
  547. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  548. data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  549. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  550. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  551. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  552. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  553. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  554. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  555. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  556. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  557. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  558. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  559. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  560. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  561. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  562. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  563. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  564. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  565. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  566. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  567. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  568. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  569. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  570. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  571. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  572. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  573. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  574. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  575. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  576. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  577. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  578. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  579. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  580. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  581. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  582. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  583. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  584. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  585. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
  586. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -6280
@@ -15,6 +15,8 @@ typedef void (* fattn_kernel_t)(
15
15
  const char * __restrict__ K,
16
16
  const char * __restrict__ V,
17
17
  const char * __restrict__ mask,
18
+ const char * __restrict__ sinks,
19
+ const int * __restrict__ KV_max,
18
20
  float * __restrict__ dst,
19
21
  float2 * __restrict__ dst_meta,
20
22
  const float scale,
@@ -23,300 +25,238 @@ typedef void (* fattn_kernel_t)(
23
25
  const float m1,
24
26
  const uint32_t n_head_log2,
25
27
  const float logit_softcap,
26
- const int ne00,
27
- const int ne01,
28
- const int ne02,
29
- const int ne03,
30
- const int ne10,
31
- const int ne11,
32
- const int ne12,
33
- const int ne13,
34
- const int ne31,
35
- const int nb31,
36
- const int nb01,
37
- const int nb02,
38
- const int nb03,
39
- const int nb11,
40
- const int nb12,
41
- const int nb13,
42
- const int nb21,
43
- const int nb22,
44
- const int nb23,
45
- const int ne0,
46
- const int ne1,
47
- const int ne2,
48
- const int ne3);
49
-
50
- typedef half (*vec_dot_KQ_f16_t)(
51
- const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);
52
- typedef float (*vec_dot_KQ_f32_t)(
28
+ const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
29
+ const int32_t nb01, const int32_t nb02, const int32_t nb03,
30
+ const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
31
+ const int32_t nb11, const int32_t nb12, const int64_t nb13,
32
+ const int32_t nb21, const int32_t nb22, const int64_t nb23,
33
+ const int32_t ne31, const int32_t ne32, const int32_t ne33,
34
+ const int32_t nb31, const int32_t nb32, const int64_t nb33);
35
+
36
+ typedef float (*vec_dot_KQ_t)(
53
37
  const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);
54
38
 
55
- template<typename T, int D, int warp_size>
56
- static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
39
+ template <int D, int nthreads>
40
+ static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_f16(
41
+ const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) {
42
+
43
+ const half2 * K_h2 = (const half2 *) K_c;
44
+ GGML_UNUSED(Q_q8);
45
+ GGML_UNUSED(Q_ds_v);
46
+
47
+ constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
48
+ constexpr int cpy_ne = cpy_nb / 4;
49
+
50
+ float sum = 0.0f;
51
+
52
+ #pragma unroll
53
+ for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += nthreads*cpy_ne) {
54
+ half2 tmp[cpy_ne];
55
+ ggml_cuda_memcpy_1<sizeof(tmp)>(tmp, K_h2 + k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne);
56
+ #pragma unroll
57
+ for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) {
58
+ #ifdef FAST_FP16_AVAILABLE
59
+ ggml_cuda_mad(sum, tmp[k_KQ_1] , ((const half2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
60
+ #else
61
+ ggml_cuda_mad(sum, __half22float2(tmp[k_KQ_1]), ((const float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
62
+ #endif // FP16_AVAILABLE
63
+ }
64
+ }
65
+
66
+ return sum;
67
+ }
68
+
69
+ template<int D, int nthreads>
70
+ static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q4_0(
57
71
  const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
58
72
 
59
73
  const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c;
60
74
  GGML_UNUSED(Q_v);
61
75
 
62
- T sum = 0.0f;
76
+ float sum = 0.0f;
63
77
 
64
78
  #pragma unroll
65
- for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += warp_size) {
66
- const int k_KQ = k_KQ_0 + threadIdx.x;
79
+ for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
80
+ const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads);
67
81
 
68
82
  const int ib = k_KQ / QI8_1;
69
83
  const int iqs4 = k_KQ % QI4_0;
70
84
  const int shift = k_KQ & (QI8_1/2);
71
85
 
72
- const int v = (get_int_b2(K_q4_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
73
- const int u = Q_q8[k_KQ_0/warp_size];
86
+ int v;
87
+ ggml_cuda_memcpy_1<sizeof(int), 2>(&v, K_q4_0[ib].qs + sizeof(int)*iqs4);
88
+ v = (v >> shift) & 0x0F0F0F0F;
89
+ const int u = Q_q8[k_KQ_0/nthreads];
74
90
 
75
91
  const int sumi = ggml_cuda_dp4a(v, u, 0);
76
92
 
77
- #ifdef FP16_AVAILABLE
78
- if (std::is_same<T, half>::value) {
79
- const half2 * Q_ds = (const half2 *) Q_ds_v;
80
-
81
- const half2 sum2 = __half2half2(K_q4_0[ib].d) * Q_ds[k_KQ_0/warp_size];
82
- sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2) /* *8/QI8_1 == 1 */);
83
- } else
84
- #endif // FP16_AVAILABLE
85
- {
86
- const float2 * Q_ds = (const float2 *) Q_ds_v;
87
-
88
- sum += (T) (__half2float(K_q4_0[ib].d) * (sumi*Q_ds[k_KQ_0/warp_size].x - (8/QI8_1)*Q_ds[k_KQ_0/warp_size].y));
89
- }
93
+ const float2 Q_ds = ((const float2 *) Q_ds_v)[k_KQ_0/nthreads];
94
+ sum += __half2float(K_q4_0[ib].d) * (sumi*Q_ds.x - (8/QI8_1)*Q_ds.y);
90
95
  }
91
96
 
92
97
  return sum;
93
98
  }
94
99
 
95
- template<typename T, int D, int warp_size>
96
- static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
100
+ template<int D, int nthreads>
101
+ static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q4_1(
97
102
  const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
98
103
 
99
104
  const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c;
100
105
  GGML_UNUSED(Q_v);
101
106
 
102
- T sum = 0.0f;
107
+ float sum = 0.0f;
103
108
 
104
109
  #pragma unroll
105
- for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += warp_size) {
106
- const int k_KQ = k_KQ_0 + threadIdx.x;
110
+ for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
111
+ const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads);
107
112
 
108
113
  const int ib = k_KQ / QI8_1;
109
114
  const int iqs4 = k_KQ % QI4_1;
110
115
  const int shift = k_KQ & (QI8_1/2);
111
116
 
112
- const int v = (get_int_b4(K_q4_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
113
- const int u = Q_q8[k_KQ_0/warp_size];
117
+ int v;
118
+ ggml_cuda_memcpy_1<sizeof(int)>(&v, K_q4_1[ib].qs + sizeof(int)*iqs4);
119
+ v = (v >> shift) & 0x0F0F0F0F;
120
+ const int u = Q_q8[k_KQ_0/nthreads];
114
121
 
115
122
  const int sumi = ggml_cuda_dp4a(v, u, 0);
116
123
 
117
- #ifdef FP16_AVAILABLE
118
- if (std::is_same<T, half>::value) {
119
- const half2 * Q_ds = (const half2 *) Q_ds_v;
120
-
121
- const half2 d4d8_m4s8 = K_q4_1[ib].dm * Q_ds[k_KQ_0/warp_size];
122
- const half2 sumid4d8_m4s8scaled = d4d8_m4s8 * make_half2(sumi, 1.0f/QI8_1);
123
- sum += (T) (__low2half(sumid4d8_m4s8scaled) + __high2half(sumid4d8_m4s8scaled));
124
- } else
125
- #endif // FP16_AVAILABLE
126
- {
127
- const float2 * Q_ds = (const float2 *) Q_ds_v;
128
-
129
- const float sumid4d8 = __low2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/warp_size].x * sumi;
130
- const float m4s8scaled = __high2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/warp_size].y / QI8_1;
124
+ const float2 K_dm = __half22float2(K_q4_1[ib].dm);
125
+ const float2 Q_ds = ((const float2 *) Q_ds_v)[k_KQ_0/nthreads];
131
126
 
132
- sum += (T) (sumid4d8 + m4s8scaled);
133
- }
127
+ sum += K_dm.x*Q_ds.x*sumi + K_dm.y*Q_ds.y/QI8_1;
134
128
  }
135
129
 
136
130
  return sum;
137
131
  }
138
132
 
139
- template<typename T, int D, int warp_size>
140
- static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
133
+ template<int D, int nthreads>
134
+ static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q5_0(
141
135
  const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
142
136
 
143
137
  const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c;
144
138
  GGML_UNUSED(Q_v);
145
139
 
146
- T sum = 0.0f;
140
+ float sum = 0.0f;
147
141
 
148
142
  #pragma unroll
149
- for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += warp_size) {
150
- const int k_KQ = k_KQ_0 + threadIdx.x;
143
+ for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
144
+ const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads);
151
145
 
152
146
  const int ib = k_KQ / QI8_1;
153
147
  const int iqs4 = k_KQ % QI5_0;
154
148
  const int iqs8 = k_KQ % QI8_1;
155
149
  const int shift = k_KQ & (QI8_1/2);
156
150
 
157
- int v = (get_int_b2(K_q5_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
158
- const int vh = get_int_b2(K_q5_0[ib].qh, 0) >> (iqs8 * QI5_0);
159
- v |= (vh << 4) & 0x00000010; // 0 -> 4
160
- v |= (vh << 11) & 0x00001000; // 1 -> 12
161
- v |= (vh << 18) & 0x00100000; // 2 -> 20
162
- v |= (vh << 25) & 0x10000000; // 3 -> 28
151
+ int v;
152
+ ggml_cuda_memcpy_1<sizeof(int), 2>(&v, K_q5_0[ib].qs + sizeof(int)*iqs4);
153
+ v = (v >> shift) & 0x0F0F0F0F;
163
154
 
164
- const int u = Q_q8[k_KQ_0/warp_size];
155
+ {
156
+ int vh;
157
+ ggml_cuda_memcpy_1<sizeof(int), 2>(&vh, K_q5_0[ib].qh);
158
+ vh >>= iqs8 * QI5_0;
159
+
160
+ v |= (vh << 4) & 0x00000010; // 0 -> 4
161
+ v |= (vh << 11) & 0x00001000; // 1 -> 12
162
+ v |= (vh << 18) & 0x00100000; // 2 -> 20
163
+ v |= (vh << 25) & 0x10000000; // 3 -> 28
164
+ }
165
165
 
166
- const int sumi = ggml_cuda_dp4a(v, u, 0);
166
+ const int u = Q_q8[k_KQ_0/nthreads];
167
167
 
168
- #ifdef FP16_AVAILABLE
169
- if (std::is_same<T, half>::value) {
170
- const half2 * Q_ds = (const half2 *) Q_ds_v;
168
+ const int sumi = ggml_cuda_dp4a(v, u, 0);
171
169
 
172
- const half2 sum2 = __half2half2(K_q5_0[ib].d) * Q_ds[k_KQ_0/warp_size];
173
- sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2)*__float2half(2.0f)) /* *16/QI8_1 == 2 */;
174
- } else
175
- #endif // FP16_AVAILABLE
176
- {
177
- const float2 * Q_ds = (const float2 *) Q_ds_v;
170
+ const float2 Q_ds = ((const float2 *) Q_ds_v)[k_KQ_0/nthreads];
178
171
 
179
- sum += (T) (__half2float(K_q5_0[ib].d) * (sumi*Q_ds[k_KQ_0/warp_size].x - (16/QI8_1)*Q_ds[k_KQ_0/warp_size].y));
180
- }
172
+ sum += __half2float(K_q5_0[ib].d) * (sumi*Q_ds.x - (16/QI8_1)*Q_ds.y);
181
173
  }
182
174
 
183
175
  return sum;
184
176
  }
185
177
 
186
-template<typename T, int D, int warp_size>
-static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
+template<int D, int nthreads>
+static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q5_1(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
 
     const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c;
     GGML_UNUSED(Q_v);
 
-    T sum = 0.0f;
+    float sum = 0.0f;
 
 #pragma unroll
-    for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += warp_size) {
-        const int k_KQ = k_KQ_0 + threadIdx.x;
+    for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
+        const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads);
 
         const int ib = k_KQ / QI8_1;
         const int iqs4 = k_KQ % QI5_1;
         const int iqs8 = k_KQ % QI8_1;
         const int shift = k_KQ & (QI8_1/2);
 
-        int v = (get_int_b2(K_q5_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
-        const int vh = get_int_b2(K_q5_1[ib].qh, 0) >> (iqs8 * QI5_1);
-        v |= (vh << 4) & 0x00000010; // 0 -> 4
-        v |= (vh << 11) & 0x00001000; // 1 -> 12
-        v |= (vh << 18) & 0x00100000; // 2 -> 20
-        v |= (vh << 25) & 0x10000000; // 3 -> 28
+        int v;
+        ggml_cuda_memcpy_1<sizeof(int)>(&v, K_q5_1[ib].qs + sizeof(int)*iqs4);
+        v = (v >> shift) & 0x0F0F0F0F;
 
-        const int u = Q_q8[k_KQ_0/warp_size];
-
-        const int sumi = ggml_cuda_dp4a(v, u, 0);
+        {
+            int vh;
+            ggml_cuda_memcpy_1<sizeof(int)>(&vh, K_q5_1[ib].qh);
+            vh >>= iqs8 * QI5_0;
+
+            v |= (vh << 4) & 0x00000010; // 0 -> 4
+            v |= (vh << 11) & 0x00001000; // 1 -> 12
+            v |= (vh << 18) & 0x00100000; // 2 -> 20
+            v |= (vh << 25) & 0x10000000; // 3 -> 28
+        }
 
-#ifdef FP16_AVAILABLE
-        if (std::is_same<T, half>::value) {
-            const half2 * Q_ds = (const half2 *) Q_ds_v;
+        const int u = Q_q8[k_KQ_0/nthreads];
 
-            const half2 d5d8_m5s8 = K_q5_1[ib].dm * Q_ds[k_KQ_0/warp_size];
-            const half2 sumid5d8_m5s8scaled = d5d8_m5s8 * make_half2(sumi, 1.0f/QI8_1);
-            sum += (T) (__low2half(sumid5d8_m5s8scaled) + __high2half(sumid5d8_m5s8scaled));
-        } else
-#endif // FP16_AVAILABLE
-        {
-            const float2 * Q_ds = (const float2 *) Q_ds_v;
+        const int sumi = ggml_cuda_dp4a(v, u, 0);
 
-            const float sumid5d8 = __low2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/warp_size].x * sumi;
-            const float m5s8scaled = __high2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/warp_size].y / QI8_1;
+        const float2 K_dm = __half22float2(K_q5_1[ib].dm);
+        const float2 Q_ds = ((const float2 *) Q_ds_v)[k_KQ_0/nthreads];
 
-            sum += (T) (sumid5d8 + m5s8scaled);
-        }
+        sum += K_dm.x*Q_ds.x*sumi + K_dm.y*Q_ds.y/QI8_1;
     }
 
     return sum;
 }
 
-template <typename T, int D, int warp_size>
-static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
+template <int D, int nthreads>
+static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q8_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
 
     const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c;
     GGML_UNUSED(Q_v);
 
-    T sum = 0.0f;
+    float sum = 0.0f;
 
 #pragma unroll
-    for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += warp_size) {
-        const int k_KQ = k_KQ_0 + threadIdx.x;
+    for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
+        const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads);
 
         const int ib = k_KQ / QI8_0;
         const int iqs = k_KQ % QI8_0;
 
-        const int v = get_int_b2(K_q8_0[ib].qs, iqs);
+        int v;
+        ggml_cuda_memcpy_1<sizeof(v), 2>(&v, K_q8_0[ib].qs + 4*iqs);
 
-        T Q_d;
-        if (std::is_same<T, half>::value) {
-            const half2 * Q_ds = (const half2 *) Q_ds_v;
-            Q_d = __low2half(Q_ds[k_KQ_0/warp_size]);
-        } else {
-            const float2 * Q_ds = (const float2 *) Q_ds_v;
-            Q_d = Q_ds[k_KQ_0/warp_size].x;
-        }
+        const float2 * Q_ds = (const float2 *) Q_ds_v;
+        const float Q_d = Q_ds[k_KQ_0/nthreads].x;
 
-        sum += vec_dot_q8_0_q8_1_impl<T, 1>(&v, &Q_q8[k_KQ_0/warp_size], K_q8_0[ib].d, Q_d);
+        sum += vec_dot_q8_0_q8_1_impl<float, 1>(&v, &Q_q8[k_KQ_0/nthreads], K_q8_0[ib].d, Q_d);
     }
 
     return sum;
 }
 
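The second template argument at these call sites reads as an alignment guarantee for the copy: q4_0/q5_0/q8_0 blocks begin with a 2-byte half scale, so their qs arrays are only 2-byte aligned and a plain 4-byte load would not be safe. Assuming that interpretation, ggml_cuda_memcpy_1<sizeof(int), 2>(&v, src) amounts to (a sketch, not the actual helper):

    int v;
    short * v16 = (short *) &v;
    v16[0] = ((const short *) src)[0]; // two 2-byte loads instead of a
    v16[1] = ((const short *) src)[1]; // 4-byte load on a 2-aligned pointer

The q4_1/q5_1 variants, whose blocks begin with a 4-byte half2 dm, omit the argument and can use a full-width load.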
-template <typename T, int D, int warp_size>
-static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
-    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) {
-
-    const half2 * K_h2 = (const half2 *) K_c;
-    GGML_UNUSED(Q_q8);
-    GGML_UNUSED(Q_ds_v);
-
-#ifdef FP16_AVAILABLE
-    if (std::is_same<T, half>::value) {
-        const half2 * Q_h2 = (const half2 *) Q_v;
-
-        half2 sum2 = make_half2(0.0f, 0.0f);
-
-#pragma unroll
-        for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += warp_size) {
-            const int k_KQ = k_KQ_0 + threadIdx.x;
-
-            const half2 K_ik = K_h2[k_KQ];
-            sum2 += K_ik * Q_h2[k_KQ_0/warp_size];
-        }
-
-        return __low2half(sum2) + __high2half(sum2);
-    }
-#endif // FP16_AVAILABLE
-
-    const float2 * Q_f2 = (const float2 *) Q_v;
-
-    float sum = 0.0f;
-
-#pragma unroll
-    for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += warp_size) {
-        const int k_KQ = k_KQ_0 + threadIdx.x;
-
-        const half2 K_ik = K_h2[k_KQ];
-        sum += __low2float(K_ik) * Q_f2[k_KQ_0/warp_size].x;
-        sum += __high2float(K_ik) * Q_f2[k_KQ_0/warp_size].y;
-    }
-
-    return sum;
-}
-
-template <typename Tds>
+template <typename Tds, int ni>
 static __device__ __forceinline__ void quantize_q8_1_to_shared(
     const float * __restrict__ x, const float scale, int * __restrict__ yq32, void * __restrict__ yds) {
 
     float vals[sizeof(int)] = {0.0f};
 #pragma unroll
     for (int l = 0; l < int(sizeof(int)); ++l) {
-        vals[l] = scale * x[4*threadIdx.x + l];
+        vals[l] = (ni == WARP_SIZE || threadIdx.x < ni) ? scale * x[4*threadIdx.x + l] : 0.0f;
     }
 
     float amax = fabsf(vals[0]);
@@ -344,7 +284,7 @@ static __device__ __forceinline__ void quantize_q8_1_to_shared(
     }
 
     yq32[threadIdx.x] = q32;
-    if (threadIdx.x % QI8_1 == 0) {
+    if (threadIdx.x % QI8_1 == 0 && (ni == WARP_SIZE || threadIdx.x < ni)) {
         if (std::is_same<Tds, half2>::value) {
             ((half2 *) yds)[threadIdx.x/QI8_1] = make_half2(d, sum);
         } else {
@@ -353,173 +293,335 @@ static __device__ __forceinline__ void quantize_q8_1_to_shared(
     }
 }
 
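The new ni parameter lets quantize_q8_1_to_shared run with fewer useful lanes than the warp is wide: lanes at or beyond ni feed zeros into the amax/sum reductions instead of reading past the input, and only in-range lanes publish block statistics to yds. A hypothetical instantiation for a 16-value tile, just to show the intent:

    // 16 useful lanes; the remaining lanes of the warp contribute zeros.
    quantize_q8_1_to_shared<float2, 16>(x_tile, scale, yq32_shared, yds_shared);

With ni == WARP_SIZE both guards fold away at compile time and the function behaves exactly as before.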
-typedef half (*dequantize_1_f16_t)(const void *, const int64_t);
-typedef float (*dequantize_1_f32_t)(const void *, const int64_t);
+typedef void (*dequantize_V_t)(const void *, void *, const int64_t);
+
+template <typename T, int ne>
+static __device__ __forceinline__ void dequantize_V_f16(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
+    if constexpr (std::is_same_v<T, half>) {
+        ggml_cuda_memcpy_1<ne*sizeof(half)>(dst, (const half *) vx + i0);
+    } else if constexpr (std::is_same_v<T, float>) {
+        static_assert(ne % 2 == 0, "bad ne");
+        half2 tmp[ne/2];
+        ggml_cuda_memcpy_1<ne*sizeof(half)>(tmp, (const half *) vx + i0);
+        float2 * dst_f2 = (float2 *) dst;
+#pragma unroll
+        for (int l = 0; l < ne/2; ++l) {
+            dst_f2[l] = __half22float2(tmp[l]);
+        }
+    } else {
+        static_assert(std::is_same_v<T, void>, "unsupported type");
+    }
+}
 
-template <typename T>
-static __device__ __forceinline__ T dequantize_1_q4_0(const void * __restrict__ vx, const int64_t i) {
+template <typename T, int ne>
+static __device__ __forceinline__ void dequantize_V_q4_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
     const block_q4_0 * x = (const block_q4_0 *) vx;
 
-    const int64_t ib = i / QK4_0;
-    const int iqs = i % (QK4_0/2);
-    const int shift = (i % QK4_0) / (QK4_0/2);
+    const int64_t ib = i0 / QK4_0;
+    const int iqs = i0 % (QK4_0/2);
+    const int shift = (i0 % QK4_0) / (QK4_0/2);
 
-    const T d = x[ib].d;
-    const int q0 = x[ib].qs[iqs];
-    const int q = ((q0 >> (4*shift)) & 0x0F) - 8;
+    int q;
+    static_assert(ne == 2 || ne == 4, "bad ne");
+    ggml_cuda_memcpy_1<ne, 2>(&q, x[ib].qs + iqs);
+    q >>= 4*shift;
+    q &= 0x0F0F0F0F;
+    q = __vsubss4(q, 0x08080808);
+
+    const int8_t * q8 = (const int8_t *) &q;
 
 #ifdef FP16_AVAILABLE
-    if (std::is_same<T, half>::value) {
-        return ((half) d)*((half) q);
-    }
+    if constexpr (std::is_same_v<T, half>) {
+        const half2 d = __half2half2(x[ib].d);
+
+#pragma unroll
+        for (int l0 = 0; l0 < ne; l0 += 2) {
+            ((half2 *) dst)[l0/2] = d * make_half2(q8[l0 + 0], q8[l0 + 1]);
+        }
+    } else
 #endif // FP16_AVAILABLE
+    if constexpr (std::is_same_v<T, float>) {
+        const float d = x[ib].d;
 
-    return ((float) d)*((float) q);
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            ((float *) dst)[l] = d * q8[l];
+        }
+    } else {
+        static_assert(std::is_same_v<T, void>, "bad type");
+    }
 }
 
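__vsubss4(a, b) is the CUDA SIMD intrinsic for byte-wise signed subtraction with saturation; here it recenters four unpacked nibbles 0..15 to -8..7 in a single instruction (saturation can never trigger for these inputs). A scalar reference of the operation, as a sketch:

    // Reference for __vsubss4: per-byte a - b, saturated to int8 range.
    static int vsubss4_ref(int a, int b) {
        int r = 0;
        for (int l = 0; l < 4; ++l) {
            int x = (int)(int8_t)(a >> 8*l) - (int)(int8_t)(b >> 8*l);
            x = x < -128 ? -128 : (x > 127 ? 127 : x); // saturate
            r |= (x & 0xFF) << 8*l;
        }
        return r;
    }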
-template <typename T>
-static __device__ __forceinline__ T dequantize_1_q4_1(const void * __restrict__ vx, const int64_t i) {
+template <typename T, int ne>
+static __device__ __forceinline__ void dequantize_V_q4_1(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
     const block_q4_1 * x = (const block_q4_1 *) vx;
 
-    const int64_t ib = i / QK4_1;
-    const int iqs = i % (QK4_1/2);
-    const int shift = (i % QK4_1) / (QK4_1/2);
+    const int64_t ib = i0 / QK4_1;
+    const int iqs = i0 % (QK4_1/2);
+    const int shift = (i0 % QK4_1) / (QK4_1/2);
 
-    const half2 dm = x[ib].dm;
-    const int q0 = x[ib].qs[iqs];
-    const int q = ((q0 >> (4*shift)) & 0x0F);
+    int q;
+    static_assert(ne == 2 || ne == 4, "bad ne");
+    ggml_cuda_memcpy_1<ne>(&q, x[ib].qs + iqs);
+    q >>= 4*shift;
+    q &= 0x0F0F0F0F;
+
+    const int8_t * q8 = (const int8_t *) &q;
 
 #ifdef FP16_AVAILABLE
-    if (std::is_same<T, half>::value) {
-        return __low2half(dm)*((half) q) + __high2half(dm);
-    }
+    if constexpr (std::is_same_v<T, half>) {
+        const half2 dm = x[ib].dm;
+        const half2 d = __half2half2( __low2half(dm));
+        const half2 m = __half2half2(__high2half(dm));
+
+#pragma unroll
+        for (int l0 = 0; l0 < ne; l0 += 2) {
+            ((half2 *) dst)[l0/2] = d * make_half2(q8[l0 + 0], q8[l0 + 1]) + m;
+        }
+    } else
 #endif // FP16_AVAILABLE
+    if constexpr (std::is_same_v<T, float>) {
+        const float2 dm = __half22float2(x[ib].dm);
 
-    return __low2float(dm)*((float) q) + __high2float(dm);
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            ((float *) dst)[l] = dm.x * q8[l] + dm.y;
+        }
+    } else {
+        static_assert(std::is_same_v<T, void>, "bad type");
+    }
 }
 
-template <typename T>
-static __device__ __forceinline__ T dequantize_1_q5_0(const void * __restrict__ vx, const int64_t i) {
+template <typename T, int ne>
+static __device__ __forceinline__ void dequantize_V_q5_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
     const block_q5_0 * x = (const block_q5_0 *) vx;
 
-    const int64_t ib = i / QK5_0;
-    const int idq = i % QK5_0;
-    const int iqs = i % (QK5_0/2);
-    const int shift = (i % QK5_0) / (QK5_0/2);
+    const int64_t ib = i0 / QK5_0;
+    const int idq = i0 % QK5_0;
+    const int iqs = i0 % (QK5_0/2);
+    const int shift = (i0 % QK5_0) / (QK5_0/2);
 
-    const T d = x[ib].d;
-    const int ql0 = x[ib].qs[iqs];
-    const int qh0 = get_int_b2(x[ib].qh, 0);
-    const int ql = ((ql0 >> (4*shift)) & 0x0F);
-    const int qh = ((qh0 >> idq) << 4) & 0x10;
-    const int q = (ql | qh) - 16;
+    int q;
+    static_assert(ne == 2 || ne == 4, "bad ne");
+    ggml_cuda_memcpy_1<ne, 2>(&q, x[ib].qs + iqs);
+    q >>= 4*shift;
+    q &= 0x0F0F0F0F;
 
-#ifdef FP16_AVAILABLE
-    if (std::is_same<T, half>::value) {
-        return ((half) d)*((half) q);
+    {
+        int qh;
+        ggml_cuda_memcpy_1<ne, 2>(&qh, x[ib].qh);
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            q |= ((qh >> (idq + l)) & 0x00000001) << (8*l + 4);
+        }
     }
+
+    q = __vsubss4(q, 0x10101010);
+
+    const int8_t * q8 = (const int8_t *) &q;
+
+#ifdef FP16_AVAILABLE
+    if constexpr (std::is_same_v<T, half>) {
+        const half2 d = __half2half2(x[ib].d);
+
+#pragma unroll
+        for (int l0 = 0; l0 < ne; l0 += 2) {
+            ((half2 *) dst)[l0/2] = d * make_half2(q8[l0 + 0], q8[l0 + 1]);
+        }
+    } else
 #endif // FP16_AVAILABLE
+    if constexpr (std::is_same_v<T, float>) {
+        const float d = x[ib].d;
 
-    return ((float) d)*((float) q);
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            ((float *) dst)[l] = d * q8[l];
+        }
+    } else {
+        static_assert(std::is_same_v<T, void>, "bad type");
+    }
 }
 
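In the q5_0 path the fifth bits are merged before recentering, so the subtrahend is 0x10101010 and the 5-bit range 0..31 maps to -16..15 in all four lanes at once. One byte lane, worked through:

    // stored quant 19 = 0b10011: low nibble 0b0011, fifth bit 1
    // gather:   0b00000011 | (1 << 4) = 0b00010011 (19)
    // recenter: 19 - 16 = 3

The q5_1 variant below deliberately skips this step: its quants stay in 0..31 and the block minimum dm.y absorbs the offset.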
-template <typename T>
-static __device__ __forceinline__ T dequantize_1_q5_1(const void * __restrict__ vx, const int64_t i) {
+template <typename T, int ne>
+static __device__ __forceinline__ void dequantize_V_q5_1(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
     const block_q5_1 * x = (const block_q5_1 *) vx;
 
-    const int64_t ib = i / QK5_1;
-    const int idq = i % QK5_1;
-    const int iqs = i % (QK5_1/2);
-    const int shift = (i % QK5_1) / (QK5_1/2);
+    const int64_t ib = i0 / QK5_1;
+    const int idq = i0 % QK5_1;
+    const int iqs = i0 % (QK5_1/2);
+    const int shift = (i0 % QK5_1) / (QK5_1/2);
 
-    const half2 dm = x[ib].dm;
-    const int ql0 = x[ib].qs[iqs];
-    const int qh0 = get_int_b4(x[ib].qh, 0);
-    const int ql = ((ql0 >> (4*shift)) & 0x0F);
-    const int qh = ((qh0 >> idq) << 4) & 0x10;
-    const int q = (ql | qh);
+    int q;
+    static_assert(ne == 2 || ne == 4, "bad ne");
+    ggml_cuda_memcpy_1<ne>(&q, x[ib].qs + iqs);
+    q >>= 4*shift;
+    q &= 0x0F0F0F0F;
 
-#ifdef FP16_AVAILABLE
-    if (std::is_same<T, half>::value) {
-        return __low2half(dm)*((half) q) + __high2half(dm);
+    {
+        int qh;
+        ggml_cuda_memcpy_1<ne>(&qh, x[ib].qh);
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            q |= ((qh >> (idq + l)) & 0x00000001) << (8*l + 4);
+        }
     }
+
+    const int8_t * q8 = (const int8_t *) &q;
+
+#ifdef FP16_AVAILABLE
+    if constexpr (std::is_same_v<T, half>) {
+        const half2 dm = x[ib].dm;
+        const half2 d = __half2half2( __low2half(dm));
+        const half2 m = __half2half2(__high2half(dm));
+
+#pragma unroll
+        for (int l0 = 0; l0 < ne; l0 += 2) {
+            ((half2 *) dst)[l0/2] = d * make_half2(q8[l0 + 0], q8[l0 + 1]) + m;
+        }
+    } else
 #endif // FP16_AVAILABLE
+    if constexpr (std::is_same_v<T, float>) {
+        const float2 dm = __half22float2(x[ib].dm);
 
-    return __low2float(dm)*((float) q) + __high2float(dm);
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            ((float *) dst)[l] = dm.x * q8[l] + dm.y;
+        }
+    } else {
+        static_assert(std::is_same_v<T, void>, "bad type");
+    }
 }
 
-template <typename T>
-static __device__ __forceinline__ T dequantize_1_q8_0(const void * __restrict__ vx, const int64_t i) {
+template <typename T, int ne>
+static __device__ __forceinline__ void dequantize_V_q8_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
     const block_q8_0 * x = (const block_q8_0 *) vx;
 
-    const int64_t ib = i / QK8_0;
-    const int iqs = i % QK8_0;
+    const int64_t ib = i0 / QK8_0;
+    const int iqs = i0 % QK8_0;
 
-    const T d = x[ib].d;
-    const int q = x[ib].qs[iqs];
+    static_assert(ne % 2 == 0, "bad ne");
+    int8_t qs[ne];
+    ggml_cuda_memcpy_1<ne, 2>(qs, x[ib].qs + iqs);
 
 #ifdef FP16_AVAILABLE
-    if (std::is_same<T, half>::value) {
-        return ((half) d)*((half) q);
-    }
+    if constexpr (std::is_same<T, half>::value) {
+        const half2 d = __half2half2(x[ib].d);
+
+#pragma unroll
+        for (int l0 = 0; l0 < ne; l0 += 2) {
+            ((half2 *) dst)[l0/2] = d * make_half2(qs[l0 + 0], qs[l0 + 1]);
+        }
+    } else
 #endif // FP16_AVAILABLE
+    if constexpr (std::is_same<T, float>::value) {
+        const float d = x[ib].d;
 
-    return ((float) d)*((float) q);
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            ((float *) dst)[l] = d * qs[l];
+        }
+    } else {
+        static_assert(std::is_same_v<T, void>, "unsupported type");
+    }
 }
 
-template <typename T>
-static __device__ __forceinline__ T dequantize_1_f16(const void * __restrict__ vx, const int64_t i) {
-    const half * x = (const half *) vx;
-
-    return x[i];
+template <ggml_type type_K, int D, int nthreads>
+constexpr __device__ vec_dot_KQ_t get_vec_dot_KQ() {
+    if constexpr (type_K == GGML_TYPE_F16) {
+        return vec_dot_fattn_vec_KQ_f16<D, nthreads>;
+    } else if constexpr (type_K == GGML_TYPE_Q4_0) {
+        return vec_dot_fattn_vec_KQ_q4_0<D, nthreads>;
+    } else if constexpr (type_K == GGML_TYPE_Q4_1) {
+        return vec_dot_fattn_vec_KQ_q4_1<D, nthreads>;
+    } else if constexpr (type_K == GGML_TYPE_Q5_0) {
+        return vec_dot_fattn_vec_KQ_q5_0<D, nthreads>;
+    } else if constexpr (type_K == GGML_TYPE_Q5_1) {
+        return vec_dot_fattn_vec_KQ_q5_1<D, nthreads>;
+    } else if constexpr (type_K == GGML_TYPE_Q8_0) {
+        return vec_dot_fattn_vec_KQ_q8_0<D, nthreads>;
+    } else {
+        static_assert(type_K == -1, "bad type");
+        return nullptr;
+    }
 }
 
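get_vec_dot_KQ moves the K-type dispatch from a runtime ternary chain to if constexpr: each kernel instantiation bakes in exactly one dot-product implementation, and an unsupported type is now a compile-time error (the static_assert(type_K == -1, ...) arm) instead of a silent nullptr. A usage sketch with a hypothetical kernel:

    template <ggml_type type_K, int D, int nthreads>
    static __global__ void fattn_vec_sketch(const char * K, const void * Q_v,
                                            const int * Q_q8, const void * Q_ds) {
        // Resolved entirely at compile time, so the indirect call can be inlined.
        constexpr vec_dot_KQ_t vec_dot = get_vec_dot_KQ<type_K, D, nthreads>();
        const float s = vec_dot(K, Q_v, Q_q8, Q_ds);
        (void) s;
    }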
-template <int D, int warp_size = WARP_SIZE>
-constexpr __device__ vec_dot_KQ_f16_t get_vec_dot_KQ_f16(ggml_type type_K) {
-    return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<half, D, warp_size> :
-        type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<half, D, warp_size> :
-        type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<half, D, warp_size> :
-        type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<half, D, warp_size> :
-        type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<half, D, warp_size> :
-        type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<half, D, warp_size> :
-        nullptr;
+template <ggml_type type_V, typename T, int ne>
+constexpr __device__ dequantize_V_t get_dequantize_V() {
+    if constexpr (type_V == GGML_TYPE_F16) {
+        return dequantize_V_f16<T, ne>;
+    } else if constexpr (type_V == GGML_TYPE_Q4_0) {
+        return dequantize_V_q4_0<T, ne>;
+    } else if constexpr (type_V == GGML_TYPE_Q4_1) {
+        return dequantize_V_q4_1<T, ne>;
+    } else if constexpr (type_V == GGML_TYPE_Q5_0) {
+        return dequantize_V_q5_0<T, ne>;
+    } else if constexpr (type_V == GGML_TYPE_Q5_1) {
+        return dequantize_V_q5_1<T, ne>;
+    } else if constexpr (type_V == GGML_TYPE_Q8_0) {
+        return dequantize_V_q8_0<T, ne>;
+    } else {
+        static_assert(type_V == -1, "bad type");
+        return nullptr;
+    }
 }
 
-template <int D, int warp_size = WARP_SIZE>
-constexpr __device__ vec_dot_KQ_f32_t get_vec_dot_KQ_f32(ggml_type type_K) {
-    return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<float, D, warp_size> :
-        type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<float, D, warp_size> :
-        type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<float, D, warp_size> :
-        type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<float, D, warp_size> :
-        type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<float, D, warp_size> :
-        type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<float, D, warp_size> :
-        nullptr;
-}
+template <int ncols1>
+__launch_bounds__(FATTN_KQ_STRIDE/2, 1)
+static __global__ void flash_attn_mask_to_KV_max(
+        const half2 * __restrict__ mask, int * __restrict__ KV_max, const int ne30, const int s31, const int s33) {
+    const int ne31 = gridDim.x;
+    const int tid = threadIdx.x;
+    const int sequence = blockIdx.y;
+    const int jt = blockIdx.x;
 
-constexpr __device__ dequantize_1_f16_t get_dequantize_1_f16(ggml_type type_V) {
-    return type_V == GGML_TYPE_Q4_0 ? dequantize_1_q4_0<half> :
-        type_V == GGML_TYPE_Q4_1 ? dequantize_1_q4_1<half> :
-        type_V == GGML_TYPE_Q5_0 ? dequantize_1_q5_0<half> :
-        type_V == GGML_TYPE_Q5_1 ? dequantize_1_q5_1<half> :
-        type_V == GGML_TYPE_Q8_0 ? dequantize_1_q8_0<half> :
-        type_V == GGML_TYPE_F16 ? dequantize_1_f16<half> :
-        nullptr;
-}
+    mask += sequence*s33 + jt*ncols1*s31;
+
+    __shared__ int buf_iw[WARP_SIZE];
+    if (tid < WARP_SIZE) {
+        buf_iw[tid] = 1;
+    }
+    __syncthreads();
 
-constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
-    return type_V == GGML_TYPE_Q4_0 ? dequantize_1_q4_0<float> :
-        type_V == GGML_TYPE_Q4_1 ? dequantize_1_q4_1<float> :
-        type_V == GGML_TYPE_Q5_0 ? dequantize_1_q5_0<float> :
-        type_V == GGML_TYPE_Q5_1 ? dequantize_1_q5_1<float> :
-        type_V == GGML_TYPE_Q8_0 ? dequantize_1_q8_0<float> :
-        type_V == GGML_TYPE_F16 ? dequantize_1_f16<float> :
-        nullptr;
+    int KV_max_sj = (ne30 - 1) * FATTN_KQ_STRIDE;
+    for (; KV_max_sj >= 0; KV_max_sj -= FATTN_KQ_STRIDE) {
+        int all_inf = 1;
+
+#pragma unroll
+        for (int j = 0; j < ncols1; ++j) {
+            const float2 tmp = __half22float2(mask[j*s31 + KV_max_sj/2 + tid]);
+            all_inf = all_inf && int(isinf(tmp.x)) && int(isinf(tmp.y));
+        }
+
+        all_inf = warp_reduce_all(all_inf);
+        if (tid % WARP_SIZE == 0) {
+            buf_iw[tid / WARP_SIZE] = all_inf;
+        }
+        __syncthreads();
+        all_inf = buf_iw[tid % WARP_SIZE];
+        __syncthreads();
+        all_inf = warp_reduce_all(all_inf);
+
+        if (!all_inf) {
+            break;
+        }
+    }
+
+    // If the break in the loop was not triggered, KV_max_sj is now -FATTN_KQ_STRIDE.
+    // If the break was triggered it's the lower edge of the tile with the first non-masked values.
+    // In either case, undo the final decrement by FATTN_KQ_STRIDE.
+    KV_max_sj += FATTN_KQ_STRIDE;
+
+    if (threadIdx.x != 0) {
+        return;
+    }
+
+    KV_max[sequence*ne31 + jt] = KV_max_sj;
 }
 
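flash_attn_mask_to_KV_max walks the mask from the last FATTN_KQ_STRIDE-wide tile backwards and records, per (sequence, query tile), the first KV boundary behind which every mask value is -infinity; the attention kernels can then stop at KV_max instead of the full KV length. A scalar reference of the scan (a sketch; host-side, names hypothetical):

    // Number of KV columns that still need processing: the end of the
    // last stride-wide tile that holds any finite mask value.
    static int KV_max_ref(const float * mask, int n_KV, int ncols1, int ld) {
        for (int sj = n_KV - FATTN_KQ_STRIDE; sj >= 0; sj -= FATTN_KQ_STRIDE) {
            for (int j = 0; j < ncols1; ++j) {
                for (int i = 0; i < FATTN_KQ_STRIDE; ++i) {
                    if (!std::isinf(mask[j*ld + sj + i])) {
                        return sj + FATTN_KQ_STRIDE;
                    }
                }
            }
        }
        return 0; // every position is masked out
    }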
 template<int D, int ncols1, int ncols2> // D == head size
 __launch_bounds__(D, 1)
 static __global__ void flash_attn_stream_k_fixup(
-        float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
+        float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11) {
     constexpr int ncols = ncols1*ncols2;
 
     const int bidx0 = blockIdx.x;
@@ -533,8 +635,8 @@ static __global__ void flash_attn_stream_k_fixup(
     const int iter_k = ne11 / FATTN_KQ_STRIDE;
     const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
 
-    const int kbc0 = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
-    const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+    const int kbc0 = (bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
+    const int kbc0_stop = (bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
 
     const bool did_not_have_any_data = kbc0 == kbc0_stop;
     const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
@@ -543,14 +645,15 @@ static __global__ void flash_attn_stream_k_fixup(
         return;
     }
 
-    const int channel = kbc0 / (iter_k*iter_j);
-    const int jt = (kbc0 - channel*iter_k*iter_j) / iter_k;
+    const int sequence = kbc0 / (iter_k*iter_j*(ne02/ncols2));
+    const int head = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
+    const int jt = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
 
     if (jt*ncols1 + j >= ne01) {
         return;
     }
 
-    dst += jt*ne02*(ncols1*D) + channel*(ncols2*D) + (j*ne02 + c)*D + tid;
+    dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + head*(ncols2*D) + (j*ne02 + c)*D + tid;
 
     // Load the partial result that needs a fixup:
     float dst_val = 0.0f;
@@ -569,7 +672,7 @@ static __global__ void flash_attn_stream_k_fixup(
     int bidx = bidx0 - 1;
     int kbc_stop = kbc0;
     while(true) {
-        const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+        const int kbc = bidx*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
         if (kbc == kbc_stop) { // Did not have any data.
             bidx--;
             kbc_stop = kbc;
@@ -607,24 +710,37 @@ static __global__ void flash_attn_stream_k_fixup(
 }
 
 template<int D> // D == head size
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 static __global__ void flash_attn_combine_results(
         const float * __restrict__ VKQ_parts,
         const float2 * __restrict__ VKQ_meta,
        float * __restrict__ dst,
        const int parallel_blocks) {
-    VKQ_parts += parallel_blocks*D * gridDim.z*blockIdx.x;
-    VKQ_meta += parallel_blocks * gridDim.z*blockIdx.x;
-    dst += D * gridDim.z*blockIdx.x;
+    // Dimension 0: threadIdx.x
+    // Dimension 1: blockIdx.x
+    // Dimension 2: blockIdx.y
+    // Dimension 3: blockIdx.z
+    // Memory layout is permuted with [0, 2, 1, 3]
+
+    const int ne01 = gridDim.x;
+    const int ne02 = gridDim.y;
+
+    const int col = blockIdx.x;
+    const int head = blockIdx.y;
+    const int sequence = blockIdx.z;
+
+    const int j_dst_unrolled = (sequence*ne01 + col)*ne02 + head;
+
+    VKQ_parts += j_dst_unrolled * parallel_blocks*D;
+    VKQ_meta += j_dst_unrolled * parallel_blocks;
+    dst += j_dst_unrolled * D;
 
     const int tid = threadIdx.x;
     __builtin_assume(tid < D);
 
     extern __shared__ float2 meta[];
     for (int i = tid; i < 2*parallel_blocks; i += D) {
-        ((float *) meta)[i] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + i];
+        ((float *) meta)[i] = ((const float *)VKQ_meta) [i];
     }
 
     __syncthreads();
@@ -637,38 +753,13 @@ static __global__ void flash_attn_combine_results(
     float VKQ_numerator = 0.0f;
     float VKQ_denominator = 0.0f;
     for (int l = 0; l < parallel_blocks; ++l) {
-        const float diff = meta[l].x - kqmax;
-        float KQ_max_scale = expf(diff);
-        const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
-        *((uint32_t *) &KQ_max_scale) &= ftz_mask;
+        const float KQ_max_scale = expf(meta[l].x - kqmax);
 
-        VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.z*D + blockIdx.z*D + tid];
+        VKQ_numerator += KQ_max_scale * VKQ_parts[l*D + tid];
        VKQ_denominator += KQ_max_scale * meta[l].y;
     }
 
-    dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator;
-}
-
-[[noreturn]]
-static void on_no_fattn_vec_case(const int D) {
-    if (D == 64) {
-        fprintf(stderr, "Unsupported KV type combination for head_size 64.\n");
-        fprintf(stderr, "By default only f16 KV cache is supported.\n");
-        fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for V cache quantization support.\n");
-        GGML_ABORT("fatal error");
-    } else if (D == 128) {
-        fprintf(stderr, "Unsupported KV type combination for head_size 128.\n");
-        fprintf(stderr, "Supported combinations:\n");
-        fprintf(stderr, " - K == q4_0, V == q4_0, 4.50 BPV\n");
-        fprintf(stderr, " - K == q8_0, V == q8_0, 8.50 BPV\n");
-        fprintf(stderr, " - K == f16, V == f16, 16.00 BPV\n");
-        fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, q5_0, q5_1, q8_0, and f16.\n");
-        GGML_ABORT("fatal error");
-    } else {
-        fprintf(stderr, "Unsupported KV type combination for head_size %d.\n", D);
-        fprintf(stderr, "Only f16 is supported.\n");
-        GGML_ABORT("fatal error");
-    }
+    dst[tid] = VKQ_numerator / VKQ_denominator;
 }
 
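The simplified loop is still the standard numerically safe merge of partial softmax results. With m_l = meta[l].x the partial row maximum, D_l = meta[l].y the partial denominator, N_l the partial numerator held in VKQ_parts, and m = max_l m_l:

    \text{out} \;=\; \frac{\sum_l e^{m_l - m} N_l}{\sum_l e^{m_l - m} D_l}

Dropping the old flush-to-zero mask only changes terms whose scale e^{m_l - m} falls below e^{SOFTMAX_FTZ_THRESHOLD}; those are negligible in both sums, so the result is unchanged up to rounding.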
 template <int DV, int ncols1, int ncols2>
@@ -686,7 +777,8 @@ void launch_fattn(
 
     GGML_ASSERT(V || is_mla);
 
-    const ggml_tensor * mask = dst->src[3];
+    const ggml_tensor * mask  = dst->src[3];
+    const ggml_tensor * sinks = dst->src[4];
 
     ggml_tensor * KQV = dst;
 
@@ -703,8 +795,6 @@ void launch_fattn(
 
     GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
 
-    GGML_ASSERT(Q->ne[3] == 1);
-
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t main_stream = ctx.stream();
     const int id = ggml_cuda_get_device();
@@ -713,6 +803,7 @@ void launch_fattn(
 
     ggml_cuda_pool_alloc<half> K_f16(pool);
     ggml_cuda_pool_alloc<half> V_f16(pool);
+    ggml_cuda_pool_alloc<int> KV_max(pool);
     ggml_cuda_pool_alloc<float> dst_tmp(pool);
     ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
 
@@ -727,43 +818,86 @@ void launch_fattn(
     size_t nb23 = V ? V->nb[3] : nb13;
 
     if (need_f16_K && K->type != GGML_TYPE_F16) {
-        GGML_ASSERT(ggml_is_contiguously_allocated(K));
-        K_f16.alloc(ggml_nelements(K));
-        to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type);
-        to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);
-        K_data = (char *) K_f16.ptr;
-
         const size_t bs = ggml_blck_size(K->type);
         const size_t ts = ggml_type_size(K->type);
 
-        nb11 = nb11*bs*sizeof(half)/ts;
-        nb12 = nb12*bs*sizeof(half)/ts;
-        nb13 = nb13*bs*sizeof(half)/ts;
+        K_f16.alloc(ggml_nelements(K));
+        if (ggml_is_contiguously_allocated(K)) {
+            to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type);
+            to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);
+
+            nb11 = nb11*bs*sizeof(half)/ts;
+            nb12 = nb12*bs*sizeof(half)/ts;
+            nb13 = nb13*bs*sizeof(half)/ts;
+        } else {
+            GGML_ASSERT(K->nb[0] == ts);
+            to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(K->type);
+            const int64_t s01 = nb11 / ts;
+            const int64_t s02 = nb12 / ts;
+            const int64_t s03 = nb13 / ts;
+            to_fp16(K_data, K_f16.ptr, K->ne[0], K->ne[1], K->ne[2], K->ne[3], s01, s02, s03, main_stream);
+
+            nb11 = K->ne[0] * sizeof(half);
+            nb12 = K->ne[1] * nb11;
+            nb13 = K->ne[2] * nb12;
+        }
+        K_data = (char *) K_f16.ptr;
     }
 
     if (V && need_f16_V && V->type != GGML_TYPE_F16) {
-        GGML_ASSERT(ggml_is_contiguously_allocated(V));
-        V_f16.alloc(ggml_nelements(V));
-        to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
-        to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
-        V_data = (char *) V_f16.ptr;
-
         const size_t bs = ggml_blck_size(V->type);
         const size_t ts = ggml_type_size(V->type);
 
-        nb21 = nb21*bs*sizeof(half)/ts;
-        nb22 = nb22*bs*sizeof(half)/ts;
-        nb23 = nb23*bs*sizeof(half)/ts;
+        V_f16.alloc(ggml_nelements(V));
+        if (ggml_is_contiguously_allocated(V)) {
+            to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
+            to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
+            V_data = (char *) V_f16.ptr;
+
+            nb21 = nb21*bs*sizeof(half)/ts;
+            nb22 = nb22*bs*sizeof(half)/ts;
+            nb23 = nb23*bs*sizeof(half)/ts;
+        } else {
+            GGML_ASSERT(V->nb[0] == ts);
+            to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(V->type);
+            const int64_t s01 = nb21 / ts;
+            const int64_t s02 = nb22 / ts;
+            const int64_t s03 = nb23 / ts;
+            to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream);
+
+            nb21 = V->ne[0] * sizeof(half);
+            nb22 = V->ne[1] * nb21;
+            nb23 = V->ne[2] * nb22;
+        }
+        V_data = (char *) V_f16.ptr;
     }
 
-    int parallel_blocks = 1;
-
     const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
     const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
 
+    // Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.
+    // Only worth the overhead if there is at least one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
+    // multiple sequences of possibly different lengths.
+    if (mask && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) {
+        const int s31 = mask->nb[1] / sizeof(half2);
+        const int s33 = mask->nb[3] / sizeof(half2);
+
+        const dim3 blocks_num_KV_max(ntiles_x, Q->ne[3], 1);
+        const dim3 block_dim_KV_max(FATTN_KQ_STRIDE/2, 1, 1);
+
+        const int ne_KV_max = blocks_num_KV_max.x*blocks_num_KV_max.y;
+        const int iter_k = K->ne[1] / FATTN_KQ_STRIDE;
+
+        KV_max.alloc(ne_KV_max);
+        flash_attn_mask_to_KV_max<ncols1><<<blocks_num_KV_max, block_dim_KV_max, 0, main_stream>>>
+            ((const half2 *) mask->data, KV_max.ptr, iter_k, s31, s33);
+        CUDA_CHECK(cudaGetLastError());
+    }
+
     const dim3 block_dim(warp_size, nwarps, 1);
     int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
     CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
+    int parallel_blocks = max_blocks_per_sm;
 
     dim3 blocks_num;
     if (stream_k) {
@@ -785,9 +919,6 @@ void launch_fattn(
         GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0);
         const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
 
-        // parallel_blocks should be at least large enough to achieve max. occupancy for a single wave:
-        parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1);
-
         // parallel_blocks must not be larger than what the tensor size allows:
         parallel_blocks = std::min(parallel_blocks, ntiles_KQ);
 
@@ -802,7 +933,7 @@ void launch_fattn(
             const int efficiency_percent = 100 * nblocks_total / (nwaves*blocks_per_wave);
 
             // Stop trying configurations with more waves if we already have good efficiency to avoid excessive overhead.
-            if (efficiency_percent_best >= 90 && nwaves > nwaves_best) {
+            if (efficiency_percent_best >= 95 && nwaves > nwaves_best) {
                 break;
             }
 
@@ -847,15 +978,15 @@ void launch_fattn(
         K_data,
         V_data,
         mask ? ((const char *) mask->data) : nullptr,
+        sinks ? ((const char *) sinks->data) : nullptr,
+        KV_max.ptr,
        !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
        scale, max_bias, m0, m1, n_head_log2, logit_softcap,
-        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-        Q->nb[1], Q->nb[2], Q->nb[3],
-        nb11, nb12, nb13,
+        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3],
+        K->ne[0], K->ne[1], K->ne[2], K->ne[3], nb11, nb12, nb13,
        nb21, nb22, nb23,
-        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+        mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,
+        mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? mask->nb[3] : 0
     );
     CUDA_CHECK(cudaGetLastError());
 
@@ -866,11 +997,11 @@ void launch_fattn(
 
         flash_attn_stream_k_fixup<DV, ncols1, ncols2>
             <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
-                ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
+                ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1]);
     }
 } else if (parallel_blocks > 1) {
     const dim3 block_dim_combine(DV, 1, 1);
-    const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z);
+    const dim3 blocks_num_combine(Q->ne[1], Q->ne[2], Q->ne[3]);
     const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);
 
     flash_attn_combine_results<DV>