whispercpp 1.3.4 → 1.3.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (630)
  1. checksums.yaml +4 -4
  2. data/README.md +60 -43
  3. data/ext/extconf.rb +2 -2
  4. data/ext/ruby_whisper.c +14 -2
  5. data/ext/ruby_whisper.h +39 -0
  6. data/ext/ruby_whisper_context.c +22 -22
  7. data/ext/ruby_whisper_model.c +12 -12
  8. data/ext/ruby_whisper_params.c +47 -23
  9. data/ext/ruby_whisper_segment.c +84 -19
  10. data/ext/ruby_whisper_token.c +351 -0
  11. data/ext/ruby_whisper_transcribe.cpp +1 -1
  12. data/ext/ruby_whisper_vad_context.c +75 -0
  13. data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
  14. data/ext/ruby_whisper_vad_segment.c +139 -0
  15. data/ext/ruby_whisper_vad_segments.c +106 -0
  16. data/ext/sources/CMakeLists.txt +4 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
  19. data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
  20. data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  21. data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
  22. data/ext/sources/examples/addon.node/vad-example.js +2 -2
  23. data/ext/sources/examples/cli/cli.cpp +121 -112
  24. data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
  25. data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
  26. data/ext/sources/examples/server/server.cpp +10 -11
  27. data/ext/sources/examples/talk-llama/CMakeLists.txt +5 -1
  28. data/ext/sources/examples/talk-llama/llama-adapter.cpp +12 -3
  29. data/ext/sources/examples/talk-llama/llama-adapter.h +7 -1
  30. data/ext/sources/examples/talk-llama/llama-arch.cpp +2046 -1974
  31. data/ext/sources/examples/talk-llama/llama-arch.h +67 -2
  32. data/ext/sources/examples/talk-llama/llama-batch.cpp +75 -33
  33. data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
  34. data/ext/sources/examples/talk-llama/llama-chat.cpp +79 -3
  35. data/ext/sources/examples/talk-llama/llama-chat.h +4 -0
  36. data/ext/sources/examples/talk-llama/llama-context.cpp +775 -78
  37. data/ext/sources/examples/talk-llama/llama-context.h +57 -9
  38. data/ext/sources/examples/talk-llama/llama-cparams.h +1 -0
  39. data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
  40. data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
  41. data/ext/sources/examples/talk-llama/llama-graph.cpp +381 -64
  42. data/ext/sources/examples/talk-llama/llama-graph.h +103 -13
  43. data/ext/sources/examples/talk-llama/llama-hparams.cpp +26 -2
  44. data/ext/sources/examples/talk-llama/llama-hparams.h +41 -10
  45. data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
  46. data/ext/sources/examples/talk-llama/llama-impl.h +1 -1
  47. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +5 -3
  48. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +145 -65
  49. data/ext/sources/examples/talk-llama/llama-kv-cache.h +22 -7
  50. data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
  51. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
  52. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +32 -19
  53. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
  54. data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
  55. data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
  56. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +91 -9
  57. data/ext/sources/examples/talk-llama/llama-model-loader.h +6 -0
  58. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
  59. data/ext/sources/examples/talk-llama/llama-model.cpp +1529 -13134
  60. data/ext/sources/examples/talk-llama/llama-model.h +44 -3
  61. data/ext/sources/examples/talk-llama/llama-quant.cpp +8 -23
  62. data/ext/sources/examples/talk-llama/llama-sampling.cpp +1294 -198
  63. data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
  64. data/ext/sources/examples/talk-llama/llama-vocab.cpp +133 -37
  65. data/ext/sources/examples/talk-llama/llama-vocab.h +45 -40
  66. data/ext/sources/examples/talk-llama/llama.cpp +729 -2
  67. data/ext/sources/examples/talk-llama/llama.h +152 -14
  68. data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
  69. data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
  70. data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
  71. data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
  72. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
  73. data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
  74. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
  75. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
  76. data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
  77. data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
  78. data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
  79. data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
  80. data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
  81. data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
  82. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
  83. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
  84. data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
  85. data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
  86. data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
  87. data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
  88. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
  89. data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
  90. data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
  91. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
  92. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
  93. data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
  94. data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
  95. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
  96. data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
  97. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
  98. data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
  99. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
  100. data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
  101. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
  102. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
  103. data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
  104. data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
  105. data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
  106. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
  107. data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
  108. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
  109. data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
  110. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
  111. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
  112. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
  113. data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
  114. data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
  115. data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
  116. data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
  117. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
  118. data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
  119. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
  120. data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
  121. data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
  122. data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
  123. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
  124. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
  125. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
  126. data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
  127. data/ext/sources/examples/talk-llama/models/models.h +569 -0
  128. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
  129. data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
  130. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
  131. data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
  132. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
  133. data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
  134. data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
  135. data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
  136. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
  137. data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
  138. data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
  139. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
  140. data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
  141. data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
  142. data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
  143. data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
  144. data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
  145. data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
  146. data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
  147. data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
  148. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
  149. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
  150. data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
  151. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
  152. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
  153. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
  154. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
  155. data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
  156. data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
  157. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
  158. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
  159. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
  160. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
  161. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
  162. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
  163. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
  164. data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
  165. data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
  166. data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
  167. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
  168. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
  169. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
  170. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
  171. data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
  172. data/ext/sources/examples/talk-llama/unicode.cpp +102 -16
  173. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
  174. data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
  175. data/ext/sources/ggml/CMakeLists.txt +82 -54
  176. data/ext/sources/ggml/include/ggml-alloc.h +9 -0
  177. data/ext/sources/ggml/include/ggml-backend.h +4 -1
  178. data/ext/sources/ggml/include/ggml-cpu.h +1 -0
  179. data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
  180. data/ext/sources/ggml/include/ggml-rpc.h +8 -11
  181. data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
  182. data/ext/sources/ggml/include/ggml.h +190 -12
  183. data/ext/sources/ggml/src/CMakeLists.txt +82 -11
  184. data/ext/sources/ggml/src/ggml-alloc.c +124 -41
  185. data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
  186. data/ext/sources/ggml/src/ggml-backend-reg.cpp +27 -3
  187. data/ext/sources/ggml/src/ggml-backend.cpp +71 -21
  188. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
  189. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -9
  190. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +57 -45
  191. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
  192. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2179 -1696
  193. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -317
  194. data/ext/sources/ggml/src/ggml-cann/common.h +283 -208
  195. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +626 -776
  196. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +156 -86
  197. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -0
  198. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  199. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +428 -26
  200. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1004 -0
  201. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
  202. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  203. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +108 -49
  204. data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  205. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6 -6
  206. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +50 -2
  207. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +5 -3
  208. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +195 -71
  209. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
  210. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +573 -106
  211. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
  212. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +298 -112
  213. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
  214. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +819 -125
  215. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
  216. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +708 -431
  217. data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -4
  218. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +671 -31
  219. data/ext/sources/ggml/src/ggml-cpu/repack.h +14 -0
  220. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +41 -43
  221. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
  222. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
  223. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
  224. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +124 -1
  225. data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
  226. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
  227. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
  228. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
  229. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
  230. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +1 -1
  231. data/ext/sources/ggml/src/ggml-cuda/common.cuh +353 -80
  232. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  233. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
  234. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +339 -246
  235. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
  236. data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
  237. data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
  238. data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
  239. data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
  240. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +31 -21
  241. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +663 -596
  242. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
  243. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1241 -0
  244. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +30 -37
  245. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +14 -13
  246. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  247. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +83 -37
  248. data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
  249. data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
  250. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1155 -164
  251. data/ext/sources/ggml/src/ggml-cuda/mean.cu +5 -4
  252. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +741 -48
  253. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +60 -12
  254. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +381 -42
  255. data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
  256. data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
  257. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
  258. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +498 -171
  259. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +375 -79
  260. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +3 -2
  261. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +241 -95
  262. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
  263. data/ext/sources/ggml/src/ggml-cuda/pad.cu +64 -33
  264. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +151 -0
  265. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
  266. data/ext/sources/ggml/src/ggml-cuda/rope.cu +192 -77
  267. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
  268. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
  269. data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
  270. data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
  271. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +203 -6
  272. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
  273. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
  274. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -20
  275. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
  276. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  277. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  278. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  279. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  280. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  281. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  282. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  283. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  284. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  285. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +19 -1
  286. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
  287. data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
  288. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +168 -76
  289. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +11 -4
  290. data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
  291. data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
  292. data/ext/sources/ggml/src/ggml-cuda/unary.cu +105 -11
  293. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +36 -0
  294. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
  295. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  296. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +12 -1
  297. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
  302. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
  303. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  304. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
  305. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
  306. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  307. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
  308. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
  309. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
  310. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
  311. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  312. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
  313. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
  314. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
  315. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
  316. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
  317. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
  318. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
  319. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
  320. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
  321. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
  322. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
  323. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
  324. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
  325. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  326. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
  327. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
  328. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
  329. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +8 -13
  330. data/ext/sources/ggml/src/ggml-impl.h +67 -6
  331. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +2 -2
  332. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +29 -20
  333. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +652 -285
  334. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +103 -56
  335. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +496 -118
  336. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +231 -9
  337. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1227 -224
  338. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +12 -0
  339. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +14 -8
  340. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1972 -704
  341. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
  342. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +11 -0
  343. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1430 -120
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +63 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +13 -0
  365. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
  366. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
  367. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
  368. data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
  369. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +6 -0
  370. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +0 -9
  371. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
  372. data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
  373. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
  374. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +34 -0
  375. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  376. data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  377. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
  378. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
  379. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +76 -3
  380. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +333 -300
  381. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +10 -2
  382. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +335 -110
  383. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
  384. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +156 -0
  385. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
  386. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
  387. data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
  388. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
  389. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
  390. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
  391. data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  392. data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  393. data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
  394. data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
  395. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +30 -17
  396. data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
  397. data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
  398. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
  399. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
  400. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  401. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  402. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +58 -0
  403. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
  404. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5013 -2859
  405. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  406. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
  407. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
  408. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  409. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
  410. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  411. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
  412. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
  413. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  414. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  415. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  416. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  417. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  418. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  419. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
  420. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  421. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  422. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
  423. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
  424. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  425. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  426. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  427. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
  428. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
  429. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
  430. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
  431. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  432. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
  433. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
  434. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  435. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  436. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  437. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  438. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  439. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
  440. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
  441. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
  442. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  443. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  444. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
  445. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
  446. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  447. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  448. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  449. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
  450. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  451. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  452. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
  453. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  454. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  455. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
  456. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  457. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  458. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
  459. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  460. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +39 -17
  461. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +19 -1
  462. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +45 -7
  463. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +50 -12
  464. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +1 -1
  465. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  466. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  467. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
  468. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
  469. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  470. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
  471. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  472. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
  473. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
  474. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
  475. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
  476. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
  477. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
  478. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  479. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
  480. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
  481. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
  482. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
  483. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
  484. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  485. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  486. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  487. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
  488. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +70 -25
  489. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
  490. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
  491. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
  492. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
  493. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
  494. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
  495. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
  496. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
  497. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
  498. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  499. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
  500. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
  501. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
  502. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
  503. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
  504. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
  505. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
  506. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +78 -103
  507. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +34 -23
  508. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
  509. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
  510. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +88 -228
  511. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
  512. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  513. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
  514. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  515. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  516. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  517. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
  518. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
  519. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  520. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
  521. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  522. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  523. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  524. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  525. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +50 -4
  526. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  527. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
  528. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
  529. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
  530. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  531. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -50
  532. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -33
  533. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -33
  534. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
  535. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
  536. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  537. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
  538. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  539. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  540. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  541. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  542. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
  543. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
  544. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
  545. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
  546. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
  547. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
  548. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  549. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
  550. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
  551. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  552. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
  553. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
  554. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  555. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  556. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
  557. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
  558. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  559. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
  560. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  561. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
  562. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
  563. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
  564. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
  565. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
  566. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  567. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
  568. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
  569. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -151
  570. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
  571. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
  572. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
  573. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1964 -435
  574. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  575. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
  576. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
  577. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +33 -10
  578. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
  579. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +1 -1
  580. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  581. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +6 -6
  582. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
  583. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
  584. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
  585. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
  586. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
  587. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  588. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
  589. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
  590. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  591. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
  592. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
  593. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
  594. data/ext/sources/ggml/src/ggml.c +425 -33
  595. data/ext/sources/include/whisper.h +1 -0
  596. data/ext/sources/src/CMakeLists.txt +3 -1
  597. data/ext/sources/src/whisper.cpp +101 -35
  598. data/ext/sources/tests/CMakeLists.txt +2 -2
  599. data/ext/sources/tests/test-vad-full.cpp +4 -2
  600. data/ext/sources/tests/test-vad.cpp +1 -1
  601. data/extsources.rb +1 -0
  602. data/lib/whisper/model/uri.rb +17 -18
  603. data/sig/whisper.rbs +119 -2
  604. data/test/test_params.rb +16 -8
  605. data/test/test_segment.rb +0 -1
  606. data/test/test_token.rb +70 -0
  607. data/test/test_vad.rb +1 -1
  608. data/test/test_vad_context.rb +50 -0
  609. data/test/test_vad_segment.rb +19 -0
  610. data/test/test_vad_segments.rb +16 -0
  611. data/test/test_whisper.rb +7 -0
  612. data/whispercpp.gemspec +1 -1
  613. metadata +287 -34
  614. data/ext/sources/build-xcframework.sh +0 -571
  615. data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
  618. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
  619. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
  620. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
  621. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
  622. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
  623. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
  624. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  625. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  626. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  627. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  628. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
  629. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
  630. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
@@ -18,6 +18,10 @@
18
18
 
19
19
  #include "common.cuh"
20
20
 
21
+ // On Volta each warp is doing 4 8x8 mma operations in parallel.
22
+ // The basic memory layout for a 32x8 output tile is to stack 4 input tiles in I direction and to mirror the B tile.
23
+ // However, the i indices in this file are by default permuted to simplify the index calculations.
24
+ // #define GGML_CUDA_MMA_NO_VOLTA_PERM
21
25
 
22
26
  #if CUDART_VERSION >= 11080
23
27
 
@@ -64,15 +68,59 @@ static __device__ __forceinline__ half2 ggml_cuda_movmatrix(const half2 x) {
64
68
 
65
69
  namespace ggml_cuda_mma {
66
70
 
71
+ // Some architectures like Volta or CDNA3 perform multiple matrix multiplications per warp in parallel,
72
+ // effectively the warp is being split into subgroups of threads that each perform a single mma instruction.
73
+ // In those cases the data can be split in different ways across the warp.
74
+ enum data_layout {
75
+ // By default the data uses the I direction as its major dimension and the J direction as its minor dimension.
76
+ // For the A/C matrices this means I major == row major, J major == column major.
77
+ // For the B matrix this means I major == column major, J major == row major.
78
+ // MIRRORED == Each data value is held exactly once per thread subgroup.
79
+ DATA_LAYOUT_I_MAJOR = 0, // Always used for Turing, Ampere, Ada Lovelace, consumer Blackwell, matrix A&B for RDNA4 and CDNA.
80
+ DATA_LAYOUT_J_MAJOR = 10, // Matrix C for CDNA and RDNA4, int and float matrix C for RDNA3.
81
+ DATA_LAYOUT_I_MAJOR_MIRRORED = 20, // Volta, matrix A&B for RDNA3.
82
+ DATA_LAYOUT_J_MAJOR_MIRRORED = 30,
83
+ };
84
+ // Implemented mma combinations are:
85
+ // - (I_MAJOR, I_MAJOR) -> I_MAJOR
86
+ // - (I_MAJOR, I_MAJOR_MIRRORED) -> I_MAJOR
87
+ // - (I_MAJOR, J_MAJOR_MIRRORED) -> I_MAJOR
88
+
89
+ static constexpr bool is_i_major(const data_layout dl) {
90
+ return dl == DATA_LAYOUT_I_MAJOR ||
91
+ dl == DATA_LAYOUT_I_MAJOR_MIRRORED;
92
+ }
93
+
94
+ static constexpr __device__ data_layout get_input_data_layout() {
95
+ #if defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
96
+ return DATA_LAYOUT_I_MAJOR_MIRRORED;
97
+ #else
98
+ return DATA_LAYOUT_I_MAJOR;
99
+ #endif // defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
100
+ }
101
+
102
+ template <int I_, int J_, typename T, data_layout ds_=DATA_LAYOUT_I_MAJOR>
103
+ struct tile {};
104
+
67
105
  template <int I_, int J_, typename T>
68
- struct tile {
69
- static constexpr int I = I_;
70
- static constexpr int J = J_;
106
+ struct tile<I_, J_, T, DATA_LAYOUT_I_MAJOR> {
107
+ static constexpr int I = I_;
108
+ static constexpr int J = J_;
109
+ static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
71
110
 
72
- #if defined(GGML_USE_HIP)
111
+ #if defined(AMD_MFMA_AVAILABLE)
73
112
  static constexpr int ne = I * J / 64;
74
113
  T x[ne] = {0};
75
114
 
115
+ static constexpr __device__ bool supported() {
116
+ if (I == 64 && J == 2) return true;
117
+ if (I == 16 && J == 8) return true;
118
+ if (I == 32 && J == 4) return true;
119
+ if (I == 16 && J == 16) return true;
120
+ if (I == 32 && J == 32) return true;
121
+ return false;
122
+ }
123
+
76
124
  static __device__ __forceinline__ int get_i(const int l) {
77
125
  if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
78
126
  return threadIdx.x % 16;
@@ -81,11 +129,12 @@ namespace ggml_cuda_mma {
81
129
  } else if constexpr (I == 32 && J == 4) {
82
130
  return threadIdx.x % 32;
83
131
  } else if constexpr (I == 16 && J == 16) {
84
- return 4 * (threadIdx.x / 16) + l;
132
+ return threadIdx.x % 16;
85
133
  } else if constexpr (I == 32 && J == 32) {
86
- return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4);
134
+ return threadIdx.x % 32;
87
135
  } else {
88
- static_assert(I == -1 && J == -1, "template specialization not implemented");
136
+ NO_DEVICE_CODE;
137
+ return -1;
89
138
  }
90
139
  }
91
140
 
@@ -97,26 +146,109 @@ namespace ggml_cuda_mma {
97
146
  } else if constexpr (I == 32 && J == 4) {
98
147
  return 2 * (threadIdx.x / 32) + l;
99
148
  } else if constexpr (I == 16 && J == 16) {
100
- return threadIdx.x % 16;
149
+ return 4 * (threadIdx.x / 16) + l;
101
150
  } else if constexpr (I == 32 && J == 32) {
102
- return threadIdx.x % 32;
151
+ return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4);
103
152
  } else {
104
- static_assert(I == -1 && J == -1, "template specialization not implemented");
153
+ NO_DEVICE_CODE;
154
+ return -1;
105
155
  }
106
156
  }
157
+ #elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
158
+ static constexpr int ne = I * J / 32;
159
+ T x[ne] = {0};
160
+
161
+ static constexpr __device__ bool supported() {
162
+ if (I == 32 && J == 8) return true;
163
+ return false;
164
+ }
165
+
166
+ static __device__ __forceinline__ int get_i(const int l) {
167
+ if constexpr (I == 32 && J == 8) {
168
+ #ifdef GGML_CUDA_MMA_NO_VOLTA_PERM
169
+ return (((threadIdx.x % 16) / 4) * 8) + ((threadIdx.x / 16) * 4) + (l & 2) + (threadIdx.x % 2);
107
170
  #else
171
+ return (l & 2) + (threadIdx.x & ~2);
172
+ #endif // GGML_CUDA_MMA_NO_VOLTA_PERM
173
+ } else {
174
+ NO_DEVICE_CODE;
175
+ return -1;
176
+ }
177
+ }
178
+
179
+ static __device__ __forceinline__ int get_j(const int l) {
180
+ if constexpr (I == 32 && J == 8) {
181
+ return (threadIdx.x & 2) + (l & (4 + 1));
182
+ } else {
183
+ NO_DEVICE_CODE;
184
+ return -1;
185
+ }
186
+ }
187
+ #elif defined(AMD_WMMA_AVAILABLE)
108
188
  static constexpr int ne = I * J / 32;
109
189
  T x[ne] = {0};
110
190
 
191
+ static constexpr __device__ bool supported() {
192
+ if (I == 16 && J == 16) return true;
193
+ if (I == 16 && J == 8) return true;
194
+ if (I == 16 && J == 4) return true;
195
+ return false;
196
+ }
197
+
111
198
  static __device__ __forceinline__ int get_i(const int l) {
112
- if constexpr (I == 8 && (J == 4 || J == 8)) {
199
+ if constexpr (supported()) {
200
+ return threadIdx.x % 16;
201
+ } else {
202
+ NO_DEVICE_CODE;
203
+ return -1;
204
+ }
205
+ }
206
+
207
+ static __device__ __forceinline__ int get_j(const int l) {
208
+ if constexpr (I == 16 && J == 16) {
209
+ // matrix C
210
+ #if defined(RDNA3)
211
+ return 2 * l + (threadIdx.x / 16);
212
+ #else
213
+ return ne * (threadIdx.x / 16) + l;
214
+ #endif // defined(RDNA3)
215
+ } else if constexpr (I == 16 && J == 8) {
216
+ // mmq input for RDNA4
217
+ return ne * (threadIdx.x / 16) + l;
218
+ } else if constexpr (I == 16 && J == 4) {
219
+ return ne * (threadIdx.x / 16) + l;
220
+ } else {
221
+ NO_DEVICE_CODE;
222
+ return -1;
223
+ }
224
+ }
225
+ #else
226
+ static constexpr int ne = I * J / 32;
227
+ T x[ne] = {0};
228
+
229
+ static constexpr __device__ bool supported() {
230
+ if (I == 8 && J == 4) return true;
231
+ if (I == 8 && J == 8) return true;
232
+ if (I == 16 && J == 8) return true;
233
+ if (I == 16 && J == 16) return true;
234
+ if (I == 32 && J == 8) return true;
235
+ return false;
236
+ }
237
+
238
+ static __device__ __forceinline__ int get_i(const int l) {
239
+ if constexpr (I == 8 && J == 4) {
240
+ return threadIdx.x / 4;
241
+ } else if constexpr (I == 8 && J == 8) {
113
242
  return threadIdx.x / 4;
114
243
  } else if constexpr (I == 16 && J == 8) {
115
- return (l / 2) * 8 + threadIdx.x / 4;
244
+ return ((l / 2) * 8) + (threadIdx.x / 4);
116
245
  } else if constexpr (I == 16 && J == 16) {
117
- return ((l / 2) % 2) * 8 + threadIdx.x / 4;
246
+ return (((l / 2) % 2) * 8) + (threadIdx.x / 4);
247
+ } else if constexpr (I == 32 && J == 8) {
248
+ return tile<16, 8, T>::get_i(l); // Memory layout simply repeated with same pattern in i direction.
118
249
  } else {
119
- static_assert(I == -1 && J == -1, "template specialization not implemented");
250
+ NO_DEVICE_CODE;
251
+ return -1;
120
252
  }
121
253
  }
122
254
 
@@ -124,82 +256,354 @@ namespace ggml_cuda_mma {
124
256
  if constexpr (I == 8 && J == 4) {
125
257
  return threadIdx.x % 4;
126
258
  } else if constexpr (I == 8 && J == 8) {
127
- return 4 * l + threadIdx.x % 4;
259
+ return (l * 4) + (threadIdx.x % 4);
128
260
  } else if constexpr (I == 16 && J == 8) {
129
- return 2 * (threadIdx.x % 4) + l % 2;
261
+ return ((threadIdx.x % 4) * 2) + (l % 2);
130
262
  } else if constexpr (I == 16 && J == 16) {
131
- return 8 * (l / 4) + 2 * (threadIdx.x % 4) + l % 2;
263
+ return ((l / 4) * 8) + ((threadIdx.x % 4) * 2) + (l % 2);
264
+ } else if constexpr (I == 32 && J == 8) {
265
+ return tile<16, 8, T>::get_j(l); // Memory layout simply repeated with same pattern in i direction.
132
266
  } else {
133
- static_assert(I == -1 && J == -1, "template specialization not implemented");
267
+ NO_DEVICE_CODE;
268
+ return -1;
134
269
  }
135
270
  }
136
271
  #endif // defined(GGML_USE_HIP)
137
272
  };
138
273
 
139
274
  template <int I_, int J_>
140
- struct tile<I_, J_, half2> {
141
- static constexpr int I = I_;
142
- static constexpr int J = J_;
275
+ struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR> {
276
+ static constexpr int I = I_;
277
+ static constexpr int J = J_;
278
+ static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
279
+
280
+ #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
281
+ static constexpr int ne = I * J / WARP_SIZE;
282
+ half2 x[ne] = {{0.0f, 0.0f}};
283
+
284
+ static constexpr __device__ bool supported() {
285
+ if (I == 32 && J == 4) return true;
286
+ return false;
287
+ }
288
+
289
+ static __device__ __forceinline__ int get_i(const int l) {
290
+ if constexpr (I == 32 && J == 4) {
291
+ #ifdef GGML_CUDA_MMA_NO_VOLTA_PERM
292
+ return (((threadIdx.x % 16) / 4) * 8) + ((threadIdx.x / 16) * 4) + (threadIdx.x % 4);
293
+ #else
294
+ return threadIdx.x;
295
+ #endif // GGML_CUDA_MMA_NO_VOLTA_PERM
296
+ } else {
297
+ NO_DEVICE_CODE;
298
+ return -1;
299
+ }
300
+ }
301
+
302
+ static __device__ __forceinline__ int get_j(const int l) {
303
+ if constexpr (I == 32 && J == 4) {
304
+ return l;
305
+ } else {
306
+ NO_DEVICE_CODE;
307
+ return -1;
308
+ }
309
+ }
310
+ #elif defined(AMD_WMMA_AVAILABLE)
311
+ static constexpr int ne = I * J / 32;
312
+ half2 x[ne] = {{0.0f, 0.0f}};
313
+
314
+ static constexpr __device__ bool supported() {
315
+ if (I == 16 && J == 8) return true;
316
+ return false;
317
+ }
318
+
319
+ static __device__ __forceinline__ int get_i(const int l) {
320
+ if constexpr (I == 16 && J == 8) {
321
+ return threadIdx.x % 16;
322
+ } else {
323
+ NO_DEVICE_CODE;
324
+ return -1;
325
+ }
326
+ }
327
+
328
+ static __device__ __forceinline__ int get_j(const int l) {
329
+ if constexpr (I == 16 && J == 8) {
330
+ return 4 * (threadIdx.x / 16) + l;
331
+ } else {
332
+ NO_DEVICE_CODE;
333
+ return -1;
334
+ }
335
+ }
336
+ #else
143
337
  static constexpr int ne = I * J / WARP_SIZE;
144
338
  half2 x[ne] = {{0.0f, 0.0f}};
145
339
 
340
+ static constexpr __device__ bool supported() {
341
+ if (I == 8 && J == 4) return true;
342
+ if (I == 8 && J == 8) return true;
343
+ if (I == 16 && J == 8) return true;
344
+ if (I == 16 && J == 16) return true;
345
+ if (I == 32 && J == 8) return true;
346
+ return false;
347
+ }
348
+
146
349
  static __device__ __forceinline__ int get_i(const int l) {
147
350
  if constexpr (I == 8 && J == 8) {
148
351
  return threadIdx.x / 4;
149
352
  } else if constexpr (I == 16 && J == 4) {
150
- return l * 8 + threadIdx.x / 4;
353
+ return (l * 8) + (threadIdx.x / 4);
151
354
  } else if constexpr (I == 16 && J == 8) {
152
- return (l % 2) * 8 + threadIdx.x / 4;
355
+ return ((l % 2) * 8) + (threadIdx.x / 4);
356
+ } else if constexpr (I == 32 && J == 8) {
357
+ return ((l / 4) * 16) + ((l % 2) * 8) + (threadIdx.x / 4);
153
358
  } else {
154
- static_assert(I == -1 && J == -1, "template specialization not implemented");
359
+ NO_DEVICE_CODE;
360
+ return -1;
155
361
  }
156
362
  }
157
363
 
158
364
  static __device__ __forceinline__ int get_j(const int l) {
159
365
  if constexpr (I == 8 && J == 8) {
160
- return l * 4 + threadIdx.x % 4;
366
+ return (l * 4) + (threadIdx.x % 4);
161
367
  } else if constexpr (I == 16 && J == 4) {
162
368
  return threadIdx.x % 4;
163
369
  } else if constexpr (I == 16 && J == 8) {
164
- return (l / 2) * 4 + threadIdx.x % 4;
370
+ return ((l / 2) * 4) + (threadIdx.x % 4);
371
+ } else if constexpr (I == 32 && J == 8) {
372
+ return ((l & 2) * 2) + (threadIdx.x % 4);
165
373
  } else {
166
- static_assert(I == -1 && J == -1, "template specialization not implemented");
374
+ NO_DEVICE_CODE;
375
+ return -1;
167
376
  }
168
377
  }
378
+ #endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
169
379
  };
170
380
 
171
381
  template <int I_, int J_>
172
- struct tile<I_, J_, nv_bfloat162> {
173
- static constexpr int I = I_;
174
- static constexpr int J = J_;
382
+ struct tile<I_, J_, nv_bfloat162, DATA_LAYOUT_I_MAJOR> {
383
+ static constexpr int I = I_;
384
+ static constexpr int J = J_;
385
+ static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
386
+
387
+ #if defined(AMD_WMMA_AVAILABLE)
388
+ static constexpr int ne = I * J / 32;
389
+ nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
390
+
391
+ static constexpr __device__ bool supported() {
392
+ return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::supported();
393
+ }
394
+
395
+ static __device__ __forceinline__ int get_i(const int l) {
396
+ return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_i(l);
397
+ }
398
+
399
+ static __device__ __forceinline__ int get_j(const int l) {
400
+ return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_j(l);
401
+ }
402
+ #else
175
403
  static constexpr int ne = I * J / WARP_SIZE;
176
404
  nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
177
405
 
406
+ static constexpr __device__ bool supported() {
407
+ if (I == 8 && J == 8) return true;
408
+ if (I == 16 && J == 4) return true;
409
+ if (I == 16 && J == 8) return true;
410
+ return false;
411
+ }
412
+
178
413
  static __device__ __forceinline__ int get_i(const int l) {
179
414
  if constexpr (I == 8 && J == 8) {
180
415
  return threadIdx.x / 4;
181
416
  } else if constexpr (I == 16 && J == 4) {
182
- return l * 8 + threadIdx.x / 4;
417
+ return (l * 8) + (threadIdx.x / 4);
183
418
  } else if constexpr (I == 16 && J == 8) {
184
- return (l % 2) * 8 + threadIdx.x / 4;
419
+ return ((l % 2) * 8) + (threadIdx.x / 4);
185
420
  } else {
186
- static_assert(I == -1 && J == -1, "template specialization not implemented");
421
+ NO_DEVICE_CODE;
422
+ return -1;
187
423
  }
188
424
  }
189
425
 
190
426
  static __device__ __forceinline__ int get_j(const int l) {
191
427
  if constexpr (I == 8 && J == 8) {
192
- return l * 4 + threadIdx.x % 4;
428
+ return (l * 4) + (threadIdx.x % 4);
193
429
  } else if constexpr (I == 16 && J == 4) {
194
430
  return threadIdx.x % 4;
195
431
  } else if constexpr (I == 16 && J == 8) {
196
- return (l / 2) * 4 + threadIdx.x % 4;
432
+ return ((l / 2) * 4) + (threadIdx.x % 4);
433
+ } else {
434
+ NO_DEVICE_CODE;
435
+ return -1;
436
+ }
437
+ }
438
+ #endif // defined(AMD_WMMA_AVAILABLE)
439
+ };
440
+
441
+ template <int I_, int J_, typename T>
442
+ struct tile<I_, J_, T, DATA_LAYOUT_J_MAJOR> {
443
+ static constexpr int I = I_;
444
+ static constexpr int J = J_;
445
+ static constexpr data_layout dl = DATA_LAYOUT_J_MAJOR;
446
+
447
+ static constexpr int ne = tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::ne;
448
+ T x[ne] = {0};
449
+
450
+ static constexpr __device__ bool supported() {
451
+ return tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::supported();
452
+ }
453
+
454
+ static __device__ __forceinline__ int get_i(const int l) {
455
+ return tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::get_j(l);
456
+ }
457
+
458
+ static __device__ __forceinline__ int get_j(const int l) {
459
+ return tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::get_i(l);
460
+ }
461
+ };
462
+
463
+ template <int I_, int J_, typename T>
464
+ struct tile<I_, J_, T, DATA_LAYOUT_I_MAJOR_MIRRORED> {
465
+ static constexpr int I = I_;
466
+ static constexpr int J = J_;
467
+ static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;
468
+
469
+ // RDNA3
470
+ static constexpr int ne = I * J / 32 * 2;
471
+
472
+ T x[ne] = {0};
473
+
474
+ static constexpr __device__ bool supported() {
475
+ if (I == 16 && J == 16) return true;
476
+ if (I == 16 && J == 8) return true;
477
+ if (I == 16 && J == 4) return true;
478
+ return false;
479
+ }
480
+
481
+ static __device__ __forceinline__ int get_i(const int /*l*/) {
482
+ if constexpr (supported()) {
483
+ return threadIdx.x % 16;
484
+ } else {
485
+ NO_DEVICE_CODE;
486
+ return -1;
487
+ }
488
+ }
489
+
490
+ static __device__ __forceinline__ int get_j(const int l) {
491
+ if constexpr (supported()) {
492
+ return l;
493
+ } else {
494
+ NO_DEVICE_CODE;
495
+ return -1;
496
+ }
497
+ }
498
+ };
499
+
500
+ template <int I_, int J_>
501
+ struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> {
502
+ static constexpr int I = I_;
503
+ static constexpr int J = J_;
504
+ static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;
505
+ #if defined(RDNA3)
506
+ static constexpr int ne = tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::ne;
507
+
508
+ half2 x[ne] = {{0.0f, 0.0f}};
509
+
510
+ static constexpr __device__ bool supported() {
511
+ return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::supported();
512
+ }
513
+
514
+ static __device__ __forceinline__ int get_i(const int l) {
515
+ return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_i(l);
516
+ }
517
+
518
+ static __device__ __forceinline__ int get_j(const int l) {
519
+ return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_j(l);
520
+ }
521
+ #else // Volta
522
+ static constexpr int ne = I * J / (WARP_SIZE/4);
523
+
524
+ half2 x[ne] = {{0.0f, 0.0f}};
525
+
526
+ static constexpr __device__ bool supported() {
527
+ if (I == 8 && J == 4) return true;
528
+ return false;
529
+ }
530
+
531
+ static __device__ __forceinline__ int get_i(const int /*l*/) {
532
+ if constexpr (I == 8 && J == 4) {
533
+ return ((threadIdx.x / 16) * 4) + (threadIdx.x % 4);
197
534
  } else {
198
- static_assert(I == -1 && J == -1, "template specialization not implemented");
535
+ NO_DEVICE_CODE;
536
+ return -1;
199
537
  }
200
538
  }
539
+
540
+ static __device__ __forceinline__ int get_j(const int l) {
541
+ if constexpr (I == 8 && J == 4) {
542
+ return l;
543
+ } else {
544
+ NO_DEVICE_CODE;
545
+ return -1;
546
+ }
547
+ }
548
+ #endif // defined(RDNA3)
549
+ };
550
+
551
+ template <int I_, int J_>
552
+ struct tile<I_, J_, nv_bfloat162, DATA_LAYOUT_I_MAJOR_MIRRORED> {
553
+ static constexpr int I = I_;
554
+ static constexpr int J = J_;
555
+ static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;
556
+ static constexpr int ne = tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::ne;
557
+
558
+ nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
559
+
560
+ static constexpr __device__ bool supported() {
561
+ return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::supported();
562
+ }
563
+
564
+ static __device__ __forceinline__ int get_i(const int l) {
565
+ return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_i(l);
566
+ }
567
+
568
+ static __device__ __forceinline__ int get_j(const int l) {
569
+ return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_j(l);
570
+ }
201
571
  };
202
572
 
573
+ template <int I_, int J_>
574
+ struct tile<I_, J_, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> {
575
+ static constexpr int I = I_;
576
+ static constexpr int J = J_;
577
+ static constexpr data_layout dl = DATA_LAYOUT_J_MAJOR_MIRRORED;
578
+ static constexpr int ne = I * J / (WARP_SIZE/4);
579
+
580
+ half2 x[ne] = {{0.0f, 0.0f}};
581
+
582
+ static constexpr __device__ bool supported() {
583
+ if (I == 8 && J == 4) return true;
584
+ return false;
585
+ }
586
+
587
+ static __device__ __forceinline__ int get_i(const int l) {
588
+ if constexpr (I == 8 && J == 4) {
589
+ return ((l / 2) * 4) + (threadIdx.x % 4);
590
+ } else {
591
+ NO_DEVICE_CODE;
592
+ return -1;
593
+ }
594
+ }
595
+
596
+ static __device__ __forceinline__ int get_j(const int l) {
597
+ if constexpr (I == 8 && J == 4) {
598
+ return ((threadIdx.x / 16) * 2) + (l % 2);
599
+ } else {
600
+ NO_DEVICE_CODE;
601
+ return -1;
602
+ }
603
+ }
604
+ };
605
+
606
+ #if defined(TURING_MMA_AVAILABLE)
203
607
  template <int I, int J>
204
608
  static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
205
609
  tile<I, J/2, half2> ret;
@@ -217,9 +621,26 @@ namespace ggml_cuda_mma {
217
621
 
218
622
  return ret;
219
623
  }
624
+ #else // Volta
625
+ template <int I, int J>
626
+ static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
627
+ tile<I, J/2, half2> ret;
628
+ #pragma unroll
629
+ for (int l0 = 0; l0 < tile_float.ne; l0 += 4) {
630
+ ret.x[l0/2 + 0] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]);
631
+ ret.x[l0/2 + 1] = make_half2(tile_float.x[l0 + 2], tile_float.x[l0 + 3]);
632
+
633
+ // On Volta FP16 and FP32 tiles have a different memory layout,
634
+ // for the conversion threads with an offset of 2 need to exchange half their values:
635
+ ret.x[l0/2 + (((threadIdx.x % 4) / 2) ^ 1)] = __shfl_xor_sync(
636
+ 0xFFFFFFFF, ret.x[l0/2 + (((threadIdx.x % 4) / 2) ^ 1)], 2, WARP_SIZE);
637
+ }
638
+ return ret;
639
+ }
640
+ #endif // defined(TURING_MMA_AVAILABLE)
220
641
 
221
- template <int I, int J, typename T>
222
- static __device__ __forceinline__ void load_generic(tile<I, J, T> & t, const T * __restrict__ xs0, const int stride) {
642
+ template <int I, int J, typename T, data_layout dl>
643
+ static __device__ __forceinline__ void load_generic(tile<I, J, T, dl> & t, const T * __restrict__ xs0, const int stride) {
223
644
  #if defined(AMD_MFMA_AVAILABLE)
224
645
  if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
225
646
  #pragma unroll
@@ -227,9 +648,28 @@ namespace ggml_cuda_mma {
227
648
  t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
228
649
  }
229
650
  } else {
230
- int64_t * xi = (int64_t *) t.x;
231
- const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
232
- xi[0] = xs[0];
651
+ ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
652
+ }
653
+ #elif defined(AMD_WMMA_AVAILABLE)
654
+ // All wmma layout has contiguous data when i-major.
655
+ if constexpr (is_i_major(dl)) {
656
+ // the data must be aligned to 16 bytes when bigger than ggml_cuda_get_max_cpy_bytes()
657
+ constexpr int aligned_copy_bytes = ggml_cuda_get_max_cpy_bytes();
658
+ if constexpr (sizeof(t.x) > aligned_copy_bytes) {
659
+ static_assert(sizeof(t.x) % aligned_copy_bytes == 0, "bad type size");
660
+ constexpr int aligned_copy_count = sizeof(t.x)/aligned_copy_bytes;
661
+ #pragma unroll
662
+ for (int i = 0; i < aligned_copy_count; ++i) {
663
+ ggml_cuda_memcpy_1<aligned_copy_bytes>(t.x + t.ne/aligned_copy_count*i, xs0 + t.get_i(0) * stride + t.get_j(t.ne/aligned_copy_count*i));
664
+ }
665
+ } else {
666
+ ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
667
+ }
668
+ } else {
669
+ #pragma unroll
670
+ for (int l = 0; l < t.ne; ++l) {
671
+ t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
672
+ }
233
673
  }
234
674
  #else
235
675
  #pragma unroll
@@ -263,25 +703,63 @@ namespace ggml_cuda_mma {
263
703
  : "=r"(xi[0]), "=r"(xi[1])
264
704
  : "l"(xs));
265
705
  #else
266
- load_generic(xs0, stride);
267
- GGML_UNUSED(t);
706
+ #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
707
+ GGML_UNUSED_VARS(t, xs0, stride);
708
+ NO_DEVICE_CODE;
709
+ #else
710
+ load_generic(t, xs0, stride);
711
+ #endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
268
712
  #endif // TURING_MMA_AVAILABLE
269
713
  }
270
714
 
271
- template <typename T>
715
+ template <typename T, data_layout dl>
272
716
  static __device__ __forceinline__ void load_ldmatrix(
273
- tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
717
+ tile<16, 8, T, dl> & t, const T * __restrict__ xs0, const int stride) {
274
718
  #if defined(TURING_MMA_AVAILABLE)
275
719
  int * xi = (int * ) t.x;
276
720
  const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
277
721
  asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
278
722
  : "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3])
279
723
  : "l"(xs));
724
+ #else
725
+ #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
726
+ #if 1
727
+ // TODO: more generic handling
728
+ static_assert(sizeof(T) == 4, "bad type size");
729
+ ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 0, xs0 + t.get_i(0)*stride + 0);
730
+ ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 4, xs0 + t.get_i(4)*stride + 4);
731
+ #else
732
+ load_generic(t, xs0, stride);
733
+ #endif // 1
280
734
  #else
281
735
  load_generic(t, xs0, stride);
736
+ #endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
282
737
  #endif // TURING_MMA_AVAILABLE
283
738
  }
284
739
 
740
+ static __device__ __forceinline__ void load_ldmatrix(
741
+ tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & t, const half2 * __restrict__ xs0, const int stride) {
742
+ ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, xs0 + t.get_i(0)*stride);
743
+ }
744
+
745
+ static __device__ __forceinline__ void load_ldmatrix(
746
+ tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> & t, const half2 * __restrict__ xs0, const int stride) {
747
+ #pragma unroll
748
+ for (int l0 = 0; l0 < t.ne; l0 += 2) {
749
+ ggml_cuda_memcpy_1<2*sizeof(half2)>(t.x + l0, xs0 + t.get_i(l0)*stride + t.get_j(l0));
750
+ }
751
+ }
752
+
753
+ static __device__ __forceinline__ void load_ldmatrix(
754
+ tile<32, 4, half2> & t, const half2 * __restrict__ xs0, const int stride) {
755
+ #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
756
+ ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, xs0 + t.get_i(0)*stride);
757
+ #else
758
+ GGML_UNUSED_VARS(t, xs0, stride);
759
+ NO_DEVICE_CODE;
760
+ #endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
761
+ }
762
+
285
763
  template <typename T>
286
764
  static __device__ __forceinline__ void load_ldmatrix_trans(
287
765
  tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
@@ -406,8 +884,9 @@ namespace ggml_cuda_mma {
406
884
  #endif // TURING_MMA_AVAILABLE
407
885
  }
408
886
 
887
+ template <data_layout dl_ab, data_layout dl_d>
409
888
  static __device__ __forceinline__ void mma(
410
- tile<16, 8, float> & D, const tile<16, 8, float> & A, const tile<8, 8, float> & B) {
889
+ tile<16, 8, float, dl_d> & D, const tile<16, 8, float, dl_ab> & A, const tile<8, 8, float, dl_ab> & B) {
411
890
  #ifdef AMPERE_MMA_AVAILABLE
412
891
  const int * Axi = (const int *) A.x;
413
892
  const int * Bxi = (const int *) B.x;
@@ -421,6 +900,27 @@ namespace ggml_cuda_mma {
421
900
  #endif // AMPERE_MMA_AVAILABLE
422
901
  }
423
902
 
903
+ static __device__ __forceinline__ void mma_block_scaled(tile<16, 8, float> & D,
904
+ const tile<16, 8, int> & A,
905
+ const tile<8, 8, int> & B,
906
+ uint32_t a_scale,
907
+ uint32_t b_scale) {
908
+ #ifdef BLACKWELL_MMA_AVAILABLE
909
+ const int * Axi = (const int *) A.x;
910
+ const int * Bxi = (const int *) B.x;
911
+ float * Dxi = (float *) D.x;
912
+
913
+ asm volatile(
914
+ "mma.sync.aligned.kind::mxf4.block_scale.scale_vec::2X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue8m0 "
915
+ "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3}, "
916
+ "%10, {0, 0}, %11, {0, 0};"
917
+ : "+f"(Dxi[0]), "+f"(Dxi[1]), "+f"(Dxi[2]), "+f"(Dxi[3])
918
+ : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]), "r"(a_scale), "r"(b_scale));
919
+ #else
920
+ GGML_UNUSED_VARS(D, A, B, a_scale, b_scale);
921
+ #endif // BLACKWELL_MMA_AVAILABLE
922
+ }
923
+
424
924
  static __device__ __forceinline__ void mma(
425
925
  tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
426
926
  #ifdef TURING_MMA_AVAILABLE
@@ -461,8 +961,9 @@ namespace ggml_cuda_mma {
461
961
  #endif // AMPERE_MMA_AVAILABLE
462
962
  }
463
963
 
964
+ template <data_layout dl_ab, data_layout dl_d>
464
965
  static __device__ __forceinline__ void mma(
465
- tile<16, 16, float> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
966
+ tile<16, 16, float, dl_d> & D, const tile<16, 8, half2, dl_ab> & A, const tile<16, 8, half2, dl_ab> & B) {
466
967
  #ifdef TURING_MMA_AVAILABLE
467
968
  const int * Axi = (const int *) A.x;
468
969
  const int * Bxi = (const int *) B.x;
@@ -489,14 +990,62 @@ namespace ggml_cuda_mma {
489
990
  : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
490
991
  : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
491
992
  #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
993
+ #elif defined(AMD_WMMA_AVAILABLE)
994
+ #if defined(RDNA4)
995
+ using halfx8_t = __attribute__((ext_vector_type(8))) _Float16;
996
+ using floatx8_t = __attribute__((ext_vector_type(8))) float;
997
+ floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
998
+ const halfx8_t& a_frag = reinterpret_cast<const halfx8_t&>(A.x[0]);
999
+ const halfx8_t& b_frag = reinterpret_cast<const halfx8_t&>(B.x[0]);
1000
+ acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(a_frag, b_frag, acc_frag);
1001
+ #elif defined(RDNA3)
1002
+ using halfx16_t = __attribute__((ext_vector_type(16))) _Float16;
1003
+ using floatx8_t = __attribute__((ext_vector_type(8))) float;
1004
+ floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
1005
+ const halfx16_t& a_frag = reinterpret_cast<const halfx16_t&>(A.x[0]);
1006
+ const halfx16_t& b_frag = reinterpret_cast<const halfx16_t&>(B.x[0]);
1007
+ acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag, b_frag, acc_frag);
1008
+ #else
1009
+ GGML_UNUSED_VARS(D, A, B);
1010
+ NO_DEVICE_CODE;
1011
+ #endif // RDNA4
492
1012
  #else
493
1013
  GGML_UNUSED_VARS(D, A, B);
494
1014
  NO_DEVICE_CODE;
495
1015
  #endif // TURING_MMA_AVAILABLE
496
1016
  }
497
1017
 
1018
+ template <data_layout dl_ab, data_layout dl_d>
498
1019
  static __device__ __forceinline__ void mma(
499
- tile<16, 16, int> & D, const tile<16, 8, int> & A, const tile<16, 8, int> & B) {
1020
+ tile<16, 16, float, dl_d> & D, const tile<16, 8, nv_bfloat162, dl_ab> & A, const tile<16, 8, nv_bfloat162, dl_ab> & B) {
1021
+ #if defined(AMD_WMMA_AVAILABLE)
1022
+ #if defined(RDNA4)
1023
+ using bf16x8_t = __attribute__((ext_vector_type(8))) __bf16;
1024
+ using floatx8_t = __attribute__((ext_vector_type(8))) float;
1025
+ floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
1026
+ const bf16x8_t& a_frag = reinterpret_cast<const bf16x8_t&>(A.x[0]);
1027
+ const bf16x8_t& b_frag = reinterpret_cast<const bf16x8_t&>(B.x[0]);
1028
+ acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(a_frag, b_frag, acc_frag);
1029
+ #elif defined(RDNA3)
1030
+ using bf16x16_t = __attribute__((ext_vector_type(16))) __bf16;
1031
+ using floatx8_t = __attribute__((ext_vector_type(8))) float;
1032
+ floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
1033
+ const bf16x16_t& a_frag = reinterpret_cast<const bf16x16_t&>(A.x[0]);
1034
+ const bf16x16_t& b_frag = reinterpret_cast<const bf16x16_t&>(B.x[0]);
1035
+ acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(a_frag, b_frag, acc_frag);
1036
+ #else
1037
+ GGML_UNUSED_VARS(D, A, B);
1038
+ NO_DEVICE_CODE;
1039
+ #endif // RDNA4
1040
+ #else
1041
+ GGML_UNUSED_VARS(D, A, B);
1042
+ NO_DEVICE_CODE;
1043
+ #endif // AMPERE_MMA_AVAILABLE
1044
+ }
1045
+
1046
+ template <data_layout dl_d, data_layout dl_ab>
1047
+ static __device__ __forceinline__ void mma(
1048
+ tile<16, 16, int, dl_d> & D, const tile<16, 8, int, dl_ab> & A, const tile<16, 8, int, dl_ab> & B) {
500
1049
  #if defined(AMD_MFMA_AVAILABLE)
501
1050
  using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
502
1051
  int32x4_t * acc = (int32x4_t *) D.x;
@@ -515,6 +1064,59 @@ namespace ggml_cuda_mma {
515
1064
  acc[0],
516
1065
  0, 0, 0);
517
1066
  #endif // defined(CDNA3)
1067
+
1068
+ #elif defined(AMD_WMMA_AVAILABLE)
1069
+
1070
+ using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
1071
+ int32x8_t * acc = (int32x8_t *) D.x;
1072
+
1073
+ #if defined(RDNA4)
1074
+ using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
1075
+ int32x2_t * a_vec = (int32x2_t *) A.x;
1076
+ int32x2_t * b_vec = (int32x2_t *) B.x;
1077
+
1078
+ acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
1079
+ true,
1080
+ a_vec[0],
1081
+ true,
1082
+ b_vec[0],
1083
+ acc[0],
1084
+ true
1085
+ );
1086
+
1087
+ acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
1088
+ true,
1089
+ a_vec[1],
1090
+ true,
1091
+ b_vec[1],
1092
+ acc[0],
1093
+ true
1094
+ );
1095
+
1096
+ #elif defined(RDNA3)
1097
+ using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
1098
+ int32x4_t * a_vec = (int32x4_t *) A.x;
1099
+ int32x4_t * b_vec = (int32x4_t *) B.x;
1100
+
1101
+ acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
1102
+ true,
1103
+ a_vec[0],
1104
+ true,
1105
+ b_vec[0],
1106
+ acc[0],
1107
+ true
1108
+ );
1109
+
1110
+ acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
1111
+ true,
1112
+ a_vec[1],
1113
+ true,
1114
+ b_vec[1],
1115
+ acc[0],
1116
+ true
1117
+ );
1118
+ #endif // RDNA4
1119
+
518
1120
  #else
519
1121
  GGML_UNUSED_VARS(D, A, B);
520
1122
  NO_DEVICE_CODE;
@@ -541,9 +1143,100 @@ namespace ggml_cuda_mma {
541
1143
  acc[0],
542
1144
  0, 0, 0);
543
1145
  #endif // defined(CDNA3)
1146
+
544
1147
  #else
545
1148
  GGML_UNUSED_VARS(D, A, B);
546
1149
  NO_DEVICE_CODE;
547
1150
  #endif // AMD_MFMA_AVAILABLE
548
1151
  }
1152
+
1153
+ template <typename T1, typename T2, int J, int K>
1154
+ static __device__ __forceinline__ void mma(
1155
+ tile<32, J, T1> & D, const tile<32, K, T2> & A, const tile<J, K, T2> & B) {
1156
+ tile <16, J, T1> * D16 = reinterpret_cast< tile<16, J, T1> *>(&D);
1157
+ const tile<16, K, T2> * A16 = reinterpret_cast<const tile<16, K, T2> *>(&A);
1158
+ mma(D16[0], A16[0], B);
1159
+ mma(D16[1], A16[1], B);
1160
+ }
1161
+
1162
+ static __device__ __forceinline__ void mma(
1163
+ tile<32, 8, float> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & B) {
1164
+ #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
1165
+ const int * Axi = (const int *) A.x;
1166
+ const int * Bxi = (const int *) B.x;
1167
+ int * Dxi = (int *) D.x;
1168
+ asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
1169
+ "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
1170
+ : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
1171
+ : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]), "r"(Bxi[1]));
1172
+ asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
1173
+ "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
1174
+ : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
1175
+ : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]), "r"(Bxi[3]));
1176
+ #else
1177
+ GGML_UNUSED_VARS(D, A, B);
1178
+ NO_DEVICE_CODE;
1179
+ #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
1180
+ }
1181
+
1182
+ static __device__ __forceinline__ void mma(
1183
+ tile<32, 4, half2> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> & B) {
1184
+ #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
1185
+ const int * Axi = (const int *) A.x;
1186
+ const int * Bxi = (const int *) B.x;
1187
+ int * Dxi = (int *) D.x;
1188
+ asm("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 "
1189
+ "{%0, %1, %2, %3}, {%4, %5}, {%6, %7}, {%0, %1, %2, %3};"
1190
+ : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
1191
+ : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]), "r"(Bxi[1]));
1192
+ asm("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 "
1193
+ "{%0, %1, %2, %3}, {%4, %5}, {%6, %7}, {%0, %1, %2, %3};"
1194
+ : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
1195
+ : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]), "r"(Bxi[3]));
1196
+ #else
1197
+ GGML_UNUSED_VARS(D, A, B);
1198
+ NO_DEVICE_CODE;
1199
+ #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
1200
+ }
1201
+
1202
+ template <data_layout dl_d, data_layout dl_ab>
1203
+ static __device__ __forceinline__ void mma(
1204
+ tile<16, 16, int, dl_d> & D, const tile<16, 4, int, dl_ab> & A, const tile<16, 4, int, dl_ab> & B) {
1205
+ #if defined(AMD_WMMA_AVAILABLE)
1206
+ using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
1207
+ int32x8_t * acc = (int32x8_t *) D.x;
1208
+ #if defined(RDNA4)
1209
+ using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
1210
+ int32x2_t * a_vec = (int32x2_t *) A.x;
1211
+ int32x2_t * b_vec = (int32x2_t *) B.x;
1212
+
1213
+ acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
1214
+ true,
1215
+ a_vec[0],
1216
+ true,
1217
+ b_vec[0],
1218
+ acc[0],
1219
+ false
1220
+ );
1221
+ #elif defined(RDNA3)
1222
+ using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
1223
+ int32x4_t * a_vec = (int32x4_t *) A.x;
1224
+ int32x4_t * b_vec = (int32x4_t *) B.x;
1225
+
1226
+ acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
1227
+ true,
1228
+ a_vec[0],
1229
+ true,
1230
+ b_vec[0],
1231
+ acc[0],
1232
+ false
1233
+ );
1234
+ #endif // RDNA4
1235
+ #else
1236
+ GGML_UNUSED(D);
1237
+ GGML_UNUSED(A);
1238
+ GGML_UNUSED(B);
1239
+ NO_DEVICE_CODE;
1240
+ #endif // AMD_WMMA_AVAILABLE
1241
+ }
549
1242
  }