whispercpp 1.3.5 → 1.3.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (610)
  1. checksums.yaml +4 -4
  2. data/LICENSE +1 -1
  3. data/README.md +99 -2
  4. data/ext/extconf.rb +1 -0
  5. data/ext/ruby_whisper.c +20 -4
  6. data/ext/ruby_whisper.h +30 -2
  7. data/ext/ruby_whisper_context.c +216 -124
  8. data/ext/ruby_whisper_context_params.c +163 -0
  9. data/ext/ruby_whisper_model.c +0 -1
  10. data/ext/ruby_whisper_params.c +0 -1
  11. data/ext/ruby_whisper_segment.c +0 -1
  12. data/ext/ruby_whisper_token.c +29 -9
  13. data/ext/ruby_whisper_transcribe.cpp +4 -1
  14. data/ext/ruby_whisper_vad_context.c +48 -1
  15. data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
  16. data/ext/ruby_whisper_vad_params.c +0 -1
  17. data/ext/ruby_whisper_vad_segment.c +0 -1
  18. data/ext/ruby_whisper_vad_segments.c +0 -1
  19. data/ext/sources/CMakeLists.txt +1 -1
  20. data/ext/sources/bindings/javascript/package.json +1 -1
  21. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  22. data/ext/sources/examples/bench/bench.cpp +23 -18
  23. data/ext/sources/examples/cli/cli.cpp +8 -0
  24. data/ext/sources/examples/common-ggml.cpp +2 -0
  25. data/ext/sources/examples/miniaudio.h +4507 -2131
  26. data/ext/sources/examples/server/server.cpp +18 -4
  27. data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -2
  28. data/ext/sources/examples/talk-llama/llama-adapter.cpp +7 -13
  29. data/ext/sources/examples/talk-llama/llama-adapter.h +4 -3
  30. data/ext/sources/examples/talk-llama/llama-arch.cpp +335 -17
  31. data/ext/sources/examples/talk-llama/llama-arch.h +42 -0
  32. data/ext/sources/examples/talk-llama/llama-batch.cpp +3 -1
  33. data/ext/sources/examples/talk-llama/llama-chat.cpp +21 -1
  34. data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
  35. data/ext/sources/examples/talk-llama/llama-context.cpp +508 -520
  36. data/ext/sources/examples/talk-llama/llama-context.h +27 -28
  37. data/ext/sources/examples/talk-llama/llama-cparams.h +5 -0
  38. data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
  39. data/ext/sources/examples/talk-llama/llama-grammar.cpp +8 -8
  40. data/ext/sources/examples/talk-llama/llama-graph.cpp +583 -130
  41. data/ext/sources/examples/talk-llama/llama-graph.h +131 -10
  42. data/ext/sources/examples/talk-llama/llama-hparams.cpp +57 -40
  43. data/ext/sources/examples/talk-llama/llama-hparams.h +79 -10
  44. data/ext/sources/examples/talk-llama/llama-impl.cpp +4 -4
  45. data/ext/sources/examples/talk-llama/llama-impl.h +13 -1
  46. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +3 -1
  47. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +274 -89
  48. data/ext/sources/examples/talk-llama/llama-kv-cache.h +2 -3
  49. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
  50. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
  51. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +11 -13
  52. data/ext/sources/examples/talk-llama/llama-mmap.cpp +28 -11
  53. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +527 -119
  54. data/ext/sources/examples/talk-llama/llama-model-loader.h +35 -5
  55. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +60 -46
  56. data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
  57. data/ext/sources/examples/talk-llama/llama-model.cpp +1365 -647
  58. data/ext/sources/examples/talk-llama/llama-model.h +72 -19
  59. data/ext/sources/examples/talk-llama/llama-quant.cpp +578 -346
  60. data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +190 -76
  61. data/ext/sources/examples/talk-llama/{llama-sampling.h → llama-sampler.h} +0 -2
  62. data/ext/sources/examples/talk-llama/llama-vocab.cpp +118 -48
  63. data/ext/sources/examples/talk-llama/llama-vocab.h +5 -0
  64. data/ext/sources/examples/talk-llama/llama.cpp +76 -22
  65. data/ext/sources/examples/talk-llama/llama.h +63 -30
  66. data/ext/sources/examples/talk-llama/models/afmoe.cpp +2 -3
  67. data/ext/sources/examples/talk-llama/models/apertus.cpp +3 -3
  68. data/ext/sources/examples/talk-llama/models/arcee.cpp +3 -3
  69. data/ext/sources/examples/talk-llama/models/arctic.cpp +4 -5
  70. data/ext/sources/examples/talk-llama/models/baichuan.cpp +4 -3
  71. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +1 -2
  72. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +3 -5
  73. data/ext/sources/examples/talk-llama/models/bert.cpp +13 -7
  74. data/ext/sources/examples/talk-llama/models/bitnet.cpp +9 -24
  75. data/ext/sources/examples/talk-llama/models/bloom.cpp +2 -2
  76. data/ext/sources/examples/talk-llama/models/chameleon.cpp +3 -3
  77. data/ext/sources/examples/talk-llama/models/chatglm.cpp +2 -2
  78. data/ext/sources/examples/talk-llama/models/codeshell.cpp +3 -3
  79. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +3 -3
  80. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +2 -2
  81. data/ext/sources/examples/talk-llama/models/command-r.cpp +2 -2
  82. data/ext/sources/examples/talk-llama/models/dbrx.cpp +4 -5
  83. data/ext/sources/examples/talk-llama/models/deci.cpp +3 -3
  84. data/ext/sources/examples/talk-llama/models/deepseek.cpp +4 -6
  85. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +24 -21
  86. data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
  87. data/ext/sources/examples/talk-llama/models/dots1.cpp +4 -6
  88. data/ext/sources/examples/talk-llama/models/dream.cpp +3 -3
  89. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +4 -6
  90. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +3 -3
  91. data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
  92. data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
  93. data/ext/sources/examples/talk-llama/models/exaone.cpp +3 -3
  94. data/ext/sources/examples/talk-llama/models/exaone4.cpp +3 -3
  95. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +2 -4
  96. data/ext/sources/examples/talk-llama/models/falcon.cpp +3 -3
  97. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +1 -1
  98. data/ext/sources/examples/talk-llama/models/gemma.cpp +1 -1
  99. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +1 -1
  100. data/ext/sources/examples/talk-llama/models/gemma3.cpp +1 -1
  101. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +7 -7
  102. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +3 -3
  103. data/ext/sources/examples/talk-llama/models/glm4.cpp +14 -7
  104. data/ext/sources/examples/talk-llama/models/gpt2.cpp +2 -2
  105. data/ext/sources/examples/talk-llama/models/gptneox.cpp +2 -2
  106. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +4 -5
  107. data/ext/sources/examples/talk-llama/models/granite.cpp +4 -5
  108. data/ext/sources/examples/talk-llama/models/grok.cpp +4 -4
  109. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +5 -7
  110. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +3 -3
  111. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +4 -5
  112. data/ext/sources/examples/talk-llama/models/internlm2.cpp +3 -3
  113. data/ext/sources/examples/talk-llama/models/jais.cpp +2 -2
  114. data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
  115. data/ext/sources/examples/talk-llama/models/jamba.cpp +3 -3
  116. data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
  117. data/ext/sources/examples/talk-llama/models/lfm2.cpp +145 -124
  118. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +4 -4
  119. data/ext/sources/examples/talk-llama/models/llada.cpp +3 -3
  120. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +4 -4
  121. data/ext/sources/examples/talk-llama/models/llama.cpp +18 -11
  122. data/ext/sources/examples/talk-llama/models/maincoder.cpp +3 -3
  123. data/ext/sources/examples/talk-llama/models/{graph-context-mamba.cpp → mamba-base.cpp} +9 -3
  124. data/ext/sources/examples/talk-llama/models/mamba.cpp +1 -2
  125. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +11 -5
  126. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +14 -13
  127. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +4 -5
  128. data/ext/sources/examples/talk-llama/models/mistral3.cpp +4 -4
  129. data/ext/sources/examples/talk-llama/models/models.h +181 -46
  130. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +2 -9
  131. data/ext/sources/examples/talk-llama/models/mpt.cpp +2 -2
  132. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +26 -14
  133. data/ext/sources/examples/talk-llama/models/nemotron.cpp +3 -3
  134. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +2 -2
  135. data/ext/sources/examples/talk-llama/models/olmo.cpp +3 -3
  136. data/ext/sources/examples/talk-llama/models/olmo2.cpp +3 -3
  137. data/ext/sources/examples/talk-llama/models/olmoe.cpp +4 -4
  138. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +1 -1
  139. data/ext/sources/examples/talk-llama/models/openelm.cpp +3 -3
  140. data/ext/sources/examples/talk-llama/models/orion.cpp +3 -3
  141. data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
  142. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +3 -3
  143. data/ext/sources/examples/talk-llama/models/phi2.cpp +2 -2
  144. data/ext/sources/examples/talk-llama/models/phi3.cpp +3 -3
  145. data/ext/sources/examples/talk-llama/models/plamo.cpp +3 -3
  146. data/ext/sources/examples/talk-llama/models/plamo2.cpp +9 -5
  147. data/ext/sources/examples/talk-llama/models/plamo3.cpp +2 -2
  148. data/ext/sources/examples/talk-llama/models/plm.cpp +15 -14
  149. data/ext/sources/examples/talk-llama/models/qwen.cpp +2 -2
  150. data/ext/sources/examples/talk-llama/models/qwen2.cpp +3 -3
  151. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +4 -4
  152. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +3 -3
  153. data/ext/sources/examples/talk-llama/models/qwen3.cpp +12 -9
  154. data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
  155. data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
  156. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +15 -8
  157. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +84 -432
  158. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +9 -18
  159. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +8 -17
  160. data/ext/sources/examples/talk-llama/models/refact.cpp +2 -2
  161. data/ext/sources/examples/talk-llama/models/rnd1.cpp +4 -4
  162. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +2 -0
  163. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +2 -0
  164. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +3 -3
  165. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +4 -4
  166. data/ext/sources/examples/talk-llama/models/smollm3.cpp +3 -3
  167. data/ext/sources/examples/talk-llama/models/stablelm.cpp +2 -2
  168. data/ext/sources/examples/talk-llama/models/starcoder.cpp +2 -2
  169. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +3 -3
  170. data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
  171. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +2 -2
  172. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +2 -2
  173. data/ext/sources/examples/talk-llama/models/xverse.cpp +3 -3
  174. data/ext/sources/examples/talk-llama/unicode.cpp +21 -65
  175. data/ext/sources/ggml/CMakeLists.txt +9 -3
  176. data/ext/sources/ggml/include/ggml-backend.h +1 -1
  177. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  178. data/ext/sources/ggml/include/ggml-cpu.h +5 -0
  179. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  180. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  181. data/ext/sources/ggml/include/ggml-rpc.h +6 -1
  182. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  183. data/ext/sources/ggml/include/ggml.h +56 -9
  184. data/ext/sources/ggml/src/CMakeLists.txt +3 -0
  185. data/ext/sources/ggml/src/ggml-alloc.c +4 -9
  186. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  187. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  188. data/ext/sources/ggml/src/ggml-backend-reg.cpp +28 -86
  189. data/ext/sources/ggml/src/ggml-backend.cpp +5 -2
  190. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  191. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +6 -2
  192. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
  193. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
  194. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +348 -189
  195. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +40 -85
  196. data/ext/sources/ggml/src/ggml-cann/common.h +3 -4
  197. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +44 -62
  198. data/ext/sources/ggml/src/ggml-common.h +11 -0
  199. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +16 -11
  200. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -19
  201. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  202. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
  203. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +85 -1
  204. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2744 -548
  205. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1653 -0
  206. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
  207. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
  208. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
  209. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +118 -18
  210. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +107 -26
  211. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  212. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  213. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +3 -0
  214. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +59 -12
  215. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +15 -0
  216. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
  217. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +965 -252
  218. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +584 -197
  219. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +903 -188
  220. data/ext/sources/ggml/src/ggml-cpu/ops.h +1 -0
  221. data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
  222. data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
  223. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2890 -679
  224. data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
  225. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
  226. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +111 -3
  227. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
  228. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +17 -0
  229. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +1 -1
  230. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +19 -10
  231. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +32 -30
  232. data/ext/sources/ggml/src/ggml-cuda/common.cuh +134 -18
  233. data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
  234. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +6 -3
  235. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +78 -64
  236. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +384 -143
  237. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +36 -22
  238. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +3 -3
  239. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +26 -5
  240. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
  241. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +127 -12
  242. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
  243. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  244. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +595 -200
  245. data/ext/sources/ggml/src/ggml-cuda/mean.cu +9 -8
  246. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +173 -6
  247. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
  248. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +158 -85
  249. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +34 -22
  250. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +127 -67
  251. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
  252. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +157 -65
  253. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -0
  254. data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
  255. data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
  256. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +1 -1
  257. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
  258. data/ext/sources/ggml/src/ggml-cuda/rope.cu +233 -133
  259. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
  260. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
  261. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +56 -32
  262. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  263. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
  264. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
  265. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
  266. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
  267. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
  268. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
  269. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +3 -3
  270. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +0 -1
  271. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +199 -135
  272. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
  273. data/ext/sources/ggml/src/ggml-cuda/unary.cu +55 -0
  274. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
  275. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
  276. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +10 -0
  277. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +82 -45
  278. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +334 -160
  279. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +7 -5
  280. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +328 -197
  281. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
  282. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +765 -234
  283. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
  284. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +412 -265
  285. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +23 -23
  286. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +1 -1
  287. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.h → hex-dma.h} +28 -3
  288. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
  289. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  290. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
  291. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +1 -1
  292. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +27 -37
  293. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +6 -35
  294. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  295. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
  296. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
  297. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  302. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  303. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  304. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
  305. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  306. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  307. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +20 -1347
  308. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +211 -13
  309. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +1119 -952
  310. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +254 -244
  311. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +36 -36
  312. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +155 -138
  313. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
  314. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  315. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +209 -114
  316. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
  317. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  318. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  319. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  320. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
  321. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +6 -0
  322. data/ext/sources/ggml/src/ggml-impl.h +62 -0
  323. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  324. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
  325. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  326. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
  327. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +274 -73
  328. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +22 -4
  329. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +102 -36
  330. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +174 -23
  331. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +580 -280
  332. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +5 -4
  333. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +320 -107
  334. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1068 -825
  335. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +19 -1
  336. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +3108 -636
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +204 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  365. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
  366. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  367. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  368. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  369. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
  370. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  371. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
  372. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  373. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
  374. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
  375. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  376. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  377. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  378. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  379. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  380. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  381. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  382. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  383. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  384. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  385. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  386. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  387. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  388. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  389. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  390. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  391. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  392. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
  393. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  394. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  395. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  396. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  397. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  398. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  399. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
  400. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
  401. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
  402. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
  403. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  404. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  405. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  406. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  407. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  408. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
  409. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  410. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
  411. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
  412. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
  413. data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
  414. data/ext/sources/ggml/src/ggml-quants.c +96 -5
  415. data/ext/sources/ggml/src/ggml-quants.h +3 -0
  416. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +15 -88
  417. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
  418. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +1 -0
  419. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
  420. data/ext/sources/ggml/src/ggml-sycl/common.hpp +315 -10
  421. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +69 -1
  422. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
  423. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
  424. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
  425. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +78 -68
  426. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
  427. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
  428. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
  429. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
  430. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
  431. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
  432. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  433. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
  434. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
  435. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +316 -51
  436. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
  437. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  438. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
  439. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
  440. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
  441. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  442. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
  443. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  444. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  445. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  446. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  447. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  448. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  449. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  450. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  451. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  452. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
  453. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
  454. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
  455. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
  456. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
  457. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
  458. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
  459. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
  460. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
  461. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
  462. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
  463. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
  464. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
  465. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
  466. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
  467. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
  468. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
  469. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
  470. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
  471. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
  472. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
  473. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
  474. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
  475. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
  476. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
  477. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
  478. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
  479. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
  480. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
  481. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
  482. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
  483. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
  484. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
  485. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
  486. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
  487. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
  488. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +13 -0
  489. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  490. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  491. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  492. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  493. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  494. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  495. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  496. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  497. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  498. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  499. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  500. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  501. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  502. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  503. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  504. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  505. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  506. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  507. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  508. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  509. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  510. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  511. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
  512. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
  513. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  514. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
  515. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  516. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  517. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  518. data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
  519. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  520. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  521. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  522. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  523. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  524. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  525. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  526. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
  527. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  528. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  529. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  530. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
  531. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
  532. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +1 -1
  533. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1250 -465
  534. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
  535. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  536. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +374 -170
  537. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +66 -22
  538. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +389 -201
  539. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +106 -58
  540. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  541. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
  542. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
  543. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
  544. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
  545. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -3
  546. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +8 -4
  547. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
  548. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
  549. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +3 -3
  550. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
  551. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +36 -63
  552. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
  553. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
  554. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
  555. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +10 -5
  556. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
  557. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  558. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +16 -10
  559. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +55 -35
  560. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1314 -109
  561. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1660 -1371
  562. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  563. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  564. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  565. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
  566. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
  567. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
  568. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +6 -0
  569. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  570. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +40 -5
  571. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +105 -60
  572. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
  573. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +68 -257
  574. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +692 -23
  575. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_reg_tile.tmpl.wgsl → mul_mat_reg_tile.wgsl} +28 -128
  576. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +31 -137
  577. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
  578. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  579. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  580. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +9 -36
  581. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
  582. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  583. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
  584. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
  585. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
  586. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +9 -6
  587. data/ext/sources/ggml/src/ggml.c +167 -33
  588. data/ext/sources/ggml/src/gguf.cpp +229 -44
  589. data/ext/sources/src/whisper.cpp +6 -28
  590. data/sig/whisper.rbs +43 -2
  591. data/test/test_context_params.rb +82 -0
  592. data/test/test_token.rb +11 -0
  593. data/test/test_vad_context.rb +58 -8
  594. data/test/test_whisper.rb +20 -0
  595. data/whispercpp.gemspec +1 -1
  596. metadata +240 -28
  597. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  598. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
  599. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
  600. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
  601. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
  602. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
  603. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
  604. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
  605. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
  606. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
  607. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  608. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
  609. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
  610. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
@@ -7,6 +7,7 @@
7
7
  #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
8
8
 
9
9
  #extension GL_KHR_shader_subgroup_basic : enable
10
+ #extension GL_KHR_shader_subgroup_arithmetic : enable
10
11
  #extension GL_KHR_shader_subgroup_vote : enable
11
12
  #extension GL_KHR_memory_scope_semantics : enable
12
13
  #extension GL_KHR_cooperative_matrix : enable
@@ -14,12 +15,12 @@
14
15
  #include "types.glsl"
15
16
  #include "flash_attn_base.glsl"
16
17
 
17
- const uint32_t HSK_per_thread = HSK / D_split;
18
- const uint32_t HSV_per_thread = HSV / D_split;
18
+ // These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd
19
+ const uint32_t MatBr = 16;
20
+ const uint32_t MatBc = 16;
19
21
 
20
- const uint32_t row_split = 4;
21
22
  const uint32_t rows_per_thread = Br / row_split;
22
- const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split / row_split;
23
+ const uint32_t cols_per_iter = gl_WorkGroupSize.x / row_split;
23
24
  const uint32_t cols_per_thread = Bc / cols_per_iter;
24
25
 
25
26
 
@@ -31,33 +32,28 @@ layout (binding = 2) readonly buffer V {float16_t data_v[];};
31
32
  layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
32
33
  layout (binding = 3) readonly buffer M {float16_t data_m[];};
33
34
 
34
- // Store the output when doing grouped query attention.
35
- // Rows index by Q's dimension 2, and the first N rows are valid.
36
- D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
37
- {
38
- uint32_t offset = (iq2 + r) * HSV + c;
39
- data_o[o_offset + offset] = D_TYPE(elem);
40
- return elem;
41
- }
42
-
43
- // These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd
44
- const uint32_t MatBr = 16;
45
- const uint32_t MatBc = 16;
46
-
47
- shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x];
48
- shared ACC_TYPEV4 tmpshv4[gl_WorkGroupSize.x];
35
+ shared float tmpsh[row_split];
49
36
 
50
37
  const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4
51
38
  shared f16vec4 Qf[Br * qstride];
52
39
 
40
+ const uint psh_stride = Br / 4 + 2;
41
+ shared f16vec4 Psh[Bc * psh_stride];
42
+
53
43
  // Avoid padding for hsk==256 to make it fit in 48KB shmem.
54
- const uint32_t sfshstride = (HSK <= 128) ? (Br + 8) : Br;
55
- shared ACC_TYPE sfsh[Bc * sfshstride];
44
+ const uint32_t sfshstride = (HSK <= 128) ? (Br / 4 + 2) : Br / 4;
45
+ shared ACC_TYPEV4 sfsh[Bc * sfshstride];
56
46
 
57
- const uint32_t kshstride = HSK_pad / 4 + 2; // in units of f16vec4
58
- shared f16vec4 ksh[Bc * kshstride];
47
+ const uint32_t D_pad = HSK_pad > HSV_pad ? HSK_pad : HSV_pad;
48
+ const uint32_t kvsh_stride = (SHMEM_STAGING != 0 ? D_pad : MatBr) / 4 + 2; // in units of f16vec4
49
+ const uint v_cols = MatBc / 4 * row_split; // total cols, 4 vec4s per MatBc * number of subgroups
50
+ const uint vsh_stride = v_cols;
51
+ shared f16vec4 kvsh[(kvsh_stride >= vsh_stride) ? (Bc * kvsh_stride) : (Bc * vsh_stride)];
59
52
 
60
- shared float slope[Br];
53
+ const uint32_t osh_stride = row_split * MatBr / 4;
54
+ shared f16vec4 pvsh[MatBc * osh_stride];
55
+
56
+ shared ACC_TYPE slope[Br];
61
57
 
62
58
  void main() {
63
59
  #ifdef NEEDS_INIT_IQ_SHMEM
@@ -69,9 +65,9 @@ void main() {
69
65
  const uint32_t tid = gl_LocalInvocationIndex;
70
66
 
71
67
  const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split;
68
+ const uint32_t d_per_thread = (HSV/4 + threads_per_rowgroup - 1) / threads_per_rowgroup;
72
69
  const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup;
73
- const uint32_t d_tid = gl_LocalInvocationIndex % D_split;
74
- const uint32_t col_tid = (gl_LocalInvocationIndex % threads_per_rowgroup) / D_split;
70
+ const uint32_t col_tid = gl_LocalInvocationIndex % threads_per_rowgroup;
75
71
 
76
72
  #define tile_row(r) (row_tid * rows_per_thread + (r))
77
73
 
@@ -82,15 +78,10 @@ void main() {
82
78
  Qf[i + tid] = f16vec4(0);
83
79
  }
84
80
  }
85
- [[unroll]] for (uint i = 0; i < Bc * kshstride; i += gl_WorkGroupSize.x) {
86
- if (i + tid < Bc * kshstride) {
87
- ksh[i + tid] = f16vec4(0);
88
- }
89
- }
90
81
  barrier();
91
82
  }
92
83
 
93
- uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
84
+ uint32_t q_offset = gqa_iq1*p.nb01 + (iq2*p.nb02+iq3*p.nb03) / 4;
94
85
 
95
86
  [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
96
87
  uint32_t d = (idx + tid) % (HSK / 4);
@@ -102,10 +93,10 @@ void main() {
102
93
  }
103
94
  barrier();
104
95
 
105
- ACC_TYPEV4 Of[rows_per_thread][HSV_per_thread / 4];
106
- [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
107
- [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
108
- Of[r][d] = ACC_TYPEV4(0.0);
96
+ f16vec4 Of[rows_per_thread][d_per_thread];
97
+ [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
98
+ [[unroll]] for (uint32_t d = 0; d < d_per_thread; ++d) {
99
+ Of[r][d] = f16vec4(0.0);
109
100
  }
110
101
  }
111
102
 
@@ -125,15 +116,17 @@ void main() {
125
116
  uint r = tid;
126
117
  slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2);
127
118
  }
128
- barrier();
129
119
  } else {
130
120
  if (tid < Br) {
131
121
  uint r = tid;
132
- slope[r] = 1.0;
122
+ slope[r] = ACC_TYPE(1.0);
133
123
  }
134
- barrier();
135
124
  }
136
125
 
126
+ const uint32_t mo_stride = CEIL_DIV(KV, 16 * Bc);
127
+ // mo_offset will point to the tile starting at row i*Br and col 0
128
+ uint32_t mo_offset = mo_stride * i;
129
+
137
130
  #if BLOCK_SIZE > 1
138
131
  uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE;
139
132
  uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE;
@@ -141,65 +134,114 @@ void main() {
141
134
  uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
142
135
  uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
143
136
  #endif
144
- uint32_t m_offset = 0;
137
+ uint32_t m_offset = gqa_iq1*KV;
145
138
  if (p.nem2 != 1 || p.nem3 != 1) {
146
- m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
139
+ m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
140
+ mo_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * CEIL_DIV(p.nem1, Br) * mo_stride;
147
141
  }
148
142
 
143
+ uint32_t mask_opt = 0;
144
+ uint32_t mask_opt_idx = ~0;
145
+ uint32_t mask_opt_bits = 0;
146
+ f16vec4 mask_cache[Bc * Br / 4 / WorkGroupSize];
147
+
149
148
  [[dont_unroll]]
150
149
  for (uint32_t j = start_j; j < end_j; ++j) {
151
150
 
152
- float mask_cache[Bc * Br / WorkGroupSize];
153
- if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
154
- bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
155
-
156
- float max_mask = NEG_FLT_MAX_OVER_2;
157
- [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
158
- uint32_t c = (idx + tid) % Bc;
159
- uint32_t r = (idx + tid) / Bc;
160
- if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
161
- if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
162
- float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
163
- mask_cache[idx / WorkGroupSize] = m;
164
- max_mask = max(max_mask, m);
165
- }
166
- }
167
- }
168
- // skip the block if the mask is entirely -inf
169
- bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
170
- barrier();
171
- if (gl_SubgroupInvocationID == 0) {
172
- tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
173
- }
174
- barrier();
175
- [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
176
- max_mask = max(max_mask, tmpsh[s]);
151
+ [[unroll]] for (uint32_t idx = 0; idx < mask_cache.length(); ++idx) {
152
+ mask_cache[idx] = f16vec4(0);
153
+ }
154
+
155
+ if (MASK_ENABLE) {
156
+ if (USE_MASK_OPT && mask_opt_idx != j / 16) {
157
+ mask_opt_idx = j / 16;
158
+ mask_opt = data_mask_opt[mo_offset + mask_opt_idx];
177
159
  }
178
- if (max_mask <= NEG_FLT_MAX_OVER_2) {
160
+ mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3;
161
+ if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) {
162
+ // skip this block
179
163
  continue;
180
164
  }
165
+ // Only load if the block is not all zeros
166
+ if (mask_opt_bits != MASK_OPT_ALL_ZERO) {
167
+ bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
168
+
169
+ float max_mask = NEG_FLT_MAX_OVER_2;
170
+ [[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) {
171
+ uint32_t c = (idx + tid) / (Br / 4);
172
+ uint32_t r = (idx + tid) % (Br / 4);
173
+ if (idx + tid < Bc * Br / 4 || idx + gl_WorkGroupSize.x <= Bc * Br / 4) {
174
+ if ((!KV_bounds_check || j * Bc + c < KV)) {
175
+ f16vec4 m;
176
+ if (!nem1_bounds_check || i * Br + r * 4 + 3 < p.nem1) {
177
+ m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)],
178
+ data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)],
179
+ data_m[m_offset + (i * Br + r * 4 + 2) * m_stride + (j * Bc + c)],
180
+ data_m[m_offset + (i * Br + r * 4 + 3) * m_stride + (j * Bc + c)]);
181
+ max_mask = max(max(max(max(max_mask, float(m[0])), float(m[1])), float(m[2])), float(m[3]));
182
+ } else if (i * Br + r * 4 + 2 < p.nem1) {
183
+ m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)],
184
+ data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)],
185
+ data_m[m_offset + (i * Br + r * 4 + 2) * m_stride + (j * Bc + c)],
186
+ 0.0);
187
+ max_mask = max(max(max(max_mask, float(m[0])), float(m[1])), float(m[2]));
188
+ } else if (i * Br + r * 4 + 1 < p.nem1) {
189
+ m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)],
190
+ data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)],
191
+ 0.0,
192
+ 0.0);
193
+ max_mask = max(max(max_mask, float(m[0])), float(m[1]));
194
+ } else if (i * Br + r * 4 < p.nem1) {
195
+ m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)],
196
+ 0.0,
197
+ 0.0,
198
+ 0.0);
199
+ max_mask = max(max_mask, float(m[0]));
200
+ } else {
201
+ m = f16vec4(0.0);
202
+ }
203
+ mask_cache[idx / WorkGroupSize] = m;
204
+ }
205
+ }
206
+ }
207
+ // skip the block if the mask is entirely -inf
208
+ bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
209
+ barrier();
210
+ if (gl_SubgroupInvocationID == 0) {
211
+ tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
212
+ }
213
+ barrier();
214
+ [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
215
+ max_mask = max(max_mask, tmpsh[s]);
216
+ }
217
+ if (max_mask <= NEG_FLT_MAX_OVER_2) {
218
+ continue;
219
+ }
220
+ }
181
221
  }
182
222
 
183
- [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
184
- uint32_t d = (idx + tid) % (HSK / 4);
185
- uint32_t c = (idx + tid) / (HSK / 4);
186
- if (c < Bc && d < HSK / 4) {
187
- f16vec4 K_Tf = f16vec4(0);
188
- if (!KV_bounds_check || j * Bc + c < KV) {
223
+ if (SHMEM_STAGING != 0) {
224
+ [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK_pad / 4; idx += gl_WorkGroupSize.x) {
225
+ uint32_t d = (idx + tid) % (HSK_pad / 4);
226
+ uint32_t c = (idx + tid) / (HSK_pad / 4);
227
+ if (idx + gl_WorkGroupSize.x <= Bc * HSK_pad / 4 || c < Bc) {
228
+ f16vec4 K_Tf = f16vec4(0);
229
+ if ((!KV_bounds_check || j * Bc + c < KV) && (HSK == HSK_pad || d < HSK / 4)) {
189
230
  #if BLOCK_SIZE > 1
190
- uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
191
- uint ib = coord / BLOCK_SIZE;
192
- uint iqs = (coord % BLOCK_SIZE);
193
- K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K));
231
+ uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
232
+ uint ib = coord / BLOCK_SIZE;
233
+ uint iqs = (coord % BLOCK_SIZE);
234
+ K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
194
235
  #else
195
- K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
236
+ K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
196
237
  #endif
197
- }
238
+ }
198
239
 
199
- ksh[c * kshstride + d] = K_Tf;
240
+ kvsh[c * kvsh_stride + d] = K_Tf;
241
+ }
200
242
  }
243
+ barrier();
201
244
  }
202
- barrier();
203
245
 
204
246
  // K * Q^T -> S^T: Bc x HSK_pad * HSK_pad x Br -> Bc x Br
205
247
  // Bc split across workgroup (four subgroups), loop over HSK in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16
@@ -208,11 +250,59 @@ void main() {
208
250
  coopmat<float16_t, gl_ScopeSubgroup, MatBc, 16, gl_MatrixUseA> KMat;
209
251
  coopmat<float16_t, gl_ScopeSubgroup, 16, MatBr, gl_MatrixUseB> QMat;
210
252
 
211
- for (uint32_t d = 0; d < HSK_pad / 16; ++d) {
212
- coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor);
253
+ [[unroll]] for (uint32_t d = 0; d < HSK_pad / 16; ++d) {
254
+ // If SHMEM_STAGING is set, a Bc * HSK_pad size tile of K is loaded to shmem
255
+ // If not, f16 K is loaded directly from global memory if aligned, otherwise
256
+ // staged through a Bc * MatBr size staging buffer.
257
+ // If K is not type f16, then it is always staged for dequantization.
258
+ if (SHMEM_STAGING == 0) {
259
+ #if BLOCK_SIZE == 1
260
+ if (KV_bounds_check || d * 16 + 16 > HSK) {
261
+ #endif
262
+ barrier();
263
+ [[unroll]] for (uint32_t idx = 0; idx < Bc * MatBr / 4; idx += gl_WorkGroupSize.x) {
264
+ uint32_t col_vec = (idx + tid) % (MatBr / 4);
265
+ uint32_t row = (idx + tid) / (MatBr / 4);
266
+ if (idx + tid < Bc * MatBr / 4) {
267
+ f16vec4 K_Tf = f16vec4(0);
268
+ if ((!KV_bounds_check || j * Bc + row < KV) && (HSK == HSK_pad || d * 16 + col_vec * 4 < HSK)) {
269
+ #if BLOCK_SIZE > 1
270
+ uint coord = (j * Bc + row) * k_stride * BLOCK_SIZE + d * 16 + col_vec * 4;
271
+ uint ib = coord / BLOCK_SIZE;
272
+ uint iqs = (coord % BLOCK_SIZE);
273
+ K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
274
+ #else
275
+ K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + row) * k_stride / 4 + d * 16 / 4 + col_vec]);
276
+ #endif
277
+ }
278
+
279
+ kvsh[row * kvsh_stride + col_vec] = K_Tf;
280
+ }
281
+ }
282
+ barrier();
283
+ #if BLOCK_SIZE == 1
284
+ }
285
+ #endif
286
+
287
+ #if BLOCK_SIZE == 1
288
+ if (KV_bounds_check || d * 16 + 16 > HSK)
289
+ #endif
290
+ {
291
+ uint coord = (gl_SubgroupID * MatBc) * kvsh_stride;
292
+ coopMatLoad(KMat, kvsh, coord, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor);
293
+ }
294
+ #if BLOCK_SIZE == 1
295
+ else {
296
+ const uint coord = k_offset / 4 + (j * Bc + gl_SubgroupID * MatBc) * k_stride / 4 + d * 16 / 4;
297
+ coopMatLoad(KMat, data_kv4, coord, k_stride / 4, gl_CooperativeMatrixLayoutRowMajor);
298
+ }
299
+ #endif
300
+ } else {
301
+ uint coord = (gl_SubgroupID * MatBc) * kvsh_stride + d * 16 / 4;
302
+ coopMatLoad(KMat, kvsh, coord, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor);
303
+ }
213
304
 
214
- uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4;
215
- coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor);
305
+ coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor);
216
306
 
217
307
  SfMat = coopMatMulAdd(KMat, QMat, SfMat);
218
308
  }
@@ -221,27 +311,27 @@ void main() {
221
311
  coopMatStore(SfMat, sfsh, coord, sfshstride, gl_CooperativeMatrixLayoutRowMajor);
222
312
  barrier();
223
313
 
224
- if (p.logit_softcap != 0.0f) {
225
- [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
226
- uint32_t c = (idx + tid) / Br;
227
- uint32_t r = (idx + tid) % Br;
228
- if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
229
- sfsh[c * sfshstride + r] = ACC_TYPE(p.logit_softcap * tanh(sfsh[c * sfshstride + r]));
314
+ if (LOGIT_SOFTCAP) {
315
+ [[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) {
316
+ uint32_t c = (idx + tid) / (Br / 4);
317
+ uint32_t r = (idx + tid) % (Br / 4);
318
+ if (idx + tid < Bc * Br / 4 || idx + gl_WorkGroupSize.x <= Bc * Br / 4) {
319
+ sfsh[c * sfshstride + r] = ACC_TYPEV4(p.logit_softcap * tanh(sfsh[c * sfshstride + r]));
230
320
  }
231
321
  }
232
322
  barrier();
233
323
  }
234
324
 
235
- if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
236
- bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
237
-
238
- [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
239
- uint32_t c = (idx + tid) % Bc;
240
- uint32_t r = (idx + tid) / Bc;
241
- if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
242
- if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
243
- float f = mask_cache[idx / WorkGroupSize];
244
- sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * f);
325
+ if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) {
326
+ [[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) {
327
+ uint32_t c = (idx + tid) / (Br / 4);
328
+ uint32_t r = (idx + tid) % (Br / 4);
329
+ if (idx + tid < Bc * Br / 4 || idx + gl_WorkGroupSize.x <= Bc * Br / 4) {
330
+ if (!KV_bounds_check || j * Bc + c < KV) {
331
+ // Mask nem1 bounds check is handled when loading masks
332
+ ACC_TYPEV4 masks = ACC_TYPEV4(mask_cache[idx / WorkGroupSize]);
333
+ ACC_TYPEV4 slopes = ACC_TYPEV4(slope[r * 4], slope[r * 4 + 1], slope[r * 4 + 2], slope[r * 4 + 3]);
334
+ sfsh[c * sfshstride + r] += slopes * masks;
245
335
  }
246
336
  }
247
337
  }
@@ -250,143 +340,237 @@ void main() {
250
340
 
251
341
  float eMf[rows_per_thread];
252
342
  [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
343
+ const uint r_vec = tile_row(r) / 4;
344
+ const uint r_comp = tile_row(r) % 4;
345
+
253
346
  float rowmaxf = NEG_FLT_MAX_OVER_2;
254
347
  [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
255
348
  if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
256
349
  continue;
257
350
  }
258
- rowmaxf = max(rowmaxf, float(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride]));
351
+ rowmaxf = max(rowmaxf, float(sfsh[r_vec + (c * cols_per_iter + col_tid) * sfshstride][r_comp]));
259
352
  }
260
353
  float Moldf = Mf[r];
261
354
 
355
+ // Compute max across the row
356
+ rowmaxf = subgroupMax(rowmaxf);
357
+
262
358
  // M = max(rowmax, Mold)
263
359
  // P = e^(S - M)
264
360
  // eM = e^(Mold - M)
265
361
  Mf[r] = max(rowmaxf, Moldf);
266
362
  eMf[r] = exp(Moldf - Mf[r]);
363
+
364
+ Lf[r] = eMf[r]*Lf[r];
267
365
  }
268
366
 
269
- [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
367
+ [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
368
+ const uint d_local = d0 / threads_per_rowgroup;
270
369
  [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
271
- Of[r][d] = ACC_TYPE(eMf[r]) * Of[r][d];
370
+ Of[r][d_local] = float16_t(eMf[r]) * Of[r][d_local];
272
371
  }
273
372
  }
274
- [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
275
- Lf[r] = eMf[r]*Lf[r];
276
- }
277
373
 
374
+ // Calculate and store Pf in Psh
278
375
  [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
279
- if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
280
- continue;
281
- }
282
- float Pf[rows_per_thread];
283
- [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
284
- Pf[r] = exp(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride] - Mf[r]);
285
- Lf[r] += Pf[r];
376
+ const uint col = c * cols_per_iter + col_tid;
377
+
378
+ [[unroll]] for (uint32_t r = 0; r < rows_per_thread; r += 4) {
379
+ const uint row = tile_row(r);
380
+ if (KV_bounds_check && j * Bc + col >= KV) {
381
+ Psh[col * psh_stride + row / 4] = f16vec4(0.0f);
382
+ } else {
383
+ const vec4 mfvec = vec4(Mf[r], Mf[r + 1], Mf[r + 2], Mf[r + 3]);
384
+ const f16vec4 Pf = f16vec4(exp(vec4(sfsh[row / 4 + col * sfshstride]) - mfvec));
385
+ [[unroll]] for (uint32_t vec_idx = 0; vec_idx < 4; ++vec_idx) {
386
+ Lf[r + vec_idx] += Pf[vec_idx];
387
+ }
388
+ Psh[col * psh_stride + row / 4] = Pf;
389
+ }
286
390
  }
287
- [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
391
+ }
392
+
393
+ if (SHMEM_STAGING != 0) {
394
+ [[unroll]] for (uint32_t idx = 0; idx < Bc * HSV_pad / 4; idx += gl_WorkGroupSize.x) {
395
+ uint32_t d = (idx + tid) % (HSV_pad / 4);
396
+ uint32_t c = (idx + tid) / (HSV_pad / 4);
397
+ if (idx + gl_WorkGroupSize.x <= Bc * HSV_pad / 4 || c < Bc) {
398
+ f16vec4 V_Tf = f16vec4(0);
399
+ if ((!KV_bounds_check || j * Bc + c < KV) && (HSV == HSV_pad || d < HSV / 4)) {
288
400
  #if BLOCK_SIZE > 1
289
- uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
290
- uint ib = coord / BLOCK_SIZE;
291
- uint iqs = (coord % BLOCK_SIZE);
292
- vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
401
+ uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE + 4 * d;
402
+ uint ib = coord / BLOCK_SIZE;
403
+ uint iqs = (coord % BLOCK_SIZE);
404
+ V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
293
405
  #else
294
- vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
406
+ V_Tf = f16vec4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]);
295
407
  #endif
296
- [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
297
- Of[r][d] += ACC_TYPE(Pf[r]) * ACC_TYPEV4(Vf);
408
+ }
409
+
410
+ kvsh[c * kvsh_stride + d] = V_Tf;
298
411
  }
299
412
  }
300
413
  }
301
-
302
414
  barrier();
303
- }
304
415
 
305
- // prevent race on tmpsh
306
- barrier();
416
+ const uint num_hsv_tiles = (HSV + MatBc * row_split - 1) / (MatBc * row_split); // round up
307
417
 
308
- // reduce across threads
418
+ // Each subgroup handles HSV/4 columns
419
+ [[unroll]] for (uint32_t hsv_tile = 0; hsv_tile < num_hsv_tiles; ++hsv_tile) {
420
+ const uint hsv_offset = (hsv_tile * row_split + gl_SubgroupID) * 16;
309
421
 
310
- float rowmaxf[rows_per_thread], eMf[rows_per_thread], Moldf[rows_per_thread];
311
- [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
312
- FLOAT_TYPE M = Mf[r];
313
- tmpsh[tid] = M;
314
- // Compute max across the row
315
- barrier();
316
- [[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) {
317
- M = max(M, tmpsh[tid ^ s]);
318
- barrier();
319
- tmpsh[tid] = M;
320
- barrier();
321
- }
322
- rowmaxf[r] = tmpsh[d_tid + row_tid * threads_per_rowgroup];
323
- barrier();
324
- }
422
+ coopmat<float16_t, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator> PVMat = coopmat<float16_t, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0);
325
423
 
326
- [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
327
- Moldf[r] = Mf[r];
424
+ // Preload V tiles for [Bc, 16 * num subgroups]
425
+ const uint v_rows = Bc;
426
+ const uint v_total = v_rows * v_cols;
427
+ const uint v_loads_per_thread = v_total / gl_WorkGroupSize.x;
328
428
 
329
- // M = max(rowmax, Mold)
330
- // eM = e^(Mold - M)
331
- Mf[r] = max(rowmaxf[r], Moldf[r]);
332
- eMf[r] = exp(Moldf[r] - Mf[r]);
429
+ // If SHMEM_STAGING is set, a Bc * HSV_pad size tile of V is loaded to shmem.
430
+ // If not, f16 V is loaded directly from global memory if aligned, otherwise
431
+ // staged through a Bc * MatBr size staging buffer.
432
+ // If V is not type f16, then it is always staged for dequantization.
433
+ if (SHMEM_STAGING == 0) {
434
+ #if BLOCK_SIZE == 1
435
+ // For f16, only preload if not aligned
436
+ if (KV_bounds_check) {
437
+ #endif
438
+ [[unroll]] for (uint32_t i = 0; i < v_loads_per_thread; ++i) {
439
+ const uint idx = i * gl_WorkGroupSize.x + tid;
440
+ const uint row = idx / v_cols;
441
+ const uint col = idx % v_cols;
333
442
 
334
- Lf[r] = eMf[r]*Lf[r];
335
- }
443
+ const uint v_row = j * Bc + row;
444
+ const uint v_col = hsv_tile * MatBc * row_split + col * 4;
336
445
 
337
- [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
338
- FLOAT_TYPE L = Lf[r];
339
- tmpsh[tid] = L;
340
- // Compute sum across the row
341
- barrier();
342
- [[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) {
343
- L += tmpsh[tid ^ s];
344
- barrier();
345
- tmpsh[tid] = L;
446
+ const uint coord = v_row * v_stride * BLOCK_SIZE + v_col;
447
+ const uint ib = coord / BLOCK_SIZE;
448
+ const uint iqs = coord % BLOCK_SIZE;
449
+
450
+ if (!KV_bounds_check || (v_row < KV && v_col < HSV)) {
451
+ #if BLOCK_SIZE > 1
452
+ kvsh[row * vsh_stride + col] = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
453
+ #else
454
+ kvsh[row * vsh_stride + col] = data_vv4[(v_offset + v_row * v_stride + v_col) / 4];
455
+ #endif
456
+ } else {
457
+ kvsh[row * vsh_stride + col] = f16vec4(0.0f);
458
+ }
459
+ }
460
+
461
+ #if BLOCK_SIZE == 1
462
+ }
463
+ #endif
464
+ }
346
465
  barrier();
347
- }
348
- Lf[r] = tmpsh[d_tid + row_tid * threads_per_rowgroup];
349
- barrier();
350
- }
351
466
 
352
- [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
353
- [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
467
+ const uint o_offset = gl_SubgroupID * MatBr / 4;
354
468
 
355
- Of[r][d] = ACC_TYPE(eMf[r]) * Of[r][d];
356
- tmpshv4[tid] = Of[r][d];
469
+ if (hsv_offset < HSV_pad) {
470
+ [[unroll]] for (uint32_t bc_chunk = 0; bc_chunk < Bc / MatBc; ++bc_chunk) {
471
+ coopMatLoad(KMat, Psh, bc_chunk * MatBc * psh_stride, psh_stride, gl_CooperativeMatrixLayoutColumnMajor);
357
472
 
358
- barrier();
359
- [[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) {
360
- Of[r][d] += tmpshv4[tid ^ s];
361
- barrier();
362
- tmpshv4[tid] = Of[r][d];
363
- barrier();
473
+ if (SHMEM_STAGING == 0) {
474
+ #if BLOCK_SIZE == 1
475
+ if (!KV_bounds_check) {
476
+ // F16 values can be loaded directly from global memory
477
+ const uint v_tile_row = j * Bc + bc_chunk * MatBc;
478
+ const uint v_tile_offset = v_offset / 4 + v_tile_row * v_stride / 4 + hsv_offset / 4;
479
+ coopMatLoad(QMat, data_vv4, v_tile_offset, v_stride / 4, gl_CooperativeMatrixLayoutRowMajor);
480
+ } else
481
+ #endif
482
+ {
483
+ const uint v_tile_offset = bc_chunk * MatBr * v_cols + gl_SubgroupID * (MatBc / 4);
484
+ coopMatLoad(QMat, kvsh, v_tile_offset, vsh_stride, gl_CooperativeMatrixLayoutRowMajor);
485
+ }
486
+ } else {
487
+ const uint v_tile_offset = bc_chunk * MatBc * kvsh_stride + (hsv_tile * row_split + gl_SubgroupID) * (MatBc / 4);
488
+ coopMatLoad(QMat, kvsh, v_tile_offset, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor);
489
+ }
490
+
491
+ PVMat = coopMatMulAdd(KMat, QMat, PVMat);
492
+ }
493
+
494
+ // Store PVMat to pvsh and load into Of
495
+ coopMatStore(PVMat, pvsh, o_offset, osh_stride, gl_CooperativeMatrixLayoutRowMajor);
364
496
  }
365
- Of[r][d] = tmpshv4[d_tid + row_tid * threads_per_rowgroup];
497
+
366
498
  barrier();
499
+
500
+ const uint hsv_per_tile = row_split * MatBc;
501
+ const uint hsv_base = hsv_tile * hsv_per_tile;
502
+ const uint d_values_per_tile = hsv_per_tile / 4;
503
+
504
+ const uint d_start = hsv_tile * d_values_per_tile;
505
+ const uint d_end = min(d_start + d_values_per_tile, HSV / 4);
506
+
507
+ [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
508
+ const uint row = tile_row(r);
509
+
510
+ [[unroll]] for (uint32_t d_local = 0; d_local < d_per_thread; ++d_local) {
511
+ const uint d = d_local * threads_per_rowgroup + col_tid;
512
+ const uint hsv_col = 4 * d;
513
+
514
+ if (hsv_col >= hsv_base && hsv_col < hsv_base + hsv_per_tile && hsv_col < HSV) {
515
+ const uint local_hsv = (hsv_col - hsv_base) / 4;
516
+ Of[r][d_local] += pvsh[row * osh_stride + local_hsv];
517
+ }
518
+ }
519
+ }
367
520
  }
521
+
522
+ barrier();
523
+ }
524
+
525
+ [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
526
+ Lf[r] = subgroupAdd(Lf[r]);
368
527
  }
369
528
 
370
529
  // If there is split_k, then the split_k resolve shader does the final
371
530
  // division by L. Store the intermediate O value and per-row m and L values.
372
531
  if (p.k_num > 1) {
373
- uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
532
+ if (p.gqa_ratio > 1) {
533
+ // note: O and Q have swapped coord 1,2.
534
+ uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)) / 4;
374
535
 
375
- [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
376
- if (tile_row(r) < N) {
377
- [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
378
- [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
379
- perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
536
+ [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
537
+ if (tile_row(r) < N) {
538
+ [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
539
+ const uint d = d0 + col_tid;
540
+ if (d >= HSV/4) break;
541
+ const uint d_local = d0 / threads_per_rowgroup;
542
+ gqaStore(tile_row(r), d, Of[r][d_local], o_offset, iq2, N);
380
543
  }
381
544
  }
382
545
  }
383
- }
384
546
 
385
- o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
386
- [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
387
- if (tile_row(r) < N) {
388
- perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
389
- perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
547
+ o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
548
+ [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
549
+ if (tile_row(r) < N) {
550
+ perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
551
+ perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
552
+ }
553
+ }
554
+ } else {
555
+ [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
556
+ const uint row = tile_row(r);
557
+ const uint global_row = i * Br + row;
558
+
559
+ if (global_row < N) {
560
+ uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3)) / 4;
561
+
562
+ [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
563
+ const uint d = d0 + col_tid;
564
+ if (d >= HSV/4) break;
565
+ data_ov4[o_offset + iq2 * HSV/4 + d] = D_TYPEV4(Of[r][d/threads_per_rowgroup]);
566
+ }
567
+ }
568
+
569
+ if (global_row < N && col_tid == 0) {
570
+ uint32_t lm_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3));
571
+ data_o[lm_offset + iq2] = D_TYPE(Lf[r]);
572
+ data_o[lm_offset + p.ne1 + iq2] = D_TYPE(Mf[r]);
573
+ }
390
574
  }
391
575
  }
392
576
 
@@ -403,8 +587,9 @@ void main() {
403
587
  if (sink > Mf[r]) {
404
588
  ms = exp(Mf[r] - sink);
405
589
 
406
- [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
407
- Of[r][d] *= ACC_TYPE(ms);
590
+ [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
591
+ const uint d_local = d0 / threads_per_rowgroup;
592
+ Of[r][d_local] *= float16_t(ms);
408
593
  }
409
594
  } else {
410
595
  vs = exp(sink - Mf[r]);
@@ -419,34 +604,37 @@ void main() {
419
604
  Lfrcp[r] = (Lf[r] == 0.0) ? 0.0 : (1.0 / Lf[r]);
420
605
  }
421
606
 
422
- [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
607
+ [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
608
+ const uint d_local = d0 / threads_per_rowgroup;
423
609
  [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
424
- Of[r][d] *= ACC_TYPE(Lfrcp[r]);
425
- #if defined(ACC_TYPE_MAX)
426
- Of[r][d] = clamp(Of[r][d], -ACC_TYPE_MAX, ACC_TYPE_MAX);
610
+ Of[r][d_local] *= float16_t(Lfrcp[r]);
611
+ #if defined(FLOAT_TYPE_MAX)
612
+ Of[r][d_local] = clamp(Of[r][d_local], -FLOAT_TYPE_MAX, FLOAT_TYPE_MAX);
427
613
  #endif
428
614
  }
429
615
  }
430
616
 
431
- uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
617
+ uint32_t o_offset = (gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV) / 4;
432
618
 
433
619
  if (p.gqa_ratio > 1) {
434
620
  [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
435
621
  if (tile_row(r) < N) {
436
- [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
437
- [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
438
- perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
439
- }
622
+ [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
623
+ const uint d = d0 + col_tid;
624
+ if (d >= HSV / 4) break;
625
+ const uint d_local = d0 / threads_per_rowgroup;
626
+ gqaStore(tile_row(r), d, Of[r][d_local], o_offset, iq2, N);
440
627
  }
441
628
  }
442
629
  }
443
630
  } else {
444
631
  [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
445
632
  if (i * Br + tile_row(r) < N) {
446
- [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
447
- [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
448
- data_o[o_offset + iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
449
- }
633
+ [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
634
+ const uint d = d0 + col_tid;
635
+ if (d >= HSV / 4) break;
636
+ const uint d_local = d0 / threads_per_rowgroup;
637
+ data_ov4[o_offset + (iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV) / 4 + d] = D_TYPEV4(Of[r][d_local]);
450
638
  }
451
639
  }
452
640
  }