whispercpp 1.3.3 → 1.3.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (963)
  1. checksums.yaml +4 -4
  2. data/README.md +60 -43
  3. data/ext/extconf.rb +2 -2
  4. data/ext/ruby_whisper.c +14 -2
  5. data/ext/ruby_whisper.h +39 -0
  6. data/ext/ruby_whisper_context.c +22 -22
  7. data/ext/ruby_whisper_model.c +12 -12
  8. data/ext/ruby_whisper_params.c +79 -25
  9. data/ext/ruby_whisper_segment.c +84 -19
  10. data/ext/ruby_whisper_token.c +351 -0
  11. data/ext/ruby_whisper_transcribe.cpp +1 -1
  12. data/ext/ruby_whisper_vad_context.c +75 -0
  13. data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
  14. data/ext/ruby_whisper_vad_segment.c +139 -0
  15. data/ext/ruby_whisper_vad_segments.c +106 -0
  16. data/ext/sources/CMakeLists.txt +4 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
  19. data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
  20. data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  21. data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
  22. data/ext/sources/examples/CMakeLists.txt +1 -0
  23. data/ext/sources/examples/addon.node/addon.cpp +19 -19
  24. data/ext/sources/examples/addon.node/index.js +7 -5
  25. data/ext/sources/examples/addon.node/vad-example.js +2 -2
  26. data/ext/sources/examples/bench/bench.cpp +26 -16
  27. data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
  28. data/ext/sources/examples/cli/cli.cpp +122 -111
  29. data/ext/sources/examples/command/command.cpp +26 -24
  30. data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
  31. data/ext/sources/examples/common-ggml.cpp +2 -0
  32. data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
  33. data/ext/sources/examples/lsp/lsp.cpp +19 -17
  34. data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
  35. data/ext/sources/examples/server/server.cpp +34 -24
  36. data/ext/sources/examples/server.py +6 -1
  37. data/ext/sources/examples/stream/stream.cpp +4 -2
  38. data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
  39. data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
  40. data/ext/sources/examples/talk-llama/CMakeLists.txt +7 -3
  41. data/ext/sources/examples/talk-llama/llama-adapter.cpp +113 -7
  42. data/ext/sources/examples/talk-llama/llama-adapter.h +13 -1
  43. data/ext/sources/examples/talk-llama/llama-arch.cpp +2136 -1491
  44. data/ext/sources/examples/talk-llama/llama-arch.h +125 -3
  45. data/ext/sources/examples/talk-llama/llama-batch.cpp +174 -100
  46. data/ext/sources/examples/talk-llama/llama-batch.h +46 -20
  47. data/ext/sources/examples/talk-llama/llama-chat.cpp +199 -8
  48. data/ext/sources/examples/talk-llama/llama-chat.h +11 -0
  49. data/ext/sources/examples/talk-llama/llama-context.cpp +1213 -413
  50. data/ext/sources/examples/talk-llama/llama-context.h +99 -36
  51. data/ext/sources/examples/talk-llama/llama-cparams.h +5 -4
  52. data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
  53. data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
  54. data/ext/sources/examples/talk-llama/llama-graph.cpp +883 -294
  55. data/ext/sources/examples/talk-llama/llama-graph.h +361 -161
  56. data/ext/sources/examples/talk-llama/llama-hparams.cpp +144 -6
  57. data/ext/sources/examples/talk-llama/llama-hparams.h +100 -23
  58. data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
  59. data/ext/sources/examples/talk-llama/llama-impl.h +3 -1
  60. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +328 -0
  61. data/ext/sources/examples/talk-llama/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +38 -29
  62. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2100 -0
  63. data/ext/sources/examples/talk-llama/llama-kv-cache.h +373 -27
  64. data/ext/sources/examples/talk-llama/llama-kv-cells.h +124 -30
  65. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +63 -41
  66. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +30 -29
  67. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +77 -35
  68. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +15 -16
  69. data/ext/sources/examples/talk-llama/llama-memory.h +16 -10
  70. data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
  71. data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
  72. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +93 -9
  73. data/ext/sources/examples/talk-llama/llama-model-loader.h +9 -2
  74. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
  75. data/ext/sources/examples/talk-llama/llama-model.cpp +3369 -10145
  76. data/ext/sources/examples/talk-llama/llama-model.h +104 -12
  77. data/ext/sources/examples/talk-llama/llama-quant.cpp +53 -30
  78. data/ext/sources/examples/talk-llama/llama-sampling.cpp +1520 -324
  79. data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
  80. data/ext/sources/examples/talk-llama/llama-vocab.cpp +562 -39
  81. data/ext/sources/examples/talk-llama/llama-vocab.h +50 -0
  82. data/ext/sources/examples/talk-llama/llama.cpp +794 -12
  83. data/ext/sources/examples/talk-llama/llama.h +246 -190
  84. data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
  85. data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
  86. data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
  87. data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
  88. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
  89. data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
  90. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
  91. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
  92. data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
  93. data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
  94. data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
  95. data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
  96. data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
  97. data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
  98. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
  99. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
  100. data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
  101. data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
  102. data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
  103. data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
  104. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
  105. data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
  106. data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
  107. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
  108. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
  109. data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
  110. data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
  111. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
  112. data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
  113. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
  114. data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
  115. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
  116. data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
  117. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
  118. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
  119. data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
  120. data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
  121. data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
  122. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
  123. data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
  124. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
  125. data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
  126. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
  127. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
  128. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
  129. data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
  130. data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
  131. data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
  132. data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
  133. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
  134. data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
  135. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
  136. data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
  137. data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
  138. data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
  139. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
  140. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
  141. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
  142. data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
  143. data/ext/sources/examples/talk-llama/models/models.h +569 -0
  144. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
  145. data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
  146. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
  147. data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
  148. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
  149. data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
  150. data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
  151. data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
  152. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
  153. data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
  154. data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
  155. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
  156. data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
  157. data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
  158. data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
  159. data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
  160. data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
  161. data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
  162. data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
  163. data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
  164. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
  165. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
  166. data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
  167. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
  168. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
  169. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
  170. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
  171. data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
  172. data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
  173. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
  174. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
  175. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
  176. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
  177. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
  178. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
  179. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
  180. data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
  181. data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
  182. data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
  183. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
  184. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
  185. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
  186. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
  187. data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
  188. data/ext/sources/examples/talk-llama/talk-llama.cpp +9 -6
  189. data/ext/sources/examples/talk-llama/unicode.cpp +309 -16
  190. data/ext/sources/examples/talk-llama/unicode.h +45 -0
  191. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
  192. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +4 -2
  193. data/ext/sources/examples/whisper.wasm/index-tmpl.html +18 -17
  194. data/ext/sources/ggml/CMakeLists.txt +135 -79
  195. data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
  196. data/ext/sources/ggml/include/ggml-alloc.h +9 -0
  197. data/ext/sources/ggml/include/ggml-backend.h +21 -2
  198. data/ext/sources/ggml/include/ggml-cpu.h +2 -1
  199. data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
  200. data/ext/sources/ggml/include/ggml-metal.h +1 -6
  201. data/ext/sources/ggml/include/ggml-opt.h +25 -6
  202. data/ext/sources/ggml/include/ggml-rpc.h +8 -11
  203. data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
  204. data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
  205. data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
  206. data/ext/sources/ggml/include/ggml.h +406 -23
  207. data/ext/sources/ggml/src/CMakeLists.txt +99 -13
  208. data/ext/sources/ggml/src/ggml-alloc.c +368 -161
  209. data/ext/sources/ggml/src/ggml-backend-impl.h +5 -5
  210. data/ext/sources/ggml/src/ggml-backend-reg.cpp +55 -14
  211. data/ext/sources/ggml/src/ggml-backend.cpp +290 -57
  212. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
  213. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -13
  214. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +14 -0
  215. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +59 -45
  216. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
  217. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2586 -1917
  218. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +348 -309
  219. data/ext/sources/ggml/src/ggml-cann/common.h +350 -133
  220. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +894 -625
  221. data/ext/sources/ggml/src/ggml-common.h +17 -0
  222. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +167 -75
  223. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +5 -2
  224. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  225. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +560 -622
  226. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1002 -270
  227. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +107 -587
  228. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
  229. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  230. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +373 -486
  231. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
  232. data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  233. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +521 -353
  234. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
  235. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
  236. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +4682 -1660
  237. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +82 -4
  238. data/ext/sources/ggml/src/ggml-cpu/common.h +14 -0
  239. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +18 -9
  240. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +263 -111
  241. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +39 -28
  242. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +683 -82
  243. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +38 -43
  244. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +435 -119
  245. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
  246. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1234 -1182
  247. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
  248. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +2167 -1480
  249. data/ext/sources/ggml/src/ggml-cpu/ops.h +10 -12
  250. data/ext/sources/ggml/src/ggml-cpu/quants.c +35 -0
  251. data/ext/sources/ggml/src/ggml-cpu/quants.h +8 -0
  252. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1132 -81
  253. data/ext/sources/ggml/src/ggml-cpu/repack.h +36 -0
  254. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +120 -93
  255. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1025 -0
  256. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
  257. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
  258. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
  259. data/ext/sources/ggml/src/ggml-cpu/traits.cpp +2 -2
  260. data/ext/sources/ggml/src/ggml-cpu/traits.h +1 -1
  261. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
  262. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
  263. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +294 -27
  264. data/ext/sources/ggml/src/ggml-cpu/vec.h +606 -48
  265. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +92 -17
  266. data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
  267. data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
  268. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
  269. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
  270. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
  271. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
  272. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
  273. data/ext/sources/ggml/src/ggml-cuda/common.cuh +588 -128
  274. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
  275. data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
  276. data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
  277. data/ext/sources/ggml/src/ggml-cuda/convert.cu +95 -22
  278. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +25 -0
  279. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
  280. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +335 -485
  281. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
  282. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  283. data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
  284. data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
  285. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
  286. data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
  287. data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
  288. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +519 -378
  289. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +750 -637
  290. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +49 -0
  291. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1244 -0
  292. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +586 -0
  293. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +98 -61
  294. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  295. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +230 -197
  296. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  297. data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
  298. data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
  299. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
  300. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1557 -294
  301. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
  302. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
  303. data/ext/sources/ggml/src/ggml-cuda/mean.cu +57 -2
  304. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +915 -69
  305. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +171 -0
  306. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +835 -0
  307. data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
  308. data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
  309. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +109 -67
  310. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1601 -733
  311. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +802 -0
  312. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +12 -0
  313. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +286 -149
  314. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
  315. data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
  316. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
  317. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
  318. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
  319. data/ext/sources/ggml/src/ggml-cuda/pad.cu +86 -32
  320. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
  321. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
  322. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +163 -10
  323. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
  324. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
  325. data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
  326. data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
  327. data/ext/sources/ggml/src/ggml-cuda/rope.cu +207 -98
  328. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
  329. data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
  330. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +330 -0
  331. data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  332. data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
  333. data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
  334. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
  335. data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
  336. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +325 -61
  337. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
  338. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
  339. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -12
  340. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +291 -104
  341. data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
  342. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +21 -4
  343. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  344. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  345. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  346. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  347. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  348. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  349. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  350. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  351. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  352. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
  353. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
  354. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
  355. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
  356. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
  357. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
  358. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
  359. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
  360. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
  361. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
  362. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
  363. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
  364. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
  365. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
  366. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
  367. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
  368. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
  369. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
  370. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
  371. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
  372. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
  373. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
  374. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
  375. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
  376. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
  377. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
  378. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
  379. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
  380. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
  381. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
  382. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
  383. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
  384. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
  385. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
  386. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
  387. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
  388. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +40 -19
  389. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
  390. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
  391. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
  392. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
  393. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
  394. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
  395. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
  396. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
  397. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
  398. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
  399. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
  400. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
  401. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
  402. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
  403. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
  404. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
  405. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
  406. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
  407. data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
  408. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +351 -0
  409. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +21 -0
  410. data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
  411. data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
  412. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
  413. data/ext/sources/ggml/src/ggml-cuda/unary.cu +189 -5
  414. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +44 -0
  415. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +248 -6
  416. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
  417. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +8 -0
  418. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +70 -37
  419. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +10 -3
  420. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
  421. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
  422. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
  423. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
  424. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
  425. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  426. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
  427. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
  428. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  429. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
  430. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
  431. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
  432. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
  433. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  434. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
  435. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
  436. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
  437. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
  438. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
  439. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
  440. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
  441. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
  442. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
  443. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
  444. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
  445. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
  446. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
  447. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  448. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
  449. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
  450. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
  451. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +16 -13
  452. data/ext/sources/ggml/src/ggml-impl.h +186 -15
  453. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -7
  454. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
  455. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
  456. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
  457. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +609 -0
  458. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1743 -0
  459. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +273 -0
  460. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1686 -0
  461. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +356 -61
  462. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +4161 -0
  463. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +94 -0
  464. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +724 -0
  465. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +4495 -1876
  466. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +21 -9
  467. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +29 -0
  468. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4005 -427
  469. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
  470. data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  471. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  472. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  473. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +147 -0
  474. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +66 -0
  475. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
  476. data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  477. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
  478. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +371 -0
  479. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
  480. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  481. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  482. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  483. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  484. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +177 -0
  485. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
  486. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  487. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  488. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
  489. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
  490. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  491. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  492. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +146 -0
  493. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +147 -0
  494. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  495. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  496. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
  497. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
  498. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
  499. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  500. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
  501. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
  502. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
  503. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
  504. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  505. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +94 -0
  506. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  507. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  508. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +208 -0
  509. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
  510. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
  511. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
  512. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
  513. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
  514. data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  515. data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  516. data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  517. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
  518. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +33 -0
  519. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +2 -2
  520. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  521. data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
  522. data/ext/sources/ggml/src/ggml-quants.c +111 -16
  523. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  524. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +497 -195
  525. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
  526. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
  527. data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
  528. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +8 -0
  529. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +6 -5
  530. data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
  531. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +50 -30
  532. data/ext/sources/ggml/src/ggml-sycl/conv.cpp +10 -4
  533. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +200 -99
  534. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  535. data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  536. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +72 -309
  537. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
  538. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
  539. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +67 -49
  540. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +77 -34
  541. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +397 -314
  542. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +12 -2
  543. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +14 -26
  544. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +9 -6
  545. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +643 -413
  546. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
  547. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +2 -2
  548. data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +80 -60
  549. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +223 -132
  550. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +230 -55
  551. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
  552. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
  553. data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
  554. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
  555. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
  556. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
  557. data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
  558. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +8 -9
  559. data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  560. data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  561. data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
  562. data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
  563. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +65 -59
  564. data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
  565. data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
  566. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
  567. data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  568. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +330 -165
  569. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
  570. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  571. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  572. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +12 -6
  573. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +60 -6
  574. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +16 -12
  575. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
  576. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +7398 -2635
  577. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  578. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
  579. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +43 -3
  580. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  581. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
  582. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  583. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +15 -6
  584. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +56 -39
  585. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  586. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  587. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  588. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  589. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  590. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  591. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +347 -0
  592. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  593. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  594. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +5 -5
  595. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +67 -13
  596. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  597. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  598. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  599. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
  600. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
  601. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
  602. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
  603. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  604. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +158 -16
  605. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +38 -3
  606. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  607. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  608. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  609. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +2 -2
  610. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  611. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +3 -2
  612. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +7 -6
  613. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +5 -3
  614. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  615. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +4 -4
  618. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +2 -2
  619. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  620. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  621. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +4 -4
  622. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  623. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  624. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +4 -4
  625. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +2 -2
  626. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  627. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
  628. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  629. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  630. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
  631. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  632. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +103 -36
  633. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +220 -0
  634. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +139 -45
  635. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +113 -38
  636. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +75 -14
  637. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  638. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  639. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  640. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  641. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  642. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  643. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  644. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +19 -17
  645. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
  646. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +21 -12
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +28 -18
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +4 -0
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +33 -17
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +125 -0
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +227 -0
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
  663. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
  664. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
  665. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
  666. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
  667. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
  668. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
  669. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +20 -14
  670. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  671. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
  672. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
  673. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
  674. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
  675. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
  676. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +143 -0
  677. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
  678. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -556
  679. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +230 -51
  680. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +566 -0
  681. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
  682. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +90 -223
  683. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
  684. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  685. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +195 -0
  686. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  687. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  688. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  689. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
  690. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +41 -5
  691. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  692. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +59 -9
  693. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  694. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  695. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  696. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  697. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +104 -14
  698. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  699. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
  700. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  701. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
  702. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  703. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -52
  704. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -35
  705. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -35
  706. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
  707. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
  708. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  709. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +5 -0
  710. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +3 -3
  711. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  712. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  713. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  714. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  715. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +30 -8
  716. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +6 -2
  717. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
  718. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
  719. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
  720. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
  721. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  722. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
  723. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
  724. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  725. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
  726. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
  727. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  728. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  729. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +16 -6
  730. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
  731. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  732. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
  733. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  734. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +5 -4
  735. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
  736. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
  737. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
  738. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
  739. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  740. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +435 -24
  741. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +148 -6
  742. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.glsl +25 -0
  743. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +619 -177
  744. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
  745. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +80 -0
  746. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
  747. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3087 -0
  748. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  749. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
  750. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
  751. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
  752. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
  753. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +147 -0
  754. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
  755. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
  756. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  757. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  758. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
  759. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
  760. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
  761. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
  762. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
  763. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +123 -0
  764. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  765. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
  766. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
  767. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
  768. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  769. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
  770. data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  771. data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
  772. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
  773. data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
  774. data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
  775. data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
  776. data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
  777. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
  778. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
  779. data/ext/sources/ggml/src/ggml.c +901 -129
  780. data/ext/sources/ggml/src/gguf.cpp +8 -1
  781. data/ext/sources/include/whisper.h +1 -0
  782. data/ext/sources/src/CMakeLists.txt +3 -1
  783. data/ext/sources/src/whisper.cpp +124 -81
  784. data/ext/sources/tests/CMakeLists.txt +8 -1
  785. data/ext/sources/tests/test-vad-full.cpp +7 -5
  786. data/ext/sources/tests/test-vad.cpp +3 -3
  787. data/extsources.rb +1 -0
  788. data/lib/whisper/model/uri.rb +17 -18
  789. data/sig/whisper.rbs +126 -2
  790. data/test/test_params.rb +24 -8
  791. data/test/test_segment.rb +0 -1
  792. data/test/test_token.rb +70 -0
  793. data/test/test_vad.rb +1 -1
  794. data/test/test_vad_context.rb +50 -0
  795. data/test/test_vad_segment.rb +19 -0
  796. data/test/test_vad_segments.rb +16 -0
  797. data/test/test_whisper.rb +8 -1
  798. data/whispercpp.gemspec +1 -1
  799. metadata +439 -179
  800. data/ext/sources/build-xcframework.sh +0 -547
  801. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +0 -279
  802. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +0 -1841
  803. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +0 -303
  804. data/ext/sources/ggml/include/ggml-kompute.h +0 -50
  805. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
  806. data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
  807. data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
  808. data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
  809. data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
  810. data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
  811. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  812. data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  813. data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  814. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  815. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  816. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  817. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  818. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  819. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  820. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
  821. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
  822. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
  823. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
  824. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
  825. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
  826. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
  827. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +0 -506
  828. data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +0 -11
  829. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
  830. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
  831. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
  832. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
  833. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
  834. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
  835. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
  836. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
  837. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
  838. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
  839. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
  840. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
  841. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
  842. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
  843. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
  844. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
  845. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
  846. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
  847. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
  848. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
  849. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
  850. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
  851. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
  852. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
  853. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
  854. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
  855. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
  856. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
  857. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
  858. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
  859. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
  860. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
  861. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
  862. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
  863. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
  864. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
  865. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
  866. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
  867. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
  868. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
  869. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
  870. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
  871. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
  872. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
  873. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
  874. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
  875. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
  876. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
  877. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
  878. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
  879. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
  880. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
  881. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
  882. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
  883. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
  884. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
  885. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
  886. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
  887. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
  888. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
  889. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
  890. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
  891. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
  892. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
  893. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
  894. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
  895. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
  896. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
  897. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
  898. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
  899. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
  900. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
  901. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
  902. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
  903. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
  904. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
  905. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
  906. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
  907. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
  908. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
  909. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
  910. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
  911. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
  912. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
  913. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
  914. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
  915. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  916. data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  917. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  918. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  919. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  920. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  921. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  922. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  923. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  924. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  925. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  926. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  927. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  928. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  929. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  930. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  931. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  932. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  933. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  934. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  935. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  936. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  937. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  938. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  939. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  940. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  941. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  942. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  943. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  944. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  945. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  946. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  947. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  948. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  949. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  950. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  951. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  952. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  953. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
  954. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -6280
  955. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +0 -162
  956. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +0 -118
  957. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -99
  958. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -58
  959. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  960. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  961. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  962. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  963. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
@@ -1,3 +1,4 @@
1
+ #pragma once
1
2
  // This file contains primitives that expose the tensor core PTX instructions for CUDA code.
2
3
  // The primitives can be used in a similar way as the nvcuda::wmma interface but with a well-defined memory layout.
3
4
  // The documentation for the PTX instructions can be found under:
@@ -12,23 +13,28 @@
12
13
  // The methods get_i and get_j can be used to get the physical 32 bit index of the lth element of a thread within a tile.
13
14
  // All matrix tiles have ne physical 32 bit elements per warp.
14
15
  //
15
- // As described in the documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes.
16
+ // As described in the PTX documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes.
17
+ // The API in this file also assumes that the pointers for load_generic are aligned to 16 bytes, unaligned pointers are considered undefined behavior.
16
18
 
17
19
  #include "common.cuh"
18
20
 
21
+ // On Volta each warp is doing 4 8x8 mma operations in parallel.
22
+ // The basic memory layout for a 32x8 output tile is to stack 4 input tiles in I direction and to mirror the B tile.
23
+ // However, the i indices in this file are by default permuted to simplify the index calculations.
24
+ // #define GGML_CUDA_MMA_NO_VOLTA_PERM
19
25
 
20
26
  #if CUDART_VERSION >= 11080
21
27
 
22
28
  static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
23
29
  int ret = 0;
24
30
 
25
- #ifdef NEW_MMA_AVAILABLE
31
+ #ifdef TURING_MMA_AVAILABLE
26
32
  asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
27
33
  : "=r"(ret) : "r"(x));
28
34
  #else
29
35
  GGML_UNUSED(x);
30
36
  NO_DEVICE_CODE;
31
- #endif // defined(NEW_MMA_AVAILABLE)
37
+ #endif // defined(TURING_MMA_AVAILABLE)
32
38
  return ret;
33
39
  }
34
40
 
@@ -62,22 +68,187 @@ static __device__ __forceinline__ half2 ggml_cuda_movmatrix(const half2 x) {
62
68
 
63
69
  namespace ggml_cuda_mma {
64
70
 
71
+ // Some architectures like Volta or CDNA3 perform multiple matrix multiplications per warp in parallel,
72
+ // effectively the warp is being split into subgroups of threads that each perform a single mma instruction.
73
+ // In those cases the data can be split in different ways across the warp.
74
+ enum data_layout {
75
+ // By default the data uses the I direction as its major dimension and the J direction as its minor dimension.
76
+ // For the A/C matrices this means I major == row major, J major == column major.
77
+ // For the B matrix this means I major == column major, J major == row major.
78
+ // MIRRORED == Each data value is held exactly once per thread subgroup.
79
+ DATA_LAYOUT_I_MAJOR = 0, // Always used for Turing, Ampere, Ada Lovelace, consumer Blackwell, matrix A&B for RDNA4 and CDNA.
80
+ DATA_LAYOUT_J_MAJOR = 10, // Matrix C for CDNA and RDNA4, int and float matrix C for RDNA3.
81
+ DATA_LAYOUT_I_MAJOR_MIRRORED = 20, // Volta, matrix A&B for RDNA3.
82
+ DATA_LAYOUT_J_MAJOR_MIRRORED = 30,
83
+ };
84
+ // Implemented mma combinations are:
85
+ // - (I_MAJOR, I_MAJOR) -> I_MAJOR
86
+ // - (I_MAJOR, I_MAJOR_MIRRORED) -> I_MAJOR
87
+ // - (I_MAJOR, J_MAJOR_MIRRORED) -> I_MAJOR
88
+
89
+ static constexpr bool is_i_major(const data_layout dl) {
90
+ return dl == DATA_LAYOUT_I_MAJOR ||
91
+ dl == DATA_LAYOUT_I_MAJOR_MIRRORED;
92
+ }
93
+
94
+ static constexpr __device__ data_layout get_input_data_layout() {
95
+ #if defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
96
+ return DATA_LAYOUT_I_MAJOR_MIRRORED;
97
+ #else
98
+ return DATA_LAYOUT_I_MAJOR;
99
+ #endif // defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
100
+ }
101
+
102
+ template <int I_, int J_, typename T, data_layout ds_=DATA_LAYOUT_I_MAJOR>
103
+ struct tile {};
104
+
65
105
  template <int I_, int J_, typename T>
66
- struct tile {
67
- static constexpr int I = I_;
68
- static constexpr int J = J_;
69
- static constexpr int ne = I * J / WARP_SIZE;
106
+ struct tile<I_, J_, T, DATA_LAYOUT_I_MAJOR> {
107
+ static constexpr int I = I_;
108
+ static constexpr int J = J_;
109
+ static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
110
+
111
+ #if defined(AMD_MFMA_AVAILABLE)
112
+ static constexpr int ne = I * J / 64;
70
113
  T x[ne] = {0};
71
114
 
115
+ static constexpr __device__ bool supported() {
116
+ if (I == 64 && J == 2) return true;
117
+ if (I == 16 && J == 8) return true;
118
+ if (I == 32 && J == 4) return true;
119
+ if (I == 16 && J == 16) return true;
120
+ if (I == 32 && J == 32) return true;
121
+ return false;
122
+ }
123
+
72
124
  static __device__ __forceinline__ int get_i(const int l) {
73
- if constexpr (I == 8 && (J == 4 || J == 8)) {
125
+ if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
126
+ return threadIdx.x % 16;
127
+ } else if constexpr (I == 16 && J == 8) {
128
+ return threadIdx.x % 16;
129
+ } else if constexpr (I == 32 && J == 4) {
130
+ return threadIdx.x % 32;
131
+ } else if constexpr (I == 16 && J == 16) {
132
+ return threadIdx.x % 16;
133
+ } else if constexpr (I == 32 && J == 32) {
134
+ return threadIdx.x % 32;
135
+ } else {
136
+ NO_DEVICE_CODE;
137
+ return -1;
138
+ }
139
+ }
140
+
141
+ static __device__ __forceinline__ int get_j(const int l) {
142
+ if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
143
+ return (2 * ((threadIdx.x / 16) % 2) + l);
144
+ } else if constexpr (I == 16 && J == 8) {
145
+ return 2 * (threadIdx.x / 16) + l;
146
+ } else if constexpr (I == 32 && J == 4) {
147
+ return 2 * (threadIdx.x / 32) + l;
148
+ } else if constexpr (I == 16 && J == 16) {
149
+ return 4 * (threadIdx.x / 16) + l;
150
+ } else if constexpr (I == 32 && J == 32) {
151
+ return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4);
152
+ } else {
153
+ NO_DEVICE_CODE;
154
+ return -1;
155
+ }
156
+ }
157
+ #elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
158
+ static constexpr int ne = I * J / 32;
159
+ T x[ne] = {0};
160
+
161
+ static constexpr __device__ bool supported() {
162
+ if (I == 32 && J == 8) return true;
163
+ return false;
164
+ }
165
+
166
+ static __device__ __forceinline__ int get_i(const int l) {
167
+ if constexpr (I == 32 && J == 8) {
168
+ #ifdef GGML_CUDA_MMA_NO_VOLTA_PERM
169
+ return (((threadIdx.x % 16) / 4) * 8) + ((threadIdx.x / 16) * 4) + (l & 2) + (threadIdx.x % 2);
170
+ #else
171
+ return (l & 2) + (threadIdx.x & ~2);
172
+ #endif // GGML_CUDA_MMA_NO_VOLTA_PERM
173
+ } else {
174
+ NO_DEVICE_CODE;
175
+ return -1;
176
+ }
177
+ }
178
+
179
+ static __device__ __forceinline__ int get_j(const int l) {
180
+ if constexpr (I == 32 && J == 8) {
181
+ return (threadIdx.x & 2) + (l & (4 + 1));
182
+ } else {
183
+ NO_DEVICE_CODE;
184
+ return -1;
185
+ }
186
+ }
187
+ #elif defined(AMD_WMMA_AVAILABLE)
188
+ static constexpr int ne = I * J / 32;
189
+ T x[ne] = {0};
190
+
191
+ static constexpr __device__ bool supported() {
192
+ if (I == 16 && J == 16) return true;
193
+ if (I == 16 && J == 8) return true;
194
+ if (I == 16 && J == 4) return true;
195
+ return false;
196
+ }
197
+
198
+ static __device__ __forceinline__ int get_i(const int l) {
199
+ if constexpr (supported()) {
200
+ return threadIdx.x % 16;
201
+ } else {
202
+ NO_DEVICE_CODE;
203
+ return -1;
204
+ }
205
+ }
206
+
207
+ static __device__ __forceinline__ int get_j(const int l) {
208
+ if constexpr (I == 16 && J == 16) {
209
+ // matrix C
210
+ #if defined(RDNA3)
211
+ return 2 * l + (threadIdx.x / 16);
212
+ #else
213
+ return ne * (threadIdx.x / 16) + l;
214
+ #endif // defined(RDNA3)
215
+ } else if constexpr (I == 16 && J == 8) {
216
+ // mmq input for RDNA4
217
+ return ne * (threadIdx.x / 16) + l;
218
+ } else if constexpr (I == 16 && J == 4) {
219
+ return ne * (threadIdx.x / 16) + l;
220
+ } else {
221
+ NO_DEVICE_CODE;
222
+ return -1;
223
+ }
224
+ }
225
+ #else
226
+ static constexpr int ne = I * J / 32;
227
+ T x[ne] = {0};
228
+
229
+ static constexpr __device__ bool supported() {
230
+ if (I == 8 && J == 4) return true;
231
+ if (I == 8 && J == 8) return true;
232
+ if (I == 16 && J == 8) return true;
233
+ if (I == 16 && J == 16) return true;
234
+ if (I == 32 && J == 8) return true;
235
+ return false;
236
+ }
237
+
238
+ static __device__ __forceinline__ int get_i(const int l) {
239
+ if constexpr (I == 8 && J == 4) {
240
+ return threadIdx.x / 4;
241
+ } else if constexpr (I == 8 && J == 8) {
74
242
  return threadIdx.x / 4;
75
243
  } else if constexpr (I == 16 && J == 8) {
76
- return (l / 2) * 8 + threadIdx.x / 4;
244
+ return ((l / 2) * 8) + (threadIdx.x / 4);
77
245
  } else if constexpr (I == 16 && J == 16) {
78
- return ((l / 2) % 2) * 8 + threadIdx.x / 4;
246
+ return (((l / 2) % 2) * 8) + (threadIdx.x / 4);
247
+ } else if constexpr (I == 32 && J == 8) {
248
+ return tile<16, 8, T>::get_i(l); // Memory layout simply repeated with same pattern in i direction.
79
249
  } else {
80
- static_assert(I == -1 && J == -1, "template specialization not implemented");
250
+ NO_DEVICE_CODE;
251
+ return -1;
81
252
  }
82
253
  }
83
254
 
@@ -85,49 +256,354 @@ namespace ggml_cuda_mma {
85
256
  if constexpr (I == 8 && J == 4) {
86
257
  return threadIdx.x % 4;
87
258
  } else if constexpr (I == 8 && J == 8) {
88
- return 4 * l + threadIdx.x % 4;
259
+ return (l * 4) + (threadIdx.x % 4);
89
260
  } else if constexpr (I == 16 && J == 8) {
90
- return 2 * (threadIdx.x % 4) + l % 2;
261
+ return ((threadIdx.x % 4) * 2) + (l % 2);
91
262
  } else if constexpr (I == 16 && J == 16) {
92
- return 8 * (l / 4) + 2 * (threadIdx.x % 4) + l % 2;
263
+ return ((l / 4) * 8) + ((threadIdx.x % 4) * 2) + (l % 2);
264
+ } else if constexpr (I == 32 && J == 8) {
265
+ return tile<16, 8, T>::get_j(l); // Memory layout simply repeated with same pattern in i direction.
93
266
  } else {
94
- static_assert(I == -1 && J == -1, "template specialization not implemented");
267
+ NO_DEVICE_CODE;
268
+ return -1;
95
269
  }
96
270
  }
271
+ #endif // defined(GGML_USE_HIP)
97
272
  };
98
273
 
99
274
  template <int I_, int J_>
100
- struct tile<I_, J_, half2> {
101
- static constexpr int I = I_;
102
- static constexpr int J = J_;
275
+ struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR> {
276
+ static constexpr int I = I_;
277
+ static constexpr int J = J_;
278
+ static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
279
+
280
+ #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
281
+ static constexpr int ne = I * J / WARP_SIZE;
282
+ half2 x[ne] = {{0.0f, 0.0f}};
283
+
284
+ static constexpr __device__ bool supported() {
285
+ if (I == 32 && J == 4) return true;
286
+ return false;
287
+ }
288
+
289
+ static __device__ __forceinline__ int get_i(const int l) {
290
+ if constexpr (I == 32 && J == 4) {
291
+ #ifdef GGML_CUDA_MMA_NO_VOLTA_PERM
292
+ return (((threadIdx.x % 16) / 4) * 8) + ((threadIdx.x / 16) * 4) + (threadIdx.x % 4);
293
+ #else
294
+ return threadIdx.x;
295
+ #endif // GGML_CUDA_MMA_NO_VOLTA_PERM
296
+ } else {
297
+ NO_DEVICE_CODE;
298
+ return -1;
299
+ }
300
+ }
301
+
302
+ static __device__ __forceinline__ int get_j(const int l) {
303
+ if constexpr (I == 32 && J == 4) {
304
+ return l;
305
+ } else {
306
+ NO_DEVICE_CODE;
307
+ return -1;
308
+ }
309
+ }
310
+ #elif defined(AMD_WMMA_AVAILABLE)
311
+ static constexpr int ne = I * J / 32;
312
+ half2 x[ne] = {{0.0f, 0.0f}};
313
+
314
+ static constexpr __device__ bool supported() {
315
+ if (I == 16 && J == 8) return true;
316
+ return false;
317
+ }
318
+
319
+ static __device__ __forceinline__ int get_i(const int l) {
320
+ if constexpr (I == 16 && J == 8) {
321
+ return threadIdx.x % 16;
322
+ } else {
323
+ NO_DEVICE_CODE;
324
+ return -1;
325
+ }
326
+ }
327
+
328
+ static __device__ __forceinline__ int get_j(const int l) {
329
+ if constexpr (I == 16 && J == 8) {
330
+ return 4 * (threadIdx.x / 16) + l;
331
+ } else {
332
+ NO_DEVICE_CODE;
333
+ return -1;
334
+ }
335
+ }
336
+ #else
103
337
  static constexpr int ne = I * J / WARP_SIZE;
104
338
  half2 x[ne] = {{0.0f, 0.0f}};
105
339
 
340
+ static constexpr __device__ bool supported() {
341
+ if (I == 8 && J == 4) return true;
342
+ if (I == 8 && J == 8) return true;
343
+ if (I == 16 && J == 8) return true;
344
+ if (I == 16 && J == 16) return true;
345
+ if (I == 32 && J == 8) return true;
346
+ return false;
347
+ }
348
+
106
349
  static __device__ __forceinline__ int get_i(const int l) {
107
350
  if constexpr (I == 8 && J == 8) {
108
351
  return threadIdx.x / 4;
109
352
  } else if constexpr (I == 16 && J == 4) {
110
- return l * 8 + threadIdx.x / 4;
353
+ return (l * 8) + (threadIdx.x / 4);
111
354
  } else if constexpr (I == 16 && J == 8) {
112
- return (l % 2) * 8 + threadIdx.x / 4;
355
+ return ((l % 2) * 8) + (threadIdx.x / 4);
356
+ } else if constexpr (I == 32 && J == 8) {
357
+ return ((l / 4) * 16) + ((l % 2) * 8) + (threadIdx.x / 4);
113
358
  } else {
114
- static_assert(I == -1 && J == -1, "template specialization not implemented");
359
+ NO_DEVICE_CODE;
360
+ return -1;
115
361
  }
116
362
  }
117
363
 
118
364
  static __device__ __forceinline__ int get_j(const int l) {
119
365
  if constexpr (I == 8 && J == 8) {
120
- return l * 4 + threadIdx.x % 4;
366
+ return (l * 4) + (threadIdx.x % 4);
121
367
  } else if constexpr (I == 16 && J == 4) {
122
368
  return threadIdx.x % 4;
123
369
  } else if constexpr (I == 16 && J == 8) {
124
- return (l / 2) * 4 + threadIdx.x % 4;
370
+ return ((l / 2) * 4) + (threadIdx.x % 4);
371
+ } else if constexpr (I == 32 && J == 8) {
372
+ return ((l & 2) * 2) + (threadIdx.x % 4);
125
373
  } else {
126
- static_assert(I == -1 && J == -1, "template specialization not implemented");
374
+ NO_DEVICE_CODE;
375
+ return -1;
127
376
  }
128
377
  }
378
+ #endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
129
379
  };
130
380
 
381
+ template <int I_, int J_>
382
+ struct tile<I_, J_, nv_bfloat162, DATA_LAYOUT_I_MAJOR> {
383
+ static constexpr int I = I_;
384
+ static constexpr int J = J_;
385
+ static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
386
+
387
+ #if defined(AMD_WMMA_AVAILABLE)
388
+ static constexpr int ne = I * J / 32;
389
+ nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
390
+
391
+ static constexpr __device__ bool supported() {
392
+ return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::supported();
393
+ }
394
+
395
+ static __device__ __forceinline__ int get_i(const int l) {
396
+ return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_i(l);
397
+ }
398
+
399
+ static __device__ __forceinline__ int get_j(const int l) {
400
+ return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_j(l);
401
+ }
402
+ #else
403
+ static constexpr int ne = I * J / WARP_SIZE;
404
+ nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
405
+
406
+ static constexpr __device__ bool supported() {
407
+ if (I == 8 && J == 8) return true;
408
+ if (I == 16 && J == 4) return true;
409
+ if (I == 16 && J == 8) return true;
410
+ return false;
411
+ }
412
+
413
+ static __device__ __forceinline__ int get_i(const int l) {
414
+ if constexpr (I == 8 && J == 8) {
415
+ return threadIdx.x / 4;
416
+ } else if constexpr (I == 16 && J == 4) {
417
+ return (l * 8) + (threadIdx.x / 4);
418
+ } else if constexpr (I == 16 && J == 8) {
419
+ return ((l % 2) * 8) + (threadIdx.x / 4);
420
+ } else {
421
+ NO_DEVICE_CODE;
422
+ return -1;
423
+ }
424
+ }
425
+
426
+ static __device__ __forceinline__ int get_j(const int l) {
427
+ if constexpr (I == 8 && J == 8) {
428
+ return (l * 4) + (threadIdx.x % 4);
429
+ } else if constexpr (I == 16 && J == 4) {
430
+ return threadIdx.x % 4;
431
+ } else if constexpr (I == 16 && J == 8) {
432
+ return ((l / 2) * 4) + (threadIdx.x % 4);
433
+ } else {
434
+ NO_DEVICE_CODE;
435
+ return -1;
436
+ }
437
+ }
438
+ #endif // defined(AMD_WMMA_AVAILABLE)
439
+ };
440
+
441
+ template <int I_, int J_, typename T>
442
+ struct tile<I_, J_, T, DATA_LAYOUT_J_MAJOR> {
443
+ static constexpr int I = I_;
444
+ static constexpr int J = J_;
445
+ static constexpr data_layout dl = DATA_LAYOUT_J_MAJOR;
446
+
447
+ static constexpr int ne = tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::ne;
448
+ T x[ne] = {0};
449
+
450
+ static constexpr __device__ bool supported() {
451
+ return tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::supported();
452
+ }
453
+
454
+ static __device__ __forceinline__ int get_i(const int l) {
455
+ return tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::get_j(l);
456
+ }
457
+
458
+ static __device__ __forceinline__ int get_j(const int l) {
459
+ return tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::get_i(l);
460
+ }
461
+ };
462
+
463
+ template <int I_, int J_, typename T>
464
+ struct tile<I_, J_, T, DATA_LAYOUT_I_MAJOR_MIRRORED> {
465
+ static constexpr int I = I_;
466
+ static constexpr int J = J_;
467
+ static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;
468
+
469
+ // RDNA3
470
+ static constexpr int ne = I * J / 32 * 2;
471
+
472
+ T x[ne] = {0};
473
+
474
+ static constexpr __device__ bool supported() {
475
+ if (I == 16 && J == 16) return true;
476
+ if (I == 16 && J == 8) return true;
477
+ if (I == 16 && J == 4) return true;
478
+ return false;
479
+ }
480
+
481
+ static __device__ __forceinline__ int get_i(const int /*l*/) {
482
+ if constexpr (supported()) {
483
+ return threadIdx.x % 16;
484
+ } else {
485
+ NO_DEVICE_CODE;
486
+ return -1;
487
+ }
488
+ }
489
+
490
+ static __device__ __forceinline__ int get_j(const int l) {
491
+ if constexpr (supported()) {
492
+ return l;
493
+ } else {
494
+ NO_DEVICE_CODE;
495
+ return -1;
496
+ }
497
+ }
498
+ };
499
+
500
+ template <int I_, int J_>
501
+ struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> {
502
+ static constexpr int I = I_;
503
+ static constexpr int J = J_;
504
+ static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;
505
+ #if defined(RDNA3)
506
+ static constexpr int ne = tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::ne;
507
+
508
+ half2 x[ne] = {{0.0f, 0.0f}};
509
+
510
+ static constexpr __device__ bool supported() {
511
+ return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::supported();
512
+ }
513
+
514
+ static __device__ __forceinline__ int get_i(const int l) {
515
+ return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_i(l);
516
+ }
517
+
518
+ static __device__ __forceinline__ int get_j(const int l) {
519
+ return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_j(l);
520
+ }
521
+ #else // Volta
522
+ static constexpr int ne = I * J / (WARP_SIZE/4);
523
+
524
+ half2 x[ne] = {{0.0f, 0.0f}};
525
+
526
+ static constexpr __device__ bool supported() {
527
+ if (I == 8 && J == 4) return true;
528
+ return false;
529
+ }
530
+
531
+ static __device__ __forceinline__ int get_i(const int /*l*/) {
532
+ if constexpr (I == 8 && J == 4) {
533
+ return ((threadIdx.x / 16) * 4) + (threadIdx.x % 4);
534
+ } else {
535
+ NO_DEVICE_CODE;
536
+ return -1;
537
+ }
538
+ }
539
+
540
+ static __device__ __forceinline__ int get_j(const int l) {
541
+ if constexpr (I == 8 && J == 4) {
542
+ return l;
543
+ } else {
544
+ NO_DEVICE_CODE;
545
+ return -1;
546
+ }
547
+ }
548
+ #endif // defined(RDNA3)
549
+ };
550
+
551
+ template <int I_, int J_>
552
+ struct tile<I_, J_, nv_bfloat162, DATA_LAYOUT_I_MAJOR_MIRRORED> {
553
+ static constexpr int I = I_;
554
+ static constexpr int J = J_;
555
+ static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;
556
+ static constexpr int ne = tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::ne;
557
+
558
+ nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
559
+
560
+ static constexpr __device__ bool supported() {
561
+ return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::supported();
562
+ }
563
+
564
+ static __device__ __forceinline__ int get_i(const int l) {
565
+ return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_i(l);
566
+ }
567
+
568
+ static __device__ __forceinline__ int get_j(const int l) {
569
+ return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_j(l);
570
+ }
571
+ };
572
+
573
+ template <int I_, int J_>
574
+ struct tile<I_, J_, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> {
575
+ static constexpr int I = I_;
576
+ static constexpr int J = J_;
577
+ static constexpr data_layout dl = DATA_LAYOUT_J_MAJOR_MIRRORED;
578
+ static constexpr int ne = I * J / (WARP_SIZE/4);
579
+
580
+ half2 x[ne] = {{0.0f, 0.0f}};
581
+
582
+ static constexpr __device__ bool supported() {
583
+ if (I == 8 && J == 4) return true;
584
+ return false;
585
+ }
586
+
587
+ static __device__ __forceinline__ int get_i(const int l) {
588
+ if constexpr (I == 8 && J == 4) {
589
+ return ((l / 2) * 4) + (threadIdx.x % 4);
590
+ } else {
591
+ NO_DEVICE_CODE;
592
+ return -1;
593
+ }
594
+ }
595
+
596
+ static __device__ __forceinline__ int get_j(const int l) {
597
+ if constexpr (I == 8 && J == 4) {
598
+ return ((threadIdx.x / 16) * 2) + (l % 2);
599
+ } else {
600
+ NO_DEVICE_CODE;
601
+ return -1;
602
+ }
603
+ }
604
+ };
605
+
606
+ #if defined(TURING_MMA_AVAILABLE)
131
607
  template <int I, int J>
132
608
  static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
133
609
  tile<I, J/2, half2> ret;
@@ -145,19 +621,68 @@ namespace ggml_cuda_mma {
145
621
 
146
622
  return ret;
147
623
  }
624
+ #else // Volta
625
+ template <int I, int J>
626
+ static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
627
+ tile<I, J/2, half2> ret;
628
+ #pragma unroll
629
+ for (int l0 = 0; l0 < tile_float.ne; l0 += 4) {
630
+ ret.x[l0/2 + 0] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]);
631
+ ret.x[l0/2 + 1] = make_half2(tile_float.x[l0 + 2], tile_float.x[l0 + 3]);
148
632
 
149
- template <int I, int J, typename T>
150
- static __device__ __forceinline__ void load_generic(tile<I, J, T> & t, const T * __restrict__ xs0, const int stride) {
633
+ // On Volta FP16 and FP32 tiles have a different memory layout,
634
+ // for the conversion threads with an offset of 2 need to exchange half their values:
635
+ ret.x[l0/2 + (((threadIdx.x % 4) / 2) ^ 1)] = __shfl_xor_sync(
636
+ 0xFFFFFFFF, ret.x[l0/2 + (((threadIdx.x % 4) / 2) ^ 1)], 2, WARP_SIZE);
637
+ }
638
+ return ret;
639
+ }
640
+ #endif // defined(TURING_MMA_AVAILABLE)
641
+
642
+ template <int I, int J, typename T, data_layout dl>
643
+ static __device__ __forceinline__ void load_generic(tile<I, J, T, dl> & t, const T * __restrict__ xs0, const int stride) {
644
+ #if defined(AMD_MFMA_AVAILABLE)
645
+ if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
646
+ #pragma unroll
647
+ for (int l = 0; l < t.ne; ++l) {
648
+ t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
649
+ }
650
+ } else {
651
+ ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
652
+ }
653
+ #elif defined(AMD_WMMA_AVAILABLE)
654
+ // All wmma layout has contiguous data when i-major.
655
+ if constexpr (is_i_major(dl)) {
656
+ // the data must be aligned to 16 bytes when bigger than ggml_cuda_get_max_cpy_bytes()
657
+ constexpr int aligned_copy_bytes = ggml_cuda_get_max_cpy_bytes();
658
+ if constexpr (sizeof(t.x) > aligned_copy_bytes) {
659
+ static_assert(sizeof(t.x) % aligned_copy_bytes == 0, "bad type size");
660
+ constexpr int aligned_copy_count = sizeof(t.x)/aligned_copy_bytes;
661
+ #pragma unroll
662
+ for (int i = 0; i < aligned_copy_count; ++i) {
663
+ ggml_cuda_memcpy_1<aligned_copy_bytes>(t.x + t.ne/aligned_copy_count*i, xs0 + t.get_i(0) * stride + t.get_j(t.ne/aligned_copy_count*i));
664
+ }
665
+ } else {
666
+ ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
667
+ }
668
+ } else {
669
+ #pragma unroll
670
+ for (int l = 0; l < t.ne; ++l) {
671
+ t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
672
+ }
673
+ }
674
+ #else
151
675
  #pragma unroll
152
676
  for (int l = 0; l < t.ne; ++l) {
153
677
  t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
154
678
  }
679
+ #endif // defined(AMD_MFMA_AVAILABLE)
155
680
  }
156
681
 
157
682
  template <typename T>
158
683
  static __device__ __forceinline__ void load_ldmatrix(
159
684
  tile<8, 8, T> & t, const T * __restrict__ xs0, const int stride) {
160
- #ifdef NEW_MMA_AVAILABLE
685
+ #ifdef TURING_MMA_AVAILABLE
161
686
  int * xi = (int *) t.x;
162
687
  const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + ((threadIdx.x / t.I) * (t.J / 2)) % t.J;
163
688
  asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
@@ -165,58 +690,94 @@ namespace ggml_cuda_mma {
165
690
  : "l"(xs));
166
691
  #else
167
692
  load_generic(t, xs0, stride);
168
- #endif // NEW_MMA_AVAILABLE
693
+ #endif // TURING_MMA_AVAILABLE
169
694
  }
170
695
 
171
696
  template <typename T>
172
697
  static __device__ __forceinline__ void load_ldmatrix(
173
698
  tile<16, 4, T> & t, const T * __restrict__ xs0, const int stride) {
174
- #ifdef NEW_MMA_AVAILABLE
699
+ #ifdef TURING_MMA_AVAILABLE
175
700
  int * xi = (int *) t.x;
176
701
  const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride;
177
702
  asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
178
703
  : "=r"(xi[0]), "=r"(xi[1])
179
704
  : "l"(xs));
180
705
  #else
181
- load_generic(xs0, stride);
182
- GGML_UNUSED(t);
183
- #endif // NEW_MMA_AVAILABLE
706
+ #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
707
+ GGML_UNUSED_VARS(t, xs0, stride);
708
+ NO_DEVICE_CODE;
709
+ #else
710
+ load_generic(t, xs0, stride);
711
+ #endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
712
+ #endif // TURING_MMA_AVAILABLE
184
713
  }
185
714
 
186
- template <typename T>
715
+ template <typename T, data_layout dl>
187
716
  static __device__ __forceinline__ void load_ldmatrix(
188
- tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
189
- #ifdef NEW_MMA_AVAILABLE
717
+ tile<16, 8, T, dl> & t, const T * __restrict__ xs0, const int stride) {
718
+ #if defined(TURING_MMA_AVAILABLE)
190
719
  int * xi = (int * ) t.x;
191
720
  const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
192
721
  asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
193
722
  : "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3])
194
723
  : "l"(xs));
724
+ #else
725
+ #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
726
+ #if 1
727
+ // TODO: more generic handling
728
+ static_assert(sizeof(T) == 4, "bad type size");
729
+ ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 0, xs0 + t.get_i(0)*stride + 0);
730
+ ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 4, xs0 + t.get_i(4)*stride + 4);
195
731
  #else
196
732
  load_generic(t, xs0, stride);
197
- #endif // NEW_MMA_AVAILABLE
733
+ #endif // 1
734
+ #else
735
+ load_generic(t, xs0, stride);
736
+ #endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
737
+ #endif // TURING_MMA_AVAILABLE
738
+ }
739
+
740
+ static __device__ __forceinline__ void load_ldmatrix(
741
+ tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & t, const half2 * __restrict__ xs0, const int stride) {
742
+ ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, xs0 + t.get_i(0)*stride);
743
+ }
744
+
745
+ static __device__ __forceinline__ void load_ldmatrix(
746
+ tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> & t, const half2 * __restrict__ xs0, const int stride) {
747
+ #pragma unroll
748
+ for (int l0 = 0; l0 < t.ne; l0 += 2) {
749
+ ggml_cuda_memcpy_1<2*sizeof(half2)>(t.x + l0, xs0 + t.get_i(l0)*stride + t.get_j(l0));
750
+ }
751
+ }
752
+
753
+ static __device__ __forceinline__ void load_ldmatrix(
754
+ tile<32, 4, half2> & t, const half2 * __restrict__ xs0, const int stride) {
755
+ #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
756
+ ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, xs0 + t.get_i(0)*stride);
757
+ #else
758
+ GGML_UNUSED_VARS(t, xs0, stride);
759
+ NO_DEVICE_CODE;
760
+ #endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
198
761
  }
199
762
 
200
763
  template <typename T>
201
764
  static __device__ __forceinline__ void load_ldmatrix_trans(
202
765
  tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
203
- #ifdef NEW_MMA_AVAILABLE
766
+ #ifdef TURING_MMA_AVAILABLE
204
767
  int * xi = (int * ) t.x;
205
768
  const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
206
769
  asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];"
207
770
  : "=r"(xi[0]), "=r"(xi[2]), "=r"(xi[1]), "=r"(xi[3])
208
771
  : "l"(xs));
209
772
  #else
210
- GGML_UNUSED(t);
211
- GGML_UNUSED(xs0);
212
- GGML_UNUSED(stride);
773
+ GGML_UNUSED_VARS(t, xs0, stride);
213
774
  NO_DEVICE_CODE;
214
- #endif // NEW_MMA_AVAILABLE
775
+ #endif // TURING_MMA_AVAILABLE
215
776
  }
216
777
 
217
778
  static __device__ __forceinline__ void mma(
218
779
  tile<16, 8, int> & D, const tile<16, 4, int> & A, const tile<8, 4, int> & B) {
219
- #ifdef NEW_MMA_AVAILABLE
780
+ #ifdef TURING_MMA_AVAILABLE
220
781
  #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
221
782
  asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
222
783
  : "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3])
@@ -231,16 +792,14 @@ namespace ggml_cuda_mma {
231
792
  : "r"(A.x[1]), "r"(B.x[0]));
232
793
  #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
233
794
  #else
234
- GGML_UNUSED(D);
235
- GGML_UNUSED(A);
236
- GGML_UNUSED(B);
795
+ GGML_UNUSED_VARS(D, A, B);
237
796
  NO_DEVICE_CODE;
238
- #endif // NEW_MMA_AVAILABLE
797
+ #endif // TURING_MMA_AVAILABLE
239
798
  }
240
799
 
241
800
  static __device__ __forceinline__ void mma(
242
801
  tile<16, 8, int> & D, const tile<16, 8, int> & A, const tile<8, 8, int> & B) {
243
- #ifdef NEW_MMA_AVAILABLE
802
+ #ifdef TURING_MMA_AVAILABLE
244
803
  #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
245
804
  asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
246
805
  : "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3])
@@ -261,16 +820,14 @@ namespace ggml_cuda_mma {
261
820
  : "r"(A.x[3]), "r"(B.x[1]));
262
821
  #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
263
822
  #else
264
- GGML_UNUSED(D);
265
- GGML_UNUSED(A);
266
- GGML_UNUSED(B);
823
+ GGML_UNUSED_VARS(D, A, B);
267
824
  NO_DEVICE_CODE;
268
- #endif // NEW_MMA_AVAILABLE
825
+ #endif // TURING_MMA_AVAILABLE
269
826
  }
270
827
 
271
828
  static __device__ __forceinline__ void mma(
272
829
  tile<16, 4, half2> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
273
- #ifdef NEW_MMA_AVAILABLE
830
+ #ifdef TURING_MMA_AVAILABLE
274
831
  const int * Axi = (const int *) A.x;
275
832
  const int * Bxi = (const int *) B.x;
276
833
  int * Dxi = (int *) D.x;
@@ -288,16 +845,14 @@ namespace ggml_cuda_mma {
288
845
  : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]));
289
846
  #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
290
847
  #else
291
- GGML_UNUSED(D);
292
- GGML_UNUSED(A);
293
- GGML_UNUSED(B);
848
+ GGML_UNUSED_VARS(D, A, B);
294
849
  NO_DEVICE_CODE;
295
- #endif // NEW_MMA_AVAILABLE
850
+ #endif // TURING_MMA_AVAILABLE
296
851
  }
297
852
 
298
853
  static __device__ __forceinline__ void mma(
299
854
  tile<16, 8, half2> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
300
- #ifdef NEW_MMA_AVAILABLE
855
+ #ifdef TURING_MMA_AVAILABLE
301
856
  const int * Axi = (const int *) A.x;
302
857
  const int * Bxi = (const int *) B.x;
303
858
  int * Dxi = (int *) D.x;
@@ -324,16 +879,51 @@ namespace ggml_cuda_mma {
324
879
  : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
325
880
  #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
326
881
  #else
327
- GGML_UNUSED(D);
328
- GGML_UNUSED(A);
329
- GGML_UNUSED(B);
882
+ GGML_UNUSED_VARS(D, A, B);
883
+ NO_DEVICE_CODE;
884
+ #endif // TURING_MMA_AVAILABLE
885
+ }
886
+
887
+ template <data_layout dl_ab, data_layout dl_d>
888
+ static __device__ __forceinline__ void mma(
889
+ tile<16, 8, float, dl_d> & D, const tile<16, 8, float, dl_ab> & A, const tile<8, 8, float, dl_ab> & B) {
890
+ #ifdef AMPERE_MMA_AVAILABLE
891
+ const int * Axi = (const int *) A.x;
892
+ const int * Bxi = (const int *) B.x;
893
+ int * Dxi = (int *) D.x;
894
+ asm("mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
895
+ : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
896
+ : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
897
+ #else
898
+ GGML_UNUSED_VARS(D, A, B);
330
899
  NO_DEVICE_CODE;
331
- #endif // NEW_MMA_AVAILABLE
900
+ #endif // AMPERE_MMA_AVAILABLE
901
+ }
902
+
903
+ static __device__ __forceinline__ void mma_block_scaled(tile<16, 8, float> & D,
904
+ const tile<16, 8, int> & A,
905
+ const tile<8, 8, int> & B,
906
+ uint32_t a_scale,
907
+ uint32_t b_scale) {
908
+ #ifdef BLACKWELL_MMA_AVAILABLE
909
+ const int * Axi = (const int *) A.x;
910
+ const int * Bxi = (const int *) B.x;
911
+ float * Dxi = (float *) D.x;
912
+
913
+ asm volatile(
914
+ "mma.sync.aligned.kind::mxf4.block_scale.scale_vec::2X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue8m0 "
915
+ "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3}, "
916
+ "%10, {0, 0}, %11, {0, 0};"
917
+ : "+f"(Dxi[0]), "+f"(Dxi[1]), "+f"(Dxi[2]), "+f"(Dxi[3])
918
+ : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]), "r"(a_scale), "r"(b_scale));
919
+ #else
920
+ GGML_UNUSED_VARS(D, A, B, a_scale, b_scale);
921
+ #endif // BLACKWELL_MMA_AVAILABLE
332
922
  }
333
923
 
334
924
  static __device__ __forceinline__ void mma(
335
925
  tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
336
- #ifdef NEW_MMA_AVAILABLE
926
+ #ifdef TURING_MMA_AVAILABLE
337
927
  const int * Axi = (const int *) A.x;
338
928
  const int * Bxi = (const int *) B.x;
339
929
  int * Dxi = (int *) D.x;
@@ -351,16 +941,30 @@ namespace ggml_cuda_mma {
351
941
  : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]));
352
942
  #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
353
943
  #else
354
- GGML_UNUSED(D);
355
- GGML_UNUSED(A);
356
- GGML_UNUSED(B);
944
+ GGML_UNUSED_VARS(D, A, B);
945
+ NO_DEVICE_CODE;
946
+ #endif // TURING_MMA_AVAILABLE
947
+ }
948
+
949
+ static __device__ __forceinline__ void mma(
950
+ tile<16, 8, float> & D, const tile<16, 8, nv_bfloat162> & A, const tile<8, 8, nv_bfloat162> & B) {
951
+ #ifdef AMPERE_MMA_AVAILABLE
952
+ const int * Axi = (const int *) A.x;
953
+ const int * Bxi = (const int *) B.x;
954
+ int * Dxi = (int *) D.x;
955
+ asm("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
956
+ : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
957
+ : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
958
+ #else
959
+ GGML_UNUSED_VARS(D, A, B);
357
960
  NO_DEVICE_CODE;
358
- #endif // NEW_MMA_AVAILABLE
961
+ #endif // AMPERE_MMA_AVAILABLE
359
962
  }
360
963
 
964
+ template <data_layout dl_ab, data_layout dl_d>
361
965
  static __device__ __forceinline__ void mma(
362
- tile<16, 16, float> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
363
- #ifdef NEW_MMA_AVAILABLE
966
+ tile<16, 16, float, dl_d> & D, const tile<16, 8, half2, dl_ab> & A, const tile<16, 8, half2, dl_ab> & B) {
967
+ #ifdef TURING_MMA_AVAILABLE
364
968
  const int * Axi = (const int *) A.x;
365
969
  const int * Bxi = (const int *) B.x;
366
970
  int * Dxi = (int *) D.x;
@@ -386,11 +990,253 @@ namespace ggml_cuda_mma {
386
990
  : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
387
991
  : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
388
992
  #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
993
+ #elif defined(AMD_WMMA_AVAILABLE)
994
+ #if defined(RDNA4)
995
+ using halfx8_t = __attribute__((ext_vector_type(8))) _Float16;
996
+ using floatx8_t = __attribute__((ext_vector_type(8))) float;
997
+ floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
998
+ const halfx8_t& a_frag = reinterpret_cast<const halfx8_t&>(A.x[0]);
999
+ const halfx8_t& b_frag = reinterpret_cast<const halfx8_t&>(B.x[0]);
1000
+ acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(a_frag, b_frag, acc_frag);
1001
+ #elif defined(RDNA3)
1002
+ using halfx16_t = __attribute__((ext_vector_type(16))) _Float16;
1003
+ using floatx8_t = __attribute__((ext_vector_type(8))) float;
1004
+ floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
1005
+ const halfx16_t& a_frag = reinterpret_cast<const halfx16_t&>(A.x[0]);
1006
+ const halfx16_t& b_frag = reinterpret_cast<const halfx16_t&>(B.x[0]);
1007
+ acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag, b_frag, acc_frag);
1008
+ #else
1009
+ GGML_UNUSED_VARS(D, A, B);
1010
+ NO_DEVICE_CODE;
1011
+ #endif // RDNA4
1012
+ #else
1013
+ GGML_UNUSED_VARS(D, A, B);
1014
+ NO_DEVICE_CODE;
1015
+ #endif // TURING_MMA_AVAILABLE
1016
+ }
1017
+
1018
+ template <data_layout dl_ab, data_layout dl_d>
1019
+ static __device__ __forceinline__ void mma(
1020
+ tile<16, 16, float, dl_d> & D, const tile<16, 8, nv_bfloat162, dl_ab> & A, const tile<16, 8, nv_bfloat162, dl_ab> & B) {
1021
+ #if defined(AMD_WMMA_AVAILABLE)
1022
+ #if defined(RDNA4)
1023
+ using bf16x8_t = __attribute__((ext_vector_type(8))) __bf16;
1024
+ using floatx8_t = __attribute__((ext_vector_type(8))) float;
1025
+ floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
1026
+ const bf16x8_t& a_frag = reinterpret_cast<const bf16x8_t&>(A.x[0]);
1027
+ const bf16x8_t& b_frag = reinterpret_cast<const bf16x8_t&>(B.x[0]);
1028
+ acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(a_frag, b_frag, acc_frag);
1029
+ #elif defined(RDNA3)
1030
+ using bf16x16_t = __attribute__((ext_vector_type(16))) __bf16;
1031
+ using floatx8_t = __attribute__((ext_vector_type(8))) float;
1032
+ floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
1033
+ const bf16x16_t& a_frag = reinterpret_cast<const bf16x16_t&>(A.x[0]);
1034
+ const bf16x16_t& b_frag = reinterpret_cast<const bf16x16_t&>(B.x[0]);
1035
+ acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(a_frag, b_frag, acc_frag);
1036
+ #else
1037
+ GGML_UNUSED_VARS(D, A, B);
1038
+ NO_DEVICE_CODE;
1039
+ #endif // RDNA4
1040
+ #else
1041
+ GGML_UNUSED_VARS(D, A, B);
1042
+ NO_DEVICE_CODE;
1043
+ #endif // AMPERE_MMA_AVAILABLE
1044
+ }
1045
+
1046
+ template <data_layout dl_d, data_layout dl_ab>
1047
+ static __device__ __forceinline__ void mma(
1048
+ tile<16, 16, int, dl_d> & D, const tile<16, 8, int, dl_ab> & A, const tile<16, 8, int, dl_ab> & B) {
1049
+ #if defined(AMD_MFMA_AVAILABLE)
1050
+ using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
1051
+ int32x4_t * acc = (int32x4_t *) D.x;
1052
+ #if defined(CDNA3)
1053
+ acc[0] = __builtin_amdgcn_mfma_i32_16x16x32_i8(((int64_t *) A.x)[0],
1054
+ ((int64_t *) B.x)[0],
1055
+ acc[0],
1056
+ 0, 0, 0);
1057
+ #elif defined(CDNA2) || defined(CDNA)
1058
+ acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[0],
1059
+ B.x[0],
1060
+ acc[0],
1061
+ 0, 0, 0);
1062
+ acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[1],
1063
+ B.x[1],
1064
+ acc[0],
1065
+ 0, 0, 0);
1066
+ #endif // defined(CDNA3)
1067
+
1068
+ #elif defined(AMD_WMMA_AVAILABLE)
1069
+
1070
+ using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
1071
+ int32x8_t * acc = (int32x8_t *) D.x;
1072
+
1073
+ #if defined(RDNA4)
1074
+ using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
1075
+ int32x2_t * a_vec = (int32x2_t *) A.x;
1076
+ int32x2_t * b_vec = (int32x2_t *) B.x;
1077
+
1078
+ acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
1079
+ true,
1080
+ a_vec[0],
1081
+ true,
1082
+ b_vec[0],
1083
+ acc[0],
1084
+ true
1085
+ );
1086
+
1087
+ acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
1088
+ true,
1089
+ a_vec[1],
1090
+ true,
1091
+ b_vec[1],
1092
+ acc[0],
1093
+ true
1094
+ );
1095
+
1096
+ #elif defined(RDNA3)
1097
+ using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
1098
+ int32x4_t * a_vec = (int32x4_t *) A.x;
1099
+ int32x4_t * b_vec = (int32x4_t *) B.x;
1100
+
1101
+ acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
1102
+ true,
1103
+ a_vec[0],
1104
+ true,
1105
+ b_vec[0],
1106
+ acc[0],
1107
+ true
1108
+ );
1109
+
1110
+ acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
1111
+ true,
1112
+ a_vec[1],
1113
+ true,
1114
+ b_vec[1],
1115
+ acc[0],
1116
+ true
1117
+ );
1118
+ #endif // RDNA4
1119
+
1120
+ #else
1121
+ GGML_UNUSED_VARS(D, A, B);
1122
+ NO_DEVICE_CODE;
1123
+ #endif // AMD_MFMA_AVAILABLE
1124
+ }
1125
+
1126
+ static __device__ __forceinline__ void mma(
1127
+ tile<32, 32, int> & D, const tile<32, 4, int> & A, const tile<32, 4, int> & B) {
1128
+ #if defined(AMD_MFMA_AVAILABLE)
1129
+ using int32x16_t = __attribute__((__vector_size__(16 * sizeof(int)))) int;
1130
+ int32x16_t * acc = (int32x16_t *) D.x;
1131
+ #if defined(CDNA3)
1132
+ acc[0] = __builtin_amdgcn_mfma_i32_32x32x16_i8(((int64_t *) A.x)[0],
1133
+ ((int64_t *) B.x)[0],
1134
+ acc[0],
1135
+ 0, 0, 0);
1136
+ #elif defined(CDNA2) || defined(CDNA)
1137
+ acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[0],
1138
+ B.x[0],
1139
+ acc[0],
1140
+ 0, 0, 0);
1141
+ acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[1],
1142
+ B.x[1],
1143
+ acc[0],
1144
+ 0, 0, 0);
1145
+ #endif // defined(CDNA3)
1146
+
1147
+ #else
1148
+ GGML_UNUSED_VARS(D, A, B);
1149
+ NO_DEVICE_CODE;
1150
+ #endif // AMD_MFMA_AVAILABLE
1151
+ }
1152
+
1153
+ template <typename T1, typename T2, int J, int K>
1154
+ static __device__ __forceinline__ void mma(
1155
+ tile<32, J, T1> & D, const tile<32, K, T2> & A, const tile<J, K, T2> & B) {
1156
+ tile <16, J, T1> * D16 = reinterpret_cast< tile<16, J, T1> *>(&D);
1157
+ const tile<16, K, T2> * A16 = reinterpret_cast<const tile<16, K, T2> *>(&A);
1158
+ mma(D16[0], A16[0], B);
1159
+ mma(D16[1], A16[1], B);
1160
+ }
1161
+
1162
+ static __device__ __forceinline__ void mma(
1163
+ tile<32, 8, float> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & B) {
1164
+ #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
1165
+ const int * Axi = (const int *) A.x;
1166
+ const int * Bxi = (const int *) B.x;
1167
+ int * Dxi = (int *) D.x;
1168
+ asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
1169
+ "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
1170
+ : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
1171
+ : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]), "r"(Bxi[1]));
1172
+ asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
1173
+ "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
1174
+ : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
1175
+ : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]), "r"(Bxi[3]));
1176
+ #else
1177
+ GGML_UNUSED_VARS(D, A, B);
1178
+ NO_DEVICE_CODE;
1179
+ #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
1180
+ }
1181
+
1182
+ static __device__ __forceinline__ void mma(
1183
+ tile<32, 4, half2> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> & B) {
1184
+ #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
1185
+ const int * Axi = (const int *) A.x;
1186
+ const int * Bxi = (const int *) B.x;
1187
+ int * Dxi = (int *) D.x;
1188
+ asm("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 "
1189
+ "{%0, %1, %2, %3}, {%4, %5}, {%6, %7}, {%0, %1, %2, %3};"
1190
+ : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
1191
+ : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]), "r"(Bxi[1]));
1192
+ asm("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 "
1193
+ "{%0, %1, %2, %3}, {%4, %5}, {%6, %7}, {%0, %1, %2, %3};"
1194
+ : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
1195
+ : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]), "r"(Bxi[3]));
1196
+ #else
1197
+ GGML_UNUSED_VARS(D, A, B);
1198
+ NO_DEVICE_CODE;
1199
+ #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
1200
+ }
1201
+
1202
+ template <data_layout dl_d, data_layout dl_ab>
1203
+ static __device__ __forceinline__ void mma(
1204
+ tile<16, 16, int, dl_d> & D, const tile<16, 4, int, dl_ab> & A, const tile<16, 4, int, dl_ab> & B) {
1205
+ #if defined(AMD_WMMA_AVAILABLE)
1206
+ using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
1207
+ int32x8_t * acc = (int32x8_t *) D.x;
1208
+ #if defined(RDNA4)
1209
+ using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
1210
+ int32x2_t * a_vec = (int32x2_t *) A.x;
1211
+ int32x2_t * b_vec = (int32x2_t *) B.x;
1212
+
1213
+ acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
1214
+ true,
1215
+ a_vec[0],
1216
+ true,
1217
+ b_vec[0],
1218
+ acc[0],
1219
+ false
1220
+ );
1221
+ #elif defined(RDNA3)
1222
+ using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
1223
+ int32x4_t * a_vec = (int32x4_t *) A.x;
1224
+ int32x4_t * b_vec = (int32x4_t *) B.x;
1225
+
1226
+ acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
1227
+ true,
1228
+ a_vec[0],
1229
+ true,
1230
+ b_vec[0],
1231
+ acc[0],
1232
+ false
1233
+ );
1234
+ #endif // RDNA4
389
1235
  #else
390
1236
  GGML_UNUSED(D);
391
1237
  GGML_UNUSED(A);
392
1238
  GGML_UNUSED(B);
393
1239
  NO_DEVICE_CODE;
394
- #endif // NEW_MMA_AVAILABLE
1240
+ #endif // AMD_WMMA_AVAILABLE
395
1241
  }
396
1242
  }