whispercpp 1.3.4 → 1.3.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (891)
  1. checksums.yaml +4 -4
  2. data/LICENSE +1 -1
  3. data/README.md +158 -44
  4. data/ext/extconf.rb +3 -2
  5. data/ext/ruby_whisper.c +34 -6
  6. data/ext/ruby_whisper.h +67 -0
  7. data/ext/ruby_whisper_context.c +236 -144
  8. data/ext/ruby_whisper_context_params.c +163 -0
  9. data/ext/ruby_whisper_model.c +12 -13
  10. data/ext/ruby_whisper_params.c +47 -24
  11. data/ext/ruby_whisper_segment.c +84 -20
  12. data/ext/ruby_whisper_token.c +371 -0
  13. data/ext/ruby_whisper_transcribe.cpp +5 -2
  14. data/ext/ruby_whisper_vad_context.c +122 -0
  15. data/ext/ruby_whisper_vad_context_detect.cpp +51 -0
  16. data/ext/ruby_whisper_vad_params.c +0 -1
  17. data/ext/ruby_whisper_vad_segment.c +138 -0
  18. data/ext/ruby_whisper_vad_segments.c +105 -0
  19. data/ext/sources/CMakeLists.txt +4 -1
  20. data/ext/sources/bindings/javascript/package.json +1 -1
  21. data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
  22. data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
  23. data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  24. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  25. data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
  26. data/ext/sources/examples/addon.node/vad-example.js +2 -2
  27. data/ext/sources/examples/bench/bench.cpp +23 -18
  28. data/ext/sources/examples/cli/cli.cpp +129 -112
  29. data/ext/sources/examples/common-ggml.cpp +2 -0
  30. data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
  31. data/ext/sources/examples/miniaudio.h +4507 -2131
  32. data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
  33. data/ext/sources/examples/server/server.cpp +28 -15
  34. data/ext/sources/examples/talk-llama/CMakeLists.txt +8 -3
  35. data/ext/sources/examples/talk-llama/llama-adapter.cpp +5 -2
  36. data/ext/sources/examples/talk-llama/llama-adapter.h +7 -0
  37. data/ext/sources/examples/talk-llama/llama-arch.cpp +2378 -1988
  38. data/ext/sources/examples/talk-llama/llama-arch.h +109 -2
  39. data/ext/sources/examples/talk-llama/llama-batch.cpp +78 -34
  40. data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
  41. data/ext/sources/examples/talk-llama/llama-chat.cpp +100 -4
  42. data/ext/sources/examples/talk-llama/llama-chat.h +5 -0
  43. data/ext/sources/examples/talk-llama/llama-context.cpp +1088 -403
  44. data/ext/sources/examples/talk-llama/llama-context.h +70 -23
  45. data/ext/sources/examples/talk-llama/llama-cparams.h +6 -0
  46. data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
  47. data/ext/sources/examples/talk-llama/llama-grammar.cpp +295 -60
  48. data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
  49. data/ext/sources/examples/talk-llama/llama-graph.cpp +925 -155
  50. data/ext/sources/examples/talk-llama/llama-graph.h +234 -23
  51. data/ext/sources/examples/talk-llama/llama-hparams.cpp +79 -38
  52. data/ext/sources/examples/talk-llama/llama-hparams.h +118 -18
  53. data/ext/sources/examples/talk-llama/llama-impl.cpp +11 -7
  54. data/ext/sources/examples/talk-llama/llama-impl.h +14 -2
  55. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +8 -4
  56. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +405 -140
  57. data/ext/sources/examples/talk-llama/llama-kv-cache.h +24 -10
  58. data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
  59. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
  60. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
  61. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
  62. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +42 -31
  63. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
  64. data/ext/sources/examples/talk-llama/llama-mmap.cpp +197 -45
  65. data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
  66. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +606 -116
  67. data/ext/sources/examples/talk-llama/llama-model-loader.h +41 -5
  68. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +61 -44
  69. data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
  70. data/ext/sources/examples/talk-llama/llama-model.cpp +2756 -13643
  71. data/ext/sources/examples/talk-llama/llama-model.h +112 -18
  72. data/ext/sources/examples/talk-llama/llama-quant.cpp +582 -365
  73. data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +1409 -199
  74. data/ext/sources/examples/talk-llama/llama-sampler.h +42 -0
  75. data/ext/sources/examples/talk-llama/llama-vocab.cpp +248 -82
  76. data/ext/sources/examples/talk-llama/llama-vocab.h +50 -40
  77. data/ext/sources/examples/talk-llama/llama.cpp +802 -21
  78. data/ext/sources/examples/talk-llama/llama.h +210 -39
  79. data/ext/sources/examples/talk-llama/models/afmoe.cpp +190 -0
  80. data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
  81. data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
  82. data/ext/sources/examples/talk-llama/models/arctic.cpp +137 -0
  83. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
  84. data/ext/sources/examples/talk-llama/models/baichuan.cpp +123 -0
  85. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +143 -0
  86. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +133 -0
  87. data/ext/sources/examples/talk-llama/models/bert.cpp +184 -0
  88. data/ext/sources/examples/talk-llama/models/bitnet.cpp +145 -0
  89. data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
  90. data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
  91. data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
  92. data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
  93. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
  94. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
  95. data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
  96. data/ext/sources/examples/talk-llama/models/dbrx.cpp +122 -0
  97. data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
  98. data/ext/sources/examples/talk-llama/models/deepseek.cpp +142 -0
  99. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +262 -0
  100. data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
  101. data/ext/sources/examples/talk-llama/models/dots1.cpp +132 -0
  102. data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
  103. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +148 -0
  104. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
  105. data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
  106. data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
  107. data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
  108. data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
  109. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +111 -0
  110. data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
  111. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
  112. data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
  113. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
  114. data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
  115. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
  116. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
  117. data/ext/sources/examples/talk-llama/models/glm4.cpp +157 -0
  118. data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
  119. data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
  120. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +195 -0
  121. data/ext/sources/examples/talk-llama/models/granite.cpp +210 -0
  122. data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
  123. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +139 -0
  124. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
  125. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +153 -0
  126. data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
  127. data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
  128. data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
  129. data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
  130. data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
  131. data/ext/sources/examples/talk-llama/models/lfm2.cpp +196 -0
  132. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
  133. data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
  134. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
  135. data/ext/sources/examples/talk-llama/models/llama.cpp +175 -0
  136. data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
  137. data/ext/sources/examples/talk-llama/models/mamba-base.cpp +289 -0
  138. data/ext/sources/examples/talk-llama/models/mamba.cpp +54 -0
  139. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +129 -0
  140. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +200 -0
  141. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +123 -0
  142. data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
  143. data/ext/sources/examples/talk-llama/models/models.h +704 -0
  144. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +109 -0
  145. data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
  146. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +162 -0
  147. data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
  148. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
  149. data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
  150. data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
  151. data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
  152. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
  153. data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
  154. data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
  155. data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
  156. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
  157. data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
  158. data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
  159. data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
  160. data/ext/sources/examples/talk-llama/models/plamo2.cpp +320 -0
  161. data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
  162. data/ext/sources/examples/talk-llama/models/plm.cpp +169 -0
  163. data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
  164. data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
  165. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
  166. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
  167. data/ext/sources/examples/talk-llama/models/qwen3.cpp +120 -0
  168. data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
  169. data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
  170. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +131 -0
  171. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +525 -0
  172. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +140 -0
  173. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +132 -0
  174. data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
  175. data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
  176. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +164 -0
  177. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
  178. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
  179. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +137 -0
  180. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
  181. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
  182. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
  183. data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
  184. data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
  185. data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
  186. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
  187. data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
  188. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
  189. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
  190. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
  191. data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
  192. data/ext/sources/examples/talk-llama/unicode.cpp +121 -79
  193. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
  194. data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
  195. data/ext/sources/ggml/CMakeLists.txt +90 -56
  196. data/ext/sources/ggml/include/ggml-alloc.h +9 -0
  197. data/ext/sources/ggml/include/ggml-backend.h +5 -2
  198. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  199. data/ext/sources/ggml/include/ggml-cpu.h +6 -0
  200. data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
  201. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  202. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  203. data/ext/sources/ggml/include/ggml-rpc.h +14 -12
  204. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  205. data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
  206. data/ext/sources/ggml/include/ggml.h +246 -21
  207. data/ext/sources/ggml/src/CMakeLists.txt +85 -11
  208. data/ext/sources/ggml/src/ggml-alloc.c +128 -50
  209. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  210. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  211. data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
  212. data/ext/sources/ggml/src/ggml-backend-reg.cpp +54 -88
  213. data/ext/sources/ggml/src/ggml-backend.cpp +76 -23
  214. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +18 -4
  215. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +11 -11
  216. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +58 -46
  217. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +139 -48
  218. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2427 -1785
  219. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -362
  220. data/ext/sources/ggml/src/ggml-cann/common.h +285 -211
  221. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +663 -831
  222. data/ext/sources/ggml/src/ggml-common.h +11 -0
  223. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +170 -95
  224. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -18
  225. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  226. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
  227. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  228. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +513 -27
  229. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +4192 -992
  230. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
  231. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  232. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1761 -49
  233. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
  234. data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  235. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
  236. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
  237. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +124 -24
  238. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +157 -28
  239. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  240. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  241. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +8 -3
  242. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +251 -80
  243. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +19 -0
  244. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +587 -119
  245. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
  246. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1093 -194
  247. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1284 -203
  248. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
  249. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1519 -527
  250. data/ext/sources/ggml/src/ggml-cpu/ops.h +6 -4
  251. data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
  252. data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
  253. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +3632 -781
  254. data/ext/sources/ggml/src/ggml-cpu/repack.h +129 -4
  255. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
  256. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +152 -46
  257. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
  258. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +152 -1
  259. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
  260. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +140 -0
  261. data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
  262. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
  263. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
  264. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +132 -6
  265. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
  266. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +33 -31
  267. data/ext/sources/ggml/src/ggml-cuda/common.cuh +474 -85
  268. data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
  269. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  270. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
  271. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +342 -246
  272. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
  273. data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
  274. data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
  275. data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
  276. data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
  277. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +98 -74
  278. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +973 -665
  279. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
  280. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1255 -0
  281. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +33 -40
  282. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +40 -18
  283. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  284. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +206 -45
  285. data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
  286. data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
  287. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
  288. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  289. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1688 -302
  290. data/ext/sources/ggml/src/ggml-cuda/mean.cu +12 -10
  291. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +908 -48
  292. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +88 -20
  293. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +502 -90
  294. data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
  295. data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
  296. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
  297. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +532 -193
  298. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +460 -104
  299. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +5 -2
  300. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +360 -122
  301. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +2 -1
  302. data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
  303. data/ext/sources/ggml/src/ggml-cuda/pad.cu +73 -39
  304. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +152 -1
  305. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
  306. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
  307. data/ext/sources/ggml/src/ggml-cuda/rope.cu +364 -149
  308. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
  309. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
  310. data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
  311. data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
  312. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +163 -41
  313. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
  314. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
  315. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +68 -50
  316. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  317. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
  318. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
  319. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
  320. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
  321. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
  322. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
  323. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
  324. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  325. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  326. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  327. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  328. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  329. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  330. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  331. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  332. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  333. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +22 -4
  334. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +95 -0
  335. data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
  336. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +275 -119
  337. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -7
  338. data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
  339. data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
  340. data/ext/sources/ggml/src/ggml-cuda/unary.cu +160 -11
  341. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +38 -0
  342. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
  343. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
  344. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  345. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +22 -1
  346. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
  347. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +117 -0
  348. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3325 -0
  349. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +46 -0
  350. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +813 -0
  351. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
  352. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +891 -0
  353. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  354. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
  355. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +713 -0
  356. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
  357. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.c +63 -0
  358. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +182 -0
  359. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
  360. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  361. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
  362. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  363. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +155 -0
  364. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +63 -0
  365. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  366. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  367. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
  368. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
  369. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
  370. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  371. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
  372. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  373. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  374. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  375. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  376. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
  377. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  378. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  379. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +26 -0
  380. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1199 -0
  381. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2670 -0
  382. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +497 -0
  383. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
  384. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +419 -0
  385. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
  386. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  387. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +382 -0
  388. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +293 -0
  389. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  390. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  391. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  392. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  393. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
  394. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
  395. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +14 -13
  396. data/ext/sources/ggml/src/ggml-impl.h +129 -6
  397. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  398. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +15 -4
  399. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  400. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +173 -34
  401. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +912 -344
  402. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +124 -59
  403. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +588 -144
  404. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +396 -23
  405. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1724 -421
  406. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +16 -3
  407. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +333 -114
  408. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +3050 -1539
  409. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
  410. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +30 -1
  411. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4279 -497
  412. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
  413. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
  414. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  415. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +267 -0
  416. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  417. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  418. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +113 -0
  419. data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  420. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
  421. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  422. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  423. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  424. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
  425. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  426. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  427. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  428. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +140 -0
  429. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  430. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
  431. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
  432. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  433. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  434. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  435. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
  436. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  437. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  438. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  439. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  440. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  441. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
  442. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  443. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  444. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  445. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
  446. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  447. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  448. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
  449. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +116 -0
  450. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  451. data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  452. data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  453. data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  454. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  455. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  456. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +39 -0
  457. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  458. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  459. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  460. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
  461. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  462. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
  463. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  464. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
  465. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
  466. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  467. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  468. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  469. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  470. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  471. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  472. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  473. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  474. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  475. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  476. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  477. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  478. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  479. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  480. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  481. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  482. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  483. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
  484. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  485. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  486. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  487. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  488. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  489. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  490. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
  491. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
  492. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
  493. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
  494. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  495. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  496. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  497. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  498. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  499. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
  500. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  501. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
  502. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
  503. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
  504. data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
  505. data/ext/sources/ggml/src/ggml-quants.c +96 -5
  506. data/ext/sources/ggml/src/ggml-quants.h +3 -0
  507. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
  508. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +59 -87
  509. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +81 -0
  510. data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
  511. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +7 -0
  512. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -29
  513. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
  514. data/ext/sources/ggml/src/ggml-sycl/common.hpp +427 -20
  515. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
  516. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +103 -1
  517. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
  518. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  519. data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  520. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
  521. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
  522. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +867 -50
  523. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +401 -358
  524. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +12 -2
  525. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
  526. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
  527. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
  528. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
  529. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
  530. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  531. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
  532. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
  533. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +645 -155
  534. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
  535. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +221 -66
  536. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
  537. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  538. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
  539. data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
  540. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
  541. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
  542. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +5 -0
  543. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
  544. data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  545. data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  546. data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
  547. data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
  548. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +457 -281
  549. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  550. data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
  551. data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
  552. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
  553. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
  554. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  555. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  556. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  557. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  558. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  559. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  560. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  561. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  562. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  563. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  564. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  565. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
  566. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
  567. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
  568. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
  569. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
  570. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
  571. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
  572. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
  573. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
  574. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
  575. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
  576. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
  577. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
  578. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
  579. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
  580. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
  581. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
  582. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
  583. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
  584. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
  585. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
  586. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
  587. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
  588. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
  589. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
  590. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
  591. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
  592. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
  593. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
  594. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
  595. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
  596. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
  597. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
  598. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
  599. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
  600. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
  601. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +71 -0
  602. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  603. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  604. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  605. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  606. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  607. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  608. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  609. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  610. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  611. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  612. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  613. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  614. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  615. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  616. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  617. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  618. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  619. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  620. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  621. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  622. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  623. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  624. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
  625. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
  626. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  627. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
  628. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  629. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  630. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  631. data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
  632. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  633. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  634. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  635. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  636. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  637. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  638. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  639. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
  640. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  641. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  642. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  643. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
  644. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
  645. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +39 -19
  646. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5994 -3055
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +18 -10
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  663. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  664. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
  665. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
  666. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  667. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  668. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  669. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
  670. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
  671. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
  672. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
  673. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  674. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
  675. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
  676. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  677. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  678. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  679. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  680. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  681. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
  682. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
  683. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
  684. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  685. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  686. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
  687. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
  688. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  689. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  690. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  691. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
  692. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  693. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  694. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
  695. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  696. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  697. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
  698. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  699. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  700. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  701. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
  702. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  703. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +386 -160
  704. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +82 -20
  705. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +400 -174
  706. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +123 -37
  707. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  708. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +10 -9
  709. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  710. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
  711. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  712. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
  713. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
  714. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  715. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
  716. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  717. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
  718. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
  719. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
  720. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
  721. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
  722. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
  723. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  724. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
  725. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
  726. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
  727. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
  728. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +13 -10
  729. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  730. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  731. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  732. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
  733. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +77 -29
  734. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
  735. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
  736. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
  737. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
  738. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
  739. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
  740. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
  741. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
  742. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
  743. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  744. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
  745. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
  746. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
  747. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
  748. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
  749. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
  750. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
  751. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +88 -105
  752. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +41 -26
  753. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
  754. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +74 -0
  755. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +92 -230
  756. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
  757. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  758. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
  759. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  760. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  761. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  762. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
  763. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
  764. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  765. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
  766. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  767. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  768. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  769. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  770. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +49 -4
  771. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  772. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
  773. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
  774. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +207 -0
  775. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  776. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +8 -49
  777. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +8 -32
  778. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +8 -32
  779. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +33 -0
  780. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +8 -38
  781. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  782. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
  783. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  784. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  785. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  786. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  787. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  788. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
  789. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
  790. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
  791. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
  792. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
  793. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
  794. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  795. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
  796. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
  797. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  798. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +50 -0
  799. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
  800. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  801. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  802. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
  803. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
  804. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  805. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
  806. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  807. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
  808. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
  809. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
  810. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
  811. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
  812. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  813. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
  814. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
  815. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +384 -180
  816. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
  817. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
  818. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1374 -0
  819. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +2544 -726
  820. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  821. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  822. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  823. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  824. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
  825. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
  826. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
  827. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +107 -0
  828. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  829. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +73 -15
  830. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +636 -0
  831. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
  832. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  833. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +72 -261
  834. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +766 -0
  835. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +147 -0
  836. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +196 -0
  837. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
  838. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  839. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  840. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
  841. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  842. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +63 -0
  843. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
  844. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  845. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  846. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
  847. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
  848. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +91 -0
  849. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +469 -0
  850. data/ext/sources/ggml/src/ggml.c +590 -64
  851. data/ext/sources/ggml/src/gguf.cpp +229 -44
  852. data/ext/sources/include/whisper.h +1 -0
  853. data/ext/sources/src/CMakeLists.txt +3 -1
  854. data/ext/sources/src/whisper.cpp +106 -62
  855. data/ext/sources/tests/CMakeLists.txt +2 -2
  856. data/ext/sources/tests/test-vad-full.cpp +4 -2
  857. data/ext/sources/tests/test-vad.cpp +1 -1
  858. data/extsources.rb +1 -0
  859. data/lib/whisper/model/uri.rb +17 -18
  860. data/sig/whisper.rbs +162 -4
  861. data/test/test_context_params.rb +82 -0
  862. data/test/test_params.rb +16 -8
  863. data/test/test_segment.rb +0 -1
  864. data/test/test_token.rb +81 -0
  865. data/test/test_vad.rb +1 -1
  866. data/test/test_vad_context.rb +100 -0
  867. data/test/test_vad_segment.rb +19 -0
  868. data/test/test_vad_segments.rb +16 -0
  869. data/test/test_whisper.rb +27 -0
  870. data/whispercpp.gemspec +1 -1
  871. metadata +502 -37
  872. data/ext/sources/build-xcframework.sh +0 -571
  873. data/ext/sources/examples/talk-llama/llama-sampling.h +0 -32
  874. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  875. data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
  876. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
  877. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
  878. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
  879. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
  880. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  881. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
  882. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
  883. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
  884. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
  885. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  886. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  887. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  888. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  889. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
  890. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
  891. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
@@ -18,6 +18,10 @@
18
18
 
19
19
  #include "common.cuh"
20
20
 
21
+ // On Volta each warp is doing 4 8x8 mma operations in parallel.
22
+ // The basic memory layout for a 32x8 output tile is to stack 4 input tiles in I direction and to mirror the B tile.
23
+ // However, the i indices in this file are by default permuted to simplify the index calculations.
24
+ // #define GGML_CUDA_MMA_NO_VOLTA_PERM
21
25
 
22
26
  #if CUDART_VERSION >= 11080
23
27
 
@@ -64,15 +68,59 @@ static __device__ __forceinline__ half2 ggml_cuda_movmatrix(const half2 x) {
64
68
 
65
69
  namespace ggml_cuda_mma {
66
70
 
71
+ // Some architectures like Volta or CDNA3 perform multiple matrix multiplications per warp in parallel,
72
+ // effectively the warp is being split into subgroups of threads that each perform a single mma instruction.
73
+ // In those cases the data can be split in different ways across the warp.
74
+ enum data_layout {
75
+ // By default the data uses the I direction as its major dimension and the J direction as its minor dimension.
76
+ // For the A/C matrices this means I major == row major, J major == column major.
77
+ // For the B matrix this means I major == column major, J major == row major.
78
+ // MIRRORED == Each data value is held exactly once per thread subgroup.
79
+ DATA_LAYOUT_I_MAJOR = 0, // Always used for Turing, Ampere, Ada Lovelace, consumer Blackwell, matrix A&B for RDNA4 and CDNA.
80
+ DATA_LAYOUT_J_MAJOR = 10, // Matrix C for CDNA and RDNA4, int and float matrix C for RDNA3.
81
+ DATA_LAYOUT_I_MAJOR_MIRRORED = 20, // Volta, matrix A&B for RDNA3.
82
+ DATA_LAYOUT_J_MAJOR_MIRRORED = 30,
83
+ };
84
+ // Implemented mma combinations are:
85
+ // - (I_MAJOR, I_MAJOR) -> I_MAJOR
86
+ // - (I_MAJOR, I_MAJOR_MIRRORED) -> I_MAJOR
87
+ // - (I_MAJOR, J_MAJOR_MIRRORED) -> I_MAJOR
88
+
89
+ static constexpr bool is_i_major(const data_layout dl) {
90
+ return dl == DATA_LAYOUT_I_MAJOR ||
91
+ dl == DATA_LAYOUT_I_MAJOR_MIRRORED;
92
+ }
93
+
94
+ static constexpr __device__ data_layout get_input_data_layout() {
95
+ #if defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
96
+ return DATA_LAYOUT_I_MAJOR_MIRRORED;
97
+ #else
98
+ return DATA_LAYOUT_I_MAJOR;
99
+ #endif // defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
100
+ }
101
+
102
+ template <int I_, int J_, typename T, data_layout ds_=DATA_LAYOUT_I_MAJOR>
103
+ struct tile {};
104
+
67
105
  template <int I_, int J_, typename T>
68
- struct tile {
69
- static constexpr int I = I_;
70
- static constexpr int J = J_;
106
+ struct tile<I_, J_, T, DATA_LAYOUT_I_MAJOR> {
107
+ static constexpr int I = I_;
108
+ static constexpr int J = J_;
109
+ static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
71
110
 
72
- #if defined(GGML_USE_HIP)
111
+ #if defined(AMD_MFMA_AVAILABLE)
73
112
  static constexpr int ne = I * J / 64;
74
113
  T x[ne] = {0};
75
114
 
115
+ static constexpr __device__ bool supported() {
116
+ if (I == 64 && J == 2) return true;
117
+ if (I == 16 && J == 8) return true;
118
+ if (I == 32 && J == 4) return true;
119
+ if (I == 16 && J == 16) return true;
120
+ if (I == 32 && J == 32) return true;
121
+ return false;
122
+ }
123
+
76
124
  static __device__ __forceinline__ int get_i(const int l) {
77
125
  if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
78
126
  return threadIdx.x % 16;
@@ -81,11 +129,12 @@ namespace ggml_cuda_mma {
81
129
  } else if constexpr (I == 32 && J == 4) {
82
130
  return threadIdx.x % 32;
83
131
  } else if constexpr (I == 16 && J == 16) {
84
- return 4 * (threadIdx.x / 16) + l;
132
+ return threadIdx.x % 16;
85
133
  } else if constexpr (I == 32 && J == 32) {
86
- return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4);
134
+ return threadIdx.x % 32;
87
135
  } else {
88
- static_assert(I == -1 && J == -1, "template specialization not implemented");
136
+ NO_DEVICE_CODE;
137
+ return -1;
89
138
  }
90
139
  }
91
140
 
@@ -97,26 +146,115 @@ namespace ggml_cuda_mma {
97
146
  } else if constexpr (I == 32 && J == 4) {
98
147
  return 2 * (threadIdx.x / 32) + l;
99
148
  } else if constexpr (I == 16 && J == 16) {
100
- return threadIdx.x % 16;
149
+ return 4 * (threadIdx.x / 16) + l;
101
150
  } else if constexpr (I == 32 && J == 32) {
102
- return threadIdx.x % 32;
151
+ return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4);
103
152
  } else {
104
- static_assert(I == -1 && J == -1, "template specialization not implemented");
153
+ NO_DEVICE_CODE;
154
+ return -1;
105
155
  }
106
156
  }
157
+ #elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
158
+ static constexpr int ne = I * J / 32;
159
+ T x[ne] = {0};
160
+
161
+ static constexpr __device__ bool supported() {
162
+ if (I == 32 && J == 8) return true;
163
+ return false;
164
+ }
165
+
166
+ static __device__ __forceinline__ int get_i(const int l) {
167
+ if constexpr (I == 32 && J == 8) {
168
+ #ifdef GGML_CUDA_MMA_NO_VOLTA_PERM
169
+ return (((threadIdx.x % 16) / 4) * 8) + ((threadIdx.x / 16) * 4) + (l & 2) + (threadIdx.x % 2);
107
170
  #else
171
+ return (l & 2) + (threadIdx.x & ~2);
172
+ #endif // GGML_CUDA_MMA_NO_VOLTA_PERM
173
+ } else {
174
+ NO_DEVICE_CODE;
175
+ return -1;
176
+ }
177
+ }
178
+
179
+ static __device__ __forceinline__ int get_j(const int l) {
180
+ if constexpr (I == 32 && J == 8) {
181
+ return (threadIdx.x & 2) + (l & (4 + 1));
182
+ } else {
183
+ NO_DEVICE_CODE;
184
+ return -1;
185
+ }
186
+ }
187
+ #elif defined(AMD_WMMA_AVAILABLE)
108
188
  static constexpr int ne = I * J / 32;
109
189
  T x[ne] = {0};
110
190
 
191
+ static constexpr __device__ bool supported() {
192
+ if (I == 16 && J == 16) return true;
193
+ if (I == 16 && J == 8) return true;
194
+ if (I == 16 && J == 4) return true;
195
+ return false;
196
+ }
197
+
111
198
  static __device__ __forceinline__ int get_i(const int l) {
112
- if constexpr (I == 8 && (J == 4 || J == 8)) {
199
+ if constexpr (supported()) {
200
+ return threadIdx.x % 16;
201
+ } else {
202
+ NO_DEVICE_CODE;
203
+ return -1;
204
+ }
205
+ }
206
+
207
+ static __device__ __forceinline__ int get_j(const int l) {
208
+ if constexpr (I == 16 && J == 16) {
209
+ #if defined(RDNA3)
210
+ if constexpr (std::is_same_v<T, float> || std::is_same_v<T, int>) {
211
+ // matrix C
212
+ return 2 * l + (threadIdx.x / 16);
213
+ } else {
214
+ // matrix A&B
215
+ return l;
216
+ }
217
+ #else
218
+ // matrix C is the transposed matrix A&B on RDNA4
219
+ return ne * (threadIdx.x / 16) + l;
220
+ #endif // defined(RDNA3)
221
+ } else if constexpr (I == 16 && J == 8) {
222
+ // mmq input for RDNA4
223
+ return ne * (threadIdx.x / 16) + l;
224
+ } else if constexpr (I == 16 && J == 4) {
225
+ return ne * (threadIdx.x / 16) + l;
226
+ } else {
227
+ NO_DEVICE_CODE;
228
+ return -1;
229
+ }
230
+ }
231
+ #else
232
+ static constexpr int ne = I * J / 32;
233
+ T x[ne] = {0};
234
+
235
+ static constexpr __device__ bool supported() {
236
+ if (I == 8 && J == 4) return true;
237
+ if (I == 8 && J == 8) return true;
238
+ if (I == 16 && J == 8) return true;
239
+ if (I == 16 && J == 16) return true;
240
+ if (I == 32 && J == 8) return true;
241
+ return false;
242
+ }
243
+
244
+ static __device__ __forceinline__ int get_i(const int l) {
245
+ if constexpr (I == 8 && J == 4) {
246
+ return threadIdx.x / 4;
247
+ } else if constexpr (I == 8 && J == 8) {
113
248
  return threadIdx.x / 4;
114
249
  } else if constexpr (I == 16 && J == 8) {
115
- return (l / 2) * 8 + threadIdx.x / 4;
250
+ return ((l / 2) * 8) + (threadIdx.x / 4);
116
251
  } else if constexpr (I == 16 && J == 16) {
117
- return ((l / 2) % 2) * 8 + threadIdx.x / 4;
252
+ return (((l / 2) % 2) * 8) + (threadIdx.x / 4);
253
+ } else if constexpr (I == 32 && J == 8) {
254
+ return tile<16, 8, T>::get_i(l); // Memory layout simply repeated with same pattern in i direction.
118
255
  } else {
119
- static_assert(I == -1 && J == -1, "template specialization not implemented");
256
+ NO_DEVICE_CODE;
257
+ return -1;
120
258
  }
121
259
  }
122
260
 
@@ -124,82 +262,395 @@ namespace ggml_cuda_mma {
124
262
  if constexpr (I == 8 && J == 4) {
125
263
  return threadIdx.x % 4;
126
264
  } else if constexpr (I == 8 && J == 8) {
127
- return 4 * l + threadIdx.x % 4;
265
+ return (l * 4) + (threadIdx.x % 4);
128
266
  } else if constexpr (I == 16 && J == 8) {
129
- return 2 * (threadIdx.x % 4) + l % 2;
267
+ return ((threadIdx.x % 4) * 2) + (l % 2);
130
268
  } else if constexpr (I == 16 && J == 16) {
131
- return 8 * (l / 4) + 2 * (threadIdx.x % 4) + l % 2;
269
+ return ((l / 4) * 8) + ((threadIdx.x % 4) * 2) + (l % 2);
270
+ } else if constexpr (I == 32 && J == 8) {
271
+ return tile<16, 8, T>::get_j(l); // Memory layout simply repeated with same pattern in i direction.
132
272
  } else {
133
- static_assert(I == -1 && J == -1, "template specialization not implemented");
273
+ NO_DEVICE_CODE;
274
+ return -1;
134
275
  }
135
276
  }
136
277
  #endif // defined(GGML_USE_HIP)
137
278
  };
138
279
 
139
280
  template <int I_, int J_>
140
- struct tile<I_, J_, half2> {
141
- static constexpr int I = I_;
142
- static constexpr int J = J_;
281
+ struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR> {
282
+ static constexpr int I = I_;
283
+ static constexpr int J = J_;
284
+ static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
285
+
286
+ #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
143
287
  static constexpr int ne = I * J / WARP_SIZE;
144
288
  half2 x[ne] = {{0.0f, 0.0f}};
145
289
 
290
+ static constexpr __device__ bool supported() {
291
+ if (I == 32 && J == 4) return true;
292
+ return false;
293
+ }
294
+
295
+ static __device__ __forceinline__ int get_i(const int l) {
296
+ if constexpr (I == 32 && J == 4) {
297
+ #ifdef GGML_CUDA_MMA_NO_VOLTA_PERM
298
+ return (((threadIdx.x % 16) / 4) * 8) + ((threadIdx.x / 16) * 4) + (threadIdx.x % 4);
299
+ #else
300
+ return threadIdx.x;
301
+ #endif // GGML_CUDA_MMA_NO_VOLTA_PERM
302
+ } else {
303
+ NO_DEVICE_CODE;
304
+ return -1;
305
+ }
306
+ }
307
+
308
+ static __device__ __forceinline__ int get_j(const int l) {
309
+ if constexpr (I == 32 && J == 4) {
310
+ return l;
311
+ } else {
312
+ NO_DEVICE_CODE;
313
+ return -1;
314
+ }
315
+ }
316
+ #elif defined(AMD_WMMA_AVAILABLE)
317
+ static constexpr int ne = I * J / 32;
318
+ half2 x[ne] = {{0.0f, 0.0f}};
319
+
320
+ static constexpr __device__ bool supported() {
321
+ if (I == 16 && J == 8) return true;
322
+ return false;
323
+ }
324
+
325
+ static __device__ __forceinline__ int get_i(const int l) {
326
+ if constexpr (I == 16 && J == 8) {
327
+ return threadIdx.x % 16;
328
+ } else {
329
+ NO_DEVICE_CODE;
330
+ return -1;
331
+ }
332
+ }
333
+
334
+ static __device__ __forceinline__ int get_j(const int l) {
335
+ if constexpr (I == 16 && J == 8) {
336
+ return ne * (threadIdx.x / 16) + l;
337
+ } else {
338
+ NO_DEVICE_CODE;
339
+ return -1;
340
+ }
341
+ }
342
+ #elif defined(AMD_MFMA_AVAILABLE)
343
+ static constexpr int ne = I * J / 64;
344
+ half2 x[ne] = {{0.0f, 0.0f}};
345
+
346
+ static constexpr __device__ bool supported() {
347
+ if (I == 16 && J == 8) return true;
348
+ return false;
349
+ }
350
+
351
+ static __device__ __forceinline__ int get_i(const int l) {
352
+ if constexpr (I == 16 && J == 8) {
353
+ return threadIdx.x % 16;
354
+ } else {
355
+ NO_DEVICE_CODE;
356
+ return -1;
357
+ }
358
+ }
359
+
360
+ static __device__ __forceinline__ int get_j(const int l) {
361
+ if constexpr (I == 16 && J == 8) {
362
+ return ne * (threadIdx.x / 16) + l;
363
+ } else {
364
+ NO_DEVICE_CODE;
365
+ return -1;
366
+ }
367
+ }
368
+ #else
369
+ static constexpr int ne = I * J / WARP_SIZE;
370
+ half2 x[ne] = {{0.0f, 0.0f}};
371
+
372
+ static constexpr __device__ bool supported() {
373
+ if (I == 8 && J == 4) return true;
374
+ if (I == 8 && J == 8) return true;
375
+ if (I == 16 && J == 8) return true;
376
+ if (I == 16 && J == 16) return true;
377
+ if (I == 32 && J == 8) return true;
378
+ return false;
379
+ }
380
+
146
381
  static __device__ __forceinline__ int get_i(const int l) {
147
382
  if constexpr (I == 8 && J == 8) {
148
383
  return threadIdx.x / 4;
149
384
  } else if constexpr (I == 16 && J == 4) {
150
- return l * 8 + threadIdx.x / 4;
385
+ return (l * 8) + (threadIdx.x / 4);
151
386
  } else if constexpr (I == 16 && J == 8) {
152
- return (l % 2) * 8 + threadIdx.x / 4;
387
+ return ((l % 2) * 8) + (threadIdx.x / 4);
388
+ } else if constexpr (I == 32 && J == 8) {
389
+ return ((l / 4) * 16) + ((l % 2) * 8) + (threadIdx.x / 4);
153
390
  } else {
154
- static_assert(I == -1 && J == -1, "template specialization not implemented");
391
+ NO_DEVICE_CODE;
392
+ return -1;
155
393
  }
156
394
  }
157
395
 
158
396
  static __device__ __forceinline__ int get_j(const int l) {
159
397
  if constexpr (I == 8 && J == 8) {
160
- return l * 4 + threadIdx.x % 4;
398
+ return (l * 4) + (threadIdx.x % 4);
161
399
  } else if constexpr (I == 16 && J == 4) {
162
400
  return threadIdx.x % 4;
163
401
  } else if constexpr (I == 16 && J == 8) {
164
- return (l / 2) * 4 + threadIdx.x % 4;
402
+ return ((l / 2) * 4) + (threadIdx.x % 4);
403
+ } else if constexpr (I == 32 && J == 8) {
404
+ return ((l & 2) * 2) + (threadIdx.x % 4);
165
405
  } else {
166
- static_assert(I == -1 && J == -1, "template specialization not implemented");
406
+ NO_DEVICE_CODE;
407
+ return -1;
167
408
  }
168
409
  }
410
+ #endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
169
411
  };
170
412
 
171
413
  template <int I_, int J_>
172
- struct tile<I_, J_, nv_bfloat162> {
173
- static constexpr int I = I_;
174
- static constexpr int J = J_;
414
+ struct tile<I_, J_, nv_bfloat162, DATA_LAYOUT_I_MAJOR> {
415
+ static constexpr int I = I_;
416
+ static constexpr int J = J_;
417
+ static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
418
+
419
+ #if defined(AMD_WMMA_AVAILABLE)
420
+ static constexpr int ne = tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::ne;
421
+ nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
422
+
423
+ static constexpr __device__ bool supported() {
424
+ return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::supported();
425
+ }
426
+
427
+ static __device__ __forceinline__ int get_i(const int l) {
428
+ return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_i(l);
429
+ }
430
+
431
+ static __device__ __forceinline__ int get_j(const int l) {
432
+ return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_j(l);
433
+ }
434
+ #elif defined(AMD_MFMA_AVAILABLE)
435
+ static constexpr int ne = tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::ne;
436
+ nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
437
+
438
+ static constexpr __device__ bool supported() {
439
+ return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::supported();
440
+ }
441
+
442
+ static __device__ __forceinline__ int get_i(const int l) {
443
+ return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_i(l);
444
+ }
445
+
446
+ static __device__ __forceinline__ int get_j(const int l) {
447
+ return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_j(l);
448
+ }
449
+ #else
175
450
  static constexpr int ne = I * J / WARP_SIZE;
176
451
  nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
177
452
 
453
+ static constexpr __device__ bool supported() {
454
+ if (I == 8 && J == 8) return true;
455
+ if (I == 16 && J == 4) return true;
456
+ if (I == 16 && J == 8) return true;
457
+ return false;
458
+ }
459
+
178
460
  static __device__ __forceinline__ int get_i(const int l) {
179
461
  if constexpr (I == 8 && J == 8) {
180
462
  return threadIdx.x / 4;
181
463
  } else if constexpr (I == 16 && J == 4) {
182
- return l * 8 + threadIdx.x / 4;
464
+ return (l * 8) + (threadIdx.x / 4);
183
465
  } else if constexpr (I == 16 && J == 8) {
184
- return (l % 2) * 8 + threadIdx.x / 4;
466
+ return ((l % 2) * 8) + (threadIdx.x / 4);
185
467
  } else {
186
- static_assert(I == -1 && J == -1, "template specialization not implemented");
468
+ NO_DEVICE_CODE;
469
+ return -1;
187
470
  }
188
471
  }
189
472
 
190
473
  static __device__ __forceinline__ int get_j(const int l) {
191
474
  if constexpr (I == 8 && J == 8) {
192
- return l * 4 + threadIdx.x % 4;
475
+ return (l * 4) + (threadIdx.x % 4);
193
476
  } else if constexpr (I == 16 && J == 4) {
194
477
  return threadIdx.x % 4;
195
478
  } else if constexpr (I == 16 && J == 8) {
196
- return (l / 2) * 4 + threadIdx.x % 4;
479
+ return ((l / 2) * 4) + (threadIdx.x % 4);
480
+ } else {
481
+ NO_DEVICE_CODE;
482
+ return -1;
483
+ }
484
+ }
485
+ #endif // defined(AMD_WMMA_AVAILABLE)
486
+ };
487
+
488
+ template <int I_, int J_, typename T>
489
+ struct tile<I_, J_, T, DATA_LAYOUT_J_MAJOR> {
490
+ static constexpr int I = I_;
491
+ static constexpr int J = J_;
492
+ static constexpr data_layout dl = DATA_LAYOUT_J_MAJOR;
493
+
494
+ static constexpr int ne = tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::ne;
495
+ T x[ne] = {0};
496
+
497
+ static constexpr __device__ bool supported() {
498
+ return tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::supported();
499
+ }
500
+
501
+ static __device__ __forceinline__ int get_i(const int l) {
502
+ return tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::get_j(l);
503
+ }
504
+
505
+ static __device__ __forceinline__ int get_j(const int l) {
506
+ return tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::get_i(l);
507
+ }
508
+ };
509
+
510
+ template <int I_, int J_, typename T>
511
+ struct tile<I_, J_, T, DATA_LAYOUT_I_MAJOR_MIRRORED> {
512
+ static constexpr int I = I_;
513
+ static constexpr int J = J_;
514
+ static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;
515
+
516
+ // RDNA3
517
+ static constexpr int ne = I * J / 32 * 2;
518
+
519
+ T x[ne] = {0};
520
+
521
+ static constexpr __device__ bool supported() {
522
+ if (I == 16 && J == 16) return true;
523
+ if (I == 16 && J == 8) return true;
524
+ if (I == 16 && J == 4) return true;
525
+ return false;
526
+ }
527
+
528
+ static __device__ __forceinline__ int get_i(const int /*l*/) {
529
+ if constexpr (supported()) {
530
+ return threadIdx.x % 16;
197
531
  } else {
198
- static_assert(I == -1 && J == -1, "template specialization not implemented");
532
+ NO_DEVICE_CODE;
533
+ return -1;
534
+ }
535
+ }
536
+
537
+ static __device__ __forceinline__ int get_j(const int l) {
538
+ if constexpr (supported()) {
539
+ return l;
540
+ } else {
541
+ NO_DEVICE_CODE;
542
+ return -1;
199
543
  }
200
544
  }
201
545
  };
202
546
 
547
+ template <int I_, int J_>
548
+ struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> {
549
+ static constexpr int I = I_;
550
+ static constexpr int J = J_;
551
+ static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;
552
+ #if defined(RDNA3)
553
+ static constexpr int ne = tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::ne;
554
+
555
+ half2 x[ne] = {{0.0f, 0.0f}};
556
+
557
+ static constexpr __device__ bool supported() {
558
+ return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::supported();
559
+ }
560
+
561
+ static __device__ __forceinline__ int get_i(const int l) {
562
+ return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_i(l);
563
+ }
564
+
565
+ static __device__ __forceinline__ int get_j(const int l) {
566
+ return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_j(l);
567
+ }
568
+ #else // Volta
569
+ static constexpr int ne = I * J / (WARP_SIZE/4);
570
+
571
+ half2 x[ne] = {{0.0f, 0.0f}};
572
+
573
+ static constexpr __device__ bool supported() {
574
+ if (I == 8 && J == 4) return true;
575
+ return false;
576
+ }
577
+
578
+ static __device__ __forceinline__ int get_i(const int /*l*/) {
579
+ if constexpr (I == 8 && J == 4) {
580
+ return ((threadIdx.x / 16) * 4) + (threadIdx.x % 4);
581
+ } else {
582
+ NO_DEVICE_CODE;
583
+ return -1;
584
+ }
585
+ }
586
+
587
+ static __device__ __forceinline__ int get_j(const int l) {
588
+ if constexpr (I == 8 && J == 4) {
589
+ return l;
590
+ } else {
591
+ NO_DEVICE_CODE;
592
+ return -1;
593
+ }
594
+ }
595
+ #endif // defined(RDNA3)
596
+ };
597
+
598
+ template <int I_, int J_>
599
+ struct tile<I_, J_, nv_bfloat162, DATA_LAYOUT_I_MAJOR_MIRRORED> {
600
+ static constexpr int I = I_;
601
+ static constexpr int J = J_;
602
+ static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;
603
+ static constexpr int ne = tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::ne;
604
+
605
+ nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
606
+
607
+ static constexpr __device__ bool supported() {
608
+ return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::supported();
609
+ }
610
+
611
+ static __device__ __forceinline__ int get_i(const int l) {
612
+ return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_i(l);
613
+ }
614
+
615
+ static __device__ __forceinline__ int get_j(const int l) {
616
+ return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_j(l);
617
+ }
618
+ };
619
+
620
+ template <int I_, int J_>
621
+ struct tile<I_, J_, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> {
622
+ static constexpr int I = I_;
623
+ static constexpr int J = J_;
624
+ static constexpr data_layout dl = DATA_LAYOUT_J_MAJOR_MIRRORED;
625
+ static constexpr int ne = I * J / (WARP_SIZE/4);
626
+
627
+ half2 x[ne] = {{0.0f, 0.0f}};
628
+
629
+ static constexpr __device__ bool supported() {
630
+ if (I == 8 && J == 4) return true;
631
+ return false;
632
+ }
633
+
634
+ static __device__ __forceinline__ int get_i(const int l) {
635
+ if constexpr (I == 8 && J == 4) {
636
+ return ((l / 2) * 4) + (threadIdx.x % 4);
637
+ } else {
638
+ NO_DEVICE_CODE;
639
+ return -1;
640
+ }
641
+ }
642
+
643
+ static __device__ __forceinline__ int get_j(const int l) {
644
+ if constexpr (I == 8 && J == 4) {
645
+ return ((threadIdx.x / 16) * 2) + (l % 2);
646
+ } else {
647
+ NO_DEVICE_CODE;
648
+ return -1;
649
+ }
650
+ }
651
+ };
652
+
653
+ #if defined(TURING_MMA_AVAILABLE)
203
654
  template <int I, int J>
204
655
  static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
205
656
  tile<I, J/2, half2> ret;
@@ -217,9 +668,54 @@ namespace ggml_cuda_mma {
217
668
 
218
669
  return ret;
219
670
  }
671
+ #elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
672
+ template <int I, int J>
673
+ static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
674
+ tile<I, J/2, half2> ret;
675
+ #pragma unroll
676
+ for (int l0 = 0; l0 < tile_float.ne; l0 += 2) {
677
+ ret.x[l0/2] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]);
678
+ }
679
+ return ret;
680
+ }
681
+
682
+ static __device__ __forceinline__ tile<8, 8, half2> get_transposed(const tile<16, 4, half2> & t) {
683
+ NO_DEVICE_CODE;
684
+ return tile<8, 8, half2>{};
685
+ }
686
+ #else // Volta
687
+ template <int I, int J>
688
+ static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
689
+ tile<I, J/2, half2> ret;
690
+ #pragma unroll
691
+ for (int l0 = 0; l0 < tile_float.ne; l0 += 4) {
692
+ ret.x[l0/2 + 0] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]);
693
+ ret.x[l0/2 + 1] = make_half2(tile_float.x[l0 + 2], tile_float.x[l0 + 3]);
694
+
695
+ // On Volta FP16 and FP32 tiles have a different memory layout,
696
+ // for the conversion threads with an offset of 2 need to exchange half their values:
697
+ ret.x[l0/2 + (((threadIdx.x % 4) / 2) ^ 1)] = __shfl_xor_sync(
698
+ 0xFFFFFFFF, ret.x[l0/2 + (((threadIdx.x % 4) / 2) ^ 1)], 2, WARP_SIZE);
699
+ }
700
+ return ret;
701
+ }
702
+ #endif // defined(TURING_MMA_AVAILABLE)
220
703
 
221
- template <int I, int J, typename T>
222
- static __device__ __forceinline__ void load_generic(tile<I, J, T> & t, const T * __restrict__ xs0, const int stride) {
704
+ static __device__ __forceinline__ void make_identity_mat(tile<16, 8, half2> & t) {
705
+ #if defined(RDNA4)
706
+ const int row = t.get_i(0);
707
+ const int left_right = t.get_j(0) / 4;
708
+ const int up_down = row / 8;
709
+ const int idx = row % 8;
710
+ reinterpret_cast<half*>(t.x)[idx] = left_right == up_down ? 1.0f : 0.0f;
711
+ #else
712
+ GGML_UNUSED_VARS(t);
713
+ NO_DEVICE_CODE;
714
+ #endif // defined(RDNA4)
715
+ }
716
+
717
+ template <int I, int J, typename T, data_layout dl>
718
+ static __device__ __forceinline__ void load_generic(tile<I, J, T, dl> & t, const T * __restrict__ xs0, const int stride) {
223
719
  #if defined(AMD_MFMA_AVAILABLE)
224
720
  if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
225
721
  #pragma unroll
@@ -227,9 +723,28 @@ namespace ggml_cuda_mma {
227
723
  t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
228
724
  }
229
725
  } else {
230
- int64_t * xi = (int64_t *) t.x;
231
- const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
232
- xi[0] = xs[0];
726
+ ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
727
+ }
728
+ #elif defined(AMD_WMMA_AVAILABLE)
729
+ // All wmma layout has contiguous data when i-major.
730
+ if constexpr (is_i_major(dl)) {
731
+ // the data must be aligned to 16 bytes when bigger than ggml_cuda_get_max_cpy_bytes()
732
+ constexpr int aligned_copy_bytes = ggml_cuda_get_max_cpy_bytes();
733
+ if constexpr (sizeof(t.x) > aligned_copy_bytes) {
734
+ static_assert(sizeof(t.x) % aligned_copy_bytes == 0, "bad type size");
735
+ constexpr int aligned_copy_count = sizeof(t.x)/aligned_copy_bytes;
736
+ #pragma unroll
737
+ for (int i = 0; i < aligned_copy_count; ++i) {
738
+ ggml_cuda_memcpy_1<aligned_copy_bytes>(t.x + t.ne/aligned_copy_count*i, xs0 + t.get_i(0) * stride + t.get_j(t.ne/aligned_copy_count*i));
739
+ }
740
+ } else {
741
+ ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
742
+ }
743
+ } else {
744
+ #pragma unroll
745
+ for (int l = 0; l < t.ne; ++l) {
746
+ t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
747
+ }
233
748
  }
234
749
  #else
235
750
  #pragma unroll
@@ -263,25 +778,63 @@ namespace ggml_cuda_mma {
263
778
  : "=r"(xi[0]), "=r"(xi[1])
264
779
  : "l"(xs));
265
780
  #else
266
- load_generic(xs0, stride);
267
- GGML_UNUSED(t);
781
+ #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
782
+ GGML_UNUSED_VARS(t, xs0, stride);
783
+ NO_DEVICE_CODE;
784
+ #else
785
+ load_generic(t, xs0, stride);
786
+ #endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
268
787
  #endif // TURING_MMA_AVAILABLE
269
788
  }
270
789
 
271
- template <typename T>
790
+ template <typename T, data_layout dl>
272
791
  static __device__ __forceinline__ void load_ldmatrix(
273
- tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
792
+ tile<16, 8, T, dl> & t, const T * __restrict__ xs0, const int stride) {
274
793
  #if defined(TURING_MMA_AVAILABLE)
275
794
  int * xi = (int * ) t.x;
276
795
  const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
277
796
  asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
278
797
  : "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3])
279
798
  : "l"(xs));
799
+ #else
800
+ #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
801
+ #if 1
802
+ // TODO: more generic handling
803
+ static_assert(sizeof(T) == 4, "bad type size");
804
+ ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 0, xs0 + t.get_i(0)*stride + 0);
805
+ ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 4, xs0 + t.get_i(4)*stride + 4);
806
+ #else
807
+ load_generic(t, xs0, stride);
808
+ #endif // 1
280
809
  #else
281
810
  load_generic(t, xs0, stride);
811
+ #endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
282
812
  #endif // TURING_MMA_AVAILABLE
283
813
  }
284
814
 
815
+ static __device__ __forceinline__ void load_ldmatrix(
816
+ tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & t, const half2 * __restrict__ xs0, const int stride) {
817
+ ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, xs0 + t.get_i(0)*stride);
818
+ }
819
+
820
+ static __device__ __forceinline__ void load_ldmatrix(
821
+ tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> & t, const half2 * __restrict__ xs0, const int stride) {
822
+ #pragma unroll
823
+ for (int l0 = 0; l0 < t.ne; l0 += 2) {
824
+ ggml_cuda_memcpy_1<2*sizeof(half2)>(t.x + l0, xs0 + t.get_i(l0)*stride + t.get_j(l0));
825
+ }
826
+ }
827
+
828
+ static __device__ __forceinline__ void load_ldmatrix(
829
+ tile<32, 4, half2> & t, const half2 * __restrict__ xs0, const int stride) {
830
+ #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
831
+ ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, xs0 + t.get_i(0)*stride);
832
+ #else
833
+ GGML_UNUSED_VARS(t, xs0, stride);
834
+ NO_DEVICE_CODE;
835
+ #endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
836
+ }
837
+
285
838
  template <typename T>
286
839
  static __device__ __forceinline__ void load_ldmatrix_trans(
287
840
  tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
@@ -400,14 +953,54 @@ namespace ggml_cuda_mma {
400
953
  : "+r"(Dxi[2]), "+r"(Dxi[3])
401
954
  : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
402
955
  #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
956
+ #elif defined(AMD_WMMA_AVAILABLE)
957
+ #if defined(RDNA4)
958
+ using halfx8_t = __attribute__((ext_vector_type(8))) _Float16;
959
+ halfx8_t& acc_frag = reinterpret_cast<halfx8_t&>(D.x[0]);
960
+ const halfx8_t& a_frag = reinterpret_cast<const halfx8_t&>(A.x[0]);
961
+ const halfx8_t& b_frag = reinterpret_cast<const halfx8_t&>(B.x[0]);
962
+ acc_frag = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12(a_frag, b_frag, acc_frag);
963
+ #else
964
+ GGML_UNUSED_VARS(D, A, B);
965
+ NO_DEVICE_CODE;
966
+ #endif // defined(RDNA4)
967
+ #elif defined(AMD_MFMA_AVAILABLE)
968
+ // MFMA: FP16 input, FP32 accumulate, convert back to half2.
969
+ using halfx4_t = __attribute__((ext_vector_type(4))) _Float16;
970
+ using floatx4_t = __attribute__((ext_vector_type(4))) float;
971
+
972
+ // Convert existing half2 accumulator to float for MFMA:
973
+ floatx4_t acc_f32;
974
+ {
975
+ const halfx4_t acc_h = reinterpret_cast<const halfx4_t&>(D.x[0]);
976
+ #pragma unroll
977
+ for (int i = 0; i < 4; ++i) {
978
+ acc_f32[i] = (float)acc_h[i];
979
+ }
980
+ }
981
+
982
+ const halfx4_t& a_frag = reinterpret_cast<const halfx4_t&>(A.x[0]);
983
+ const halfx4_t& b_frag = reinterpret_cast<const halfx4_t&>(B.x[0]);
984
+ acc_f32 = __builtin_amdgcn_mfma_f32_16x16x16f16(a_frag, b_frag, acc_f32, 0, 0, 0);
985
+
986
+ // Convert back to half2:
987
+ {
988
+ halfx4_t result_h;
989
+ #pragma unroll
990
+ for (int i = 0; i < 4; ++i) {
991
+ result_h[i] = (_Float16)acc_f32[i];
992
+ }
993
+ reinterpret_cast<halfx4_t&>(D.x[0]) = result_h;
994
+ }
403
995
  #else
404
996
  GGML_UNUSED_VARS(D, A, B);
405
997
  NO_DEVICE_CODE;
406
998
  #endif // TURING_MMA_AVAILABLE
407
999
  }
408
1000
 
1001
+ template <data_layout dl_ab, data_layout dl_d>
409
1002
  static __device__ __forceinline__ void mma(
410
- tile<16, 8, float> & D, const tile<16, 8, float> & A, const tile<8, 8, float> & B) {
1003
+ tile<16, 8, float, dl_d> & D, const tile<16, 8, float, dl_ab> & A, const tile<8, 8, float, dl_ab> & B) {
411
1004
  #ifdef AMPERE_MMA_AVAILABLE
412
1005
  const int * Axi = (const int *) A.x;
413
1006
  const int * Bxi = (const int *) B.x;
@@ -421,6 +1014,53 @@ namespace ggml_cuda_mma {
421
1014
  #endif // AMPERE_MMA_AVAILABLE
422
1015
  }
423
1016
 
1017
+ template <data_layout dl_ab, data_layout dl_d>
1018
+ static __device__ __forceinline__ void mma(
1019
+ tile<16, 16, float, dl_d> & D, const tile<16, 8, float, dl_ab> & A, const tile<16, 8, float, dl_ab> & B) {
1020
+ #ifdef AMD_MFMA_AVAILABLE
1021
+ using floatx4_t = __attribute__((ext_vector_type(4))) float;
1022
+ floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]);
1023
+ #if defined(CDNA3)
1024
+ using floatx2_t = __attribute__((ext_vector_type(2))) float;
1025
+ const floatx2_t& a_frag = reinterpret_cast<const floatx2_t&>(A.x[0]);
1026
+ const floatx2_t& b_frag = reinterpret_cast<const floatx2_t&>(B.x[0]);
1027
+ acc_frag = __builtin_amdgcn_mfma_f32_16x16x8_xf32(a_frag, b_frag, acc_frag, 0, 0, 0);
1028
+ #elif defined(CDNA2) || defined(CDNA1)
1029
+ #pragma unroll
1030
+ for (int i = 0; i < 2; ++i) {
1031
+ acc_frag = __builtin_amdgcn_mfma_f32_16x16x4f32(A.x[i], B.x[i], acc_frag, 0, 0, 0);
1032
+ }
1033
+ #else
1034
+ GGML_UNUSED_VARS(D, A, B);
1035
+ NO_DEVICE_CODE;
1036
+ #endif // defined(CDNA3)
1037
+ #else
1038
+ GGML_UNUSED_VARS(D, A, B);
1039
+ NO_DEVICE_CODE;
1040
+ #endif // AMD_MFMA_AVAILABLE
1041
+ }
1042
+
1043
+ static __device__ __forceinline__ void mma_block_scaled(tile<16, 8, float> & D,
1044
+ const tile<16, 8, int> & A,
1045
+ const tile<8, 8, int> & B,
1046
+ uint32_t a_scale,
1047
+ uint32_t b_scale) {
1048
+ #ifdef BLACKWELL_MMA_AVAILABLE
1049
+ const int * Axi = (const int *) A.x;
1050
+ const int * Bxi = (const int *) B.x;
1051
+ float * Dxi = (float *) D.x;
1052
+
1053
+ asm volatile(
1054
+ "mma.sync.aligned.kind::mxf4.block_scale.scale_vec::2X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue8m0 "
1055
+ "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3}, "
1056
+ "%10, {0, 0}, %11, {0, 0};"
1057
+ : "+f"(Dxi[0]), "+f"(Dxi[1]), "+f"(Dxi[2]), "+f"(Dxi[3])
1058
+ : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]), "r"(a_scale), "r"(b_scale));
1059
+ #else
1060
+ GGML_UNUSED_VARS(D, A, B, a_scale, b_scale);
1061
+ #endif // BLACKWELL_MMA_AVAILABLE
1062
+ }
1063
+
424
1064
  static __device__ __forceinline__ void mma(
425
1065
  tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
426
1066
  #ifdef TURING_MMA_AVAILABLE
@@ -461,8 +1101,9 @@ namespace ggml_cuda_mma {
461
1101
  #endif // AMPERE_MMA_AVAILABLE
462
1102
  }
463
1103
 
1104
+ template <data_layout dl_ab, data_layout dl_d>
464
1105
  static __device__ __forceinline__ void mma(
465
- tile<16, 16, float> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
1106
+ tile<16, 16, float, dl_d> & D, const tile<16, 8, half2, dl_ab> & A, const tile<16, 8, half2, dl_ab> & B) {
466
1107
  #ifdef TURING_MMA_AVAILABLE
467
1108
  const int * Axi = (const int *) A.x;
468
1109
  const int * Bxi = (const int *) B.x;
@@ -489,14 +1130,89 @@ namespace ggml_cuda_mma {
489
1130
  : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
490
1131
  : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
491
1132
  #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
1133
+ #elif defined(AMD_WMMA_AVAILABLE)
1134
+ #if defined(RDNA4)
1135
+ using halfx8_t = __attribute__((ext_vector_type(8))) _Float16;
1136
+ using floatx8_t = __attribute__((ext_vector_type(8))) float;
1137
+ floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
1138
+ const halfx8_t& a_frag = reinterpret_cast<const halfx8_t&>(A.x[0]);
1139
+ const halfx8_t& b_frag = reinterpret_cast<const halfx8_t&>(B.x[0]);
1140
+ acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(a_frag, b_frag, acc_frag);
1141
+ #elif defined(RDNA3)
1142
+ using halfx16_t = __attribute__((ext_vector_type(16))) _Float16;
1143
+ using floatx8_t = __attribute__((ext_vector_type(8))) float;
1144
+ floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
1145
+ const halfx16_t& a_frag = reinterpret_cast<const halfx16_t&>(A.x[0]);
1146
+ const halfx16_t& b_frag = reinterpret_cast<const halfx16_t&>(B.x[0]);
1147
+ acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag, b_frag, acc_frag);
1148
+ #else
1149
+ GGML_UNUSED_VARS(D, A, B);
1150
+ NO_DEVICE_CODE;
1151
+ #endif // RDNA4
1152
+ #elif defined(AMD_MFMA_AVAILABLE)
1153
+ using halfx4_t = __attribute__((ext_vector_type(4))) _Float16;
1154
+ using floatx4_t = __attribute__((ext_vector_type(4))) float;
1155
+ floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]);
1156
+ const halfx4_t& a_frag = reinterpret_cast<const halfx4_t&>(A.x[0]);
1157
+ const halfx4_t& b_frag = reinterpret_cast<const halfx4_t&>(B.x[0]);
1158
+ acc_frag = __builtin_amdgcn_mfma_f32_16x16x16f16(a_frag, b_frag, acc_frag, 0, 0, 0);
492
1159
  #else
493
1160
  GGML_UNUSED_VARS(D, A, B);
494
1161
  NO_DEVICE_CODE;
495
1162
  #endif // TURING_MMA_AVAILABLE
496
1163
  }
497
1164
 
1165
+ template <data_layout dl_ab, data_layout dl_d>
1166
+ static __device__ __forceinline__ void mma(
1167
+ tile<16, 16, float, dl_d> & D, const tile<16, 8, nv_bfloat162, dl_ab> & A, const tile<16, 8, nv_bfloat162, dl_ab> & B) {
1168
+ #if defined(AMD_WMMA_AVAILABLE)
1169
+ #if defined(RDNA4)
1170
+ using bf16x8_t = __attribute__((ext_vector_type(8))) __bf16;
1171
+ using floatx8_t = __attribute__((ext_vector_type(8))) float;
1172
+ floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
1173
+ const bf16x8_t& a_frag = reinterpret_cast<const bf16x8_t&>(A.x[0]);
1174
+ const bf16x8_t& b_frag = reinterpret_cast<const bf16x8_t&>(B.x[0]);
1175
+ acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(a_frag, b_frag, acc_frag);
1176
+ #elif defined(RDNA3)
1177
+ using bf16x16_t = __attribute__((ext_vector_type(16))) __bf16;
1178
+ using floatx8_t = __attribute__((ext_vector_type(8))) float;
1179
+ floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
1180
+ const bf16x16_t& a_frag = reinterpret_cast<const bf16x16_t&>(A.x[0]);
1181
+ const bf16x16_t& b_frag = reinterpret_cast<const bf16x16_t&>(B.x[0]);
1182
+ acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(a_frag, b_frag, acc_frag);
1183
+ #else
1184
+ GGML_UNUSED_VARS(D, A, B);
1185
+ NO_DEVICE_CODE;
1186
+ #endif // defined(RDNA4)
1187
+ #elif defined(AMD_MFMA_AVAILABLE)
1188
+ using floatx4_t = __attribute__((ext_vector_type(4))) float;
1189
+ floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]);
1190
+ #if defined(CDNA3) || defined(CDNA2)
1191
+ using bf16x4_t = __attribute__((ext_vector_type(4))) __bf16;
1192
+ const bf16x4_t& a_frag = reinterpret_cast<const bf16x4_t&>(A.x[0]);
1193
+ const bf16x4_t& b_frag = reinterpret_cast<const bf16x4_t&>(B.x[0]);
1194
+ acc_frag = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_frag, b_frag, acc_frag, 0, 0, 0);
1195
+ #elif defined(CDNA1)
1196
+ #pragma unroll
1197
+ for (int i = 0; i < 2; ++i) {
1198
+ using bf16x2_t = __attribute__((ext_vector_type(2))) __bf16;
1199
+ const bf16x2_t& a_frag = reinterpret_cast<const bf16x2_t&>(A.x[i]);
1200
+ const bf16x2_t& b_frag = reinterpret_cast<const bf16x2_t&>(B.x[i]);
1201
+ acc_frag = __builtin_amdgcn_mfma_f32_16x16x8bf16(a_frag, b_frag, acc_frag, 0, 0, 0);
1202
+ }
1203
+ #else
1204
+ GGML_UNUSED_VARS(D, A, B);
1205
+ NO_DEVICE_CODE;
1206
+ #endif // defined(CDNA3) || defined(CDNA2)
1207
+ #else
1208
+ GGML_UNUSED_VARS(D, A, B);
1209
+ NO_DEVICE_CODE;
1210
+ #endif // defined(AMD_WMMA_AVAILABLE)
1211
+ }
1212
+
1213
+ template <data_layout dl_d, data_layout dl_ab>
498
1214
  static __device__ __forceinline__ void mma(
499
- tile<16, 16, int> & D, const tile<16, 8, int> & A, const tile<16, 8, int> & B) {
1215
+ tile<16, 16, int, dl_d> & D, const tile<16, 8, int, dl_ab> & A, const tile<16, 8, int, dl_ab> & B) {
500
1216
  #if defined(AMD_MFMA_AVAILABLE)
501
1217
  using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
502
1218
  int32x4_t * acc = (int32x4_t *) D.x;
@@ -515,6 +1231,59 @@ namespace ggml_cuda_mma {
515
1231
  acc[0],
516
1232
  0, 0, 0);
517
1233
  #endif // defined(CDNA3)
1234
+
1235
+ #elif defined(AMD_WMMA_AVAILABLE)
1236
+
1237
+ using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
1238
+ int32x8_t * acc = (int32x8_t *) D.x;
1239
+
1240
+ #if defined(RDNA4)
1241
+ using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
1242
+ int32x2_t * a_vec = (int32x2_t *) A.x;
1243
+ int32x2_t * b_vec = (int32x2_t *) B.x;
1244
+
1245
+ acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
1246
+ true,
1247
+ a_vec[0],
1248
+ true,
1249
+ b_vec[0],
1250
+ acc[0],
1251
+ true
1252
+ );
1253
+
1254
+ acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
1255
+ true,
1256
+ a_vec[1],
1257
+ true,
1258
+ b_vec[1],
1259
+ acc[0],
1260
+ true
1261
+ );
1262
+
1263
+ #elif defined(RDNA3)
1264
+ using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
1265
+ int32x4_t * a_vec = (int32x4_t *) A.x;
1266
+ int32x4_t * b_vec = (int32x4_t *) B.x;
1267
+
1268
+ acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
1269
+ true,
1270
+ a_vec[0],
1271
+ true,
1272
+ b_vec[0],
1273
+ acc[0],
1274
+ true
1275
+ );
1276
+
1277
+ acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
1278
+ true,
1279
+ a_vec[1],
1280
+ true,
1281
+ b_vec[1],
1282
+ acc[0],
1283
+ true
1284
+ );
1285
+ #endif // RDNA4
1286
+
518
1287
  #else
519
1288
  GGML_UNUSED_VARS(D, A, B);
520
1289
  NO_DEVICE_CODE;
@@ -541,9 +1310,100 @@ namespace ggml_cuda_mma {
541
1310
  acc[0],
542
1311
  0, 0, 0);
543
1312
  #endif // defined(CDNA3)
1313
+
544
1314
  #else
545
1315
  GGML_UNUSED_VARS(D, A, B);
546
1316
  NO_DEVICE_CODE;
547
1317
  #endif // AMD_MFMA_AVAILABLE
548
1318
  }
1319
+
1320
+ template <typename T1, typename T2, int J, int K>
1321
+ static __device__ __forceinline__ void mma(
1322
+ tile<32, J, T1> & D, const tile<32, K, T2> & A, const tile<J, K, T2> & B) {
1323
+ tile <16, J, T1> * D16 = reinterpret_cast< tile<16, J, T1> *>(&D);
1324
+ const tile<16, K, T2> * A16 = reinterpret_cast<const tile<16, K, T2> *>(&A);
1325
+ mma(D16[0], A16[0], B);
1326
+ mma(D16[1], A16[1], B);
1327
+ }
1328
+
1329
+ static __device__ __forceinline__ void mma(
1330
+ tile<32, 8, float> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & B) {
1331
+ #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
1332
+ const int * Axi = (const int *) A.x;
1333
+ const int * Bxi = (const int *) B.x;
1334
+ int * Dxi = (int *) D.x;
1335
+ asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
1336
+ "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
1337
+ : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
1338
+ : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]), "r"(Bxi[1]));
1339
+ asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
1340
+ "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
1341
+ : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
1342
+ : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]), "r"(Bxi[3]));
1343
+ #else
1344
+ GGML_UNUSED_VARS(D, A, B);
1345
+ NO_DEVICE_CODE;
1346
+ #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
1347
+ }
1348
+
1349
+ static __device__ __forceinline__ void mma(
1350
+ tile<32, 4, half2> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> & B) {
1351
+ #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
1352
+ const int * Axi = (const int *) A.x;
1353
+ const int * Bxi = (const int *) B.x;
1354
+ int * Dxi = (int *) D.x;
1355
+ asm("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 "
1356
+ "{%0, %1, %2, %3}, {%4, %5}, {%6, %7}, {%0, %1, %2, %3};"
1357
+ : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
1358
+ : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]), "r"(Bxi[1]));
1359
+ asm("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 "
1360
+ "{%0, %1, %2, %3}, {%4, %5}, {%6, %7}, {%0, %1, %2, %3};"
1361
+ : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
1362
+ : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]), "r"(Bxi[3]));
1363
+ #else
1364
+ GGML_UNUSED_VARS(D, A, B);
1365
+ NO_DEVICE_CODE;
1366
+ #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
1367
+ }
1368
+
1369
+ template <data_layout dl_d, data_layout dl_ab>
1370
+ static __device__ __forceinline__ void mma(
1371
+ tile<16, 16, int, dl_d> & D, const tile<16, 4, int, dl_ab> & A, const tile<16, 4, int, dl_ab> & B) {
1372
+ #if defined(AMD_WMMA_AVAILABLE)
1373
+ using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
1374
+ int32x8_t * acc = (int32x8_t *) D.x;
1375
+ #if defined(RDNA4)
1376
+ using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
1377
+ int32x2_t * a_vec = (int32x2_t *) A.x;
1378
+ int32x2_t * b_vec = (int32x2_t *) B.x;
1379
+
1380
+ acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
1381
+ true,
1382
+ a_vec[0],
1383
+ true,
1384
+ b_vec[0],
1385
+ acc[0],
1386
+ false
1387
+ );
1388
+ #elif defined(RDNA3)
1389
+ using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
1390
+ int32x4_t * a_vec = (int32x4_t *) A.x;
1391
+ int32x4_t * b_vec = (int32x4_t *) B.x;
1392
+
1393
+ acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
1394
+ true,
1395
+ a_vec[0],
1396
+ true,
1397
+ b_vec[0],
1398
+ acc[0],
1399
+ false
1400
+ );
1401
+ #endif // RDNA4
1402
+ #else
1403
+ GGML_UNUSED(D);
1404
+ GGML_UNUSED(A);
1405
+ GGML_UNUSED(B);
1406
+ NO_DEVICE_CODE;
1407
+ #endif // AMD_WMMA_AVAILABLE
1408
+ }
549
1409
  }