whispercpp 1.3.4 → 1.3.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (891)
  1. checksums.yaml +4 -4
  2. data/LICENSE +1 -1
  3. data/README.md +158 -44
  4. data/ext/extconf.rb +3 -2
  5. data/ext/ruby_whisper.c +34 -6
  6. data/ext/ruby_whisper.h +67 -0
  7. data/ext/ruby_whisper_context.c +236 -144
  8. data/ext/ruby_whisper_context_params.c +163 -0
  9. data/ext/ruby_whisper_model.c +12 -13
  10. data/ext/ruby_whisper_params.c +47 -24
  11. data/ext/ruby_whisper_segment.c +84 -20
  12. data/ext/ruby_whisper_token.c +371 -0
  13. data/ext/ruby_whisper_transcribe.cpp +5 -2
  14. data/ext/ruby_whisper_vad_context.c +122 -0
  15. data/ext/ruby_whisper_vad_context_detect.cpp +51 -0
  16. data/ext/ruby_whisper_vad_params.c +0 -1
  17. data/ext/ruby_whisper_vad_segment.c +138 -0
  18. data/ext/ruby_whisper_vad_segments.c +105 -0
  19. data/ext/sources/CMakeLists.txt +4 -1
  20. data/ext/sources/bindings/javascript/package.json +1 -1
  21. data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
  22. data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
  23. data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  24. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  25. data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
  26. data/ext/sources/examples/addon.node/vad-example.js +2 -2
  27. data/ext/sources/examples/bench/bench.cpp +23 -18
  28. data/ext/sources/examples/cli/cli.cpp +129 -112
  29. data/ext/sources/examples/common-ggml.cpp +2 -0
  30. data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
  31. data/ext/sources/examples/miniaudio.h +4507 -2131
  32. data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
  33. data/ext/sources/examples/server/server.cpp +28 -15
  34. data/ext/sources/examples/talk-llama/CMakeLists.txt +8 -3
  35. data/ext/sources/examples/talk-llama/llama-adapter.cpp +5 -2
  36. data/ext/sources/examples/talk-llama/llama-adapter.h +7 -0
  37. data/ext/sources/examples/talk-llama/llama-arch.cpp +2378 -1988
  38. data/ext/sources/examples/talk-llama/llama-arch.h +109 -2
  39. data/ext/sources/examples/talk-llama/llama-batch.cpp +78 -34
  40. data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
  41. data/ext/sources/examples/talk-llama/llama-chat.cpp +100 -4
  42. data/ext/sources/examples/talk-llama/llama-chat.h +5 -0
  43. data/ext/sources/examples/talk-llama/llama-context.cpp +1088 -403
  44. data/ext/sources/examples/talk-llama/llama-context.h +70 -23
  45. data/ext/sources/examples/talk-llama/llama-cparams.h +6 -0
  46. data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
  47. data/ext/sources/examples/talk-llama/llama-grammar.cpp +295 -60
  48. data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
  49. data/ext/sources/examples/talk-llama/llama-graph.cpp +925 -155
  50. data/ext/sources/examples/talk-llama/llama-graph.h +234 -23
  51. data/ext/sources/examples/talk-llama/llama-hparams.cpp +79 -38
  52. data/ext/sources/examples/talk-llama/llama-hparams.h +118 -18
  53. data/ext/sources/examples/talk-llama/llama-impl.cpp +11 -7
  54. data/ext/sources/examples/talk-llama/llama-impl.h +14 -2
  55. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +8 -4
  56. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +405 -140
  57. data/ext/sources/examples/talk-llama/llama-kv-cache.h +24 -10
  58. data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
  59. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
  60. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
  61. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
  62. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +42 -31
  63. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
  64. data/ext/sources/examples/talk-llama/llama-mmap.cpp +197 -45
  65. data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
  66. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +606 -116
  67. data/ext/sources/examples/talk-llama/llama-model-loader.h +41 -5
  68. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +61 -44
  69. data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
  70. data/ext/sources/examples/talk-llama/llama-model.cpp +2756 -13643
  71. data/ext/sources/examples/talk-llama/llama-model.h +112 -18
  72. data/ext/sources/examples/talk-llama/llama-quant.cpp +582 -365
  73. data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +1409 -199
  74. data/ext/sources/examples/talk-llama/llama-sampler.h +42 -0
  75. data/ext/sources/examples/talk-llama/llama-vocab.cpp +248 -82
  76. data/ext/sources/examples/talk-llama/llama-vocab.h +50 -40
  77. data/ext/sources/examples/talk-llama/llama.cpp +802 -21
  78. data/ext/sources/examples/talk-llama/llama.h +210 -39
  79. data/ext/sources/examples/talk-llama/models/afmoe.cpp +190 -0
  80. data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
  81. data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
  82. data/ext/sources/examples/talk-llama/models/arctic.cpp +137 -0
  83. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
  84. data/ext/sources/examples/talk-llama/models/baichuan.cpp +123 -0
  85. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +143 -0
  86. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +133 -0
  87. data/ext/sources/examples/talk-llama/models/bert.cpp +184 -0
  88. data/ext/sources/examples/talk-llama/models/bitnet.cpp +145 -0
  89. data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
  90. data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
  91. data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
  92. data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
  93. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
  94. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
  95. data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
  96. data/ext/sources/examples/talk-llama/models/dbrx.cpp +122 -0
  97. data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
  98. data/ext/sources/examples/talk-llama/models/deepseek.cpp +142 -0
  99. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +262 -0
  100. data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
  101. data/ext/sources/examples/talk-llama/models/dots1.cpp +132 -0
  102. data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
  103. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +148 -0
  104. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
  105. data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
  106. data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
  107. data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
  108. data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
  109. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +111 -0
  110. data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
  111. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
  112. data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
  113. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
  114. data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
  115. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
  116. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
  117. data/ext/sources/examples/talk-llama/models/glm4.cpp +157 -0
  118. data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
  119. data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
  120. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +195 -0
  121. data/ext/sources/examples/talk-llama/models/granite.cpp +210 -0
  122. data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
  123. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +139 -0
  124. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
  125. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +153 -0
  126. data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
  127. data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
  128. data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
  129. data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
  130. data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
  131. data/ext/sources/examples/talk-llama/models/lfm2.cpp +196 -0
  132. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
  133. data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
  134. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
  135. data/ext/sources/examples/talk-llama/models/llama.cpp +175 -0
  136. data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
  137. data/ext/sources/examples/talk-llama/models/mamba-base.cpp +289 -0
  138. data/ext/sources/examples/talk-llama/models/mamba.cpp +54 -0
  139. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +129 -0
  140. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +200 -0
  141. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +123 -0
  142. data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
  143. data/ext/sources/examples/talk-llama/models/models.h +704 -0
  144. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +109 -0
  145. data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
  146. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +162 -0
  147. data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
  148. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
  149. data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
  150. data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
  151. data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
  152. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
  153. data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
  154. data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
  155. data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
  156. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
  157. data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
  158. data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
  159. data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
  160. data/ext/sources/examples/talk-llama/models/plamo2.cpp +320 -0
  161. data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
  162. data/ext/sources/examples/talk-llama/models/plm.cpp +169 -0
  163. data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
  164. data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
  165. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
  166. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
  167. data/ext/sources/examples/talk-llama/models/qwen3.cpp +120 -0
  168. data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
  169. data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
  170. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +131 -0
  171. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +525 -0
  172. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +140 -0
  173. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +132 -0
  174. data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
  175. data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
  176. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +164 -0
  177. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
  178. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
  179. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +137 -0
  180. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
  181. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
  182. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
  183. data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
  184. data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
  185. data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
  186. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
  187. data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
  188. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
  189. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
  190. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
  191. data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
  192. data/ext/sources/examples/talk-llama/unicode.cpp +121 -79
  193. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
  194. data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
  195. data/ext/sources/ggml/CMakeLists.txt +90 -56
  196. data/ext/sources/ggml/include/ggml-alloc.h +9 -0
  197. data/ext/sources/ggml/include/ggml-backend.h +5 -2
  198. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  199. data/ext/sources/ggml/include/ggml-cpu.h +6 -0
  200. data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
  201. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  202. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  203. data/ext/sources/ggml/include/ggml-rpc.h +14 -12
  204. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  205. data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
  206. data/ext/sources/ggml/include/ggml.h +246 -21
  207. data/ext/sources/ggml/src/CMakeLists.txt +85 -11
  208. data/ext/sources/ggml/src/ggml-alloc.c +128 -50
  209. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  210. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  211. data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
  212. data/ext/sources/ggml/src/ggml-backend-reg.cpp +54 -88
  213. data/ext/sources/ggml/src/ggml-backend.cpp +76 -23
  214. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +18 -4
  215. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +11 -11
  216. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +58 -46
  217. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +139 -48
  218. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2427 -1785
  219. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -362
  220. data/ext/sources/ggml/src/ggml-cann/common.h +285 -211
  221. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +663 -831
  222. data/ext/sources/ggml/src/ggml-common.h +11 -0
  223. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +170 -95
  224. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -18
  225. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  226. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
  227. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  228. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +513 -27
  229. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +4192 -992
  230. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
  231. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  232. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1761 -49
  233. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
  234. data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  235. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
  236. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
  237. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +124 -24
  238. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +157 -28
  239. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  240. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  241. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +8 -3
  242. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +251 -80
  243. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +19 -0
  244. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +587 -119
  245. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
  246. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1093 -194
  247. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1284 -203
  248. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
  249. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1519 -527
  250. data/ext/sources/ggml/src/ggml-cpu/ops.h +6 -4
  251. data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
  252. data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
  253. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +3632 -781
  254. data/ext/sources/ggml/src/ggml-cpu/repack.h +129 -4
  255. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
  256. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +152 -46
  257. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
  258. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +152 -1
  259. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
  260. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +140 -0
  261. data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
  262. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
  263. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
  264. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +132 -6
  265. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
  266. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +33 -31
  267. data/ext/sources/ggml/src/ggml-cuda/common.cuh +474 -85
  268. data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
  269. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  270. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
  271. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +342 -246
  272. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
  273. data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
  274. data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
  275. data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
  276. data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
  277. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +98 -74
  278. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +973 -665
  279. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
  280. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1255 -0
  281. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +33 -40
  282. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +40 -18
  283. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  284. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +206 -45
  285. data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
  286. data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
  287. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
  288. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  289. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1688 -302
  290. data/ext/sources/ggml/src/ggml-cuda/mean.cu +12 -10
  291. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +908 -48
  292. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +88 -20
  293. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +502 -90
  294. data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
  295. data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
  296. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
  297. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +532 -193
  298. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +460 -104
  299. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +5 -2
  300. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +360 -122
  301. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +2 -1
  302. data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
  303. data/ext/sources/ggml/src/ggml-cuda/pad.cu +73 -39
  304. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +152 -1
  305. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
  306. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
  307. data/ext/sources/ggml/src/ggml-cuda/rope.cu +364 -149
  308. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
  309. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
  310. data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
  311. data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
  312. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +163 -41
  313. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
  314. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
  315. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +68 -50
  316. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  317. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
  318. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
  319. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
  320. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
  321. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
  322. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
  323. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
  324. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  325. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  326. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  327. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  328. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  329. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  330. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  331. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  332. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  333. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +22 -4
  334. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +95 -0
  335. data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
  336. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +275 -119
  337. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -7
  338. data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
  339. data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
  340. data/ext/sources/ggml/src/ggml-cuda/unary.cu +160 -11
  341. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +38 -0
  342. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
  343. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
  344. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  345. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +22 -1
  346. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
  347. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +117 -0
  348. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3325 -0
  349. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +46 -0
  350. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +813 -0
  351. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
  352. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +891 -0
  353. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  354. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
  355. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +713 -0
  356. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
  357. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.c +63 -0
  358. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +182 -0
  359. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
  360. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  361. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
  362. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  363. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +155 -0
  364. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +63 -0
  365. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  366. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  367. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
  368. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
  369. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
  370. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  371. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
  372. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  373. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  374. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  375. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  376. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
  377. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  378. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  379. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +26 -0
  380. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1199 -0
  381. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2670 -0
  382. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +497 -0
  383. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
  384. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +419 -0
  385. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
  386. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  387. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +382 -0
  388. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +293 -0
  389. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  390. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  391. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  392. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  393. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
  394. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
  395. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +14 -13
  396. data/ext/sources/ggml/src/ggml-impl.h +129 -6
  397. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  398. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +15 -4
  399. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  400. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +173 -34
  401. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +912 -344
  402. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +124 -59
  403. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +588 -144
  404. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +396 -23
  405. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1724 -421
  406. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +16 -3
  407. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +333 -114
  408. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +3050 -1539
  409. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
  410. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +30 -1
  411. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4279 -497
  412. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
  413. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
  414. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  415. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +267 -0
  416. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  417. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  418. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +113 -0
  419. data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  420. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
  421. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  422. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  423. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  424. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
  425. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  426. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  427. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  428. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +140 -0
  429. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  430. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
  431. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
  432. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  433. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  434. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  435. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
  436. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  437. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  438. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  439. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  440. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  441. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
  442. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  443. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  444. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  445. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
  446. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  447. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  448. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
  449. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +116 -0
  450. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  451. data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  452. data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  453. data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  454. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  455. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  456. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +39 -0
  457. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  458. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  459. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  460. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
  461. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  462. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
  463. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  464. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
  465. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
  466. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  467. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  468. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  469. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  470. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  471. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  472. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  473. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  474. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  475. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  476. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  477. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  478. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  479. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  480. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  481. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  482. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  483. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
  484. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  485. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  486. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  487. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  488. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  489. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  490. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
  491. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
  492. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
  493. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
  494. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  495. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  496. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  497. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  498. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  499. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
  500. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  501. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
  502. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
  503. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
  504. data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
  505. data/ext/sources/ggml/src/ggml-quants.c +96 -5
  506. data/ext/sources/ggml/src/ggml-quants.h +3 -0
  507. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
  508. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +59 -87
  509. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +81 -0
  510. data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
  511. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +7 -0
  512. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -29
  513. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
  514. data/ext/sources/ggml/src/ggml-sycl/common.hpp +427 -20
  515. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
  516. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +103 -1
  517. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
  518. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  519. data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  520. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
  521. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
  522. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +867 -50
  523. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +401 -358
  524. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +12 -2
  525. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
  526. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
  527. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
  528. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
  529. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
  530. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  531. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
  532. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
  533. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +645 -155
  534. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
  535. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +221 -66
  536. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
  537. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  538. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
  539. data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
  540. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
  541. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
  542. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +5 -0
  543. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
  544. data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  545. data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  546. data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
  547. data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
  548. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +457 -281
  549. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  550. data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
  551. data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
  552. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
  553. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
  554. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  555. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  556. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  557. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  558. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  559. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  560. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  561. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  562. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  563. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  564. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  565. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
  566. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
  567. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
  568. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
  569. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
  570. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
  571. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
  572. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
  573. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
  574. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
  575. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
  576. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
  577. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
  578. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
  579. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
  580. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
  581. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
  582. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
  583. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
  584. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
  585. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
  586. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
  587. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
  588. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
  589. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
  590. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
  591. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
  592. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
  593. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
  594. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
  595. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
  596. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
  597. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
  598. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
  599. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
  600. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
  601. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +71 -0
  602. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  603. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  604. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  605. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  606. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  607. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  608. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  609. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  610. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  611. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  612. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  613. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  614. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  615. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  616. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  617. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  618. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  619. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  620. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  621. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  622. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  623. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  624. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
  625. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
  626. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  627. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
  628. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  629. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  630. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  631. data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
  632. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  633. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  634. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  635. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  636. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  637. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  638. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  639. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
  640. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  641. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  642. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  643. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
  644. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
  645. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +39 -19
  646. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5994 -3055
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +18 -10
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  663. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  664. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
  665. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
  666. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  667. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  668. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  669. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
  670. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
  671. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
  672. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
  673. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  674. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
  675. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
  676. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  677. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  678. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  679. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  680. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  681. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
  682. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
  683. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
  684. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  685. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  686. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
  687. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
  688. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  689. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  690. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  691. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
  692. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  693. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  694. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
  695. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  696. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  697. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
  698. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  699. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  700. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  701. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
  702. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  703. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +386 -160
  704. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +82 -20
  705. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +400 -174
  706. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +123 -37
  707. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  708. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +10 -9
  709. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  710. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
  711. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  712. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
  713. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
  714. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  715. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
  716. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  717. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
  718. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
  719. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
  720. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
  721. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
  722. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
  723. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  724. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
  725. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
  726. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
  727. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
  728. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +13 -10
  729. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  730. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  731. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  732. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
  733. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +77 -29
  734. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
  735. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
  736. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
  737. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
  738. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
  739. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
  740. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
  741. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
  742. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
  743. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  744. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
  745. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
  746. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
  747. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
  748. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
  749. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
  750. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
  751. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +88 -105
  752. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +41 -26
  753. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
  754. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +74 -0
  755. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +92 -230
  756. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
  757. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  758. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
  759. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  760. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  761. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  762. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
  763. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
  764. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  765. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
  766. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  767. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  768. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  769. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  770. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +49 -4
  771. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  772. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
  773. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
  774. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +207 -0
  775. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  776. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +8 -49
  777. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +8 -32
  778. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +8 -32
  779. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +33 -0
  780. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +8 -38
  781. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  782. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
  783. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  784. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  785. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  786. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  787. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  788. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
  789. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
  790. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
  791. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
  792. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
  793. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
  794. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  795. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
  796. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
  797. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  798. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +50 -0
  799. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
  800. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  801. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  802. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
  803. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
  804. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  805. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
  806. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  807. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
  808. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
  809. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
  810. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
  811. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
  812. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  813. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
  814. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
  815. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +384 -180
  816. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
  817. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
  818. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1374 -0
  819. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +2544 -726
  820. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  821. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  822. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  823. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  824. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
  825. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
  826. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
  827. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +107 -0
  828. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  829. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +73 -15
  830. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +636 -0
  831. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
  832. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  833. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +72 -261
  834. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +766 -0
  835. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +147 -0
  836. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +196 -0
  837. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
  838. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  839. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  840. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
  841. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  842. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +63 -0
  843. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
  844. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  845. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  846. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
  847. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
  848. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +91 -0
  849. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +469 -0
  850. data/ext/sources/ggml/src/ggml.c +590 -64
  851. data/ext/sources/ggml/src/gguf.cpp +229 -44
  852. data/ext/sources/include/whisper.h +1 -0
  853. data/ext/sources/src/CMakeLists.txt +3 -1
  854. data/ext/sources/src/whisper.cpp +106 -62
  855. data/ext/sources/tests/CMakeLists.txt +2 -2
  856. data/ext/sources/tests/test-vad-full.cpp +4 -2
  857. data/ext/sources/tests/test-vad.cpp +1 -1
  858. data/extsources.rb +1 -0
  859. data/lib/whisper/model/uri.rb +17 -18
  860. data/sig/whisper.rbs +162 -4
  861. data/test/test_context_params.rb +82 -0
  862. data/test/test_params.rb +16 -8
  863. data/test/test_segment.rb +0 -1
  864. data/test/test_token.rb +81 -0
  865. data/test/test_vad.rb +1 -1
  866. data/test/test_vad_context.rb +100 -0
  867. data/test/test_vad_segment.rb +19 -0
  868. data/test/test_vad_segments.rb +16 -0
  869. data/test/test_whisper.rb +27 -0
  870. data/whispercpp.gemspec +1 -1
  871. metadata +502 -37
  872. data/ext/sources/build-xcframework.sh +0 -571
  873. data/ext/sources/examples/talk-llama/llama-sampling.h +0 -32
  874. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  875. data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
  876. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
  877. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
  878. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
  879. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
  880. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  881. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
  882. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
  883. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
  884. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
  885. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  886. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  887. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  888. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  889. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
  890. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
  891. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
@@ -10,6 +10,8 @@
10
10
 
11
11
  #include <cassert>
12
12
  #include <algorithm>
13
+ #include <limits>
14
+ #include <cmath>
13
15
 
14
16
  static ggml_metal_buffer_id ggml_metal_get_buffer_id(const ggml_tensor * t) {
15
17
  if (!t) {
@@ -201,6 +203,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
201
203
  GGML_ABORT("unsupported op");
202
204
  }
203
205
 
206
+ if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
207
+ return 1;
208
+ }
209
+
204
210
  int n_fuse = 1;
205
211
 
206
212
  // check if the current node can run concurrently with other nodes before it
@@ -219,13 +225,17 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
219
225
  }
220
226
 
221
227
  if (ctx->debug_graph > 0) {
222
- GGML_LOG_DEBUG("%s: node[%5d] - %-12s %s\n", __func__, idx, ggml_op_name(node->op), is_concurrent ? "(concurrent)" : "");
228
+ GGML_LOG_DEBUG("%s: node[%5d] - %-12s %-12s %s\n", __func__, idx, ggml_op_name(node->op), ggml_get_name(node), is_concurrent ? "(concurrent)" : "");
223
229
  }
224
230
  if (ctx->debug_graph > 1) {
225
231
  GGML_TENSOR_LOCALS( int64_t, ne0, node->src[0], ne);
226
232
  GGML_TENSOR_LOCALS(uint64_t, nb0, node->src[0], nb);
227
233
  GGML_TENSOR_LOCALS( int64_t, ne1, node->src[1], ne);
228
234
  GGML_TENSOR_LOCALS(uint64_t, nb1, node->src[1], nb);
235
+ GGML_TENSOR_LOCALS( int64_t, ne2, node->src[2], ne);
236
+ GGML_TENSOR_LOCALS(uint64_t, nb2, node->src[2], nb);
237
+ GGML_TENSOR_LOCALS( int64_t, ne3, node->src[3], ne);
238
+ GGML_TENSOR_LOCALS(uint64_t, nb3, node->src[3], nb);
229
239
  GGML_TENSOR_LOCALS( int64_t, ne, node, ne);
230
240
  GGML_TENSOR_LOCALS(uint64_t, nb, node, nb);
231
241
 
@@ -237,6 +247,14 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
237
247
  GGML_LOG_DEBUG("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[1]->type), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
238
248
  ggml_is_contiguous(node->src[1]), node->src[1]->name);
239
249
  }
250
+ if (node->src[2]) {
251
+ GGML_LOG_DEBUG("%s: src2 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[2]->type), ne20, ne21, ne22, ne23, nb20, nb21, nb22, nb23,
252
+ ggml_is_contiguous(node->src[2]), node->src[2]->name);
253
+ }
254
+ if (node->src[3]) {
255
+ GGML_LOG_DEBUG("%s: src3 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[3]->type), ne30, ne31, ne32, ne33, nb30, nb31, nb32, nb33,
256
+ ggml_is_contiguous(node->src[3]), node->src[3]->name);
257
+ }
240
258
  if (node) {
241
259
  GGML_LOG_DEBUG("%s: node - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(node->type), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
242
260
  node->name);
@@ -269,13 +287,9 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
269
287
  n_fuse = ggml_metal_op_acc(ctx, idx);
270
288
  } break;
271
289
  case GGML_OP_SCALE:
272
- {
273
- n_fuse = ggml_metal_op_scale(ctx, idx);
274
- } break;
290
+ case GGML_OP_FILL:
275
291
  case GGML_OP_CLAMP:
276
- {
277
- n_fuse = ggml_metal_op_clamp(ctx, idx);
278
- } break;
292
+ case GGML_OP_LEAKY_RELU:
279
293
  case GGML_OP_SQR:
280
294
  case GGML_OP_SQRT:
281
295
  case GGML_OP_SIN:
@@ -289,11 +303,19 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
289
303
  {
290
304
  n_fuse = ggml_metal_op_glu(ctx, idx);
291
305
  } break;
306
+ case GGML_OP_SUM:
307
+ {
308
+ n_fuse = ggml_metal_op_sum(ctx, idx);
309
+ } break;
292
310
  case GGML_OP_SUM_ROWS:
293
311
  case GGML_OP_MEAN:
294
312
  {
295
313
  n_fuse = ggml_metal_op_sum_rows(ctx, idx);
296
314
  } break;
315
+ case GGML_OP_CUMSUM:
316
+ {
317
+ n_fuse = ggml_metal_op_cumsum(ctx, idx);
318
+ } break;
297
319
  case GGML_OP_SOFT_MAX:
298
320
  {
299
321
  n_fuse = ggml_metal_op_soft_max(ctx, idx);
@@ -311,6 +333,14 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
311
333
  {
312
334
  n_fuse = ggml_metal_op_rwkv(ctx, idx);
313
335
  } break;
336
+ case GGML_OP_GATED_DELTA_NET:
337
+ {
338
+ n_fuse = ggml_metal_op_gated_delta_net(ctx, idx);
339
+ } break;
340
+ case GGML_OP_SOLVE_TRI:
341
+ {
342
+ n_fuse = ggml_metal_op_solve_tri(ctx, idx);
343
+ } break;
314
344
  case GGML_OP_MUL_MAT:
315
345
  {
316
346
  n_fuse = ggml_metal_op_mul_mat(ctx, idx);
@@ -327,6 +357,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
327
357
  {
328
358
  n_fuse = ggml_metal_op_set_rows(ctx, idx);
329
359
  } break;
360
+ case GGML_OP_DIAG:
361
+ {
362
+ n_fuse = ggml_metal_op_diag(ctx, idx);
363
+ } break;
330
364
  case GGML_OP_L2_NORM:
331
365
  {
332
366
  n_fuse = ggml_metal_op_l2_norm(ctx, idx);
@@ -348,10 +382,18 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
348
382
  {
349
383
  n_fuse = ggml_metal_op_im2col(ctx, idx);
350
384
  } break;
385
+ case GGML_OP_CONV_2D:
386
+ {
387
+ n_fuse = ggml_metal_op_conv_2d(ctx, idx);
388
+ } break;
351
389
  case GGML_OP_CONV_TRANSPOSE_1D:
352
390
  {
353
391
  n_fuse = ggml_metal_op_conv_transpose_1d(ctx, idx);
354
392
  } break;
393
+ case GGML_OP_CONV_TRANSPOSE_2D:
394
+ {
395
+ n_fuse = ggml_metal_op_conv_transpose_2d(ctx, idx);
396
+ } break;
355
397
  case GGML_OP_UPSCALE:
356
398
  {
357
399
  n_fuse = ggml_metal_op_upscale(ctx, idx);
@@ -376,20 +418,32 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
376
418
  {
377
419
  n_fuse = ggml_metal_op_argsort(ctx, idx);
378
420
  } break;
379
- case GGML_OP_LEAKY_RELU:
421
+ case GGML_OP_TOP_K:
422
+ {
423
+ n_fuse = ggml_metal_op_top_k(ctx, idx);
424
+ } break;
425
+ case GGML_OP_TRI:
380
426
  {
381
- n_fuse = ggml_metal_op_leaky_relu(ctx, idx);
427
+ n_fuse = ggml_metal_op_tri(ctx, idx);
382
428
  } break;
383
429
  case GGML_OP_FLASH_ATTN_EXT:
384
430
  {
385
431
  n_fuse = ggml_metal_op_flash_attn_ext(ctx, idx);
386
432
  } break;
433
+ case GGML_OP_SET:
434
+ {
435
+ n_fuse = ggml_metal_op_set(ctx, idx);
436
+ } break;
387
437
  case GGML_OP_DUP:
388
438
  case GGML_OP_CPY:
389
439
  case GGML_OP_CONT:
390
440
  {
391
441
  n_fuse = ggml_metal_op_cpy(ctx, idx);
392
442
  } break;
443
+ case GGML_OP_POOL_1D:
444
+ {
445
+ n_fuse = ggml_metal_op_pool_1d(ctx, idx);
446
+ } break;
393
447
  case GGML_OP_POOL_2D:
394
448
  {
395
449
  n_fuse = ggml_metal_op_pool_2d(ctx, idx);
@@ -398,7 +452,19 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
398
452
  {
399
453
  n_fuse = ggml_metal_op_argmax(ctx, idx);
400
454
  } break;
401
- default:
455
+ case GGML_OP_OPT_STEP_ADAMW:
456
+ {
457
+ n_fuse = ggml_metal_op_opt_step_adamw(ctx, idx);
458
+ } break;
459
+ case GGML_OP_OPT_STEP_SGD:
460
+ {
461
+ n_fuse = ggml_metal_op_opt_step_sgd(ctx, idx);
462
+ } break;
463
+ case GGML_OP_COUNT_EQUAL:
464
+ {
465
+ n_fuse = ggml_metal_op_count_equal(ctx, idx);
466
+ } break;
467
+ default:
402
468
  {
403
469
  GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(node->op));
404
470
  GGML_ABORT("fatal error");
@@ -482,7 +548,7 @@ int ggml_metal_op_concat(ggml_metal_op_t ctx, int idx) {
482
548
  /*.dim =*/ dim,
483
549
  };
484
550
 
485
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_CONCAT);
551
+ auto pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_CONCAT);
486
552
 
487
553
  ggml_metal_encoder_set_pipeline(enc, pipeline);
488
554
  ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@@ -506,9 +572,9 @@ int ggml_metal_op_repeat(ggml_metal_op_t ctx, int idx) {
506
572
  GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
507
573
  GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
508
574
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
509
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
575
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
510
576
 
511
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_repeat(lib, op->type);
577
+ auto pipeline = ggml_metal_library_get_pipeline_repeat(lib, op->type);
512
578
 
513
579
  ggml_metal_kargs_repeat args = {
514
580
  /*.ne00 =*/ ne00,
@@ -552,14 +618,14 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
552
618
  GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
553
619
  GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
554
620
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
555
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
621
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
556
622
 
557
623
  GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
558
624
  GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
559
625
  GGML_ASSERT(op->type == GGML_TYPE_F32);
560
626
 
561
- GGML_ASSERT(ggml_is_contiguous(op->src[0]));
562
- GGML_ASSERT(ggml_is_contiguous(op->src[1]));
627
+ GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
628
+ GGML_ASSERT(ggml_is_contiguous_rows(op->src[1]));
563
629
 
564
630
  const size_t pnb1 = ((const int32_t *) op->op_params)[0];
565
631
  const size_t pnb2 = ((const int32_t *) op->op_params)[1];
@@ -569,14 +635,15 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
569
635
  const bool inplace = (bool) ((const int32_t *) op->op_params)[4];
570
636
 
571
637
  if (!inplace) {
572
- // run a separete kernel to cpy src->dst
638
+ // run a separate kernel to cpy src->dst
573
639
  // not sure how to avoid this
574
640
  // TODO: make a simpler cpy_bytes kernel
575
641
 
576
642
  //const id<MTLComputePipelineState> pipeline = ctx->pipelines[GGML_METAL_PIPELINE_TYPE_CPY_F32_F32].obj;
577
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
643
+ auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
578
644
 
579
645
  ggml_metal_kargs_cpy args = {
646
+ /*.nk0 =*/ ne00,
580
647
  /*.ne00 =*/ ne00,
581
648
  /*.ne01 =*/ ne01,
582
649
  /*.ne02 =*/ ne02,
@@ -608,10 +675,10 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
608
675
  }
609
676
 
610
677
  ggml_metal_kargs_bin args = {
611
- /*.ne00 =*/ ne00,
612
- /*.ne01 =*/ ne01,
613
- /*.ne02 =*/ ne02,
614
- /*.ne03 =*/ ne03,
678
+ /*.ne00 =*/ ne10,
679
+ /*.ne01 =*/ ne11,
680
+ /*.ne02 =*/ ne12,
681
+ /*.ne03 =*/ ne13,
615
682
  /*.nb00 =*/ nb00,
616
683
  /*.nb01 =*/ pnb1,
617
684
  /*.nb02 =*/ pnb2,
@@ -624,10 +691,10 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
624
691
  /*.nb11 =*/ nb11,
625
692
  /*.nb12 =*/ nb12,
626
693
  /*.nb13 =*/ nb13,
627
- /*.ne0 =*/ ne0,
628
- /*.ne1 =*/ ne1,
629
- /*.ne2 =*/ ne2,
630
- /*.ne3 =*/ ne3,
694
+ /*.ne0 =*/ ne10,
695
+ /*.ne1 =*/ ne11,
696
+ /*.ne2 =*/ ne12,
697
+ /*.ne3 =*/ ne13,
631
698
  /*.nb0 =*/ nb0,
632
699
  /*.nb1 =*/ pnb1,
633
700
  /*.nb2 =*/ pnb2,
@@ -636,7 +703,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
636
703
  /*.o1 =*/ { 0 },
637
704
  };
638
705
 
639
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_bin(lib, GGML_OP_ADD, 1, false);
706
+ auto pipeline = ggml_metal_library_get_pipeline_bin_one(lib, GGML_OP_ADD);
640
707
 
641
708
  ggml_metal_encoder_set_pipeline(enc, pipeline);
642
709
  ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@@ -644,14 +711,20 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
644
711
  ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
645
712
  ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
646
713
 
647
- const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00);
714
+ const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
715
+
716
+ int nth = 1;
717
+
718
+ while (2*nth < args.ne0 && nth < nth_max) {
719
+ nth *= 2;
720
+ }
648
721
 
649
722
  ggml_metal_encoder_dispatch_threadgroups(enc, ne11, ne12, ne13, nth, 1, 1);
650
723
 
651
724
  return 1;
652
725
  }
653
726
 
654
- int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) {
727
+ int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
655
728
  ggml_tensor * op = ctx->node(idx);
656
729
 
657
730
  ggml_metal_library_t lib = ctx->lib;
@@ -660,100 +733,82 @@ int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) {
660
733
  GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
661
734
  GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
662
735
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
663
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
736
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
664
737
 
665
- float scale;
666
- float bias;
667
- memcpy(&scale, ((const int32_t *) op->op_params) + 0, sizeof(float));
668
- memcpy(&bias, ((const int32_t *) op->op_params) + 1, sizeof(float));
738
+ GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
669
739
 
670
- ggml_metal_kargs_scale args = {
671
- /*.scale =*/ scale,
672
- /*.bias =*/ bias,
673
- };
740
+ ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
741
+ ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
674
742
 
675
- int64_t n = ggml_nelements(op);
743
+ ggml_metal_kargs_unary args = {
744
+ /*.ne00 =*/ ne00,
745
+ /*.ne01 =*/ ne01,
746
+ /*.ne02 =*/ ne02,
747
+ /*.ne03 =*/ ne03,
748
+ /*.nb00 =*/ nb00,
749
+ /*.nb01 =*/ nb01,
750
+ /*.nb02 =*/ nb02,
751
+ /*.nb03 =*/ nb03,
752
+ /*.ne0 =*/ ne0,
753
+ /*.ne1 =*/ ne1,
754
+ /*.ne2 =*/ ne2,
755
+ /*.ne3 =*/ ne3,
756
+ /*.nb0 =*/ nb0,
757
+ /*.nb1 =*/ nb1,
758
+ /*.nb2 =*/ nb2,
759
+ /*.nb3 =*/ nb3,
760
+ /*.slope =*/ 0.0,
761
+ /*.scale =*/ 0.0,
762
+ /*.bias =*/ 0.0,
763
+ /*.val =*/ 0.0,
764
+ /*.min =*/ 0.0,
765
+ /*.max =*/ 0.0,
766
+ };
676
767
 
677
- if (n % 4 == 0) {
678
- n /= 4;
768
+ if (op->op == GGML_OP_LEAKY_RELU) {
769
+ args.slope = ggml_get_op_params_f32(op, 0);
679
770
  }
680
771
 
681
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
682
-
683
- ggml_metal_encoder_set_pipeline(enc, pipeline);
684
- ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
685
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
686
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
687
-
688
- ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
689
-
690
- return 1;
691
- }
692
-
693
- int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) {
694
- ggml_tensor * op = ctx->node(idx);
695
-
696
- ggml_metal_library_t lib = ctx->lib;
697
- ggml_metal_encoder_t enc = ctx->enc;
698
-
699
- GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
700
- GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
701
- GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
702
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
772
+ if (op->op == GGML_OP_SCALE) {
773
+ args.scale = ggml_get_op_params_f32(op, 0);
774
+ args.bias = ggml_get_op_params_f32(op, 1);
775
+ }
703
776
 
704
- float min;
705
- float max;
706
- memcpy(&min, ((const int32_t *) op->op_params) + 0, sizeof(float));
707
- memcpy(&max, ((const int32_t *) op->op_params) + 1, sizeof(float));
777
+ if (op->op == GGML_OP_FILL) {
778
+ args.val = ggml_get_op_params_f32(op, 0);
779
+ }
708
780
 
709
- ggml_metal_kargs_clamp args = {
710
- /*.min =*/ min,
711
- /*.max =*/ max,
712
- };
781
+ if (op->op == GGML_OP_CLAMP) {
782
+ args.min = ggml_get_op_params_f32(op, 0);
783
+ args.max = ggml_get_op_params_f32(op, 1);
784
+ }
713
785
 
714
- int64_t n = ggml_nelements(op);
786
+ auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
715
787
 
716
- if (n % 4 == 0) {
717
- n /= 4;
788
+ if (pipeline.c4) {
789
+ args.ne00 = ne00/4;
790
+ args.ne0 = ne0/4;
718
791
  }
719
792
 
720
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
721
-
722
793
  ggml_metal_encoder_set_pipeline(enc, pipeline);
723
794
  ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
724
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
725
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
726
-
727
- ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
728
-
729
- return 1;
730
- }
795
+ ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
796
+ ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
731
797
 
732
- int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
733
- ggml_tensor * op = ctx->node(idx);
798
+ if (pipeline.cnt) {
799
+ const int n = pipeline.c4 ? ggml_nelements(op)/4 : ggml_nelements(op);
734
800
 
735
- ggml_metal_library_t lib = ctx->lib;
736
- ggml_metal_encoder_t enc = ctx->enc;
801
+ ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
802
+ } else {
803
+ const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
737
804
 
738
- GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
739
- GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
740
- GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
741
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
805
+ const int nth = MIN(args.ne00, nth_max);
742
806
 
743
- int64_t n = ggml_nelements(op);
807
+ const int nk0 = (args.ne00 + nth - 1)/nth;
744
808
 
745
- if (n % 4 == 0) {
746
- n /= 4;
809
+ ggml_metal_encoder_dispatch_threadgroups(enc, nk0*ne01, ne02, ne03, nth, 1, 1);
747
810
  }
748
811
 
749
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
750
-
751
- ggml_metal_encoder_set_pipeline(enc, pipeline);
752
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 0);
753
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 1);
754
-
755
- ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
756
-
757
812
  return 1;
758
813
  }
759
814
 
@@ -768,13 +823,13 @@ int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) {
768
823
  GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
769
824
  GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
770
825
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
771
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
826
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
772
827
 
773
828
  if (op->src[1]) {
774
829
  GGML_ASSERT(ggml_are_same_shape(op->src[0], op->src[1]));
775
830
  }
776
831
 
777
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_glu(lib, op);
832
+ auto pipeline = ggml_metal_library_get_pipeline_glu(lib, op);
778
833
 
779
834
  const int32_t swp = ggml_get_op_params_i32(op, 1);
780
835
  const float alpha = ggml_get_op_params_f32(op, 2);
@@ -800,18 +855,6 @@ int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) {
800
855
 
801
856
  const int32_t nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00/2);
802
857
 
803
- //[encoder setComputePipelineState:pipeline];
804
- //[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
805
- //if (src1) {
806
- // [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
807
- //} else {
808
- // [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
809
- //}
810
- //[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
811
- //[encoder setBytes:&args length:sizeof(args) atIndex:3];
812
-
813
- //[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
814
-
815
858
  ggml_metal_encoder_set_pipeline(enc, pipeline);
816
859
  ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
817
860
  ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
@@ -827,6 +870,43 @@ int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) {
827
870
  return 1;
828
871
  }
829
872
 
873
+ int ggml_metal_op_sum(ggml_metal_op_t ctx, int idx) {
874
+ ggml_tensor * op = ctx->node(idx);
875
+
876
+ ggml_metal_library_t lib = ctx->lib;
877
+ ggml_metal_encoder_t enc = ctx->enc;
878
+
879
+ const uint64_t n = (uint64_t) ggml_nelements(op->src[0]);
880
+
881
+ ggml_metal_kargs_sum args = {
882
+ /*.np =*/ n,
883
+ };
884
+
885
+ auto pipeline = ggml_metal_library_get_pipeline_sum(lib, op);
886
+
887
+ int nth = 32; // SIMD width
888
+
889
+ while (nth < (int) n && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
890
+ nth *= 2;
891
+ }
892
+
893
+ nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
894
+ nth = std::min(nth, (int) n);
895
+
896
+ const int nsg = (nth + 31) / 32;
897
+
898
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
899
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
900
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
901
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
902
+
903
+ ggml_metal_encoder_set_threadgroup_memory_size(enc, nsg * sizeof(float), 0);
904
+
905
+ ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, nth, 1, 1);
906
+
907
+ return 1;
908
+ }
909
+
830
910
  int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
831
911
  ggml_tensor * op = ctx->node(idx);
832
912
 
@@ -836,7 +916,12 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
836
916
  GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
837
917
  GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
838
918
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
839
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
919
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
920
+
921
+ GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
922
+
923
+ ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
924
+ ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
840
925
 
841
926
  ggml_metal_kargs_sum_rows args = {
842
927
  /*.ne00 =*/ ne00,
@@ -857,31 +942,28 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
857
942
  /*.nb3 =*/ nb3,
858
943
  };
859
944
 
860
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op);
945
+ auto pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op);
946
+
947
+ if (pipeline.c4) {
948
+ args.ne00 = ne00/4;
949
+ args.ne0 = ne0/4;
950
+ }
861
951
 
862
952
  int nth = 32; // SIMD width
863
953
 
864
- while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
954
+ while (nth < args.ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
865
955
  nth *= 2;
866
956
  }
867
957
 
868
958
  nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
869
- nth = std::min(nth, ne00);
870
-
871
- const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
959
+ nth = std::min(nth, (int) args.ne00);
872
960
 
873
- //[encoder setComputePipelineState:pipeline];
874
- //[encoder setBytes:&args length:sizeof(args) atIndex:0];
875
- //[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
876
- //[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
877
- //[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
878
-
879
- //[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
961
+ const size_t smem = pipeline.smem;
880
962
 
881
963
  ggml_metal_encoder_set_pipeline(enc, pipeline);
882
964
  ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
883
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
884
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
965
+ ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
966
+ ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
885
967
 
886
968
  ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
887
969
 
@@ -890,6 +972,149 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
890
972
  return 1;
891
973
  }
892
974
 
975
+ int ggml_metal_op_cumsum(ggml_metal_op_t ctx, int idx) {
976
+ ggml_tensor * op = ctx->node(idx);
977
+
978
+ ggml_metal_library_t lib = ctx->lib;
979
+ ggml_metal_encoder_t enc = ctx->enc;
980
+
981
+ GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
982
+
983
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
984
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
985
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
986
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
987
+
988
+ auto pipeline_blk = ggml_metal_library_get_pipeline_cumsum_blk(lib, op);
989
+
990
+ int nth = 1;
991
+ while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_blk)) {
992
+ nth *= 2;
993
+ }
994
+
995
+ GGML_ASSERT(ne00 <= nth*nth);
996
+
997
+ const int64_t net0 = (ne00 + nth - 1) / nth;
998
+ const int64_t net1 = ne01;
999
+ const int64_t net2 = ne02;
1000
+ const int64_t net3 = ne03;
1001
+
1002
+ const uint64_t nbt0 = sizeof(float);
1003
+ const uint64_t nbt1 = net0*nbt0;
1004
+ const uint64_t nbt2 = net1*nbt1;
1005
+ const uint64_t nbt3 = net2*nbt2;
1006
+
1007
+ const size_t smem = GGML_PAD(32*sizeof(float), 16);
1008
+
1009
+ ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
1010
+ ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
1011
+
1012
+ ggml_metal_buffer_id bid_tmp = bid_dst;
1013
+ bid_tmp.offs += ggml_nbytes(op);
1014
+
1015
+ {
1016
+ ggml_metal_kargs_cumsum_blk args = {
1017
+ /*.ne00 =*/ ne00,
1018
+ /*.ne01 =*/ ne01,
1019
+ /*.ne02 =*/ ne02,
1020
+ /*.ne03 =*/ ne03,
1021
+ /*.nb00 =*/ nb00,
1022
+ /*.nb01 =*/ nb01,
1023
+ /*.nb02 =*/ nb02,
1024
+ /*.nb03 =*/ nb03,
1025
+ /*.net0 =*/ net0,
1026
+ /*.net1 =*/ net1,
1027
+ /*.net2 =*/ net2,
1028
+ /*.net3 =*/ net3,
1029
+ /*.nbt0 =*/ nbt0,
1030
+ /*.nbt1 =*/ nbt1,
1031
+ /*.nbt2 =*/ nbt2,
1032
+ /*.nbt3 =*/ nbt3,
1033
+ /*.outb =*/ ne00 > nth,
1034
+ };
1035
+
1036
+ ggml_metal_encoder_set_pipeline(enc, pipeline_blk);
1037
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1038
+ ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
1039
+ ggml_metal_encoder_set_buffer (enc, bid_tmp, 2);
1040
+ ggml_metal_encoder_set_buffer (enc, bid_dst, 3);
1041
+
1042
+ ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
1043
+
1044
+ ggml_metal_encoder_dispatch_threadgroups(enc, net0*ne01, ne02, ne03, nth, 1, 1);
1045
+ }
1046
+
1047
+ if (ne00 > nth) {
1048
+ ggml_metal_op_concurrency_reset(ctx);
1049
+
1050
+ {
1051
+ ggml_metal_kargs_cumsum_blk args = {
1052
+ /*.ne00 =*/ net0,
1053
+ /*.ne01 =*/ net1,
1054
+ /*.ne02 =*/ net2,
1055
+ /*.ne03 =*/ net3,
1056
+ /*.nb00 =*/ nbt0,
1057
+ /*.nb01 =*/ nbt1,
1058
+ /*.nb02 =*/ nbt2,
1059
+ /*.nb03 =*/ nbt3,
1060
+ /*.net0 =*/ net0,
1061
+ /*.net1 =*/ net1,
1062
+ /*.net2 =*/ net2,
1063
+ /*.net3 =*/ net3,
1064
+ /*.nbt0 =*/ nbt0,
1065
+ /*.nbt1 =*/ nbt1,
1066
+ /*.nbt2 =*/ nbt2,
1067
+ /*.nbt3 =*/ nbt3,
1068
+ /*.outb =*/ false,
1069
+ };
1070
+
1071
+ ggml_metal_encoder_set_pipeline(enc, pipeline_blk);
1072
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1073
+ ggml_metal_encoder_set_buffer (enc, bid_tmp, 1);
1074
+ ggml_metal_encoder_set_buffer (enc, bid_tmp, 2);
1075
+ ggml_metal_encoder_set_buffer (enc, bid_tmp, 3);
1076
+
1077
+ ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
1078
+
1079
+ ggml_metal_encoder_dispatch_threadgroups(enc, net1, net2, net3, nth, 1, 1);
1080
+ }
1081
+
1082
+ ggml_metal_op_concurrency_reset(ctx);
1083
+
1084
+ {
1085
+ auto pipeline_add = ggml_metal_library_get_pipeline_cumsum_add(lib, op);
1086
+
1087
+ ggml_metal_kargs_cumsum_add args = {
1088
+ /*.ne00 =*/ ne00,
1089
+ /*.ne01 =*/ ne01,
1090
+ /*.ne02 =*/ ne02,
1091
+ /*.ne03 =*/ ne03,
1092
+ /*.nb00 =*/ nb00,
1093
+ /*.nb01 =*/ nb01,
1094
+ /*.nb02 =*/ nb02,
1095
+ /*.nb03 =*/ nb03,
1096
+ /*.net0 =*/ net0,
1097
+ /*.net1 =*/ net1,
1098
+ /*.net2 =*/ net2,
1099
+ /*.net3 =*/ net3,
1100
+ /*.nbt0 =*/ nbt0,
1101
+ /*.nbt1 =*/ nbt1,
1102
+ /*.nbt2 =*/ nbt2,
1103
+ /*.nbt3 =*/ nbt3,
1104
+ };
1105
+
1106
+ ggml_metal_encoder_set_pipeline(enc, pipeline_add);
1107
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1108
+ ggml_metal_encoder_set_buffer (enc, bid_tmp, 1);
1109
+ ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
1110
+
1111
+ ggml_metal_encoder_dispatch_threadgroups(enc, net0*ne01, ne02, ne03, nth, 1, 1);
1112
+ }
1113
+ }
1114
+
1115
+ return 1;
1116
+ }
1117
+
893
1118
  int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) {
894
1119
  ggml_tensor * op = ctx->node(idx);
895
1120
 
@@ -901,28 +1126,36 @@ int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) {
901
1126
  GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
902
1127
  GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
903
1128
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
904
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
1129
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
905
1130
 
906
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type);
1131
+ auto pipeline = ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type);
907
1132
 
908
1133
  ggml_metal_kargs_get_rows args = {
909
- /*.ne00 =*/ ne00,
910
- /*.nb01 =*/ nb01,
911
- /*.nb02 =*/ nb02,
912
- /*.ne10 =*/ ne10,
913
- /*.nb10 =*/ nb10,
914
- /*.nb11 =*/ nb11,
915
- /*.nb1 =*/ nb1,
916
- /*.nb2 =*/ nb2,
1134
+ /*.ne00t =*/ ggml_is_quantized(op->src[0]->type) ? ne00/16 : ne00,
1135
+ /*.ne00 =*/ ne00,
1136
+ /*.nb01 =*/ nb01,
1137
+ /*.nb02 =*/ nb02,
1138
+ /*.nb03 =*/ nb03,
1139
+ /*.ne10 =*/ ne10,
1140
+ /*.nb10 =*/ nb10,
1141
+ /*.nb11 =*/ nb11,
1142
+ /*.nb12 =*/ nb12,
1143
+ /*.nb1 =*/ nb1,
1144
+ /*.nb2 =*/ nb2,
1145
+ /*.nb3 =*/ nb3,
917
1146
  };
918
1147
 
1148
+ const int nth = std::min(args.ne00t, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
1149
+
1150
+ const int nw0 = (args.ne00t + nth - 1)/nth;
1151
+
919
1152
  ggml_metal_encoder_set_pipeline(enc, pipeline);
920
1153
  ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
921
1154
  ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
922
1155
  ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
923
1156
  ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
924
1157
 
925
- ggml_metal_encoder_dispatch_threadgroups(enc, ne10, ne11, ne12, 32, 1, 1);
1158
+ ggml_metal_encoder_dispatch_threadgroups(enc, nw0*ne10, ne11, ne12, nth, 1, 1);
926
1159
 
927
1160
  return 1;
928
1161
  }
@@ -938,9 +1171,9 @@ int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) {
938
1171
  GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
939
1172
  GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
940
1173
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
941
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
1174
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
942
1175
 
943
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_set_rows(lib, op->src[1]->type, op->type);
1176
+ auto pipeline = ggml_metal_library_get_pipeline_set_rows(lib, op->src[1]->type, op->type);
944
1177
 
945
1178
  const int32_t nk0 = ne0/ggml_blck_size(op->type);
946
1179
 
@@ -989,6 +1222,48 @@ int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) {
989
1222
  return 1;
990
1223
  }
991
1224
 
1225
+ int ggml_metal_op_diag(ggml_metal_op_t ctx, int idx) {
1226
+ ggml_tensor * op = ctx->node(idx);
1227
+
1228
+ ggml_metal_library_t lib = ctx->lib;
1229
+ ggml_metal_encoder_t enc = ctx->enc;
1230
+
1231
+ GGML_TENSOR_LOCALS(int32_t, ne0, op->src[0], ne);
1232
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1233
+ GGML_TENSOR_LOCALS(int32_t, ne, op, ne);
1234
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1235
+
1236
+ ggml_metal_kargs_diag args = {
1237
+ /*.ne00 =*/ne00,
1238
+ /*.ne01 =*/ne01,
1239
+ /*.ne02 =*/ne02,
1240
+ /*.ne03 =*/ne03,
1241
+ /*.nb00 =*/nb00,
1242
+ /*.nb01 =*/nb01,
1243
+ /*.nb02 =*/nb02,
1244
+ /*.nb03 =*/nb03,
1245
+ /*.ne0 =*/ne0,
1246
+ /*.ne1 =*/ne1,
1247
+ /*.ne2 =*/ne2,
1248
+ /*.ne3 =*/ne3,
1249
+ /*.nb0 =*/nb0,
1250
+ /*.nb1 =*/nb1,
1251
+ /*.nb2 =*/nb2,
1252
+ /*.nb3 =*/nb3,
1253
+ };
1254
+
1255
+ auto pipeline = ggml_metal_library_get_pipeline_diag(lib, op);
1256
+
1257
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
1258
+ ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
1259
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1260
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 2);
1261
+
1262
+ ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, 32, 1, 1);
1263
+
1264
+ return 1;
1265
+ }
1266
+
992
1267
  int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
993
1268
  ggml_tensor * op = ctx->node(idx);
994
1269
 
@@ -1002,7 +1277,7 @@ int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
1002
1277
  GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
1003
1278
  GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
1004
1279
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1005
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
1280
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1006
1281
 
1007
1282
  float scale;
1008
1283
  float max_bias;
@@ -1041,7 +1316,7 @@ int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
1041
1316
  /*.n_head_log2 =*/ n_head_log2,
1042
1317
  };
1043
1318
 
1044
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_soft_max(lib, op);
1319
+ auto pipeline = ggml_metal_library_get_pipeline_soft_max(lib, op);
1045
1320
 
1046
1321
  int nth = 32; // SIMD width
1047
1322
 
@@ -1055,7 +1330,7 @@ int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
1055
1330
  }
1056
1331
  }
1057
1332
 
1058
- const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
1333
+ const size_t smem = pipeline.smem;
1059
1334
 
1060
1335
  ggml_metal_encoder_set_pipeline(enc, pipeline);
1061
1336
  ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
@@ -1090,7 +1365,7 @@ int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) {
1090
1365
  GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
1091
1366
  GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
1092
1367
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1093
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
1368
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1094
1369
 
1095
1370
  ggml_metal_kargs_ssm_conv args = {
1096
1371
  /*.ne00 =*/ ne00,
@@ -1111,18 +1386,46 @@ int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) {
1111
1386
  /*.nb2 =*/ nb2,
1112
1387
  };
1113
1388
 
1114
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_ssm_conv(lib, op);
1115
-
1116
- ggml_metal_encoder_set_pipeline(enc, pipeline);
1117
- ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
1118
- ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1119
- ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
1120
- ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
1389
+ // Use batched kernel for prefill (ne1 > 1) to reduce threadgroup dispatch overhead
1390
+ const bool use_batched = (ne1 > 1);
1121
1391
 
1122
- ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1);
1392
+ if (use_batched) {
1393
+ // Determine the smallest power of 2 that's >= ne1, but <= 256
1394
+ int BATCH_SIZE;
1395
+ if (ne1 > 128) BATCH_SIZE = 256;
1396
+ else if (ne1 > 64 ) BATCH_SIZE = 128;
1397
+ else if (ne1 > 32 ) BATCH_SIZE = 64;
1398
+ else if (ne1 > 16 ) BATCH_SIZE = 32;
1399
+ else if (ne1 > 8 ) BATCH_SIZE = 16;
1400
+ else if (ne1 > 4 ) BATCH_SIZE = 8;
1401
+ else BATCH_SIZE = 2;
1123
1402
 
1124
- return 1;
1125
- }
1403
+ auto pipeline = ggml_metal_library_get_pipeline_ssm_conv_batched(lib, op, BATCH_SIZE);
1404
+
1405
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
1406
+ ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
1407
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1408
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
1409
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
1410
+
1411
+ // Dispatch: ne01 rows, ceil(ne1/BATCH_SIZE) token batches, ne02 sequences
1412
+ // Each threadgroup has BATCH_SIZE threads, each handling one token
1413
+ const int n_token_batches = (ne1 + BATCH_SIZE - 1) / BATCH_SIZE;
1414
+ ggml_metal_encoder_dispatch_threadgroups(enc, ne01, n_token_batches, ne02, BATCH_SIZE, 1, 1);
1415
+ } else {
1416
+ auto pipeline = ggml_metal_library_get_pipeline_ssm_conv(lib, op);
1417
+
1418
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
1419
+ ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
1420
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1421
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
1422
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
1423
+
1424
+ ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1);
1425
+ }
1426
+
1427
+ return 1;
1428
+ }
1126
1429
 
1127
1430
  int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
1128
1431
  ggml_tensor * op = ctx->node(idx);
@@ -1145,7 +1448,7 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
1145
1448
  GGML_TENSOR_LOCALS( int32_t, ne6, op->src[6], ne);
1146
1449
  GGML_TENSOR_LOCALS(uint64_t, nb6, op->src[6], nb);
1147
1450
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1148
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
1451
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1149
1452
 
1150
1453
  const ggml_tensor * src3 = op->src[3];
1151
1454
  const ggml_tensor * src4 = op->src[4];
@@ -1172,26 +1475,37 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
1172
1475
  /*.n_seq_tokens =*/ n_seq_tokens,
1173
1476
  /*.n_seqs =*/ n_seqs,
1174
1477
  /*.s_off =*/ ggml_nelements(op->src[1]) * sizeof(float),
1478
+ /*.nb00 =*/ nb00,
1175
1479
  /*.nb01 =*/ nb01,
1176
1480
  /*.nb02 =*/ nb02,
1177
1481
  /*.nb03 =*/ nb03,
1482
+ /*.nb10 =*/ nb10,
1178
1483
  /*.nb11 =*/ nb11,
1179
1484
  /*.nb12 =*/ nb12,
1485
+ /*.ns12 =*/ nb12/nb10,
1180
1486
  /*.nb13 =*/ nb13,
1487
+ /*.nb20 =*/ nb20,
1181
1488
  /*.nb21 =*/ nb21,
1489
+ /*.ns21 =*/ nb21/nb20,
1182
1490
  /*.nb22 =*/ nb22,
1491
+ /*.ne30 =*/ ne30,
1183
1492
  /*.nb31 =*/ nb31,
1184
1493
  /*.nb41 =*/ nb41,
1185
1494
  /*.nb42 =*/ nb42,
1495
+ /*.ns42 =*/ nb42/nb40,
1186
1496
  /*.nb43 =*/ nb43,
1187
1497
  /*.nb51 =*/ nb51,
1188
1498
  /*.nb52 =*/ nb52,
1499
+ /*.ns52 =*/ nb52/nb50,
1189
1500
  /*.nb53 =*/ nb53,
1501
+ /*.nb0 =*/ nb0,
1190
1502
  };
1191
1503
 
1192
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_ssm_scan(lib, op);
1504
+ auto pipeline = ggml_metal_library_get_pipeline_ssm_scan(lib, op);
1193
1505
 
1194
- const size_t sms = ggml_metal_pipeline_get_smem(pipeline);
1506
+ GGML_ASSERT(d_state <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
1507
+
1508
+ const size_t smem = pipeline.smem;
1195
1509
 
1196
1510
  ggml_metal_encoder_set_pipeline(enc, pipeline);
1197
1511
  ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@@ -1204,15 +1518,9 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
1204
1518
  ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[6]), 7);
1205
1519
  ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 8);
1206
1520
 
1207
- ggml_metal_encoder_set_threadgroup_memory_size(enc, sms, 0);
1521
+ ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
1208
1522
 
1209
- if (ne30 == 1) {
1210
- // Mamba-2
1211
- ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1);
1212
- } else {
1213
- GGML_ASSERT(d_inner == 1);
1214
- ggml_metal_encoder_dispatch_threadgroups(enc, n_head, n_seqs, 1, d_state, 1, 1);
1215
- }
1523
+ ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1);
1216
1524
 
1217
1525
  return 1;
1218
1526
  }
@@ -1226,14 +1534,14 @@ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) {
1226
1534
  GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1227
1535
  GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1228
1536
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1229
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
1537
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1230
1538
 
1231
1539
  const int64_t B = op->op == GGML_OP_RWKV_WKV6 ? op->src[5]->ne[1] : op->src[6]->ne[1];
1232
1540
  const int64_t T = op->src[0]->ne[2];
1233
1541
  const int64_t C = op->ne[0];
1234
1542
  const int64_t H = op->src[0]->ne[1];
1235
1543
 
1236
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_rwkv(lib, op);
1544
+ auto pipeline = ggml_metal_library_get_pipeline_rwkv(lib, op);
1237
1545
 
1238
1546
  int ida = 0;
1239
1547
 
@@ -1258,41 +1566,298 @@ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) {
1258
1566
  return 1;
1259
1567
  }
1260
1568
 
1261
- int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
1569
+ int ggml_metal_op_gated_delta_net(ggml_metal_op_t ctx, int idx) {
1262
1570
  ggml_tensor * op = ctx->node(idx);
1263
1571
 
1264
1572
  ggml_metal_library_t lib = ctx->lib;
1265
1573
  ggml_metal_encoder_t enc = ctx->enc;
1266
1574
 
1575
+
1267
1576
  GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1268
1577
  GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1578
+ GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
1579
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
1580
+ GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
1581
+ GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
1269
1582
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1270
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
1583
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1271
1584
 
1272
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
1585
+ auto pipeline = ggml_metal_library_get_pipeline_gated_delta_net(lib, op);
1273
1586
 
1274
- GGML_ASSERT(ne00 % ggml_blck_size(op->src[0]->type) == 0);
1587
+ int ida = 0;
1588
+
1589
+ ggml_metal_kargs_gated_delta_net args = {
1590
+ /*.ne00 =*/ ne00,
1591
+ /*.ne01 =*/ ne01,
1592
+ /*.ne02 =*/ ne02,
1593
+ /*.ne03 =*/ ne03,
1594
+ /*.nb00 =*/ nb00,
1595
+ /*.nb01 =*/ nb01,
1596
+ /*.nb02 =*/ nb02,
1597
+ /*.nb03 =*/ nb03,
1598
+ /*.ne10 =*/ ne10,
1599
+ /*.ne11 =*/ ne11,
1600
+ /*.ne12 =*/ ne12,
1601
+ /*.ne13 =*/ ne13,
1602
+ /*.nb10 =*/ nb10,
1603
+ /*.nb11 =*/ nb11,
1604
+ /*.nb12 =*/ nb12,
1605
+ /*.nb13 =*/ nb13,
1606
+ /*.ne20 =*/ ne20,
1607
+ /*.ne21 =*/ ne21,
1608
+ /*.ne22 =*/ ne22,
1609
+ /*.ne23 =*/ ne23,
1610
+ /*.nb20 =*/ nb20,
1611
+ /*.nb21 =*/ nb21,
1612
+ /*.nb22 =*/ nb22,
1613
+ /*.nb23 =*/ nb23,
1614
+ /*.ns02 =*/ (int32_t) (nb02/sizeof(float)),
1615
+ /*.ns12 =*/ (int32_t) (nb12/sizeof(float)),
1616
+ /*.ns22 =*/ (int32_t) (nb22/sizeof(float)),
1617
+ /*.ne0 =*/ ne0,
1618
+ /*.ne1 =*/ ne1,
1619
+ /*.ne2 =*/ ne2,
1620
+ /*.ne3 =*/ ne3,
1621
+ /*.nb0 =*/ nb0,
1622
+ /*.nb1 =*/ nb1,
1623
+ /*.nb2 =*/ nb2,
1624
+ /*.nb3 =*/ nb3,
1625
+ };
1275
1626
 
1276
- // TODO: support
1277
- //const int32_t nk00 = ne00/ggml_blck_size(op->type);
1278
- const int32_t nk00 = ne00;
1627
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
1628
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++);
1629
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++); // q
1630
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++); // k
1631
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++); // v
1632
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), ida++); // gate
1633
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), ida++); // beta
1634
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[5]), ida++); // state
1635
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), ida++); // dst
1279
1636
 
1280
- int nth = 32; // SIMD width
1637
+ const int nsg = pipeline.nsg;
1281
1638
 
1282
- while (nth < nk00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
1283
- nth *= 2;
1639
+ ggml_metal_encoder_dispatch_threadgroups(enc, op->src[2]->ne[0]/nsg, op->src[2]->ne[1], op->src[2]->ne[3], 32, nsg, 1);
1640
+
1641
+ return 1;
1642
+ }
1643
+
1644
+ int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) {
1645
+ ggml_tensor * op = ctx->node(idx);
1646
+
1647
+ ggml_metal_library_t lib = ctx->lib;
1648
+ ggml_metal_encoder_t enc = ctx->enc;
1649
+
1650
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1651
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1652
+ GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
1653
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
1654
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1655
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1656
+
1657
+ ggml_metal_kargs_solve_tri args = {
1658
+ /*.ne00 =*/ ne00,
1659
+ /*.ne01 =*/ ne01,
1660
+ /*.ne02 =*/ ne02,
1661
+ /*.ne03 =*/ ne03,
1662
+ /*.nb00 =*/ nb00,
1663
+ /*.nb01 =*/ nb01,
1664
+ /*.nb02 =*/ nb02,
1665
+ /*.nb03 =*/ nb03,
1666
+ /*.ne10 =*/ ne10,
1667
+ /*.ne11 =*/ ne11,
1668
+ /*.ne12 =*/ ne12,
1669
+ /*.ne13 =*/ ne13,
1670
+ /*.nb10 =*/ nb10,
1671
+ /*.nb11 =*/ nb11,
1672
+ /*.nb12 =*/ nb12,
1673
+ /*.nb13 =*/ nb13,
1674
+ /*.ne0 =*/ ne0,
1675
+ /*.ne1 =*/ ne1,
1676
+ /*.ne2 =*/ ne2,
1677
+ /*.ne3 =*/ ne3,
1678
+ /*.nb0 =*/ nb0,
1679
+ /*.nb1 =*/ nb1,
1680
+ /*.nb2 =*/ nb2,
1681
+ /*.nb3 =*/ nb3,
1682
+ };
1683
+
1684
+ auto pipeline = ggml_metal_library_get_pipeline_solve_tri(lib, op);
1685
+
1686
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
1687
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1688
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1689
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
1690
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
1691
+
1692
+ const int nsg = pipeline.nsg;
1693
+
1694
+ ggml_metal_encoder_set_threadgroup_memory_size(enc, pipeline.smem, 0);
1695
+
1696
+ ggml_metal_encoder_dispatch_threadgroups(enc, (ne10 + nsg - 1)/nsg, ne02, ne03, 32, nsg, 1);
1697
+
1698
+ return 1;
1699
+ }
1700
+
1701
+ int ggml_metal_op_set(ggml_metal_op_t ctx, int idx) {
1702
+ ggml_tensor * op = ctx->node(idx);
1703
+
1704
+ ggml_metal_library_t lib = ctx->lib;
1705
+ ggml_metal_encoder_t enc = ctx->enc;
1706
+
1707
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1708
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1709
+ GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
1710
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
1711
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1712
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1713
+
1714
+ ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
1715
+ ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
1716
+ ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
1717
+
1718
+ const size_t pnb1 = ((const int32_t *) op->op_params)[0];
1719
+ const size_t pnb2 = ((const int32_t *) op->op_params)[1];
1720
+ const size_t pnb3 = ((const int32_t *) op->op_params)[2];
1721
+ const size_t offs = ((const int32_t *) op->op_params)[3];
1722
+
1723
+ const bool inplace = (bool) ((const int32_t *) op->op_params)[4];
1724
+
1725
+ if (!inplace) {
1726
+ // run a separate kernel to cpy src->dst
1727
+ // not sure how to avoid this
1728
+ // TODO: make a simpler cpy_bytes kernel
1729
+
1730
+ //const id<MTLComputePipelineState> pipeline = ctx->pipelines[GGML_METAL_PIPELINE_TYPE_CPY_F32_F32].obj;
1731
+ auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
1732
+
1733
+ ggml_metal_kargs_cpy args = {
1734
+ /*.nk0 =*/ ne00,
1735
+ /*.ne00 =*/ ne00,
1736
+ /*.ne01 =*/ ne01,
1737
+ /*.ne02 =*/ ne02,
1738
+ /*.ne03 =*/ ne03,
1739
+ /*.nb00 =*/ nb00,
1740
+ /*.nb01 =*/ nb01,
1741
+ /*.nb02 =*/ nb02,
1742
+ /*.nb03 =*/ nb03,
1743
+ /*.ne0 =*/ ne0,
1744
+ /*.ne1 =*/ ne1,
1745
+ /*.ne2 =*/ ne2,
1746
+ /*.ne3 =*/ ne3,
1747
+ /*.nb0 =*/ nb0,
1748
+ /*.nb1 =*/ nb1,
1749
+ /*.nb2 =*/ nb2,
1750
+ /*.nb3 =*/ nb3,
1751
+ };
1752
+
1753
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
1754
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1755
+ ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
1756
+ ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
1757
+
1758
+ const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00);
1759
+
1760
+ ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
1761
+
1762
+ ggml_metal_op_concurrency_reset(ctx);
1284
1763
  }
1285
1764
 
1286
- nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
1765
+ auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[1]->type, op->type);
1766
+
1767
+ GGML_ASSERT(ne10 % ggml_blck_size(op->src[1]->type) == 0);
1768
+
1769
+ int64_t nk0 = ne10;
1770
+ if (ggml_is_quantized(op->src[1]->type)) {
1771
+ nk0 = ne10/16;
1772
+ } else if (ggml_is_quantized(op->type)) {
1773
+ nk0 = ne10/ggml_blck_size(op->type);
1774
+ }
1775
+
1776
+ int nth = std::min<int>(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
1777
+
1778
+ // when rows are small, we can batch them together in a single threadgroup
1779
+ int nrptg = 1;
1780
+
1781
+ // TODO: relax this constraint in the future
1782
+ if (ggml_blck_size(op->src[1]->type) == 1 && ggml_blck_size(op->type) == 1) {
1783
+ if (nth > nk0) {
1784
+ nrptg = (nth + nk0 - 1)/nk0;
1785
+ nth = nk0;
1786
+
1787
+ if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
1788
+ nrptg--;
1789
+ }
1790
+ }
1791
+ }
1792
+
1793
+ nth = std::min<int>(nth, nk0);
1794
+
1795
+ ggml_metal_kargs_cpy args = {
1796
+ /*.nk0 =*/ nk0,
1797
+ /*.ne00 =*/ ne10,
1798
+ /*.ne01 =*/ ne11,
1799
+ /*.ne02 =*/ ne12,
1800
+ /*.ne03 =*/ ne13,
1801
+ /*.nb00 =*/ nb10,
1802
+ /*.nb01 =*/ nb11,
1803
+ /*.nb02 =*/ nb12,
1804
+ /*.nb03 =*/ nb13,
1805
+ /*.ne0 =*/ ne10,
1806
+ /*.ne1 =*/ ne11,
1807
+ /*.ne2 =*/ ne12,
1808
+ /*.ne3 =*/ ne13,
1809
+ /*.nb0 =*/ ggml_element_size(op),
1810
+ /*.nb1 =*/ pnb1,
1811
+ /*.nb2 =*/ pnb2,
1812
+ /*.nb3 =*/ pnb3,
1813
+ };
1814
+
1815
+ const int nw0 = nrptg == 1 ? (nk0 + nth - 1)/nth : 1;
1816
+
1817
+ bid_dst.offs += offs;
1818
+
1819
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
1820
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1821
+ ggml_metal_encoder_set_buffer (enc, bid_src1, 1);
1822
+ ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
1823
+
1824
+ ggml_metal_encoder_dispatch_threadgroups(enc, nw0*(ne11 + nrptg - 1)/nrptg, ne12, ne13, nth, nrptg, 1);
1825
+
1826
+ return 1;
1827
+ }
1828
+
1829
+ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
1830
+ ggml_tensor * op = ctx->node(idx);
1831
+
1832
+ ggml_metal_library_t lib = ctx->lib;
1833
+ ggml_metal_encoder_t enc = ctx->enc;
1834
+
1835
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1836
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1837
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1838
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1839
+
1840
+ auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
1841
+
1842
+ GGML_ASSERT(ne00 % ggml_blck_size(op->src[0]->type) == 0);
1843
+
1844
+ int64_t nk0 = ne00;
1845
+ if (ggml_is_quantized(op->src[0]->type)) {
1846
+ nk0 = ne00/16;
1847
+ } else if (ggml_is_quantized(op->type)) {
1848
+ nk0 = ne00/ggml_blck_size(op->type);
1849
+ }
1850
+
1851
+ int nth = std::min<int>(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
1287
1852
 
1288
1853
  // when rows are small, we can batch them together in a single threadgroup
1289
1854
  int nrptg = 1;
1290
1855
 
1291
1856
  // TODO: relax this constraint in the future
1292
1857
  if (ggml_blck_size(op->src[0]->type) == 1 && ggml_blck_size(op->type) == 1) {
1293
- if (nth > nk00) {
1294
- nrptg = (nth + nk00 - 1)/nk00;
1295
- nth = nk00;
1858
+ if (nth > nk0) {
1859
+ nrptg = (nth + nk0 - 1)/nk0;
1860
+ nth = nk0;
1296
1861
 
1297
1862
  if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
1298
1863
  nrptg--;
@@ -1300,10 +1865,11 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
1300
1865
  }
1301
1866
  }
1302
1867
 
1303
- nth = std::min(nth, nk00);
1868
+ nth = std::min<int>(nth, nk0);
1304
1869
 
1305
1870
  ggml_metal_kargs_cpy args = {
1306
- /*.ne00 =*/ nk00,
1871
+ /*.nk0 =*/ nk0,
1872
+ /*.ne00 =*/ ne00,
1307
1873
  /*.ne01 =*/ ne01,
1308
1874
  /*.ne02 =*/ ne02,
1309
1875
  /*.ne03 =*/ ne03,
@@ -1321,16 +1887,66 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
1321
1887
  /*.nb3 =*/ nb3,
1322
1888
  };
1323
1889
 
1890
+ const int nw0 = nrptg == 1 ? (nk0 + nth - 1)/nth : 1;
1891
+
1324
1892
  ggml_metal_encoder_set_pipeline(enc, pipeline);
1325
1893
  ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1326
1894
  ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1327
1895
  ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
1328
1896
 
1329
- ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, nrptg, 1);
1897
+ ggml_metal_encoder_dispatch_threadgroups(enc, nw0*(ne01 + nrptg - 1)/nrptg, ne02, ne03, nth, nrptg, 1);
1898
+
1899
+ return 1;
1900
+ }
1901
+
1902
+ int ggml_metal_op_pool_1d(ggml_metal_op_t ctx, int idx) {
1903
+ ggml_tensor * op = ctx->node(idx);
1904
+
1905
+ ggml_metal_library_t lib = ctx->lib;
1906
+ ggml_metal_encoder_t enc = ctx->enc;
1907
+
1908
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1909
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1910
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1911
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1912
+
1913
+ const int32_t * opts = op->op_params;
1914
+ ggml_op_pool op_pool = (ggml_op_pool) opts[0];
1915
+
1916
+ const int32_t k0 = opts[1];
1917
+ const int32_t s0 = opts[2];
1918
+ const int32_t p0 = opts[3];
1919
+
1920
+ const int64_t IW = op->src[0]->ne[0];
1921
+ const int64_t OW = op->ne[0];
1922
+
1923
+ const int64_t np = ggml_nelements(op);
1924
+
1925
+ ggml_metal_kargs_pool_1d args_pool_1d = {
1926
+ /* .k0 = */ k0,
1927
+ /* .s0 = */ s0,
1928
+ /* .p0 = */ p0,
1929
+ /* .IW = */ IW,
1930
+ /* .OW = */ OW,
1931
+ /* .np = */ np
1932
+ };
1933
+
1934
+ auto pipeline = ggml_metal_library_get_pipeline_pool_1d(lib, op, op_pool);
1935
+
1936
+ const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), (int) np);
1937
+ const int ntg = (np + nth - 1) / nth;
1938
+
1939
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
1940
+ ggml_metal_encoder_set_bytes (enc, &args_pool_1d, sizeof(args_pool_1d), 0);
1941
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1942
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
1943
+
1944
+ ggml_metal_encoder_dispatch_threadgroups(enc, ntg, 1, 1, nth, 1, 1);
1330
1945
 
1331
1946
  return 1;
1332
1947
  }
1333
1948
 
1949
+
1334
1950
  int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) {
1335
1951
  ggml_tensor * op = ctx->node(idx);
1336
1952
 
@@ -1340,7 +1956,7 @@ int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) {
1340
1956
  GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1341
1957
  GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1342
1958
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1343
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
1959
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1344
1960
 
1345
1961
  const int32_t * opts = op->op_params;
1346
1962
  ggml_op_pool op_pool = (ggml_op_pool) opts[0];
@@ -1376,7 +1992,7 @@ int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) {
1376
1992
  /* .np = */ np
1377
1993
  };
1378
1994
 
1379
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_pool_2d(lib, op, op_pool);
1995
+ auto pipeline = ggml_metal_library_get_pipeline_pool_2d(lib, op, op_pool);
1380
1996
 
1381
1997
  const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), (int) np);
1382
1998
  const int ntg = (np + nth - 1) / nth;
@@ -1404,7 +2020,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
1404
2020
  GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
1405
2021
  GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
1406
2022
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1407
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
2023
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1408
2024
 
1409
2025
  GGML_ASSERT(ne00 == ne10);
1410
2026
 
@@ -1426,6 +2042,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
1426
2042
  (
1427
2043
  op->src[0]->type == GGML_TYPE_F32 || // TODO: helper function
1428
2044
  op->src[0]->type == GGML_TYPE_F16 ||
2045
+ op->src[0]->type == GGML_TYPE_BF16 ||
1429
2046
  op->src[0]->type == GGML_TYPE_Q4_0 ||
1430
2047
  op->src[0]->type == GGML_TYPE_Q4_1 ||
1431
2048
  op->src[0]->type == GGML_TYPE_Q5_0 ||
@@ -1440,6 +2057,8 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
1440
2057
  op->src[0]->type == GGML_TYPE_Q4_K ||
1441
2058
  op->src[0]->type == GGML_TYPE_Q5_K ||
1442
2059
  op->src[0]->type == GGML_TYPE_Q6_K ||
2060
+ op->src[0]->type == GGML_TYPE_Q2_K ||
2061
+ op->src[0]->type == GGML_TYPE_Q3_K ||
1443
2062
  false) && (ne11 >= 4 && ne11 <= 8)
1444
2063
  )
1445
2064
  )
@@ -1468,7 +2087,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
1468
2087
  const int16_t r0ptg = nypsg*nsg; // num src0 rows per threadgroup
1469
2088
  int16_t r1ptg = 4; // num src1 rows per threadgroup
1470
2089
 
1471
- // note: not sure how optimal are those across all different hardware. there might be someting cleverer
2090
+ // note: not sure how optimal are those across all different hardware. there might be something cleverer
1472
2091
  switch (ne11) {
1473
2092
  case 2:
1474
2093
  r1ptg = 2; break;
@@ -1485,7 +2104,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
1485
2104
  GGML_ABORT("unsupported ne11");
1486
2105
  };
1487
2106
 
1488
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, nsg, nxpsg, r1ptg);
2107
+ auto pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, nsg, nxpsg, r1ptg);
1489
2108
 
1490
2109
  ggml_metal_kargs_mul_mv_ext args = {
1491
2110
  /*.ne00 =*/ ne00,
@@ -1520,9 +2139,8 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
1520
2139
  !ggml_is_transposed(op->src[1]) &&
1521
2140
  // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
1522
2141
  // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
1523
- props_dev->has_simdgroup_mm && ne00 >= 64 &&
1524
- (ne11 > ne11_mm_min || (ggml_is_quantized(op->src[0]->type) && ne12 > 1))) {
1525
- //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
2142
+ props_dev->has_simdgroup_mm && ne00 >= 64 && ne11 > ne11_mm_min) {
2143
+ //GGML_LOG_INFO("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
1526
2144
 
1527
2145
  // some Metal matrix data types require aligned pointers
1528
2146
  // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
@@ -1533,7 +2151,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
1533
2151
  // default: break;
1534
2152
  //}
1535
2153
 
1536
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm(lib, op);
2154
+ auto pipeline = ggml_metal_library_get_pipeline_mul_mm(lib, op);
1537
2155
 
1538
2156
  ggml_metal_kargs_mul_mm args = {
1539
2157
  /*.ne00 =*/ ne00,
@@ -1558,18 +2176,18 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
1558
2176
  ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
1559
2177
  ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
1560
2178
 
1561
- const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
2179
+ const size_t smem = pipeline.smem;
1562
2180
 
1563
2181
  ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
1564
2182
  ggml_metal_encoder_dispatch_threadgroups(enc, ((ne11 + 31)/32), ((ne01 + 63)/64), ne12*ne13, 128, 1, 1);
1565
2183
  } else {
1566
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv(lib, op);
2184
+ auto pipeline = ggml_metal_library_get_pipeline_mul_mv(lib, op);
1567
2185
 
1568
- const int nr0 = ggml_metal_pipeline_get_nr0(pipeline);
1569
- const int nr1 = ggml_metal_pipeline_get_nr1(pipeline);
1570
- const int nsg = ggml_metal_pipeline_get_nsg(pipeline);
2186
+ const int nr0 = pipeline.nr0;
2187
+ const int nr1 = pipeline.nr1;
2188
+ const int nsg = pipeline.nsg;
1571
2189
 
1572
- const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
2190
+ const size_t smem = pipeline.smem;
1573
2191
 
1574
2192
  ggml_metal_kargs_mul_mv args = {
1575
2193
  /*.ne00 =*/ ne00,
@@ -1646,7 +2264,7 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
1646
2264
  GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
1647
2265
  GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
1648
2266
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1649
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
2267
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1650
2268
 
1651
2269
  // src2 = ids
1652
2270
  GGML_ASSERT(op->src[2]->type == GGML_TYPE_I32);
@@ -1700,9 +2318,9 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
1700
2318
  nb21,
1701
2319
  };
1702
2320
 
1703
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm_id_map0(lib, ne02, ne20);
2321
+ auto pipeline = ggml_metal_library_get_pipeline_mul_mm_id_map0(lib, ne02, ne20);
1704
2322
 
1705
- const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
2323
+ const size_t smem = pipeline.smem;
1706
2324
 
1707
2325
  GGML_ASSERT(ne02 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
1708
2326
 
@@ -1723,7 +2341,7 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
1723
2341
  ggml_metal_op_concurrency_reset(ctx);
1724
2342
 
1725
2343
  {
1726
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm_id(lib, op);
2344
+ auto pipeline = ggml_metal_library_get_pipeline_mul_mm_id(lib, op);
1727
2345
 
1728
2346
  ggml_metal_kargs_mul_mm_id args = {
1729
2347
  /*.ne00 =*/ ne00,
@@ -1752,20 +2370,20 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
1752
2370
  ggml_metal_encoder_set_buffer (enc, bid_ids, 4);
1753
2371
  ggml_metal_encoder_set_buffer (enc, bid_dst, 5);
1754
2372
 
1755
- const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
2373
+ const size_t smem = pipeline.smem;
1756
2374
 
1757
2375
  ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
1758
2376
 
1759
2377
  ggml_metal_encoder_dispatch_threadgroups(enc, (ne21 + 31)/32, (ne01 + 63)/64, ne02, 128, 1, 1);
1760
2378
  }
1761
2379
  } else {
1762
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_id(lib, op);
2380
+ auto pipeline = ggml_metal_library_get_pipeline_mul_mv_id(lib, op);
1763
2381
 
1764
- const int nr0 = ggml_metal_pipeline_get_nr0(pipeline);
1765
- const int nr1 = ggml_metal_pipeline_get_nr1(pipeline);
1766
- const int nsg = ggml_metal_pipeline_get_nsg(pipeline);
2382
+ const int nr0 = pipeline.nr0;
2383
+ const int nr1 = pipeline.nr1;
2384
+ const int nsg = pipeline.nsg;
1767
2385
 
1768
- const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
2386
+ const size_t smem = pipeline.smem;
1769
2387
 
1770
2388
  ggml_metal_kargs_mul_mv_id args = {
1771
2389
  /*.nei0 =*/ ne20,
@@ -1849,7 +2467,7 @@ int ggml_metal_op_add_id(ggml_metal_op_t ctx, int idx) {
1849
2467
  /*.nb21 =*/ nb21,
1850
2468
  };
1851
2469
 
1852
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_ADD_ID);
2470
+ auto pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_ADD_ID);
1853
2471
 
1854
2472
  ggml_metal_encoder_set_pipeline(enc, pipeline);
1855
2473
  ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@@ -1875,20 +2493,118 @@ bool ggml_metal_op_flash_attn_ext_use_vec(const ggml_tensor * op) {
1875
2493
  return (ne01 < 20) && (ne00 % 32 == 0);
1876
2494
  }
1877
2495
 
2496
+ size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) {
2497
+ assert(op->op == GGML_OP_FLASH_ATTN_EXT);
2498
+
2499
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2500
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2501
+ GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
2502
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
2503
+ GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
2504
+ GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
2505
+ GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
2506
+ GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
2507
+
2508
+ size_t res = 0;
2509
+
2510
+ const bool has_mask = op->src[3] != nullptr;
2511
+
2512
+ // note: the non-vec kernel requires more extra memory, so always reserve for it
2513
+ GGML_ASSERT(OP_FLASH_ATTN_EXT_NCPSG >= OP_FLASH_ATTN_EXT_VEC_NCPSG);
2514
+
2515
+ //if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
2516
+ if (false) {
2517
+ // note: always reserve the padding space to avoid graph reallocations
2518
+ //const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0;
2519
+ const bool has_kvpad = true;
2520
+
2521
+ if (has_kvpad) {
2522
+ res += OP_FLASH_ATTN_EXT_VEC_NCPSG*(
2523
+ nb11*ne12*ne13 +
2524
+ nb21*ne22*ne23 +
2525
+ (has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
2526
+ }
2527
+ } else {
2528
+ //const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_NCPSG != 0;
2529
+ const bool has_kvpad = true;
2530
+
2531
+ if (has_kvpad) {
2532
+ res += OP_FLASH_ATTN_EXT_NCPSG*(
2533
+ nb11*ne12*ne13 +
2534
+ nb21*ne22*ne23 +
2535
+ (has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
2536
+ }
2537
+ }
2538
+
2539
+ return res;
2540
+ }
2541
+
2542
+ size_t ggml_metal_op_flash_attn_ext_extra_blk(const ggml_tensor * op) {
2543
+ assert(op->op == GGML_OP_FLASH_ATTN_EXT);
2544
+
2545
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2546
+ //GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2547
+ //GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
2548
+ //GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
2549
+ //GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
2550
+ //GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
2551
+ GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
2552
+ GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
2553
+
2554
+ size_t res = 0;
2555
+
2556
+ const bool has_mask = op->src[3] != nullptr;
2557
+
2558
+ if (!has_mask) {
2559
+ return res;
2560
+ }
2561
+
2562
+ const bool is_vec = ggml_metal_op_flash_attn_ext_use_vec(op);
2563
+
2564
+ // this optimization is not useful for the vector kernels
2565
+ // note: always reserve the blk buffer to avoid graph reallocations
2566
+ //if (is_vec) {
2567
+ // return res;
2568
+ //}
2569
+
2570
+ const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPSG : OP_FLASH_ATTN_EXT_NQPSG;
2571
+ const int ncpsg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG;
2572
+
2573
+ const int64_t ne1 = (ne01 + nqptg - 1)/nqptg;
2574
+ const int64_t ne0 = (ne30 + ncpsg - 1)/ncpsg;
2575
+
2576
+ res += GGML_PAD(ggml_type_size(GGML_TYPE_I8)*ne0*ne1*ne32*ne33, 32);
2577
+
2578
+ return res;
2579
+ }
2580
+
1878
2581
  size_t ggml_metal_op_flash_attn_ext_extra_tmp(const ggml_tensor * op) {
1879
2582
  assert(op->op == GGML_OP_FLASH_ATTN_EXT);
1880
2583
 
1881
- const int64_t nwg = 32;
2584
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2585
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2586
+ //GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
2587
+ //GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
2588
+ GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
2589
+ GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
2590
+ //GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
2591
+ //GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
2592
+
2593
+ size_t res = 0;
1882
2594
 
1883
- const int64_t ne01 = op->src[0]->ne[1];
1884
- const int64_t ne02 = op->src[0]->ne[2];
1885
- const int64_t ne03 = op->src[0]->ne[3];
1886
- const int64_t ne20 = op->src[2]->ne[0];
2595
+ // note: always reserve the temp buffer to avoid graph reallocations
2596
+ //if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
2597
+ if (true) {
2598
+ const int64_t nwg = 32;
2599
+ const int64_t ne01_max = std::min(ne01, 32);
1887
2600
 
1888
- // temp buffer for writing the results from each workgroup
1889
- // - ne20: the size of the Value head
1890
- // - + 2: the S and M values for each intermediate result
1891
- return ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2));
2601
+ // temp buffer for writing the results from each workgroup
2602
+ // - ne20: the size of the Value head
2603
+ // - + 2: the S and M values for each intermediate result
2604
+ res += ggml_type_size(GGML_TYPE_F32)*(ne01_max*ne02*ne03*nwg*(ne20 + 2));
2605
+ }
2606
+
2607
+ return res;
1892
2608
  }
1893
2609
 
1894
2610
  int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
@@ -1910,8 +2626,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
1910
2626
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1911
2627
  GGML_TENSOR_LOCALS( int32_t, nb, op, nb);
1912
2628
 
1913
- GGML_ASSERT(ne00 % 4 == 0);
1914
- GGML_ASSERT(ne11 % 32 == 0);
2629
+ GGML_ASSERT(ne00 % 4 == 0);
1915
2630
 
1916
2631
  GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
1917
2632
  GGML_ASSERT(op->src[1]->type == op->src[2]->type);
@@ -1921,8 +2636,8 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
1921
2636
  GGML_ASSERT(ne12 == ne22);
1922
2637
 
1923
2638
  GGML_ASSERT(!op->src[3] || op->src[3]->type == GGML_TYPE_F16);
1924
- GGML_ASSERT(!op->src[3] || op->src[3]->ne[1] >= GGML_PAD(op->src[0]->ne[1], 8) &&
1925
- "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
2639
+ GGML_ASSERT(!op->src[3] || op->src[3]->ne[1] >= op->src[0]->ne[1] &&
2640
+ "the Flash-Attention Metal kernel requires the mask to be at least n_queries big");
1926
2641
 
1927
2642
  float scale;
1928
2643
  float max_bias;
@@ -1949,15 +2664,107 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
1949
2664
 
1950
2665
  GGML_ASSERT(ne01 < 65536);
1951
2666
 
2667
+ ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
2668
+ ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
2669
+ ggml_metal_buffer_id bid_src2 = ggml_metal_get_buffer_id(op->src[2]);
2670
+ ggml_metal_buffer_id bid_src3 = has_mask ? ggml_metal_get_buffer_id(op->src[3]) : bid_src0;
2671
+ ggml_metal_buffer_id bid_src4 = has_sinks ? ggml_metal_get_buffer_id(op->src[4]) : bid_src0;
2672
+
2673
+ ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
2674
+
2675
+ ggml_metal_buffer_id bid_pad = bid_dst;
2676
+ bid_pad.offs += ggml_nbytes(op);
2677
+
2678
+ ggml_metal_buffer_id bid_blk = bid_pad;
2679
+ bid_blk.offs += ggml_metal_op_flash_attn_ext_extra_pad(op);
2680
+
2681
+ ggml_metal_buffer_id bid_tmp = bid_blk;
2682
+ bid_tmp.offs += ggml_metal_op_flash_attn_ext_extra_blk(op);
2683
+
1952
2684
  if (!ggml_metal_op_flash_attn_ext_use_vec(op)) {
1953
2685
  // half8x8 kernel
1954
- const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
1955
- const int64_t ncpsg = 64; // cache values per simdgroup !! sync with kernel template arguments !!
2686
+ const int nqptg = OP_FLASH_ATTN_EXT_NQPSG; // queries per threadgroup
2687
+ const int ncpsg = OP_FLASH_ATTN_EXT_NCPSG; // cache values per simdgroup
1956
2688
 
1957
2689
  GGML_ASSERT(nqptg <= 32);
1958
2690
  GGML_ASSERT(nqptg % 8 == 0);
1959
2691
  GGML_ASSERT(ncpsg % 32 == 0);
1960
2692
 
2693
+ bool need_sync = false;
2694
+
2695
+ const bool has_kvpad = ne11 % ncpsg != 0;
2696
+
2697
+ if (has_kvpad) {
2698
+ assert(ggml_metal_op_flash_attn_ext_extra_pad(op) != 0);
2699
+
2700
+ ggml_metal_kargs_flash_attn_ext_pad args0 = {
2701
+ /*.ne11 =*/ne11,
2702
+ /*.ne_12_2 =*/ne12,
2703
+ /*.ne_12_3 =*/ne13,
2704
+ /*.nb11 =*/nb11,
2705
+ /*.nb12 =*/nb12,
2706
+ /*.nb13 =*/nb13,
2707
+ /*.nb21 =*/nb21,
2708
+ /*.nb22 =*/nb22,
2709
+ /*.nb23 =*/nb23,
2710
+ /*.ne31 =*/ne31,
2711
+ /*.ne32 =*/ne32,
2712
+ /*.ne33 =*/ne33,
2713
+ /*.nb31 =*/nb31,
2714
+ /*.nb32 =*/nb32,
2715
+ /*.nb33 =*/nb33,
2716
+ };
2717
+
2718
+ auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
2719
+
2720
+ ggml_metal_encoder_set_pipeline(enc, pipeline0);
2721
+ ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
2722
+ ggml_metal_encoder_set_buffer (enc, bid_src1, 1);
2723
+ ggml_metal_encoder_set_buffer (enc, bid_src2, 2);
2724
+ ggml_metal_encoder_set_buffer (enc, bid_src3, 3);
2725
+ ggml_metal_encoder_set_buffer (enc, bid_pad, 4);
2726
+
2727
+ assert(ne12 == ne22);
2728
+ assert(ne13 == ne23);
2729
+
2730
+ ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
2731
+
2732
+ need_sync = true;
2733
+ }
2734
+
2735
+ if (has_mask) {
2736
+ assert(ggml_metal_op_flash_attn_ext_extra_blk(op) != 0);
2737
+
2738
+ ggml_metal_kargs_flash_attn_ext_blk args0 = {
2739
+ /*.ne01 =*/ ne01,
2740
+ /*.ne30 =*/ ne30,
2741
+ /*.ne31 =*/ ne31,
2742
+ /*.ne32 =*/ ne32,
2743
+ /*.ne33 =*/ ne33,
2744
+ /*.nb31 =*/ nb31,
2745
+ /*.nb32 =*/ nb32,
2746
+ /*.nb33 =*/ nb33,
2747
+ };
2748
+
2749
+ auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_blk(lib, op, nqptg, ncpsg);
2750
+
2751
+ ggml_metal_encoder_set_pipeline(enc, pipeline0);
2752
+ ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
2753
+ ggml_metal_encoder_set_buffer (enc, bid_src3, 1);
2754
+ ggml_metal_encoder_set_buffer (enc, bid_blk, 2);
2755
+
2756
+ const int32_t nblk1 = ((ne01 + nqptg - 1)/nqptg);
2757
+ const int32_t nblk0 = ((ne30 + ncpsg - 1)/ncpsg);
2758
+
2759
+ ggml_metal_encoder_dispatch_threadgroups(enc, nblk0, nblk1, ne32*ne33, 32, 1, 1);
2760
+
2761
+ need_sync = true;
2762
+ }
2763
+
2764
+ if (need_sync) {
2765
+ ggml_metal_op_concurrency_reset(ctx);
2766
+ }
2767
+
1961
2768
  const int is_q = ggml_is_quantized(op->src[1]->type) ? 1 : 0;
1962
2769
 
1963
2770
  // 2*(2*ncpsg)
@@ -1985,7 +2792,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
1985
2792
 
1986
2793
  // simdgroups per threadgroup (a.k.a. warps)
1987
2794
  //nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
1988
- int32_t nsg = 4;
2795
+ int32_t nsg = ne00 >= 512 ? 8 : 4;
1989
2796
 
1990
2797
  const size_t smem = FATTN_SMEM(nsg);
1991
2798
 
@@ -2007,6 +2814,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
2007
2814
  /*.nb21 =*/ nb21,
2008
2815
  /*.nb22 =*/ nb22,
2009
2816
  /*.nb23 =*/ nb23,
2817
+ /*.ne31 =*/ ne31,
2010
2818
  /*.ne32 =*/ ne32,
2011
2819
  /*.ne33 =*/ ne33,
2012
2820
  /*.nb31 =*/ nb31,
@@ -2023,24 +2831,18 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
2023
2831
  /*.logit_softcap =*/ logit_softcap,
2024
2832
  };
2025
2833
 
2026
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, nsg);
2834
+ auto pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg);
2027
2835
 
2028
2836
  ggml_metal_encoder_set_pipeline(enc, pipeline);
2029
2837
  ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
2030
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
2031
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
2032
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3);
2033
- if (op->src[3]) {
2034
- ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[3]), 4);
2035
- } else {
2036
- ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 4);
2037
- }
2038
- if (op->src[4]) {
2039
- ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[4]), 5);
2040
- } else {
2041
- ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 5);
2042
- }
2043
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 6);
2838
+ ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
2839
+ ggml_metal_encoder_set_buffer (enc, bid_src1, 2);
2840
+ ggml_metal_encoder_set_buffer (enc, bid_src2, 3);
2841
+ ggml_metal_encoder_set_buffer (enc, bid_src3, 4);
2842
+ ggml_metal_encoder_set_buffer (enc, bid_src4, 5);
2843
+ ggml_metal_encoder_set_buffer (enc, bid_pad, 6);
2844
+ ggml_metal_encoder_set_buffer (enc, bid_blk, 7);
2845
+ ggml_metal_encoder_set_buffer (enc, bid_dst, 8);
2044
2846
 
2045
2847
  ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
2046
2848
 
@@ -2048,14 +2850,63 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
2048
2850
  #undef FATTN_SMEM
2049
2851
  } else {
2050
2852
  // half4x4 kernel
2051
- const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !!
2052
- const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
2053
- const int64_t nkpsg = 1*ncpsg;
2853
+ const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPSG; // queries per threadgroup
2854
+ const int ncpsg = OP_FLASH_ATTN_EXT_VEC_NCPSG; // cache values per simdgroup !! sync with kernel template arguments !!
2855
+ const int nhptg = 1; // heads per threadgroup
2054
2856
 
2055
2857
  GGML_ASSERT(nqptg <= 32);
2056
2858
  GGML_ASSERT(nqptg % 1 == 0);
2057
2859
  GGML_ASSERT(ncpsg % 32 == 0);
2058
2860
 
2861
+ bool need_sync = false;
2862
+
2863
+ const bool has_kvpad = ne11 % ncpsg != 0;
2864
+
2865
+ if (has_kvpad) {
2866
+ assert(ggml_metal_op_flash_attn_ext_extra_pad(op) != 0);
2867
+
2868
+ ggml_metal_kargs_flash_attn_ext_pad args0 = {
2869
+ /*.ne11 =*/ne11,
2870
+ /*.ne_12_2 =*/ne12,
2871
+ /*.ne_12_3 =*/ne13,
2872
+ /*.nb11 =*/nb11,
2873
+ /*.nb12 =*/nb12,
2874
+ /*.nb13 =*/nb13,
2875
+ /*.nb21 =*/nb21,
2876
+ /*.nb22 =*/nb22,
2877
+ /*.nb23 =*/nb23,
2878
+ /*.ne31 =*/ne31,
2879
+ /*.ne32 =*/ne32,
2880
+ /*.ne33 =*/ne33,
2881
+ /*.nb31 =*/nb31,
2882
+ /*.nb32 =*/nb32,
2883
+ /*.nb33 =*/nb33,
2884
+ };
2885
+
2886
+ auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
2887
+
2888
+ ggml_metal_encoder_set_pipeline(enc, pipeline0);
2889
+ ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
2890
+ ggml_metal_encoder_set_buffer (enc, bid_src1, 1);
2891
+ ggml_metal_encoder_set_buffer (enc, bid_src2, 2);
2892
+ ggml_metal_encoder_set_buffer (enc, bid_src3, 3);
2893
+ ggml_metal_encoder_set_buffer (enc, bid_pad, 4);
2894
+
2895
+ assert(ne12 == ne22);
2896
+ assert(ne13 == ne23);
2897
+
2898
+ ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
2899
+
2900
+ need_sync = true;
2901
+ }
2902
+
2903
+ if (need_sync) {
2904
+ ggml_metal_op_concurrency_reset(ctx);
2905
+ }
2906
+
2907
+ // note: for simplicity assume the K is larger or equal than V
2908
+ GGML_ASSERT(ne10 >= ne20);
2909
+
2059
2910
  // ne00 + 2*ncpsg*(nsg)
2060
2911
  // for each query, we load it as f16 in shared memory (ne00)
2061
2912
  // and store the soft_max values and the mask
@@ -2063,28 +2914,9 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
2063
2914
  // ne20*(nsg)
2064
2915
  // each simdgroup has a full f32 head vector in shared mem to accumulate results
2065
2916
  //
2066
- #define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*GGML_PAD(ne20, 128)*(nsg))*(sizeof(float)/2), 16))
2067
-
2068
- int64_t nsgmax = 2;
2069
- while (true) {
2070
- const size_t smem = FATTN_SMEM(nsgmax);
2071
- // avoid using more than half of the threadgroup memory - can cause slow downs especially for large head sizes
2072
- if (smem > props_dev->max_theadgroup_memory_size/2) {
2073
- break;
2074
- }
2075
- nsgmax *= 2;
2076
- }
2077
- nsgmax /= 2;
2078
-
2079
- // simdgroups per threadgroup (a.k.a. warps)
2080
- //const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));
2081
- const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) 1024/32)));
2917
+ #define FATTN_SMEM(nsg) (GGML_PAD(((GGML_PAD(ne00, 128) + 4*ncpsg + 2*GGML_PAD(ne20, 128))*(nsg))*(sizeof(float)/2), 16))
2082
2918
 
2083
2919
  int64_t nsg = 1;
2084
- while (nsg <= nsgt) {
2085
- nsg *= 2;
2086
- }
2087
- nsg /= 2;
2088
2920
 
2089
2921
  // workgroups
2090
2922
  // each workgroup handles nsg*nkpsg cache values
@@ -2097,7 +2929,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
2097
2929
  } else {
2098
2930
  nwg = 32;
2099
2931
  nsg = 1;
2100
- while (2*nwg*nsg*nkpsg < ne11 && nsg < 4) {
2932
+ while (2*nwg*nsg*ncpsg < ne11 && nsg < 4) {
2101
2933
  nsg *= 2;
2102
2934
  }
2103
2935
  }
@@ -2120,6 +2952,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
2120
2952
  /*.nb21 =*/ nb21,
2121
2953
  /*.nb22 =*/ nb22,
2122
2954
  /*.nb23 =*/ nb23,
2955
+ /*.ne31 =*/ ne31,
2123
2956
  /*.ne32 =*/ ne32,
2124
2957
  /*.ne33 =*/ ne33,
2125
2958
  /*.nb31 =*/ nb31,
@@ -2136,25 +2969,17 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
2136
2969
  /*.logit_softcap =*/ logit_softcap,
2137
2970
  };
2138
2971
 
2139
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, nsg, nwg);
2972
+ auto pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg, nwg);
2140
2973
 
2141
2974
  GGML_ASSERT(nsg*32 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
2142
2975
 
2143
2976
  ggml_metal_encoder_set_pipeline(enc, pipeline);
2144
2977
  ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
2145
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
2146
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
2147
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3);
2148
- if (op->src[3]) {
2149
- ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[3]), 4);
2150
- } else {
2151
- ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 4);
2152
- }
2153
- if (op->src[4]) {
2154
- ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[4]), 5);
2155
- } else {
2156
- ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 5);
2157
- }
2978
+ ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
2979
+ ggml_metal_encoder_set_buffer (enc, bid_src1, 2);
2980
+ ggml_metal_encoder_set_buffer (enc, bid_src2, 3);
2981
+ ggml_metal_encoder_set_buffer (enc, bid_src3, 4);
2982
+ ggml_metal_encoder_set_buffer (enc, bid_src4, 5);
2158
2983
 
2159
2984
  const size_t smem = FATTN_SMEM(nsg);
2160
2985
 
@@ -2162,26 +2987,28 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
2162
2987
  GGML_ASSERT(smem <= props_dev->max_theadgroup_memory_size);
2163
2988
 
2164
2989
  if (nwg == 1) {
2990
+ assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) == 0);
2991
+
2165
2992
  // using 1 workgroup -> write the result directly into dst
2166
- ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 6);
2993
+ ggml_metal_encoder_set_buffer(enc, bid_pad, 6);
2994
+ ggml_metal_encoder_set_buffer(enc, bid_dst, 7);
2167
2995
 
2168
2996
  ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
2169
2997
 
2170
- ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
2998
+ ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, (ne02 + nhptg - 1)/nhptg, ne03*nwg, 32, nsg, 1);
2171
2999
  } else {
2172
3000
  // sanity checks
3001
+ assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) != 0);
3002
+
2173
3003
  GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
2174
3004
  GGML_ASSERT((uint64_t)ne1*ne2*ne3 <= (1u << 31));
2175
3005
 
2176
- ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
2177
-
2178
3006
  // write the results from each workgroup into a temp buffer
2179
- ggml_metal_buffer_id bid_tmp = bid_dst;
2180
- bid_tmp.offs += ggml_nbytes(op);
2181
- ggml_metal_encoder_set_buffer(enc, bid_tmp, 6);
3007
+ ggml_metal_encoder_set_buffer(enc, bid_pad, 6);
3008
+ ggml_metal_encoder_set_buffer(enc, bid_tmp, 7);
2182
3009
 
2183
3010
  ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
2184
- ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
3011
+ ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, (ne02 + nhptg - 1)/nhptg, ne03*nwg, 32, nsg, 1);
2185
3012
 
2186
3013
  // sync the 2 kernels
2187
3014
  ggml_metal_op_concurrency_reset(ctx);
@@ -2194,7 +3021,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
2194
3021
  nrows,
2195
3022
  };
2196
3023
 
2197
- ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(lib, op, ne20, nwg);
3024
+ auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(lib, op, ne20, nwg);
2198
3025
 
2199
3026
  ggml_metal_encoder_set_pipeline(enc, pipeline0);
2200
3027
  ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
@@ -2233,8 +3060,6 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
2233
3060
  GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
2234
3061
  GGML_ASSERT(ggml_is_contiguous_rows(op->src[1]));
2235
3062
 
2236
- bool bcast_row = false;
2237
-
2238
3063
  ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
2239
3064
  ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
2240
3065
  ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
@@ -2326,20 +3151,9 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
2326
3151
  // the offsets of src1 and all fused buffers are relative to the start of the src1 buffer
2327
3152
  bid_src1.offs = 0;
2328
3153
 
2329
- ggml_metal_pipeline_t pipeline = nullptr;
2330
-
2331
- if (ggml_nelements(op->src[1]) == ne10 && ggml_is_contiguous(op->src[1]) && ne00 % 4 == 0 && ne10 % 4 == 0) {
2332
- GGML_ASSERT(ggml_is_contiguous(op->src[0]));
3154
+ struct ggml_metal_pipeline_with_params pipeline;
2333
3155
 
2334
- // src1 is a row
2335
- GGML_ASSERT(ne11 == 1);
2336
-
2337
- pipeline = ggml_metal_library_get_pipeline_bin(lib, op->op, n_fuse, true);
2338
-
2339
- bcast_row = true;
2340
- } else {
2341
- pipeline = ggml_metal_library_get_pipeline_bin(lib, op->op, n_fuse, false);
2342
- }
3156
+ pipeline = ggml_metal_library_get_pipeline_bin(lib, op, n_fuse);
2343
3157
 
2344
3158
  if (n_fuse > 1) {
2345
3159
  bid_dst = ggml_metal_get_buffer_id(ctx->node(idx + n_fuse - 1));
@@ -2353,20 +3167,26 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
2353
3167
  }
2354
3168
  }
2355
3169
 
3170
+ if (pipeline.c4) {
3171
+ args.ne00 = ne00/4;
3172
+ args.ne10 = ne10/4;
3173
+ args.ne0 = ne0/4;
3174
+ }
3175
+
2356
3176
  ggml_metal_encoder_set_pipeline(enc, pipeline);
2357
3177
  ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
2358
3178
  ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
2359
3179
  ggml_metal_encoder_set_buffer (enc, bid_src1, 2);
2360
3180
  ggml_metal_encoder_set_buffer (enc, bid_dst, 3);
2361
3181
 
2362
- if (bcast_row) {
2363
- const int64_t n = ggml_nelements(op)/4;
2364
-
2365
- ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
3182
+ if (pipeline.cnt) {
3183
+ ggml_metal_encoder_dispatch_threadgroups(enc, args.ne0, ggml_nrows(op), 1, 1, 1, 1);
2366
3184
  } else {
2367
- int nth = 32;
3185
+ const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
3186
+
3187
+ int nth = 1;
2368
3188
 
2369
- while (16*nth < ne0 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
3189
+ while (2*nth < args.ne0 && nth < nth_max) {
2370
3190
  nth *= 2;
2371
3191
  }
2372
3192
 
@@ -2385,41 +3205,61 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
2385
3205
  GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2386
3206
  GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2387
3207
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
2388
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3208
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3209
+
3210
+ GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
3211
+
3212
+ ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
3213
+ ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
2389
3214
 
2390
3215
  float eps;
2391
3216
  memcpy(&eps, op->op_params, sizeof(float));
2392
3217
 
2393
- int nth = 32; // SIMD width
2394
-
2395
3218
  ggml_metal_kargs_l2_norm args = {
2396
- /*.ne00 =*/ ne00,
2397
- /*.ne00_4 =*/ ne00/4,
2398
- /*.nb01 =*/ nb01,
2399
- /*.eps =*/ eps,
3219
+ /*.ne00 =*/ ne00,
3220
+ /*.ne01 =*/ ne01,
3221
+ /*.ne02 =*/ ne02,
3222
+ /*.ne03 =*/ ne03,
3223
+ /*.nb00 =*/ nb00,
3224
+ /*.nb01 =*/ nb01,
3225
+ /*.nb02 =*/ nb02,
3226
+ /*.nb03 =*/ nb03,
3227
+ /*.ne0 =*/ ne0,
3228
+ /*.ne1 =*/ ne1,
3229
+ /*.ne2 =*/ ne2,
3230
+ /*.ne3 =*/ ne3,
3231
+ /*.nb0 =*/ nb0,
3232
+ /*.nb1 =*/ nb1,
3233
+ /*.nb2 =*/ nb2,
3234
+ /*.nb3 =*/ nb3,
3235
+ /*.eps =*/ eps,
2400
3236
  };
2401
3237
 
2402
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_l2_norm(lib, op);
3238
+ auto pipeline = ggml_metal_library_get_pipeline_l2_norm(lib, op);
3239
+
3240
+ if (pipeline.c4) {
3241
+ args.ne00 = ne00/4;
3242
+ args.ne0 = ne0/4;
3243
+ }
3244
+
3245
+ int nth = 32; // SIMD width
2403
3246
 
2404
- while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
3247
+ while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
2405
3248
  nth *= 2;
2406
3249
  }
2407
3250
 
2408
3251
  nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
2409
- nth = std::min(nth, ne00/4);
2410
3252
 
2411
- const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
2412
-
2413
- const int64_t nrows = ggml_nrows(op->src[0]);
3253
+ const size_t smem = pipeline.smem;
2414
3254
 
2415
3255
  ggml_metal_encoder_set_pipeline(enc, pipeline);
2416
3256
  ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
2417
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
2418
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
3257
+ ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
3258
+ ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
2419
3259
 
2420
3260
  ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
2421
3261
 
2422
- ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1);
3262
+ ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
2423
3263
 
2424
3264
  return 1;
2425
3265
  }
@@ -2433,7 +3273,7 @@ int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) {
2433
3273
  GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2434
3274
  GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2435
3275
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
2436
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3276
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
2437
3277
 
2438
3278
  const int32_t ngrp = ((const int32_t *) op->op_params)[0];
2439
3279
 
@@ -2451,7 +3291,7 @@ int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) {
2451
3291
  /*.eps =*/ eps,
2452
3292
  };
2453
3293
 
2454
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_group_norm(lib, op);
3294
+ auto pipeline = ggml_metal_library_get_pipeline_group_norm(lib, op);
2455
3295
 
2456
3296
  int nth = 32; // SIMD width
2457
3297
  //while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
@@ -2461,7 +3301,7 @@ int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) {
2461
3301
  //nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
2462
3302
  //nth = std::min(nth, ne00/4);
2463
3303
 
2464
- const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
3304
+ const size_t smem = pipeline.smem;
2465
3305
 
2466
3306
  ggml_metal_encoder_set_pipeline(enc, pipeline);
2467
3307
  ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@@ -2488,7 +3328,7 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) {
2488
3328
  GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2489
3329
  GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2490
3330
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
2491
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3331
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
2492
3332
 
2493
3333
  float eps;
2494
3334
  memcpy(&eps, op->op_params, sizeof(float));
@@ -2586,7 +3426,7 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) {
2586
3426
  }
2587
3427
  }
2588
3428
 
2589
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_norm(lib, op, n_fuse);
3429
+ auto pipeline = ggml_metal_library_get_pipeline_norm(lib, op, n_fuse);
2590
3430
 
2591
3431
  int nth = 32; // SIMD width
2592
3432
 
@@ -2597,7 +3437,7 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) {
2597
3437
  nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
2598
3438
  nth = std::min(nth, args.ne00_t);
2599
3439
 
2600
- const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
3440
+ const size_t smem = pipeline.smem;
2601
3441
 
2602
3442
  ggml_metal_encoder_set_pipeline(enc, pipeline);
2603
3443
  ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@@ -2624,7 +3464,7 @@ int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) {
2624
3464
  GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
2625
3465
  GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
2626
3466
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
2627
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3467
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
2628
3468
 
2629
3469
  // make sure we have one or more position id(ne10) per token(ne02)
2630
3470
  GGML_ASSERT(ne10 % ne02 == 0);
@@ -2688,9 +3528,10 @@ int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) {
2688
3528
  /* sect_1 =*/ sect_1,
2689
3529
  /* sect_2 =*/ sect_2,
2690
3530
  /* sect_3 =*/ sect_3,
3531
+ /* src2 =*/ op->src[2] != nullptr,
2691
3532
  };
2692
3533
 
2693
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_rope(lib, op);
3534
+ auto pipeline = ggml_metal_library_get_pipeline_rope(lib, op);
2694
3535
 
2695
3536
  ggml_metal_encoder_set_pipeline(enc, pipeline);
2696
3537
  ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@@ -2717,7 +3558,7 @@ int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) {
2717
3558
  GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2718
3559
  GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2719
3560
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
2720
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3561
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
2721
3562
 
2722
3563
  const int32_t s0 = ((const int32_t *)(op->op_params))[0];
2723
3564
  const int32_t s1 = ((const int32_t *)(op->op_params))[1];
@@ -2762,7 +3603,7 @@ int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) {
2762
3603
  /*.KHW =*/ KH * KW,
2763
3604
  };
2764
3605
 
2765
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_im2col(lib, op);
3606
+ auto pipeline = ggml_metal_library_get_pipeline_im2col(lib, op);
2766
3607
 
2767
3608
  GGML_ASSERT(KH*KW <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
2768
3609
 
@@ -2770,15 +3611,138 @@ int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) {
2770
3611
 
2771
3612
  ggml_metal_encoder_set_pipeline(enc, pipeline);
2772
3613
  ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
2773
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 1);
2774
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
3614
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 1);
3615
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
3616
+
3617
+ ggml_metal_encoder_dispatch_threadgroups(enc, IC, OH, OW, ntptg0, KH, KW);
3618
+
3619
+ return 1;
3620
+ }
3621
+
3622
+ int ggml_metal_op_conv_2d(ggml_metal_op_t ctx, int idx) {
3623
+ ggml_tensor * op = ctx->node(idx);
3624
+
3625
+ ggml_metal_library_t lib = ctx->lib;
3626
+ ggml_metal_encoder_t enc = ctx->enc;
3627
+
3628
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3629
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3630
+ GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
3631
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
3632
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3633
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3634
+
3635
+ GGML_ASSERT(ggml_is_contiguous(op->src[0]));
3636
+ GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
3637
+ GGML_ASSERT(op->type == GGML_TYPE_F32);
3638
+ GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
3639
+
3640
+ const int32_t s0 = ((const int32_t *) op->op_params)[0];
3641
+ const int32_t s1 = ((const int32_t *) op->op_params)[1];
3642
+ const int32_t p0 = ((const int32_t *) op->op_params)[2];
3643
+ const int32_t p1 = ((const int32_t *) op->op_params)[3];
3644
+ const int32_t d0 = ((const int32_t *) op->op_params)[4];
3645
+ const int32_t d1 = ((const int32_t *) op->op_params)[5];
3646
+
3647
+ ggml_metal_kargs_conv_2d args = {
3648
+ /*.nb00 =*/ nb00,
3649
+ /*.nb01 =*/ nb01,
3650
+ /*.nb02 =*/ nb02,
3651
+ /*.nb03 =*/ nb03,
3652
+ /*.nb10 =*/ nb10,
3653
+ /*.nb11 =*/ nb11,
3654
+ /*.nb12 =*/ nb12,
3655
+ /*.nb13 =*/ nb13,
3656
+ /*.nb0 =*/ nb0,
3657
+ /*.nb1 =*/ nb1,
3658
+ /*.nb2 =*/ nb2,
3659
+ /*.nb3 =*/ nb3,
3660
+ /*.IW =*/ ne10,
3661
+ /*.IH =*/ ne11,
3662
+ /*.KW =*/ ne00,
3663
+ /*.KH =*/ ne01,
3664
+ /*.IC =*/ ne02,
3665
+ /*.OC =*/ ne03,
3666
+ /*.OW =*/ ne0,
3667
+ /*.OH =*/ ne1,
3668
+ /*.N =*/ ne3,
3669
+ /*.s0 =*/ s0,
3670
+ /*.s1 =*/ s1,
3671
+ /*.p0 =*/ p0,
3672
+ /*.p1 =*/ p1,
3673
+ /*.d0 =*/ d0,
3674
+ /*.d1 =*/ d1,
3675
+ };
3676
+
3677
+ auto pipeline = ggml_metal_library_get_pipeline_conv_2d(lib, op);
3678
+
3679
+ int nth = ggml_metal_pipeline_max_theads_per_threadgroup(pipeline);
3680
+ nth = std::min(nth, 256);
3681
+ nth = std::max(nth, 1);
3682
+
3683
+ const uint64_t n_out = ggml_nelements(op);
3684
+
3685
+ uint64_t tg = (n_out + nth - 1)/nth;
3686
+ tg = std::max<uint64_t>(tg, 1);
3687
+ tg = std::min<uint64_t>(tg, (uint64_t) std::numeric_limits<int>::max());
3688
+
3689
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
3690
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3691
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
3692
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
3693
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
3694
+
3695
+ ggml_metal_encoder_dispatch_threadgroups(enc, tg, 1, 1, nth, 1, 1);
3696
+
3697
+ return 1;
3698
+ }
3699
+
3700
+ int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {
3701
+ ggml_tensor * op = ctx->node(idx);
3702
+
3703
+ ggml_metal_library_t lib = ctx->lib;
3704
+ ggml_metal_encoder_t enc = ctx->enc;
3705
+
3706
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3707
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3708
+ GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
3709
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
3710
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3711
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3712
+
3713
+ const int32_t s0 = ((const int32_t *)(op->op_params))[0];
3714
+
3715
+ const int32_t IC = op->src[1]->ne[1];
3716
+ const int32_t IL = op->src[1]->ne[0];
3717
+
3718
+ const int32_t K = op->src[0]->ne[0];
3719
+
3720
+ const int32_t OL = op->ne[0];
3721
+ const int32_t OC = op->ne[1];
3722
+
3723
+ ggml_metal_kargs_conv_transpose_1d args = {
3724
+ /*.IC =*/ IC,
3725
+ /*.IL =*/ IL,
3726
+ /*.K =*/ K,
3727
+ /*.s0 =*/ s0,
3728
+ /*.nb0 =*/ nb0,
3729
+ /*.nb1 =*/ nb1,
3730
+ };
3731
+
3732
+ auto pipeline = ggml_metal_library_get_pipeline_conv_transpose_1d(lib, op);
3733
+
3734
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
3735
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3736
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
3737
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
3738
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
2775
3739
 
2776
- ggml_metal_encoder_dispatch_threadgroups(enc, IC, OH, OW, ntptg0, KH, KW);
3740
+ ggml_metal_encoder_dispatch_threadgroups(enc, OL, OC, 1, 1, 1, 1);
2777
3741
 
2778
3742
  return 1;
2779
3743
  }
2780
3744
 
2781
- int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {
3745
+ int ggml_metal_op_conv_transpose_2d(ggml_metal_op_t ctx, int idx) {
2782
3746
  ggml_tensor * op = ctx->node(idx);
2783
3747
 
2784
3748
  ggml_metal_library_t lib = ctx->lib;
@@ -2789,28 +3753,35 @@ int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {
2789
3753
  GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
2790
3754
  GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
2791
3755
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
2792
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3756
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
2793
3757
 
2794
3758
  const int32_t s0 = ((const int32_t *)(op->op_params))[0];
2795
3759
 
2796
- const int32_t IC = op->src[1]->ne[1];
2797
- const int32_t IL = op->src[1]->ne[0];
3760
+ const int32_t IC = op->src[1]->ne[2];
3761
+ const int32_t IH = op->src[1]->ne[1];
3762
+ const int32_t IW = op->src[1]->ne[0];
2798
3763
 
2799
- const int32_t K = op->src[0]->ne[0];
3764
+ const int32_t KH = op->src[0]->ne[1];
3765
+ const int32_t KW = op->src[0]->ne[0];
2800
3766
 
2801
- const int32_t OL = op->ne[0];
2802
- const int32_t OC = op->ne[1];
3767
+ const int32_t OW = op->ne[0];
3768
+ const int32_t OH = op->ne[1];
3769
+ const int32_t OC = op->ne[2];
2803
3770
 
2804
- ggml_metal_kargs_conv_transpose_1d args = {
3771
+ ggml_metal_kargs_conv_transpose_2d args = {
2805
3772
  /*.IC =*/ IC,
2806
- /*.IL =*/ IL,
2807
- /*.K =*/ K,
3773
+ /*.IH =*/ IH,
3774
+ /*.IW =*/ IW,
3775
+ /*.KH =*/ KH,
3776
+ /*.KW =*/ KW,
3777
+ /*.OC =*/ OC,
2808
3778
  /*.s0 =*/ s0,
2809
3779
  /*.nb0 =*/ nb0,
2810
3780
  /*.nb1 =*/ nb1,
3781
+ /*.nb2 =*/ nb2,
2811
3782
  };
2812
3783
 
2813
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_conv_transpose_1d(lib, op);
3784
+ auto pipeline = ggml_metal_library_get_pipeline_conv_transpose_2d(lib, op);
2814
3785
 
2815
3786
  ggml_metal_encoder_set_pipeline(enc, pipeline);
2816
3787
  ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@@ -2818,7 +3789,11 @@ int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {
2818
3789
  ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
2819
3790
  ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
2820
3791
 
2821
- ggml_metal_encoder_dispatch_threadgroups(enc, OL, OC, 1, 1, 1, 1);
3792
+ // Metal requires buffer size to be multiple of 16 bytes
3793
+ const size_t smem = GGML_PAD(KW * KH * sizeof(float), 16);
3794
+ ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
3795
+
3796
+ ggml_metal_encoder_dispatch_threadgroups(enc, OW, OH, OC, KW, KH, 1);
2822
3797
 
2823
3798
  return 1;
2824
3799
  }
@@ -2832,37 +3807,48 @@ int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) {
2832
3807
  GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2833
3808
  GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2834
3809
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
2835
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3810
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3811
+
3812
+ float sf0 = (float)ne0/op->src[0]->ne[0];
3813
+ float sf1 = (float)ne1/op->src[0]->ne[1];
3814
+ float sf2 = (float)ne2/op->src[0]->ne[2];
3815
+ float sf3 = (float)ne3/op->src[0]->ne[3];
3816
+
3817
+ const int32_t mode_flags = ggml_get_op_params_i32(op, 0);
2836
3818
 
2837
- const float sf0 = (float)ne0/op->src[0]->ne[0];
2838
- const float sf1 = (float)ne1/op->src[0]->ne[1];
2839
- const float sf2 = (float)ne2/op->src[0]->ne[2];
2840
- const float sf3 = (float)ne3/op->src[0]->ne[3];
3819
+ float poffs = 0.5f;
3820
+
3821
+ if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
3822
+ poffs = 0.0f;
3823
+ sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0;
3824
+ sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1;
3825
+ }
2841
3826
 
2842
3827
  ggml_metal_kargs_upscale args = {
2843
- /*.ne00 =*/ ne00,
2844
- /*.ne01 =*/ ne01,
2845
- /*.ne02 =*/ ne02,
2846
- /*.ne03 =*/ ne03,
2847
- /*.nb00 =*/ nb00,
2848
- /*.nb01 =*/ nb01,
2849
- /*.nb02 =*/ nb02,
2850
- /*.nb03 =*/ nb03,
2851
- /*.ne0 =*/ ne0,
2852
- /*.ne1 =*/ ne1,
2853
- /*.ne2 =*/ ne2,
2854
- /*.ne3 =*/ ne3,
2855
- /*.nb0 =*/ nb0,
2856
- /*.nb1 =*/ nb1,
2857
- /*.nb2 =*/ nb2,
2858
- /*.nb3 =*/ nb3,
2859
- /*.sf0 =*/ sf0,
2860
- /*.sf1 =*/ sf1,
2861
- /*.sf2 =*/ sf2,
2862
- /*.sf3 =*/ sf3
3828
+ /*.ne00 =*/ ne00,
3829
+ /*.ne01 =*/ ne01,
3830
+ /*.ne02 =*/ ne02,
3831
+ /*.ne03 =*/ ne03,
3832
+ /*.nb00 =*/ nb00,
3833
+ /*.nb01 =*/ nb01,
3834
+ /*.nb02 =*/ nb02,
3835
+ /*.nb03 =*/ nb03,
3836
+ /*.ne0 =*/ ne0,
3837
+ /*.ne1 =*/ ne1,
3838
+ /*.ne2 =*/ ne2,
3839
+ /*.ne3 =*/ ne3,
3840
+ /*.nb0 =*/ nb0,
3841
+ /*.nb1 =*/ nb1,
3842
+ /*.nb2 =*/ nb2,
3843
+ /*.nb3 =*/ nb3,
3844
+ /*.sf0 =*/ sf0,
3845
+ /*.sf1 =*/ sf1,
3846
+ /*.sf2 =*/ sf2,
3847
+ /*.sf3 =*/ sf3,
3848
+ /*.poffs =*/ poffs,
2863
3849
  };
2864
3850
 
2865
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_upscale(lib, op);
3851
+ auto pipeline = ggml_metal_library_get_pipeline_upscale(lib, op);
2866
3852
 
2867
3853
  const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
2868
3854
 
@@ -2885,7 +3871,7 @@ int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) {
2885
3871
  GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2886
3872
  GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2887
3873
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
2888
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3874
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
2889
3875
 
2890
3876
  ggml_metal_kargs_pad args = {
2891
3877
  /*.ne00 =*/ ne00,
@@ -2906,7 +3892,7 @@ int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) {
2906
3892
  /*.nb3 =*/ nb3
2907
3893
  };
2908
3894
 
2909
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_pad(lib, op);
3895
+ auto pipeline = ggml_metal_library_get_pipeline_pad(lib, op);
2910
3896
 
2911
3897
  const int nth = std::min(1024, ne0);
2912
3898
 
@@ -2929,7 +3915,7 @@ int ggml_metal_op_pad_reflect_1d(ggml_metal_op_t ctx, int idx) {
2929
3915
  GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2930
3916
  GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2931
3917
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
2932
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3918
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
2933
3919
 
2934
3920
  ggml_metal_kargs_pad_reflect_1d args = {
2935
3921
  /*.ne00 =*/ ne00,
@@ -2952,7 +3938,7 @@ int ggml_metal_op_pad_reflect_1d(ggml_metal_op_t ctx, int idx) {
2952
3938
  /*.p1 =*/ ((const int32_t *)(op->op_params))[1]
2953
3939
  };
2954
3940
 
2955
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_pad_reflect_1d(lib, op);
3941
+ auto pipeline = ggml_metal_library_get_pipeline_pad_reflect_1d(lib, op);
2956
3942
 
2957
3943
  const int nth = std::min(1024, ne0);
2958
3944
 
@@ -2973,7 +3959,7 @@ int ggml_metal_op_arange(ggml_metal_op_t ctx, int idx) {
2973
3959
  ggml_metal_encoder_t enc = ctx->enc;
2974
3960
 
2975
3961
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
2976
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3962
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
2977
3963
 
2978
3964
  float start;
2979
3965
  float step;
@@ -2989,13 +3975,7 @@ int ggml_metal_op_arange(ggml_metal_op_t ctx, int idx) {
2989
3975
 
2990
3976
  const int nth = std::min(1024, ne0);
2991
3977
 
2992
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_arange(lib, op);
2993
-
2994
- //[encoder setComputePipelineState:pipeline];
2995
- //[encoder setBuffer:id_dst offset:offs_dst atIndex:0];
2996
- //[encoder setBytes:&args length:sizeof(args) atIndex:1];
2997
-
2998
- //[encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
3978
+ auto pipeline = ggml_metal_library_get_pipeline_arange(lib, op);
2999
3979
 
3000
3980
  ggml_metal_encoder_set_pipeline(enc, pipeline);
3001
3981
  ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@@ -3015,7 +3995,7 @@ int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx) {
3015
3995
  GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3016
3996
  GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3017
3997
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3018
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3998
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3019
3999
 
3020
4000
  const int dim = op->op_params[0];
3021
4001
  const int max_period = op->op_params[1];
@@ -3026,7 +4006,7 @@ int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx) {
3026
4006
  /*.max_period =*/ max_period,
3027
4007
  };
3028
4008
 
3029
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_timestep_embedding(lib, op);
4009
+ auto pipeline = ggml_metal_library_get_pipeline_timestep_embedding(lib, op);
3030
4010
 
3031
4011
  const int nth = std::max(1, std::min(1024, dim/2));
3032
4012
 
@@ -3049,14 +4029,14 @@ int ggml_metal_op_argmax(ggml_metal_op_t ctx, int idx) {
3049
4029
  GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3050
4030
  GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3051
4031
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3052
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
4032
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3053
4033
 
3054
4034
  ggml_metal_kargs_argmax args = {
3055
4035
  /*.ne00 = */ ne00,
3056
4036
  /*.nb01 = */ nb01,
3057
4037
  };
3058
4038
 
3059
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_argmax(lib, op);
4039
+ auto pipeline = ggml_metal_library_get_pipeline_argmax(lib, op);
3060
4040
 
3061
4041
  const int64_t nrows = ggml_nrows(op->src[0]);
3062
4042
 
@@ -3065,7 +4045,7 @@ int ggml_metal_op_argmax(ggml_metal_op_t ctx, int idx) {
3065
4045
  nth *= 2;
3066
4046
  }
3067
4047
 
3068
- const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
4048
+ const size_t smem = pipeline.smem;
3069
4049
 
3070
4050
  ggml_metal_encoder_set_pipeline(enc, pipeline);
3071
4051
  ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@@ -3085,74 +4065,397 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
3085
4065
  ggml_metal_library_t lib = ctx->lib;
3086
4066
  ggml_metal_encoder_t enc = ctx->enc;
3087
4067
 
4068
+ GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
4069
+
3088
4070
  GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3089
4071
  GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3090
4072
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3091
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
4073
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
4074
+
4075
+ auto pipeline = ggml_metal_library_get_pipeline_argsort(lib, op);
3092
4076
 
3093
4077
  // bitonic sort requires the number of elements to be power of 2
3094
- int64_t ne00_padded = 1;
3095
- while (ne00_padded < ne00) {
3096
- ne00_padded *= 2;
4078
+ int nth = 1;
4079
+ while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
4080
+ nth *= 2;
3097
4081
  }
3098
4082
 
3099
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_argsort(lib, op);
3100
-
3101
- const int64_t nrows = ggml_nrows(op->src[0]);
4083
+ const int npr = (ne00 + nth - 1)/nth;
3102
4084
 
3103
4085
  // Metal kernels require the buffer size to be multiple of 16 bytes
3104
4086
  // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
3105
- const size_t smem = GGML_PAD(ne00_padded*sizeof(int32_t), 16);
4087
+ const size_t smem = GGML_PAD(nth*sizeof(int32_t), 16);
4088
+
4089
+ ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
4090
+ ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
4091
+
4092
+ ggml_metal_buffer_id bid_tmp = bid_dst;
4093
+ bid_tmp.offs += ggml_nbytes(op);
4094
+
4095
+ if ((int) ceil(std::log(npr) / std::log(2)) % 2 == 1) {
4096
+ std::swap(bid_dst, bid_tmp);
4097
+ }
3106
4098
 
3107
4099
  ggml_metal_kargs_argsort args = {
3108
- /*.ncols =*/ ne00,
3109
- /*.ncols_pad =*/ ne00_padded
4100
+ /*.ne00 =*/ ne00,
4101
+ /*.ne01 =*/ ne01,
4102
+ /*.ne02 =*/ ne02,
4103
+ /*.ne03 =*/ ne03,
4104
+ /*.nb00 =*/ nb00,
4105
+ /*.nb01 =*/ nb01,
4106
+ /*.nb02 =*/ nb02,
4107
+ /*.nb03 =*/ nb03,
4108
+ /*.ne0 =*/ ne0,
4109
+ /*.ne1 =*/ ne1,
4110
+ /*.ne2 =*/ ne2,
4111
+ /*.ne3 =*/ ne3,
4112
+ /*.top_k =*/ nth,
3110
4113
  };
3111
4114
 
3112
4115
  ggml_metal_encoder_set_pipeline(enc, pipeline);
3113
4116
  ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3114
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
3115
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
4117
+ ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
4118
+ ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
3116
4119
 
3117
4120
  ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
3118
4121
 
3119
- ggml_metal_encoder_dispatch_threadgroups(enc, 1, nrows, 1, ne00_padded, 1, 1);
4122
+ ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1);
4123
+
4124
+ auto pipeline_merge = ggml_metal_library_get_pipeline_argsort_merge(lib, op);
4125
+
4126
+ int len = nth;
4127
+
4128
+ while (len < ne00) {
4129
+ ggml_metal_op_concurrency_reset(ctx);
4130
+
4131
+ ggml_metal_kargs_argsort_merge args_merge = {
4132
+ /*.ne00 =*/ ne00,
4133
+ /*.ne01 =*/ ne01,
4134
+ /*.ne02 =*/ ne02,
4135
+ /*.ne03 =*/ ne03,
4136
+ /*.nb00 =*/ nb00,
4137
+ /*.nb01 =*/ nb01,
4138
+ /*.nb02 =*/ nb02,
4139
+ /*.nb03 =*/ nb03,
4140
+ /*.ne0 =*/ ne0,
4141
+ /*.ne1 =*/ ne1,
4142
+ /*.ne2 =*/ ne2,
4143
+ /*.ne3 =*/ ne3,
4144
+ /*.top_k =*/ ne00,
4145
+ /*.len =*/ len,
4146
+ };
4147
+
4148
+ // merges per row
4149
+ const int nm = (ne00 + 2*len - 1) / (2*len);
4150
+
4151
+ const int nth = std::min(512, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_merge));
4152
+
4153
+ ggml_metal_encoder_set_pipeline(enc, pipeline_merge);
4154
+ ggml_metal_encoder_set_bytes (enc, &args_merge, sizeof(args_merge), 0);
4155
+ ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
4156
+ ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
4157
+ ggml_metal_encoder_set_buffer (enc, bid_tmp, 3);
4158
+
4159
+ ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1);
4160
+
4161
+ std::swap(bid_dst, bid_tmp);
4162
+
4163
+ len <<= 1;
4164
+ }
3120
4165
 
3121
4166
  return 1;
3122
4167
  }
3123
4168
 
3124
- int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
4169
+ int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) {
3125
4170
  ggml_tensor * op = ctx->node(idx);
3126
4171
 
3127
4172
  ggml_metal_library_t lib = ctx->lib;
3128
4173
  ggml_metal_encoder_t enc = ctx->enc;
3129
4174
 
4175
+ GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
4176
+
3130
4177
  GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3131
4178
  GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3132
4179
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3133
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
4180
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
4181
+
4182
+ auto pipeline = ggml_metal_library_get_pipeline_top_k(lib, op);
4183
+
4184
+ // bitonic sort requires the number of elements to be power of 2
4185
+ int nth = 1;
4186
+ while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
4187
+ nth *= 2;
4188
+ }
4189
+
4190
+ // blocks per row
4191
+ const int npr = (ne00 + nth - 1)/nth;
4192
+
4193
+ const size_t smem = GGML_PAD(nth*sizeof(int32_t), 16);
4194
+
4195
+ ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
4196
+ ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
4197
+
4198
+ ggml_metal_buffer_id bid_tmp = bid_dst;
4199
+ bid_tmp.offs += sizeof(int32_t)*ggml_nelements(op->src[0]);
4200
+
4201
+ if ((int) ceil(std::log(npr) / std::log(2)) % 2 == 1) {
4202
+ std::swap(bid_dst, bid_tmp);
4203
+ }
4204
+
4205
+ const int top_k = ne0;
4206
+
4207
+ ggml_metal_kargs_argsort args = {
4208
+ /*.ne00 =*/ ne00,
4209
+ /*.ne01 =*/ ne01,
4210
+ /*.ne02 =*/ ne02,
4211
+ /*.ne03 =*/ ne03,
4212
+ /*.nb00 =*/ nb00,
4213
+ /*.nb01 =*/ nb01,
4214
+ /*.nb02 =*/ nb02,
4215
+ /*.nb03 =*/ nb03,
4216
+ /*.ne0 =*/ ne0,
4217
+ /*.ne1 =*/ ne1,
4218
+ /*.ne2 =*/ ne2,
4219
+ /*.ne3 =*/ ne3,
4220
+ /*.top_k =*/ std::min(nth, top_k), // for each block, keep just the top_k indices
4221
+ };
4222
+
4223
+ if (npr > 1) {
4224
+ args.ne0 = (npr - 1)*args.top_k + std::min(ne00 - (npr - 1)*nth, args.top_k);
4225
+ }
4226
+
4227
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
4228
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
4229
+ ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
4230
+ ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
4231
+
4232
+ ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
4233
+
4234
+ ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1);
4235
+
4236
+ auto pipeline_merge = ggml_metal_library_get_pipeline_top_k_merge(lib, op);
4237
+
4238
+ int len = args.top_k;
4239
+
4240
+ while (len < args.ne0) {
4241
+ ggml_metal_op_concurrency_reset(ctx);
4242
+
4243
+ // merges per row
4244
+ const int nm = (args.ne0 + 2*len - 1) / (2*len);
4245
+
4246
+ const int nth = std::min(512, std::min(len, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_merge)));
4247
+
4248
+ ggml_metal_kargs_argsort_merge args_merge = {
4249
+ /*.ne00 =*/ ne00,
4250
+ /*.ne01 =*/ ne01,
4251
+ /*.ne02 =*/ ne02,
4252
+ /*.ne03 =*/ ne03,
4253
+ /*.nb00 =*/ nb00,
4254
+ /*.nb01 =*/ nb01,
4255
+ /*.nb02 =*/ nb02,
4256
+ /*.nb03 =*/ nb03,
4257
+ /*.ne0 =*/ args.ne0,
4258
+ /*.ne1 =*/ ne1,
4259
+ /*.ne2 =*/ ne2,
4260
+ /*.ne3 =*/ ne3,
4261
+ /*.top_k =*/ nm == 1 ? top_k : args.ne0, // the final merge outputs top_k elements
4262
+ /*.len =*/ len,
4263
+ };
4264
+
4265
+ ggml_metal_encoder_set_pipeline(enc, pipeline_merge);
4266
+ ggml_metal_encoder_set_bytes (enc, &args_merge, sizeof(args_merge), 0);
4267
+ ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
4268
+ ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
4269
+ ggml_metal_encoder_set_buffer (enc, bid_tmp, 3);
4270
+
4271
+ ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1);
4272
+
4273
+ std::swap(bid_dst, bid_tmp);
4274
+
4275
+ len <<= 1;
4276
+ }
4277
+
4278
+ return 1;
4279
+ }
4280
+
4281
+ int ggml_metal_op_tri(ggml_metal_op_t ctx, int idx) {
4282
+ ggml_tensor * op = ctx->node(idx);
3134
4283
 
3135
- float slope;
3136
- memcpy(&slope, op->op_params, sizeof(float));
4284
+ ggml_metal_library_t lib = ctx->lib;
4285
+ ggml_metal_encoder_t enc = ctx->enc;
4286
+
4287
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
4288
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
4289
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
4290
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3137
4291
 
3138
- ggml_metal_kargs_leaky_relu args = {
3139
- /*.slope =*/ slope
4292
+ ggml_metal_kargs_tri args = {
4293
+ /*.ne00 =*/ ne00,
4294
+ /*.ne01 =*/ ne01,
4295
+ /*.ne02 =*/ ne02,
4296
+ /*.ne03 =*/ ne03,
4297
+ /*.nb00 =*/ nb00,
4298
+ /*.nb01 =*/ nb01,
4299
+ /*.nb02 =*/ nb02,
4300
+ /*.nb03 =*/ nb03,
4301
+ /*.ne0 =*/ ne0,
4302
+ /*.ne1 =*/ ne1,
4303
+ /*.ne2 =*/ ne2,
4304
+ /*.ne3 =*/ ne3,
4305
+ /*.nb0 =*/ nb0,
4306
+ /*.nb1 =*/ nb1,
4307
+ /*.nb2 =*/ nb2,
4308
+ /*.nb3 =*/ nb3,
3140
4309
  };
3141
4310
 
3142
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
4311
+ auto pipeline = ggml_metal_library_get_pipeline_tri(lib, op);
3143
4312
 
3144
- int64_t n = ggml_nelements(op);
4313
+ int nth = 32; // SIMD width
3145
4314
 
3146
- if (n % 4 == 0) {
3147
- n /= 4;
4315
+ while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
4316
+ nth *= 2;
3148
4317
  }
3149
4318
 
4319
+ nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
4320
+ nth = std::min(nth, ne00);
4321
+
3150
4322
  ggml_metal_encoder_set_pipeline(enc, pipeline);
3151
4323
  ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3152
4324
  ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
3153
4325
  ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
3154
4326
 
3155
- ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
4327
+ ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
4328
+
4329
+ return 1;
4330
+ }
4331
+
4332
+ int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) {
4333
+ ggml_tensor * op = ctx->node(idx);
4334
+
4335
+ ggml_metal_library_t lib = ctx->lib;
4336
+ ggml_metal_encoder_t enc = ctx->enc;
4337
+
4338
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
4339
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
4340
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
4341
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
4342
+
4343
+ auto pipeline = ggml_metal_library_get_pipeline_opt_step_adamw(lib, op);
4344
+
4345
+ const int64_t np = ggml_nelements(op->src[0]);
4346
+ ggml_metal_kargs_opt_step_adamw args = {
4347
+ /*.np =*/ np,
4348
+ };
4349
+
4350
+ int ida = 0;
4351
+
4352
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
4353
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++);
4354
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);
4355
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++);
4356
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++);
4357
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), ida++);
4358
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), ida++);
4359
+
4360
+ const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
4361
+ const int64_t n = (np + nth - 1) / nth;
4362
+
4363
+ ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1);
4364
+
4365
+ return 1;
4366
+ }
4367
+
4368
+ int ggml_metal_op_opt_step_sgd(ggml_metal_op_t ctx, int idx) {
4369
+ ggml_tensor * op = ctx->node(idx);
4370
+
4371
+ ggml_metal_library_t lib = ctx->lib;
4372
+ ggml_metal_encoder_t enc = ctx->enc;
4373
+
4374
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
4375
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
4376
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
4377
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
4378
+
4379
+ auto pipeline = ggml_metal_library_get_pipeline_opt_step_sgd(lib, op);
4380
+
4381
+ const int64_t np = ggml_nelements(op->src[0]);
4382
+ ggml_metal_kargs_opt_step_sgd args = {
4383
+ /*.np =*/ np,
4384
+ };
4385
+
4386
+ int ida = 0;
4387
+
4388
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
4389
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++);
4390
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);
4391
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++);
4392
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++);
4393
+
4394
+ const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
4395
+ const int64_t n = (np + nth - 1) / nth;
4396
+
4397
+ ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1);
4398
+
4399
+ return 1;
4400
+ }
4401
+
4402
+ int ggml_metal_op_count_equal(ggml_metal_op_t ctx, int idx) {
4403
+ ggml_tensor * op = ctx->node(idx);
4404
+
4405
+ ggml_metal_library_t lib = ctx->lib;
4406
+ ggml_metal_encoder_t enc = ctx->enc;
4407
+
4408
+ GGML_TENSOR_LOCALS(int32_t, ne0, op->src[0], ne);
4409
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
4410
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
4411
+
4412
+ {
4413
+ ggml_metal_kargs_memset args = { /*.val =*/ 0 };
4414
+
4415
+ auto pipeline = ggml_metal_library_get_pipeline_memset(lib, op);
4416
+
4417
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
4418
+ ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
4419
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 1);
4420
+
4421
+ ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, 1, 1, 1);
4422
+ }
4423
+
4424
+ ggml_metal_op_concurrency_reset(ctx);
4425
+
4426
+ {
4427
+ ggml_metal_kargs_count_equal args = {
4428
+ /*.ne00 =*/ ne00,
4429
+ /*.ne01 =*/ ne01,
4430
+ /*.ne02 =*/ ne02,
4431
+ /*.ne03 =*/ ne03,
4432
+ /*.nb00 =*/ nb00,
4433
+ /*.nb01 =*/ nb01,
4434
+ /*.nb02 =*/ nb02,
4435
+ /*.nb03 =*/ nb03,
4436
+ /*.nb10 =*/ nb10,
4437
+ /*.nb11 =*/ nb11,
4438
+ /*.nb12 =*/ nb12,
4439
+ /*.nb13 =*/ nb13,
4440
+ };
4441
+
4442
+ auto pipeline = ggml_metal_library_get_pipeline_count_equal(lib, op);
4443
+
4444
+ const size_t smem = pipeline.smem;
4445
+
4446
+ const int nth = 32*pipeline.nsg;
4447
+
4448
+ GGML_ASSERT(nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
4449
+
4450
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
4451
+ ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
4452
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
4453
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
4454
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
4455
+
4456
+ ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
4457
+ ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
4458
+ }
3156
4459
 
3157
4460
  return 1;
3158
4461
  }