whispercpp 1.3.4 → 1.3.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (630)
  1. checksums.yaml +4 -4
  2. data/README.md +60 -43
  3. data/ext/extconf.rb +2 -2
  4. data/ext/ruby_whisper.c +14 -2
  5. data/ext/ruby_whisper.h +39 -0
  6. data/ext/ruby_whisper_context.c +22 -22
  7. data/ext/ruby_whisper_model.c +12 -12
  8. data/ext/ruby_whisper_params.c +47 -23
  9. data/ext/ruby_whisper_segment.c +84 -19
  10. data/ext/ruby_whisper_token.c +351 -0
  11. data/ext/ruby_whisper_transcribe.cpp +1 -1
  12. data/ext/ruby_whisper_vad_context.c +75 -0
  13. data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
  14. data/ext/ruby_whisper_vad_segment.c +139 -0
  15. data/ext/ruby_whisper_vad_segments.c +106 -0
  16. data/ext/sources/CMakeLists.txt +4 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
  19. data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
  20. data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  21. data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
  22. data/ext/sources/examples/addon.node/vad-example.js +2 -2
  23. data/ext/sources/examples/cli/cli.cpp +121 -112
  24. data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
  25. data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
  26. data/ext/sources/examples/server/server.cpp +10 -11
  27. data/ext/sources/examples/talk-llama/CMakeLists.txt +5 -1
  28. data/ext/sources/examples/talk-llama/llama-adapter.cpp +12 -3
  29. data/ext/sources/examples/talk-llama/llama-adapter.h +7 -1
  30. data/ext/sources/examples/talk-llama/llama-arch.cpp +2046 -1974
  31. data/ext/sources/examples/talk-llama/llama-arch.h +67 -2
  32. data/ext/sources/examples/talk-llama/llama-batch.cpp +75 -33
  33. data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
  34. data/ext/sources/examples/talk-llama/llama-chat.cpp +79 -3
  35. data/ext/sources/examples/talk-llama/llama-chat.h +4 -0
  36. data/ext/sources/examples/talk-llama/llama-context.cpp +775 -78
  37. data/ext/sources/examples/talk-llama/llama-context.h +57 -9
  38. data/ext/sources/examples/talk-llama/llama-cparams.h +1 -0
  39. data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
  40. data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
  41. data/ext/sources/examples/talk-llama/llama-graph.cpp +381 -64
  42. data/ext/sources/examples/talk-llama/llama-graph.h +103 -13
  43. data/ext/sources/examples/talk-llama/llama-hparams.cpp +26 -2
  44. data/ext/sources/examples/talk-llama/llama-hparams.h +41 -10
  45. data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
  46. data/ext/sources/examples/talk-llama/llama-impl.h +1 -1
  47. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +5 -3
  48. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +145 -65
  49. data/ext/sources/examples/talk-llama/llama-kv-cache.h +22 -7
  50. data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
  51. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
  52. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +32 -19
  53. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
  54. data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
  55. data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
  56. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +91 -9
  57. data/ext/sources/examples/talk-llama/llama-model-loader.h +6 -0
  58. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
  59. data/ext/sources/examples/talk-llama/llama-model.cpp +1529 -13134
  60. data/ext/sources/examples/talk-llama/llama-model.h +44 -3
  61. data/ext/sources/examples/talk-llama/llama-quant.cpp +8 -23
  62. data/ext/sources/examples/talk-llama/llama-sampling.cpp +1294 -198
  63. data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
  64. data/ext/sources/examples/talk-llama/llama-vocab.cpp +133 -37
  65. data/ext/sources/examples/talk-llama/llama-vocab.h +45 -40
  66. data/ext/sources/examples/talk-llama/llama.cpp +729 -2
  67. data/ext/sources/examples/talk-llama/llama.h +152 -14
  68. data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
  69. data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
  70. data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
  71. data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
  72. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
  73. data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
  74. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
  75. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
  76. data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
  77. data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
  78. data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
  79. data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
  80. data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
  81. data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
  82. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
  83. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
  84. data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
  85. data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
  86. data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
  87. data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
  88. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
  89. data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
  90. data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
  91. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
  92. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
  93. data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
  94. data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
  95. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
  96. data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
  97. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
  98. data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
  99. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
  100. data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
  101. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
  102. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
  103. data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
  104. data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
  105. data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
  106. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
  107. data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
  108. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
  109. data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
  110. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
  111. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
  112. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
  113. data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
  114. data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
  115. data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
  116. data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
  117. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
  118. data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
  119. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
  120. data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
  121. data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
  122. data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
  123. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
  124. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
  125. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
  126. data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
  127. data/ext/sources/examples/talk-llama/models/models.h +569 -0
  128. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
  129. data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
  130. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
  131. data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
  132. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
  133. data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
  134. data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
  135. data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
  136. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
  137. data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
  138. data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
  139. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
  140. data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
  141. data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
  142. data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
  143. data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
  144. data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
  145. data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
  146. data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
  147. data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
  148. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
  149. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
  150. data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
  151. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
  152. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
  153. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
  154. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
  155. data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
  156. data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
  157. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
  158. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
  159. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
  160. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
  161. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
  162. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
  163. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
  164. data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
  165. data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
  166. data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
  167. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
  168. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
  169. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
  170. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
  171. data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
  172. data/ext/sources/examples/talk-llama/unicode.cpp +102 -16
  173. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
  174. data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
  175. data/ext/sources/ggml/CMakeLists.txt +82 -54
  176. data/ext/sources/ggml/include/ggml-alloc.h +9 -0
  177. data/ext/sources/ggml/include/ggml-backend.h +4 -1
  178. data/ext/sources/ggml/include/ggml-cpu.h +1 -0
  179. data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
  180. data/ext/sources/ggml/include/ggml-rpc.h +8 -11
  181. data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
  182. data/ext/sources/ggml/include/ggml.h +190 -12
  183. data/ext/sources/ggml/src/CMakeLists.txt +82 -11
  184. data/ext/sources/ggml/src/ggml-alloc.c +124 -41
  185. data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
  186. data/ext/sources/ggml/src/ggml-backend-reg.cpp +27 -3
  187. data/ext/sources/ggml/src/ggml-backend.cpp +71 -21
  188. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
  189. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -9
  190. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +57 -45
  191. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
  192. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2179 -1696
  193. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -317
  194. data/ext/sources/ggml/src/ggml-cann/common.h +283 -208
  195. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +626 -776
  196. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +156 -86
  197. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -0
  198. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  199. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +428 -26
  200. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1004 -0
  201. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
  202. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  203. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +108 -49
  204. data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  205. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6 -6
  206. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +50 -2
  207. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +5 -3
  208. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +195 -71
  209. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
  210. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +573 -106
  211. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
  212. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +298 -112
  213. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
  214. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +819 -125
  215. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
  216. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +708 -431
  217. data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -4
  218. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +671 -31
  219. data/ext/sources/ggml/src/ggml-cpu/repack.h +14 -0
  220. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +41 -43
  221. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
  222. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
  223. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
  224. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +124 -1
  225. data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
  226. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
  227. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
  228. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
  229. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
  230. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +1 -1
  231. data/ext/sources/ggml/src/ggml-cuda/common.cuh +353 -80
  232. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  233. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
  234. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +339 -246
  235. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
  236. data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
  237. data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
  238. data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
  239. data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
  240. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +31 -21
  241. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +663 -596
  242. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
  243. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1241 -0
  244. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +30 -37
  245. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +14 -13
  246. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  247. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +83 -37
  248. data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
  249. data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
  250. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1155 -164
  251. data/ext/sources/ggml/src/ggml-cuda/mean.cu +5 -4
  252. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +741 -48
  253. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +60 -12
  254. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +381 -42
  255. data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
  256. data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
  257. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
  258. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +498 -171
  259. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +375 -79
  260. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +3 -2
  261. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +241 -95
  262. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
  263. data/ext/sources/ggml/src/ggml-cuda/pad.cu +64 -33
  264. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +151 -0
  265. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
  266. data/ext/sources/ggml/src/ggml-cuda/rope.cu +192 -77
  267. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
  268. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
  269. data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
  270. data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
  271. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +203 -6
  272. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
  273. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
  274. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -20
  275. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
  276. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  277. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  278. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  279. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  280. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  281. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  282. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  283. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  284. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  285. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +19 -1
  286. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
  287. data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
  288. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +168 -76
  289. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +11 -4
  290. data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
  291. data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
  292. data/ext/sources/ggml/src/ggml-cuda/unary.cu +105 -11
  293. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +36 -0
  294. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
  295. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  296. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +12 -1
  297. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
  302. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
  303. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  304. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
  305. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
  306. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  307. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
  308. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
  309. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
  310. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
  311. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  312. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
  313. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
  314. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
  315. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
  316. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
  317. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
  318. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
  319. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
  320. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
  321. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
  322. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
  323. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
  324. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
  325. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  326. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
  327. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
  328. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
  329. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +8 -13
  330. data/ext/sources/ggml/src/ggml-impl.h +67 -6
  331. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +2 -2
  332. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +29 -20
  333. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +652 -285
  334. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +103 -56
  335. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +496 -118
  336. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +231 -9
  337. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1227 -224
  338. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +12 -0
  339. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +14 -8
  340. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1972 -704
  341. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
  342. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +11 -0
  343. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1430 -120
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +63 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +13 -0
  365. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
  366. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
  367. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
  368. data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
  369. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +6 -0
  370. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +0 -9
  371. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
  372. data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
  373. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
  374. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +34 -0
  375. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  376. data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  377. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
  378. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
  379. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +76 -3
  380. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +333 -300
  381. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +10 -2
  382. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +335 -110
  383. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
  384. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +156 -0
  385. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
  386. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
  387. data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
  388. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
  389. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
  390. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
  391. data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  392. data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  393. data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
  394. data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
  395. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +30 -17
  396. data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
  397. data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
  398. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
  399. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
  400. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  401. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  402. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +58 -0
  403. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
  404. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5013 -2859
  405. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  406. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
  407. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
  408. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  409. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
  410. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  411. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
  412. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
  413. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  414. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  415. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  416. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  417. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  418. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  419. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
  420. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  421. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  422. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
  423. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
  424. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  425. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  426. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  427. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
  428. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
  429. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
  430. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
  431. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  432. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
  433. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
  434. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  435. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  436. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  437. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  438. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  439. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
  440. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
  441. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
  442. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  443. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  444. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
  445. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
  446. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  447. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  448. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  449. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
  450. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  451. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  452. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
  453. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  454. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  455. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
  456. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  457. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  458. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
  459. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  460. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +39 -17
  461. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +19 -1
  462. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +45 -7
  463. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +50 -12
  464. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +1 -1
  465. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  466. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  467. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
  468. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
  469. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  470. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
  471. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  472. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
  473. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
  474. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
  475. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
  476. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
  477. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
  478. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  479. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
  480. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
  481. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
  482. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
  483. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
  484. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  485. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  486. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  487. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
  488. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +70 -25
  489. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
  490. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
  491. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
  492. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
  493. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
  494. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
  495. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
  496. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
  497. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
  498. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  499. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
  500. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
  501. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
  502. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
  503. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
  504. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
  505. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
  506. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +78 -103
  507. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +34 -23
  508. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
  509. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
  510. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +88 -228
  511. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
  512. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  513. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
  514. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  515. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  516. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  517. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
  518. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
  519. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  520. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
  521. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  522. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  523. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  524. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  525. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +50 -4
  526. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  527. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
  528. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
  529. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
  530. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  531. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -50
  532. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -33
  533. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -33
  534. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
  535. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
  536. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  537. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
  538. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  539. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  540. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  541. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  542. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
  543. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
  544. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
  545. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
  546. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
  547. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
  548. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  549. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
  550. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
  551. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  552. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
  553. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
  554. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  555. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  556. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
  557. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
  558. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  559. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
  560. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  561. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
  562. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
  563. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
  564. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
  565. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
  566. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  567. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
  568. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
  569. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -151
  570. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
  571. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
  572. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
  573. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1964 -435
  574. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  575. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
  576. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
  577. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +33 -10
  578. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
  579. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +1 -1
  580. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  581. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +6 -6
  582. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
  583. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
  584. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
  585. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
  586. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
  587. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  588. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
  589. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
  590. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  591. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
  592. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
  593. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
  594. data/ext/sources/ggml/src/ggml.c +425 -33
  595. data/ext/sources/include/whisper.h +1 -0
  596. data/ext/sources/src/CMakeLists.txt +3 -1
  597. data/ext/sources/src/whisper.cpp +101 -35
  598. data/ext/sources/tests/CMakeLists.txt +2 -2
  599. data/ext/sources/tests/test-vad-full.cpp +4 -2
  600. data/ext/sources/tests/test-vad.cpp +1 -1
  601. data/extsources.rb +1 -0
  602. data/lib/whisper/model/uri.rb +17 -18
  603. data/sig/whisper.rbs +119 -2
  604. data/test/test_params.rb +16 -8
  605. data/test/test_segment.rb +0 -1
  606. data/test/test_token.rb +70 -0
  607. data/test/test_vad.rb +1 -1
  608. data/test/test_vad_context.rb +50 -0
  609. data/test/test_vad_segment.rb +19 -0
  610. data/test/test_vad_segments.rb +16 -0
  611. data/test/test_whisper.rb +7 -0
  612. data/whispercpp.gemspec +1 -1
  613. metadata +287 -34
  614. data/ext/sources/build-xcframework.sh +0 -571
  615. data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
  618. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
  619. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
  620. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
  621. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
  622. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
  623. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
  624. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  625. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  626. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  627. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  628. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
  629. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
  630. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
@@ -10,6 +10,8 @@
10
10
 
11
11
  #include <cassert>
12
12
  #include <algorithm>
13
+ #include <limits>
14
+ #include <cmath>
13
15
 
14
16
  static ggml_metal_buffer_id ggml_metal_get_buffer_id(const ggml_tensor * t) {
15
17
  if (!t) {
@@ -219,13 +221,17 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
219
221
  }
220
222
 
221
223
  if (ctx->debug_graph > 0) {
222
- GGML_LOG_DEBUG("%s: node[%5d] - %-12s %s\n", __func__, idx, ggml_op_name(node->op), is_concurrent ? "(concurrent)" : "");
224
+ GGML_LOG_DEBUG("%s: node[%5d] - %-12s %-12s %s\n", __func__, idx, ggml_op_name(node->op), ggml_get_name(node), is_concurrent ? "(concurrent)" : "");
223
225
  }
224
226
  if (ctx->debug_graph > 1) {
225
227
  GGML_TENSOR_LOCALS( int64_t, ne0, node->src[0], ne);
226
228
  GGML_TENSOR_LOCALS(uint64_t, nb0, node->src[0], nb);
227
229
  GGML_TENSOR_LOCALS( int64_t, ne1, node->src[1], ne);
228
230
  GGML_TENSOR_LOCALS(uint64_t, nb1, node->src[1], nb);
231
+ GGML_TENSOR_LOCALS( int64_t, ne2, node->src[2], ne);
232
+ GGML_TENSOR_LOCALS(uint64_t, nb2, node->src[2], nb);
233
+ GGML_TENSOR_LOCALS( int64_t, ne3, node->src[3], ne);
234
+ GGML_TENSOR_LOCALS(uint64_t, nb3, node->src[3], nb);
229
235
  GGML_TENSOR_LOCALS( int64_t, ne, node, ne);
230
236
  GGML_TENSOR_LOCALS(uint64_t, nb, node, nb);
231
237
 
@@ -237,6 +243,14 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
237
243
  GGML_LOG_DEBUG("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[1]->type), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
238
244
  ggml_is_contiguous(node->src[1]), node->src[1]->name);
239
245
  }
246
+ if (node->src[2]) {
247
+ GGML_LOG_DEBUG("%s: src2 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[2]->type), ne20, ne21, ne22, ne23, nb20, nb21, nb22, nb23,
248
+ ggml_is_contiguous(node->src[2]), node->src[2]->name);
249
+ }
250
+ if (node->src[3]) {
251
+ GGML_LOG_DEBUG("%s: src3 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[3]->type), ne30, ne31, ne32, ne33, nb30, nb31, nb32, nb33,
252
+ ggml_is_contiguous(node->src[3]), node->src[3]->name);
253
+ }
240
254
  if (node) {
241
255
  GGML_LOG_DEBUG("%s: node - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(node->type), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
242
256
  node->name);
@@ -272,6 +286,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
272
286
  {
273
287
  n_fuse = ggml_metal_op_scale(ctx, idx);
274
288
  } break;
289
+ case GGML_OP_FILL:
290
+ {
291
+ n_fuse = ggml_metal_op_fill(ctx, idx);
292
+ } break;
275
293
  case GGML_OP_CLAMP:
276
294
  {
277
295
  n_fuse = ggml_metal_op_clamp(ctx, idx);
@@ -289,11 +307,19 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
289
307
  {
290
308
  n_fuse = ggml_metal_op_glu(ctx, idx);
291
309
  } break;
310
+ case GGML_OP_SUM:
311
+ {
312
+ n_fuse = ggml_metal_op_sum(ctx, idx);
313
+ } break;
292
314
  case GGML_OP_SUM_ROWS:
293
315
  case GGML_OP_MEAN:
294
316
  {
295
317
  n_fuse = ggml_metal_op_sum_rows(ctx, idx);
296
318
  } break;
319
+ case GGML_OP_CUMSUM:
320
+ {
321
+ n_fuse = ggml_metal_op_cumsum(ctx, idx);
322
+ } break;
297
323
  case GGML_OP_SOFT_MAX:
298
324
  {
299
325
  n_fuse = ggml_metal_op_soft_max(ctx, idx);
@@ -348,10 +374,18 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
348
374
  {
349
375
  n_fuse = ggml_metal_op_im2col(ctx, idx);
350
376
  } break;
377
+ case GGML_OP_CONV_2D:
378
+ {
379
+ n_fuse = ggml_metal_op_conv_2d(ctx, idx);
380
+ } break;
351
381
  case GGML_OP_CONV_TRANSPOSE_1D:
352
382
  {
353
383
  n_fuse = ggml_metal_op_conv_transpose_1d(ctx, idx);
354
384
  } break;
385
+ case GGML_OP_CONV_TRANSPOSE_2D:
386
+ {
387
+ n_fuse = ggml_metal_op_conv_transpose_2d(ctx, idx);
388
+ } break;
355
389
  case GGML_OP_UPSCALE:
356
390
  {
357
391
  n_fuse = ggml_metal_op_upscale(ctx, idx);
@@ -376,10 +410,18 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
376
410
  {
377
411
  n_fuse = ggml_metal_op_argsort(ctx, idx);
378
412
  } break;
413
+ case GGML_OP_TOP_K:
414
+ {
415
+ n_fuse = ggml_metal_op_top_k(ctx, idx);
416
+ } break;
379
417
  case GGML_OP_LEAKY_RELU:
380
418
  {
381
419
  n_fuse = ggml_metal_op_leaky_relu(ctx, idx);
382
420
  } break;
421
+ case GGML_OP_TRI:
422
+ {
423
+ n_fuse = ggml_metal_op_tri(ctx, idx);
424
+ } break;
383
425
  case GGML_OP_FLASH_ATTN_EXT:
384
426
  {
385
427
  n_fuse = ggml_metal_op_flash_attn_ext(ctx, idx);
@@ -398,7 +440,19 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
398
440
  {
399
441
  n_fuse = ggml_metal_op_argmax(ctx, idx);
400
442
  } break;
401
- default:
443
+ case GGML_OP_OPT_STEP_ADAMW:
444
+ {
445
+ n_fuse = ggml_metal_op_opt_step_adamw(ctx, idx);
446
+ } break;
447
+ case GGML_OP_OPT_STEP_SGD:
448
+ {
449
+ n_fuse = ggml_metal_op_opt_step_sgd(ctx, idx);
450
+ } break;
451
+ case GGML_OP_COUNT_EQUAL:
452
+ {
453
+ n_fuse = ggml_metal_op_count_equal(ctx, idx);
454
+ } break;
455
+ default:
402
456
  {
403
457
  GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(node->op));
404
458
  GGML_ABORT("fatal error");
@@ -482,7 +536,7 @@ int ggml_metal_op_concat(ggml_metal_op_t ctx, int idx) {
482
536
  /*.dim =*/ dim,
483
537
  };
484
538
 
485
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_CONCAT);
539
+ auto pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_CONCAT);
486
540
 
487
541
  ggml_metal_encoder_set_pipeline(enc, pipeline);
488
542
  ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@@ -506,9 +560,9 @@ int ggml_metal_op_repeat(ggml_metal_op_t ctx, int idx) {
506
560
  GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
507
561
  GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
508
562
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
509
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
563
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
510
564
 
511
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_repeat(lib, op->type);
565
+ auto pipeline = ggml_metal_library_get_pipeline_repeat(lib, op->type);
512
566
 
513
567
  ggml_metal_kargs_repeat args = {
514
568
  /*.ne00 =*/ ne00,
@@ -552,7 +606,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
552
606
  GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
553
607
  GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
554
608
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
555
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
609
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
556
610
 
557
611
  GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
558
612
  GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
@@ -574,9 +628,10 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
574
628
  // TODO: make a simpler cpy_bytes kernel
575
629
 
576
630
  //const id<MTLComputePipelineState> pipeline = ctx->pipelines[GGML_METAL_PIPELINE_TYPE_CPY_F32_F32].obj;
577
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
631
+ auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
578
632
 
579
633
  ggml_metal_kargs_cpy args = {
634
+ /*.nk0 =*/ ne00,
580
635
  /*.ne00 =*/ ne00,
581
636
  /*.ne01 =*/ ne01,
582
637
  /*.ne02 =*/ ne02,
@@ -636,7 +691,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
636
691
  /*.o1 =*/ { 0 },
637
692
  };
638
693
 
639
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_bin(lib, GGML_OP_ADD, 1, false);
694
+ auto pipeline = ggml_metal_library_get_pipeline_bin(lib, GGML_OP_ADD, 1, false);
640
695
 
641
696
  ggml_metal_encoder_set_pipeline(enc, pipeline);
642
697
  ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@@ -660,7 +715,7 @@ int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) {
660
715
  GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
661
716
  GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
662
717
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
663
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
718
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
664
719
 
665
720
  float scale;
666
721
  float bias;
@@ -678,7 +733,42 @@ int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) {
678
733
  n /= 4;
679
734
  }
680
735
 
681
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
736
+ auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
737
+
738
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
739
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
740
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
741
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
742
+
743
+ ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
744
+
745
+ return 1;
746
+ }
747
+
748
+ int ggml_metal_op_fill(ggml_metal_op_t ctx, int idx) {
749
+ ggml_tensor * op = ctx->node(idx);
750
+
751
+ ggml_metal_library_t lib = ctx->lib;
752
+ ggml_metal_encoder_t enc = ctx->enc;
753
+
754
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
755
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
756
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
757
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
758
+
759
+ const float val = ggml_get_op_params_f32(op, 0);
760
+
761
+ ggml_metal_kargs_fill args = {
762
+ /*.val =*/ val
763
+ };
764
+
765
+ int64_t n = ggml_nelements(op);
766
+
767
+ if (n % 4 == 0) {
768
+ n /= 4;
769
+ }
770
+
771
+ auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
682
772
 
683
773
  ggml_metal_encoder_set_pipeline(enc, pipeline);
684
774
  ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@@ -699,7 +789,7 @@ int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) {
699
789
  GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
700
790
  GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
701
791
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
702
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
792
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
703
793
 
704
794
  float min;
705
795
  float max;
@@ -717,7 +807,7 @@ int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) {
717
807
  n /= 4;
718
808
  }
719
809
 
720
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
810
+ auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
721
811
 
722
812
  ggml_metal_encoder_set_pipeline(enc, pipeline);
723
813
  ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@@ -738,7 +828,7 @@ int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
738
828
  GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
739
829
  GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
740
830
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
741
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
831
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
742
832
 
743
833
  int64_t n = ggml_nelements(op);
744
834
 
@@ -746,7 +836,7 @@ int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
746
836
  n /= 4;
747
837
  }
748
838
 
749
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
839
+ auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
750
840
 
751
841
  ggml_metal_encoder_set_pipeline(enc, pipeline);
752
842
  ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 0);
@@ -768,13 +858,13 @@ int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) {
768
858
  GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
769
859
  GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
770
860
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
771
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
861
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
772
862
 
773
863
  if (op->src[1]) {
774
864
  GGML_ASSERT(ggml_are_same_shape(op->src[0], op->src[1]));
775
865
  }
776
866
 
777
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_glu(lib, op);
867
+ auto pipeline = ggml_metal_library_get_pipeline_glu(lib, op);
778
868
 
779
869
  const int32_t swp = ggml_get_op_params_i32(op, 1);
780
870
  const float alpha = ggml_get_op_params_f32(op, 2);
@@ -800,18 +890,6 @@ int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) {
800
890
 
801
891
  const int32_t nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00/2);
802
892
 
803
- //[encoder setComputePipelineState:pipeline];
804
- //[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
805
- //if (src1) {
806
- // [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
807
- //} else {
808
- // [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
809
- //}
810
- //[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
811
- //[encoder setBytes:&args length:sizeof(args) atIndex:3];
812
-
813
- //[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
814
-
815
893
  ggml_metal_encoder_set_pipeline(enc, pipeline);
816
894
  ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
817
895
  ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
@@ -827,6 +905,43 @@ int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) {
827
905
  return 1;
828
906
  }
829
907
 
908
+ int ggml_metal_op_sum(ggml_metal_op_t ctx, int idx) {
909
+ ggml_tensor * op = ctx->node(idx);
910
+
911
+ ggml_metal_library_t lib = ctx->lib;
912
+ ggml_metal_encoder_t enc = ctx->enc;
913
+
914
+ const uint64_t n = (uint64_t) ggml_nelements(op->src[0]);
915
+
916
+ ggml_metal_kargs_sum args = {
917
+ /*.np =*/ n,
918
+ };
919
+
920
+ auto pipeline = ggml_metal_library_get_pipeline_sum(lib, op);
921
+
922
+ int nth = 32; // SIMD width
923
+
924
+ while (nth < (int) n && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
925
+ nth *= 2;
926
+ }
927
+
928
+ nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
929
+ nth = std::min(nth, (int) n);
930
+
931
+ const int nsg = (nth + 31) / 32;
932
+
933
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
934
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
935
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
936
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
937
+
938
+ ggml_metal_encoder_set_threadgroup_memory_size(enc, nsg * sizeof(float), 0);
939
+
940
+ ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, nth, 1, 1);
941
+
942
+ return 1;
943
+ }
944
+
830
945
  int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
831
946
  ggml_tensor * op = ctx->node(idx);
832
947
 
@@ -836,7 +951,7 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
836
951
  GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
837
952
  GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
838
953
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
839
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
954
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
840
955
 
841
956
  ggml_metal_kargs_sum_rows args = {
842
957
  /*.ne00 =*/ ne00,
@@ -857,7 +972,7 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
857
972
  /*.nb3 =*/ nb3,
858
973
  };
859
974
 
860
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op);
975
+ auto pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op);
861
976
 
862
977
  int nth = 32; // SIMD width
863
978
 
@@ -868,15 +983,7 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
868
983
  nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
869
984
  nth = std::min(nth, ne00);
870
985
 
871
- const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
872
-
873
- //[encoder setComputePipelineState:pipeline];
874
- //[encoder setBytes:&args length:sizeof(args) atIndex:0];
875
- //[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
876
- //[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
877
- //[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
878
-
879
- //[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
986
+ const size_t smem = pipeline.smem;
880
987
 
881
988
  ggml_metal_encoder_set_pipeline(enc, pipeline);
882
989
  ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@@ -890,6 +997,149 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
890
997
  return 1;
891
998
  }
892
999
 
1000
+ int ggml_metal_op_cumsum(ggml_metal_op_t ctx, int idx) {
1001
+ ggml_tensor * op = ctx->node(idx);
1002
+
1003
+ ggml_metal_library_t lib = ctx->lib;
1004
+ ggml_metal_encoder_t enc = ctx->enc;
1005
+
1006
+ GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
1007
+
1008
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1009
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1010
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1011
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1012
+
1013
+ auto pipeline_blk = ggml_metal_library_get_pipeline_cumsum_blk(lib, op);
1014
+
1015
+ int nth = 1;
1016
+ while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_blk)) {
1017
+ nth *= 2;
1018
+ }
1019
+
1020
+ GGML_ASSERT(ne00 <= nth*nth);
1021
+
1022
+ const int64_t net0 = (ne00 + nth - 1) / nth;
1023
+ const int64_t net1 = ne01;
1024
+ const int64_t net2 = ne02;
1025
+ const int64_t net3 = ne03;
1026
+
1027
+ const uint64_t nbt0 = sizeof(float);
1028
+ const uint64_t nbt1 = net0*nbt0;
1029
+ const uint64_t nbt2 = net1*nbt1;
1030
+ const uint64_t nbt3 = net2*nbt2;
1031
+
1032
+ const size_t smem = GGML_PAD(32*sizeof(float), 16);
1033
+
1034
+ ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
1035
+ ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
1036
+
1037
+ ggml_metal_buffer_id bid_tmp = bid_dst;
1038
+ bid_tmp.offs += ggml_nbytes(op);
1039
+
1040
+ {
1041
+ ggml_metal_kargs_cumsum_blk args = {
1042
+ /*.ne00 =*/ ne00,
1043
+ /*.ne01 =*/ ne01,
1044
+ /*.ne02 =*/ ne02,
1045
+ /*.ne03 =*/ ne03,
1046
+ /*.nb00 =*/ nb00,
1047
+ /*.nb01 =*/ nb01,
1048
+ /*.nb02 =*/ nb02,
1049
+ /*.nb03 =*/ nb03,
1050
+ /*.net0 =*/ net0,
1051
+ /*.net1 =*/ net1,
1052
+ /*.net2 =*/ net2,
1053
+ /*.net3 =*/ net3,
1054
+ /*.nbt0 =*/ nbt0,
1055
+ /*.nbt1 =*/ nbt1,
1056
+ /*.nbt2 =*/ nbt2,
1057
+ /*.nbt3 =*/ nbt3,
1058
+ /*.outb =*/ ne00 > nth,
1059
+ };
1060
+
1061
+ ggml_metal_encoder_set_pipeline(enc, pipeline_blk);
1062
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1063
+ ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
1064
+ ggml_metal_encoder_set_buffer (enc, bid_tmp, 2);
1065
+ ggml_metal_encoder_set_buffer (enc, bid_dst, 3);
1066
+
1067
+ ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
1068
+
1069
+ ggml_metal_encoder_dispatch_threadgroups(enc, net0*ne01, ne02, ne03, nth, 1, 1);
1070
+ }
1071
+
1072
+ if (ne00 > nth) {
1073
+ ggml_metal_op_concurrency_reset(ctx);
1074
+
1075
+ {
1076
+ ggml_metal_kargs_cumsum_blk args = {
1077
+ /*.ne00 =*/ net0,
1078
+ /*.ne01 =*/ net1,
1079
+ /*.ne02 =*/ net2,
1080
+ /*.ne03 =*/ net3,
1081
+ /*.nb00 =*/ nbt0,
1082
+ /*.nb01 =*/ nbt1,
1083
+ /*.nb02 =*/ nbt2,
1084
+ /*.nb03 =*/ nbt3,
1085
+ /*.net0 =*/ net0,
1086
+ /*.net1 =*/ net1,
1087
+ /*.net2 =*/ net2,
1088
+ /*.net3 =*/ net3,
1089
+ /*.nbt0 =*/ nbt0,
1090
+ /*.nbt1 =*/ nbt1,
1091
+ /*.nbt2 =*/ nbt2,
1092
+ /*.nbt3 =*/ nbt3,
1093
+ /*.outb =*/ false,
1094
+ };
1095
+
1096
+ ggml_metal_encoder_set_pipeline(enc, pipeline_blk);
1097
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1098
+ ggml_metal_encoder_set_buffer (enc, bid_tmp, 1);
1099
+ ggml_metal_encoder_set_buffer (enc, bid_tmp, 2);
1100
+ ggml_metal_encoder_set_buffer (enc, bid_tmp, 3);
1101
+
1102
+ ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
1103
+
1104
+ ggml_metal_encoder_dispatch_threadgroups(enc, net1, net2, net3, nth, 1, 1);
1105
+ }
1106
+
1107
+ ggml_metal_op_concurrency_reset(ctx);
1108
+
1109
+ {
1110
+ auto pipeline_add = ggml_metal_library_get_pipeline_cumsum_add(lib, op);
1111
+
1112
+ ggml_metal_kargs_cumsum_add args = {
1113
+ /*.ne00 =*/ ne00,
1114
+ /*.ne01 =*/ ne01,
1115
+ /*.ne02 =*/ ne02,
1116
+ /*.ne03 =*/ ne03,
1117
+ /*.nb00 =*/ nb00,
1118
+ /*.nb01 =*/ nb01,
1119
+ /*.nb02 =*/ nb02,
1120
+ /*.nb03 =*/ nb03,
1121
+ /*.net0 =*/ net0,
1122
+ /*.net1 =*/ net1,
1123
+ /*.net2 =*/ net2,
1124
+ /*.net3 =*/ net3,
1125
+ /*.nbt0 =*/ nbt0,
1126
+ /*.nbt1 =*/ nbt1,
1127
+ /*.nbt2 =*/ nbt2,
1128
+ /*.nbt3 =*/ nbt3,
1129
+ };
1130
+
1131
+ ggml_metal_encoder_set_pipeline(enc, pipeline_add);
1132
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1133
+ ggml_metal_encoder_set_buffer (enc, bid_tmp, 1);
1134
+ ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
1135
+
1136
+ ggml_metal_encoder_dispatch_threadgroups(enc, net0*ne01, ne02, ne03, nth, 1, 1);
1137
+ }
1138
+ }
1139
+
1140
+ return 1;
1141
+ }
1142
+
893
1143
  int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) {
894
1144
  ggml_tensor * op = ctx->node(idx);
895
1145
 
@@ -901,28 +1151,36 @@ int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) {
901
1151
  GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
902
1152
  GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
903
1153
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
904
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
1154
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
905
1155
 
906
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type);
1156
+ auto pipeline = ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type);
907
1157
 
908
1158
  ggml_metal_kargs_get_rows args = {
909
- /*.ne00 =*/ ne00,
910
- /*.nb01 =*/ nb01,
911
- /*.nb02 =*/ nb02,
912
- /*.ne10 =*/ ne10,
913
- /*.nb10 =*/ nb10,
914
- /*.nb11 =*/ nb11,
915
- /*.nb1 =*/ nb1,
916
- /*.nb2 =*/ nb2,
1159
+ /*.ne00t =*/ ggml_is_quantized(op->src[0]->type) ? ne00/16 : ne00,
1160
+ /*.ne00 =*/ ne00,
1161
+ /*.nb01 =*/ nb01,
1162
+ /*.nb02 =*/ nb02,
1163
+ /*.nb03 =*/ nb03,
1164
+ /*.ne10 =*/ ne10,
1165
+ /*.nb10 =*/ nb10,
1166
+ /*.nb11 =*/ nb11,
1167
+ /*.nb12 =*/ nb12,
1168
+ /*.nb1 =*/ nb1,
1169
+ /*.nb2 =*/ nb2,
1170
+ /*.nb3 =*/ nb3,
917
1171
  };
918
1172
 
1173
+ const int nth = std::min(args.ne00t, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
1174
+
1175
+ const int nw0 = (args.ne00t + nth - 1)/nth;
1176
+
919
1177
  ggml_metal_encoder_set_pipeline(enc, pipeline);
920
1178
  ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
921
1179
  ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
922
1180
  ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
923
1181
  ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
924
1182
 
925
- ggml_metal_encoder_dispatch_threadgroups(enc, ne10, ne11, ne12, 32, 1, 1);
1183
+ ggml_metal_encoder_dispatch_threadgroups(enc, nw0*ne10, ne11, ne12, nth, 1, 1);
926
1184
 
927
1185
  return 1;
928
1186
  }
@@ -938,9 +1196,9 @@ int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) {
938
1196
  GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
939
1197
  GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
940
1198
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
941
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
1199
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
942
1200
 
943
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_set_rows(lib, op->src[1]->type, op->type);
1201
+ auto pipeline = ggml_metal_library_get_pipeline_set_rows(lib, op->src[1]->type, op->type);
944
1202
 
945
1203
  const int32_t nk0 = ne0/ggml_blck_size(op->type);
946
1204
 
@@ -1002,7 +1260,7 @@ int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
1002
1260
  GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
1003
1261
  GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
1004
1262
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1005
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
1263
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1006
1264
 
1007
1265
  float scale;
1008
1266
  float max_bias;
@@ -1041,7 +1299,7 @@ int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
1041
1299
  /*.n_head_log2 =*/ n_head_log2,
1042
1300
  };
1043
1301
 
1044
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_soft_max(lib, op);
1302
+ auto pipeline = ggml_metal_library_get_pipeline_soft_max(lib, op);
1045
1303
 
1046
1304
  int nth = 32; // SIMD width
1047
1305
 
@@ -1055,7 +1313,7 @@ int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
1055
1313
  }
1056
1314
  }
1057
1315
 
1058
- const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
1316
+ const size_t smem = pipeline.smem;
1059
1317
 
1060
1318
  ggml_metal_encoder_set_pipeline(enc, pipeline);
1061
1319
  ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
@@ -1090,7 +1348,7 @@ int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) {
1090
1348
  GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
1091
1349
  GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
1092
1350
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1093
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
1351
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1094
1352
 
1095
1353
  ggml_metal_kargs_ssm_conv args = {
1096
1354
  /*.ne00 =*/ ne00,
@@ -1111,15 +1369,43 @@ int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) {
1111
1369
  /*.nb2 =*/ nb2,
1112
1370
  };
1113
1371
 
1114
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_ssm_conv(lib, op);
1372
+ // Use batched kernel for prefill (ne1 > 1) to reduce threadgroup dispatch overhead
1373
+ const bool use_batched = (ne1 > 1);
1115
1374
 
1116
- ggml_metal_encoder_set_pipeline(enc, pipeline);
1117
- ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
1118
- ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1119
- ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
1120
- ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
1375
+ if (use_batched) {
1376
+ // Determine the smallest power of 2 that's >= ne1, but <= 256
1377
+ int BATCH_SIZE;
1378
+ if (ne1 > 128) BATCH_SIZE = 256;
1379
+ else if (ne1 > 64 ) BATCH_SIZE = 128;
1380
+ else if (ne1 > 32 ) BATCH_SIZE = 64;
1381
+ else if (ne1 > 16 ) BATCH_SIZE = 32;
1382
+ else if (ne1 > 8 ) BATCH_SIZE = 16;
1383
+ else if (ne1 > 4 ) BATCH_SIZE = 8;
1384
+ else BATCH_SIZE = 2;
1385
+
1386
+ auto pipeline = ggml_metal_library_get_pipeline_ssm_conv_batched(lib, op, BATCH_SIZE);
1387
+
1388
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
1389
+ ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
1390
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1391
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
1392
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
1393
+
1394
+ // Dispatch: ne01 rows, ceil(ne1/BATCH_SIZE) token batches, ne02 sequences
1395
+ // Each threadgroup has BATCH_SIZE threads, each handling one token
1396
+ const int n_token_batches = (ne1 + BATCH_SIZE - 1) / BATCH_SIZE;
1397
+ ggml_metal_encoder_dispatch_threadgroups(enc, ne01, n_token_batches, ne02, BATCH_SIZE, 1, 1);
1398
+ } else {
1399
+ auto pipeline = ggml_metal_library_get_pipeline_ssm_conv(lib, op);
1400
+
1401
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
1402
+ ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
1403
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1404
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
1405
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
1121
1406
 
1122
- ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1);
1407
+ ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1);
1408
+ }
1123
1409
 
1124
1410
  return 1;
1125
1411
  }
@@ -1145,7 +1431,7 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
1145
1431
  GGML_TENSOR_LOCALS( int32_t, ne6, op->src[6], ne);
1146
1432
  GGML_TENSOR_LOCALS(uint64_t, nb6, op->src[6], nb);
1147
1433
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1148
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
1434
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1149
1435
 
1150
1436
  const ggml_tensor * src3 = op->src[3];
1151
1437
  const ggml_tensor * src4 = op->src[4];
@@ -1172,26 +1458,37 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
1172
1458
  /*.n_seq_tokens =*/ n_seq_tokens,
1173
1459
  /*.n_seqs =*/ n_seqs,
1174
1460
  /*.s_off =*/ ggml_nelements(op->src[1]) * sizeof(float),
1461
+ /*.nb00 =*/ nb00,
1175
1462
  /*.nb01 =*/ nb01,
1176
1463
  /*.nb02 =*/ nb02,
1177
1464
  /*.nb03 =*/ nb03,
1465
+ /*.nb10 =*/ nb10,
1178
1466
  /*.nb11 =*/ nb11,
1179
1467
  /*.nb12 =*/ nb12,
1468
+ /*.ns12 =*/ nb12/nb10,
1180
1469
  /*.nb13 =*/ nb13,
1470
+ /*.nb20 =*/ nb20,
1181
1471
  /*.nb21 =*/ nb21,
1472
+ /*.ns21 =*/ nb21/nb20,
1182
1473
  /*.nb22 =*/ nb22,
1474
+ /*.ne30 =*/ ne30,
1183
1475
  /*.nb31 =*/ nb31,
1184
1476
  /*.nb41 =*/ nb41,
1185
1477
  /*.nb42 =*/ nb42,
1478
+ /*.ns42 =*/ nb42/nb40,
1186
1479
  /*.nb43 =*/ nb43,
1187
1480
  /*.nb51 =*/ nb51,
1188
1481
  /*.nb52 =*/ nb52,
1482
+ /*.ns52 =*/ nb52/nb50,
1189
1483
  /*.nb53 =*/ nb53,
1484
+ /*.nb0 =*/ nb0,
1190
1485
  };
1191
1486
 
1192
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_ssm_scan(lib, op);
1487
+ auto pipeline = ggml_metal_library_get_pipeline_ssm_scan(lib, op);
1488
+
1489
+ GGML_ASSERT(d_state <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
1193
1490
 
1194
- const size_t sms = ggml_metal_pipeline_get_smem(pipeline);
1491
+ const size_t smem = pipeline.smem;
1195
1492
 
1196
1493
  ggml_metal_encoder_set_pipeline(enc, pipeline);
1197
1494
  ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@@ -1204,15 +1501,9 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
1204
1501
  ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[6]), 7);
1205
1502
  ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 8);
1206
1503
 
1207
- ggml_metal_encoder_set_threadgroup_memory_size(enc, sms, 0);
1504
+ ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
1208
1505
 
1209
- if (ne30 == 1) {
1210
- // Mamba-2
1211
- ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1);
1212
- } else {
1213
- GGML_ASSERT(d_inner == 1);
1214
- ggml_metal_encoder_dispatch_threadgroups(enc, n_head, n_seqs, 1, d_state, 1, 1);
1215
- }
1506
+ ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1);
1216
1507
 
1217
1508
  return 1;
1218
1509
  }
@@ -1226,14 +1517,14 @@ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) {
1226
1517
  GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1227
1518
  GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1228
1519
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1229
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
1520
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1230
1521
 
1231
1522
  const int64_t B = op->op == GGML_OP_RWKV_WKV6 ? op->src[5]->ne[1] : op->src[6]->ne[1];
1232
1523
  const int64_t T = op->src[0]->ne[2];
1233
1524
  const int64_t C = op->ne[0];
1234
1525
  const int64_t H = op->src[0]->ne[1];
1235
1526
 
1236
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_rwkv(lib, op);
1527
+ auto pipeline = ggml_metal_library_get_pipeline_rwkv(lib, op);
1237
1528
 
1238
1529
  int ida = 0;
1239
1530
 
@@ -1267,32 +1558,29 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
1267
1558
  GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1268
1559
  GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1269
1560
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1270
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
1561
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1271
1562
 
1272
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
1563
+ auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
1273
1564
 
1274
1565
  GGML_ASSERT(ne00 % ggml_blck_size(op->src[0]->type) == 0);
1275
1566
 
1276
- // TODO: support
1277
- //const int32_t nk00 = ne00/ggml_blck_size(op->type);
1278
- const int32_t nk00 = ne00;
1279
-
1280
- int nth = 32; // SIMD width
1281
-
1282
- while (nth < nk00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
1283
- nth *= 2;
1567
+ int64_t nk0 = ne00;
1568
+ if (ggml_is_quantized(op->src[0]->type)) {
1569
+ nk0 = ne00/16;
1570
+ } else if (ggml_is_quantized(op->type)) {
1571
+ nk0 = ne00/ggml_blck_size(op->type);
1284
1572
  }
1285
1573
 
1286
- nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
1574
+ int nth = std::min<int>(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
1287
1575
 
1288
1576
  // when rows are small, we can batch them together in a single threadgroup
1289
1577
  int nrptg = 1;
1290
1578
 
1291
1579
  // TODO: relax this constraint in the future
1292
1580
  if (ggml_blck_size(op->src[0]->type) == 1 && ggml_blck_size(op->type) == 1) {
1293
- if (nth > nk00) {
1294
- nrptg = (nth + nk00 - 1)/nk00;
1295
- nth = nk00;
1581
+ if (nth > nk0) {
1582
+ nrptg = (nth + nk0 - 1)/nk0;
1583
+ nth = nk0;
1296
1584
 
1297
1585
  if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
1298
1586
  nrptg--;
@@ -1300,10 +1588,11 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
1300
1588
  }
1301
1589
  }
1302
1590
 
1303
- nth = std::min(nth, nk00);
1591
+ nth = std::min<int>(nth, nk0);
1304
1592
 
1305
1593
  ggml_metal_kargs_cpy args = {
1306
- /*.ne00 =*/ nk00,
1594
+ /*.nk0 =*/ nk0,
1595
+ /*.ne00 =*/ ne00,
1307
1596
  /*.ne01 =*/ ne01,
1308
1597
  /*.ne02 =*/ ne02,
1309
1598
  /*.ne03 =*/ ne03,
@@ -1321,12 +1610,14 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
1321
1610
  /*.nb3 =*/ nb3,
1322
1611
  };
1323
1612
 
1613
+ const int nw0 = nrptg == 1 ? (nk0 + nth - 1)/nth : 1;
1614
+
1324
1615
  ggml_metal_encoder_set_pipeline(enc, pipeline);
1325
1616
  ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1326
1617
  ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1327
1618
  ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
1328
1619
 
1329
- ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, nrptg, 1);
1620
+ ggml_metal_encoder_dispatch_threadgroups(enc, nw0*(ne01 + nrptg - 1)/nrptg, ne02, ne03, nth, nrptg, 1);
1330
1621
 
1331
1622
  return 1;
1332
1623
  }
@@ -1340,7 +1631,7 @@ int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) {
1340
1631
  GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1341
1632
  GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1342
1633
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1343
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
1634
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1344
1635
 
1345
1636
  const int32_t * opts = op->op_params;
1346
1637
  ggml_op_pool op_pool = (ggml_op_pool) opts[0];
@@ -1376,7 +1667,7 @@ int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) {
1376
1667
  /* .np = */ np
1377
1668
  };
1378
1669
 
1379
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_pool_2d(lib, op, op_pool);
1670
+ auto pipeline = ggml_metal_library_get_pipeline_pool_2d(lib, op, op_pool);
1380
1671
 
1381
1672
  const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), (int) np);
1382
1673
  const int ntg = (np + nth - 1) / nth;
@@ -1404,7 +1695,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
1404
1695
  GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
1405
1696
  GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
1406
1697
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1407
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
1698
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1408
1699
 
1409
1700
  GGML_ASSERT(ne00 == ne10);
1410
1701
 
@@ -1485,7 +1776,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
1485
1776
  GGML_ABORT("unsupported ne11");
1486
1777
  };
1487
1778
 
1488
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, nsg, nxpsg, r1ptg);
1779
+ auto pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, nsg, nxpsg, r1ptg);
1489
1780
 
1490
1781
  ggml_metal_kargs_mul_mv_ext args = {
1491
1782
  /*.ne00 =*/ ne00,
@@ -1520,9 +1811,8 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
1520
1811
  !ggml_is_transposed(op->src[1]) &&
1521
1812
  // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
1522
1813
  // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
1523
- props_dev->has_simdgroup_mm && ne00 >= 64 &&
1524
- (ne11 > ne11_mm_min || (ggml_is_quantized(op->src[0]->type) && ne12 > 1))) {
1525
- //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
1814
+ props_dev->has_simdgroup_mm && ne00 >= 64 && ne11 > ne11_mm_min) {
1815
+ //GGML_LOG_INFO("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
1526
1816
 
1527
1817
  // some Metal matrix data types require aligned pointers
1528
1818
  // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
@@ -1533,7 +1823,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
1533
1823
  // default: break;
1534
1824
  //}
1535
1825
 
1536
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm(lib, op);
1826
+ auto pipeline = ggml_metal_library_get_pipeline_mul_mm(lib, op);
1537
1827
 
1538
1828
  ggml_metal_kargs_mul_mm args = {
1539
1829
  /*.ne00 =*/ ne00,
@@ -1558,18 +1848,18 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
1558
1848
  ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
1559
1849
  ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
1560
1850
 
1561
- const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
1851
+ const size_t smem = pipeline.smem;
1562
1852
 
1563
1853
  ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
1564
1854
  ggml_metal_encoder_dispatch_threadgroups(enc, ((ne11 + 31)/32), ((ne01 + 63)/64), ne12*ne13, 128, 1, 1);
1565
1855
  } else {
1566
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv(lib, op);
1856
+ auto pipeline = ggml_metal_library_get_pipeline_mul_mv(lib, op);
1567
1857
 
1568
- const int nr0 = ggml_metal_pipeline_get_nr0(pipeline);
1569
- const int nr1 = ggml_metal_pipeline_get_nr1(pipeline);
1570
- const int nsg = ggml_metal_pipeline_get_nsg(pipeline);
1858
+ const int nr0 = pipeline.nr0;
1859
+ const int nr1 = pipeline.nr1;
1860
+ const int nsg = pipeline.nsg;
1571
1861
 
1572
- const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
1862
+ const size_t smem = pipeline.smem;
1573
1863
 
1574
1864
  ggml_metal_kargs_mul_mv args = {
1575
1865
  /*.ne00 =*/ ne00,
@@ -1646,7 +1936,7 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
1646
1936
  GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
1647
1937
  GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
1648
1938
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1649
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
1939
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1650
1940
 
1651
1941
  // src2 = ids
1652
1942
  GGML_ASSERT(op->src[2]->type == GGML_TYPE_I32);
@@ -1700,9 +1990,9 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
1700
1990
  nb21,
1701
1991
  };
1702
1992
 
1703
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm_id_map0(lib, ne02, ne20);
1993
+ auto pipeline = ggml_metal_library_get_pipeline_mul_mm_id_map0(lib, ne02, ne20);
1704
1994
 
1705
- const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
1995
+ const size_t smem = pipeline.smem;
1706
1996
 
1707
1997
  GGML_ASSERT(ne02 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
1708
1998
 
@@ -1723,7 +2013,7 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
1723
2013
  ggml_metal_op_concurrency_reset(ctx);
1724
2014
 
1725
2015
  {
1726
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm_id(lib, op);
2016
+ auto pipeline = ggml_metal_library_get_pipeline_mul_mm_id(lib, op);
1727
2017
 
1728
2018
  ggml_metal_kargs_mul_mm_id args = {
1729
2019
  /*.ne00 =*/ ne00,
@@ -1752,20 +2042,20 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
1752
2042
  ggml_metal_encoder_set_buffer (enc, bid_ids, 4);
1753
2043
  ggml_metal_encoder_set_buffer (enc, bid_dst, 5);
1754
2044
 
1755
- const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
2045
+ const size_t smem = pipeline.smem;
1756
2046
 
1757
2047
  ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
1758
2048
 
1759
2049
  ggml_metal_encoder_dispatch_threadgroups(enc, (ne21 + 31)/32, (ne01 + 63)/64, ne02, 128, 1, 1);
1760
2050
  }
1761
2051
  } else {
1762
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_id(lib, op);
2052
+ auto pipeline = ggml_metal_library_get_pipeline_mul_mv_id(lib, op);
1763
2053
 
1764
- const int nr0 = ggml_metal_pipeline_get_nr0(pipeline);
1765
- const int nr1 = ggml_metal_pipeline_get_nr1(pipeline);
1766
- const int nsg = ggml_metal_pipeline_get_nsg(pipeline);
2054
+ const int nr0 = pipeline.nr0;
2055
+ const int nr1 = pipeline.nr1;
2056
+ const int nsg = pipeline.nsg;
1767
2057
 
1768
- const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
2058
+ const size_t smem = pipeline.smem;
1769
2059
 
1770
2060
  ggml_metal_kargs_mul_mv_id args = {
1771
2061
  /*.nei0 =*/ ne20,
@@ -1849,7 +2139,7 @@ int ggml_metal_op_add_id(ggml_metal_op_t ctx, int idx) {
1849
2139
  /*.nb21 =*/ nb21,
1850
2140
  };
1851
2141
 
1852
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_ADD_ID);
2142
+ auto pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_ADD_ID);
1853
2143
 
1854
2144
  ggml_metal_encoder_set_pipeline(enc, pipeline);
1855
2145
  ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@@ -1875,20 +2165,118 @@ bool ggml_metal_op_flash_attn_ext_use_vec(const ggml_tensor * op) {
1875
2165
  return (ne01 < 20) && (ne00 % 32 == 0);
1876
2166
  }
1877
2167
 
2168
+ size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) {
2169
+ assert(op->op == GGML_OP_FLASH_ATTN_EXT);
2170
+
2171
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2172
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2173
+ GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
2174
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
2175
+ GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
2176
+ GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
2177
+ GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
2178
+ GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
2179
+
2180
+ size_t res = 0;
2181
+
2182
+ const bool has_mask = op->src[3] != nullptr;
2183
+
2184
+ // note: the non-vec kernel requires more extra memory, so always reserve for it
2185
+ GGML_ASSERT(OP_FLASH_ATTN_EXT_NCPSG >= OP_FLASH_ATTN_EXT_VEC_NCPSG);
2186
+
2187
+ //if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
2188
+ if (false) {
2189
+ // note: always reserve the padding space to avoid graph reallocations
2190
+ //const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0;
2191
+ const bool has_kvpad = true;
2192
+
2193
+ if (has_kvpad) {
2194
+ res += OP_FLASH_ATTN_EXT_VEC_NCPSG*(
2195
+ nb11*ne12*ne13 +
2196
+ nb21*ne22*ne23 +
2197
+ (has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
2198
+ }
2199
+ } else {
2200
+ //const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_NCPSG != 0;
2201
+ const bool has_kvpad = true;
2202
+
2203
+ if (has_kvpad) {
2204
+ res += OP_FLASH_ATTN_EXT_NCPSG*(
2205
+ nb11*ne12*ne13 +
2206
+ nb21*ne22*ne23 +
2207
+ (has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
2208
+ }
2209
+ }
2210
+
2211
+ return res;
2212
+ }
2213
+
2214
+ size_t ggml_metal_op_flash_attn_ext_extra_blk(const ggml_tensor * op) {
2215
+ assert(op->op == GGML_OP_FLASH_ATTN_EXT);
2216
+
2217
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2218
+ //GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2219
+ //GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
2220
+ //GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
2221
+ //GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
2222
+ //GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
2223
+ GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
2224
+ GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
2225
+
2226
+ size_t res = 0;
2227
+
2228
+ const bool has_mask = op->src[3] != nullptr;
2229
+
2230
+ if (!has_mask) {
2231
+ return res;
2232
+ }
2233
+
2234
+ const bool is_vec = ggml_metal_op_flash_attn_ext_use_vec(op);
2235
+
2236
+ // this optimization is not useful for the vector kernels
2237
+ // note: always reserve the blk buffer to avoid graph reallocations
2238
+ //if (is_vec) {
2239
+ // return res;
2240
+ //}
2241
+
2242
+ const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPTG : OP_FLASH_ATTN_EXT_NQPTG;
2243
+ const int ncpsg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG;
2244
+
2245
+ const int64_t ne1 = (ne01 + nqptg - 1)/nqptg;
2246
+ const int64_t ne0 = (ne30 + ncpsg - 1)/ncpsg;
2247
+
2248
+ res += GGML_PAD(ggml_type_size(GGML_TYPE_I8)*ne0*ne1*ne32*ne33, 32);
2249
+
2250
+ return res;
2251
+ }
2252
+
1878
2253
  size_t ggml_metal_op_flash_attn_ext_extra_tmp(const ggml_tensor * op) {
1879
2254
  assert(op->op == GGML_OP_FLASH_ATTN_EXT);
1880
2255
 
1881
- const int64_t nwg = 32;
2256
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2257
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2258
+ //GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
2259
+ //GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
2260
+ GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
2261
+ GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
2262
+ //GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
2263
+ //GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
2264
+
2265
+ size_t res = 0;
2266
+
2267
+ // note: always reserve the temp buffer to avoid graph reallocations
2268
+ //if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
2269
+ if (true) {
2270
+ const int64_t nwg = 32;
2271
+ const int64_t ne01_max = std::min(ne01, 32);
1882
2272
 
1883
- const int64_t ne01 = op->src[0]->ne[1];
1884
- const int64_t ne02 = op->src[0]->ne[2];
1885
- const int64_t ne03 = op->src[0]->ne[3];
1886
- const int64_t ne20 = op->src[2]->ne[0];
2273
+ // temp buffer for writing the results from each workgroup
2274
+ // - ne20: the size of the Value head
2275
+ // - + 2: the S and M values for each intermediate result
2276
+ res += ggml_type_size(GGML_TYPE_F32)*(ne01_max*ne02*ne03*nwg*(ne20 + 2));
2277
+ }
1887
2278
 
1888
- // temp buffer for writing the results from each workgroup
1889
- // - ne20: the size of the Value head
1890
- // - + 2: the S and M values for each intermediate result
1891
- return ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2));
2279
+ return res;
1892
2280
  }
1893
2281
 
1894
2282
  int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
@@ -1910,8 +2298,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
1910
2298
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1911
2299
  GGML_TENSOR_LOCALS( int32_t, nb, op, nb);
1912
2300
 
1913
- GGML_ASSERT(ne00 % 4 == 0);
1914
- GGML_ASSERT(ne11 % 32 == 0);
2301
+ GGML_ASSERT(ne00 % 4 == 0);
1915
2302
 
1916
2303
  GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
1917
2304
  GGML_ASSERT(op->src[1]->type == op->src[2]->type);
@@ -1921,8 +2308,8 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
1921
2308
  GGML_ASSERT(ne12 == ne22);
1922
2309
 
1923
2310
  GGML_ASSERT(!op->src[3] || op->src[3]->type == GGML_TYPE_F16);
1924
- GGML_ASSERT(!op->src[3] || op->src[3]->ne[1] >= GGML_PAD(op->src[0]->ne[1], 8) &&
1925
- "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
2311
+ GGML_ASSERT(!op->src[3] || op->src[3]->ne[1] >= op->src[0]->ne[1] &&
2312
+ "the Flash-Attention Metal kernel requires the mask to be at least n_queries big");
1926
2313
 
1927
2314
  float scale;
1928
2315
  float max_bias;
@@ -1949,15 +2336,107 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
1949
2336
 
1950
2337
  GGML_ASSERT(ne01 < 65536);
1951
2338
 
1952
- if (!ggml_metal_op_flash_attn_ext_use_vec(op)) {
1953
- // half8x8 kernel
1954
- const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
1955
- const int64_t ncpsg = 64; // cache values per simdgroup !! sync with kernel template arguments !!
2339
+ ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
2340
+ ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
2341
+ ggml_metal_buffer_id bid_src2 = ggml_metal_get_buffer_id(op->src[2]);
2342
+ ggml_metal_buffer_id bid_src3 = has_mask ? ggml_metal_get_buffer_id(op->src[3]) : bid_src0;
2343
+ ggml_metal_buffer_id bid_src4 = has_sinks ? ggml_metal_get_buffer_id(op->src[4]) : bid_src0;
2344
+
2345
+ ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
2346
+
2347
+ ggml_metal_buffer_id bid_pad = bid_dst;
2348
+ bid_pad.offs += ggml_nbytes(op);
2349
+
2350
+ ggml_metal_buffer_id bid_blk = bid_pad;
2351
+ bid_blk.offs += ggml_metal_op_flash_attn_ext_extra_pad(op);
2352
+
2353
+ ggml_metal_buffer_id bid_tmp = bid_blk;
2354
+ bid_tmp.offs += ggml_metal_op_flash_attn_ext_extra_blk(op);
2355
+
2356
+ if (!ggml_metal_op_flash_attn_ext_use_vec(op)) {
2357
+ // half8x8 kernel
2358
+ const int nqptg = OP_FLASH_ATTN_EXT_NQPTG; // queries per threadgroup
2359
+ const int ncpsg = OP_FLASH_ATTN_EXT_NCPSG; // cache values per simdgroup
1956
2360
 
1957
2361
  GGML_ASSERT(nqptg <= 32);
1958
2362
  GGML_ASSERT(nqptg % 8 == 0);
1959
2363
  GGML_ASSERT(ncpsg % 32 == 0);
1960
2364
 
2365
+ bool need_sync = false;
2366
+
2367
+ const bool has_kvpad = ne11 % ncpsg != 0;
2368
+
2369
+ if (has_kvpad) {
2370
+ assert(ggml_metal_op_flash_attn_ext_extra_pad(op) != 0);
2371
+
2372
+ ggml_metal_kargs_flash_attn_ext_pad args0 = {
2373
+ /*.ne11 =*/ne11,
2374
+ /*.ne_12_2 =*/ne12,
2375
+ /*.ne_12_3 =*/ne13,
2376
+ /*.nb11 =*/nb11,
2377
+ /*.nb12 =*/nb12,
2378
+ /*.nb13 =*/nb13,
2379
+ /*.nb21 =*/nb21,
2380
+ /*.nb22 =*/nb22,
2381
+ /*.nb23 =*/nb23,
2382
+ /*.ne31 =*/ne31,
2383
+ /*.ne32 =*/ne32,
2384
+ /*.ne33 =*/ne33,
2385
+ /*.nb31 =*/nb31,
2386
+ /*.nb32 =*/nb32,
2387
+ /*.nb33 =*/nb33,
2388
+ };
2389
+
2390
+ auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
2391
+
2392
+ ggml_metal_encoder_set_pipeline(enc, pipeline0);
2393
+ ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
2394
+ ggml_metal_encoder_set_buffer (enc, bid_src1, 1);
2395
+ ggml_metal_encoder_set_buffer (enc, bid_src2, 2);
2396
+ ggml_metal_encoder_set_buffer (enc, bid_src3, 3);
2397
+ ggml_metal_encoder_set_buffer (enc, bid_pad, 4);
2398
+
2399
+ assert(ne12 == ne22);
2400
+ assert(ne13 == ne23);
2401
+
2402
+ ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
2403
+
2404
+ need_sync = true;
2405
+ }
2406
+
2407
+ if (has_mask) {
2408
+ assert(ggml_metal_op_flash_attn_ext_extra_blk(op) != 0);
2409
+
2410
+ ggml_metal_kargs_flash_attn_ext_blk args0 = {
2411
+ /*.ne01 =*/ ne01,
2412
+ /*.ne30 =*/ ne30,
2413
+ /*.ne31 =*/ ne31,
2414
+ /*.ne32 =*/ ne32,
2415
+ /*.ne33 =*/ ne33,
2416
+ /*.nb31 =*/ nb31,
2417
+ /*.nb32 =*/ nb32,
2418
+ /*.nb33 =*/ nb33,
2419
+ };
2420
+
2421
+ auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_blk(lib, op, nqptg, ncpsg);
2422
+
2423
+ ggml_metal_encoder_set_pipeline(enc, pipeline0);
2424
+ ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
2425
+ ggml_metal_encoder_set_buffer (enc, bid_src3, 1);
2426
+ ggml_metal_encoder_set_buffer (enc, bid_blk, 2);
2427
+
2428
+ const int32_t nblk1 = ((ne01 + nqptg - 1)/nqptg);
2429
+ const int32_t nblk0 = ((ne30 + ncpsg - 1)/ncpsg);
2430
+
2431
+ ggml_metal_encoder_dispatch_threadgroups(enc, nblk0, nblk1, ne32*ne33, 32, 1, 1);
2432
+
2433
+ need_sync = true;
2434
+ }
2435
+
2436
+ if (need_sync) {
2437
+ ggml_metal_op_concurrency_reset(ctx);
2438
+ }
2439
+
1961
2440
  const int is_q = ggml_is_quantized(op->src[1]->type) ? 1 : 0;
1962
2441
 
1963
2442
  // 2*(2*ncpsg)
@@ -2007,6 +2486,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
2007
2486
  /*.nb21 =*/ nb21,
2008
2487
  /*.nb22 =*/ nb22,
2009
2488
  /*.nb23 =*/ nb23,
2489
+ /*.ne31 =*/ ne31,
2010
2490
  /*.ne32 =*/ ne32,
2011
2491
  /*.ne33 =*/ ne33,
2012
2492
  /*.nb31 =*/ nb31,
@@ -2023,24 +2503,18 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
2023
2503
  /*.logit_softcap =*/ logit_softcap,
2024
2504
  };
2025
2505
 
2026
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, nsg);
2506
+ auto pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg);
2027
2507
 
2028
2508
  ggml_metal_encoder_set_pipeline(enc, pipeline);
2029
2509
  ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
2030
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
2031
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
2032
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3);
2033
- if (op->src[3]) {
2034
- ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[3]), 4);
2035
- } else {
2036
- ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 4);
2037
- }
2038
- if (op->src[4]) {
2039
- ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[4]), 5);
2040
- } else {
2041
- ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 5);
2042
- }
2043
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 6);
2510
+ ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
2511
+ ggml_metal_encoder_set_buffer (enc, bid_src1, 2);
2512
+ ggml_metal_encoder_set_buffer (enc, bid_src2, 3);
2513
+ ggml_metal_encoder_set_buffer (enc, bid_src3, 4);
2514
+ ggml_metal_encoder_set_buffer (enc, bid_src4, 5);
2515
+ ggml_metal_encoder_set_buffer (enc, bid_pad, 6);
2516
+ ggml_metal_encoder_set_buffer (enc, bid_blk, 7);
2517
+ ggml_metal_encoder_set_buffer (enc, bid_dst, 8);
2044
2518
 
2045
2519
  ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
2046
2520
 
@@ -2048,14 +2522,60 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
2048
2522
  #undef FATTN_SMEM
2049
2523
  } else {
2050
2524
  // half4x4 kernel
2051
- const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !!
2052
- const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
2053
- const int64_t nkpsg = 1*ncpsg;
2525
+ const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPTG; // queries per threadgroup
2526
+ const int ncpsg = OP_FLASH_ATTN_EXT_VEC_NCPSG; // cache values per simdgroup !! sync with kernel template arguments !!
2527
+ const int nkpsg = 1*ncpsg;
2054
2528
 
2055
2529
  GGML_ASSERT(nqptg <= 32);
2056
2530
  GGML_ASSERT(nqptg % 1 == 0);
2057
2531
  GGML_ASSERT(ncpsg % 32 == 0);
2058
2532
 
2533
+ bool need_sync = false;
2534
+
2535
+ const bool has_kvpad = ne11 % ncpsg != 0;
2536
+
2537
+ if (has_kvpad) {
2538
+ assert(ggml_metal_op_flash_attn_ext_extra_pad(op) != 0);
2539
+
2540
+ ggml_metal_kargs_flash_attn_ext_pad args0 = {
2541
+ /*.ne11 =*/ne11,
2542
+ /*.ne_12_2 =*/ne12,
2543
+ /*.ne_12_3 =*/ne13,
2544
+ /*.nb11 =*/nb11,
2545
+ /*.nb12 =*/nb12,
2546
+ /*.nb13 =*/nb13,
2547
+ /*.nb21 =*/nb21,
2548
+ /*.nb22 =*/nb22,
2549
+ /*.nb23 =*/nb23,
2550
+ /*.ne31 =*/ne31,
2551
+ /*.ne32 =*/ne32,
2552
+ /*.ne33 =*/ne33,
2553
+ /*.nb31 =*/nb31,
2554
+ /*.nb32 =*/nb32,
2555
+ /*.nb33 =*/nb33,
2556
+ };
2557
+
2558
+ auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
2559
+
2560
+ ggml_metal_encoder_set_pipeline(enc, pipeline0);
2561
+ ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
2562
+ ggml_metal_encoder_set_buffer (enc, bid_src1, 1);
2563
+ ggml_metal_encoder_set_buffer (enc, bid_src2, 2);
2564
+ ggml_metal_encoder_set_buffer (enc, bid_src3, 3);
2565
+ ggml_metal_encoder_set_buffer (enc, bid_pad, 4);
2566
+
2567
+ assert(ne12 == ne22);
2568
+ assert(ne13 == ne23);
2569
+
2570
+ ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
2571
+
2572
+ need_sync = true;
2573
+ }
2574
+
2575
+ if (need_sync) {
2576
+ ggml_metal_op_concurrency_reset(ctx);
2577
+ }
2578
+
2059
2579
  // ne00 + 2*ncpsg*(nsg)
2060
2580
  // for each query, we load it as f16 in shared memory (ne00)
2061
2581
  // and store the soft_max values and the mask
@@ -2120,6 +2640,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
2120
2640
  /*.nb21 =*/ nb21,
2121
2641
  /*.nb22 =*/ nb22,
2122
2642
  /*.nb23 =*/ nb23,
2643
+ /*.ne31 =*/ ne31,
2123
2644
  /*.ne32 =*/ ne32,
2124
2645
  /*.ne33 =*/ ne33,
2125
2646
  /*.nb31 =*/ nb31,
@@ -2136,25 +2657,17 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
2136
2657
  /*.logit_softcap =*/ logit_softcap,
2137
2658
  };
2138
2659
 
2139
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, nsg, nwg);
2660
+ auto pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg, nwg);
2140
2661
 
2141
2662
  GGML_ASSERT(nsg*32 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
2142
2663
 
2143
2664
  ggml_metal_encoder_set_pipeline(enc, pipeline);
2144
2665
  ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
2145
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
2146
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
2147
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3);
2148
- if (op->src[3]) {
2149
- ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[3]), 4);
2150
- } else {
2151
- ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 4);
2152
- }
2153
- if (op->src[4]) {
2154
- ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[4]), 5);
2155
- } else {
2156
- ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 5);
2157
- }
2666
+ ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
2667
+ ggml_metal_encoder_set_buffer (enc, bid_src1, 2);
2668
+ ggml_metal_encoder_set_buffer (enc, bid_src2, 3);
2669
+ ggml_metal_encoder_set_buffer (enc, bid_src3, 4);
2670
+ ggml_metal_encoder_set_buffer (enc, bid_src4, 5);
2158
2671
 
2159
2672
  const size_t smem = FATTN_SMEM(nsg);
2160
2673
 
@@ -2162,23 +2675,25 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
2162
2675
  GGML_ASSERT(smem <= props_dev->max_theadgroup_memory_size);
2163
2676
 
2164
2677
  if (nwg == 1) {
2678
+ assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) == 0);
2679
+
2165
2680
  // using 1 workgroup -> write the result directly into dst
2166
- ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 6);
2681
+ ggml_metal_encoder_set_buffer(enc, bid_pad, 6);
2682
+ ggml_metal_encoder_set_buffer(enc, bid_dst, 7);
2167
2683
 
2168
2684
  ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
2169
2685
 
2170
2686
  ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
2171
2687
  } else {
2172
2688
  // sanity checks
2689
+ assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) != 0);
2690
+
2173
2691
  GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
2174
2692
  GGML_ASSERT((uint64_t)ne1*ne2*ne3 <= (1u << 31));
2175
2693
 
2176
- ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
2177
-
2178
2694
  // write the results from each workgroup into a temp buffer
2179
- ggml_metal_buffer_id bid_tmp = bid_dst;
2180
- bid_tmp.offs += ggml_nbytes(op);
2181
- ggml_metal_encoder_set_buffer(enc, bid_tmp, 6);
2695
+ ggml_metal_encoder_set_buffer(enc, bid_pad, 6);
2696
+ ggml_metal_encoder_set_buffer(enc, bid_tmp, 7);
2182
2697
 
2183
2698
  ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
2184
2699
  ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
@@ -2194,7 +2709,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
2194
2709
  nrows,
2195
2710
  };
2196
2711
 
2197
- ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(lib, op, ne20, nwg);
2712
+ auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(lib, op, ne20, nwg);
2198
2713
 
2199
2714
  ggml_metal_encoder_set_pipeline(enc, pipeline0);
2200
2715
  ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
@@ -2326,7 +2841,7 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
2326
2841
  // the offsets of src1 and all fused buffers are relative to the start of the src1 buffer
2327
2842
  bid_src1.offs = 0;
2328
2843
 
2329
- ggml_metal_pipeline_t pipeline = nullptr;
2844
+ struct ggml_metal_pipeline_with_params pipeline;
2330
2845
 
2331
2846
  if (ggml_nelements(op->src[1]) == ne10 && ggml_is_contiguous(op->src[1]) && ne00 % 4 == 0 && ne10 % 4 == 0) {
2332
2847
  GGML_ASSERT(ggml_is_contiguous(op->src[0]));
@@ -2385,7 +2900,7 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
2385
2900
  GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2386
2901
  GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2387
2902
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
2388
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
2903
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
2389
2904
 
2390
2905
  float eps;
2391
2906
  memcpy(&eps, op->op_params, sizeof(float));
@@ -2399,7 +2914,7 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
2399
2914
  /*.eps =*/ eps,
2400
2915
  };
2401
2916
 
2402
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_l2_norm(lib, op);
2917
+ auto pipeline = ggml_metal_library_get_pipeline_l2_norm(lib, op);
2403
2918
 
2404
2919
  while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
2405
2920
  nth *= 2;
@@ -2408,7 +2923,7 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
2408
2923
  nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
2409
2924
  nth = std::min(nth, ne00/4);
2410
2925
 
2411
- const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
2926
+ const size_t smem = pipeline.smem;
2412
2927
 
2413
2928
  const int64_t nrows = ggml_nrows(op->src[0]);
2414
2929
 
@@ -2433,7 +2948,7 @@ int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) {
2433
2948
  GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2434
2949
  GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2435
2950
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
2436
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
2951
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
2437
2952
 
2438
2953
  const int32_t ngrp = ((const int32_t *) op->op_params)[0];
2439
2954
 
@@ -2451,7 +2966,7 @@ int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) {
2451
2966
  /*.eps =*/ eps,
2452
2967
  };
2453
2968
 
2454
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_group_norm(lib, op);
2969
+ auto pipeline = ggml_metal_library_get_pipeline_group_norm(lib, op);
2455
2970
 
2456
2971
  int nth = 32; // SIMD width
2457
2972
  //while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
@@ -2461,7 +2976,7 @@ int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) {
2461
2976
  //nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
2462
2977
  //nth = std::min(nth, ne00/4);
2463
2978
 
2464
- const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
2979
+ const size_t smem = pipeline.smem;
2465
2980
 
2466
2981
  ggml_metal_encoder_set_pipeline(enc, pipeline);
2467
2982
  ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@@ -2488,7 +3003,7 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) {
2488
3003
  GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2489
3004
  GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2490
3005
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
2491
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3006
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
2492
3007
 
2493
3008
  float eps;
2494
3009
  memcpy(&eps, op->op_params, sizeof(float));
@@ -2586,7 +3101,7 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) {
2586
3101
  }
2587
3102
  }
2588
3103
 
2589
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_norm(lib, op, n_fuse);
3104
+ auto pipeline = ggml_metal_library_get_pipeline_norm(lib, op, n_fuse);
2590
3105
 
2591
3106
  int nth = 32; // SIMD width
2592
3107
 
@@ -2597,7 +3112,7 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) {
2597
3112
  nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
2598
3113
  nth = std::min(nth, args.ne00_t);
2599
3114
 
2600
- const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
3115
+ const size_t smem = pipeline.smem;
2601
3116
 
2602
3117
  ggml_metal_encoder_set_pipeline(enc, pipeline);
2603
3118
  ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@@ -2624,7 +3139,7 @@ int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) {
2624
3139
  GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
2625
3140
  GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
2626
3141
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
2627
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3142
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
2628
3143
 
2629
3144
  // make sure we have one or more position id(ne10) per token(ne02)
2630
3145
  GGML_ASSERT(ne10 % ne02 == 0);
@@ -2688,9 +3203,10 @@ int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) {
2688
3203
  /* sect_1 =*/ sect_1,
2689
3204
  /* sect_2 =*/ sect_2,
2690
3205
  /* sect_3 =*/ sect_3,
3206
+ /* src2 =*/ op->src[2] != nullptr,
2691
3207
  };
2692
3208
 
2693
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_rope(lib, op);
3209
+ auto pipeline = ggml_metal_library_get_pipeline_rope(lib, op);
2694
3210
 
2695
3211
  ggml_metal_encoder_set_pipeline(enc, pipeline);
2696
3212
  ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@@ -2717,7 +3233,7 @@ int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) {
2717
3233
  GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2718
3234
  GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2719
3235
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
2720
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3236
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
2721
3237
 
2722
3238
  const int32_t s0 = ((const int32_t *)(op->op_params))[0];
2723
3239
  const int32_t s1 = ((const int32_t *)(op->op_params))[1];
@@ -2762,7 +3278,7 @@ int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) {
2762
3278
  /*.KHW =*/ KH * KW,
2763
3279
  };
2764
3280
 
2765
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_im2col(lib, op);
3281
+ auto pipeline = ggml_metal_library_get_pipeline_im2col(lib, op);
2766
3282
 
2767
3283
  GGML_ASSERT(KH*KW <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
2768
3284
 
@@ -2778,6 +3294,84 @@ int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) {
2778
3294
  return 1;
2779
3295
  }
2780
3296
 
3297
+ int ggml_metal_op_conv_2d(ggml_metal_op_t ctx, int idx) {
3298
+ ggml_tensor * op = ctx->node(idx);
3299
+
3300
+ ggml_metal_library_t lib = ctx->lib;
3301
+ ggml_metal_encoder_t enc = ctx->enc;
3302
+
3303
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3304
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3305
+ GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
3306
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
3307
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3308
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3309
+
3310
+ GGML_ASSERT(ggml_is_contiguous(op->src[0]));
3311
+ GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
3312
+ GGML_ASSERT(op->type == GGML_TYPE_F32);
3313
+ GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
3314
+
3315
+ const int32_t s0 = ((const int32_t *) op->op_params)[0];
3316
+ const int32_t s1 = ((const int32_t *) op->op_params)[1];
3317
+ const int32_t p0 = ((const int32_t *) op->op_params)[2];
3318
+ const int32_t p1 = ((const int32_t *) op->op_params)[3];
3319
+ const int32_t d0 = ((const int32_t *) op->op_params)[4];
3320
+ const int32_t d1 = ((const int32_t *) op->op_params)[5];
3321
+
3322
+ ggml_metal_kargs_conv_2d args = {
3323
+ /*.nb00 =*/ nb00,
3324
+ /*.nb01 =*/ nb01,
3325
+ /*.nb02 =*/ nb02,
3326
+ /*.nb03 =*/ nb03,
3327
+ /*.nb10 =*/ nb10,
3328
+ /*.nb11 =*/ nb11,
3329
+ /*.nb12 =*/ nb12,
3330
+ /*.nb13 =*/ nb13,
3331
+ /*.nb0 =*/ nb0,
3332
+ /*.nb1 =*/ nb1,
3333
+ /*.nb2 =*/ nb2,
3334
+ /*.nb3 =*/ nb3,
3335
+ /*.IW =*/ ne10,
3336
+ /*.IH =*/ ne11,
3337
+ /*.KW =*/ ne00,
3338
+ /*.KH =*/ ne01,
3339
+ /*.IC =*/ ne02,
3340
+ /*.OC =*/ ne03,
3341
+ /*.OW =*/ ne0,
3342
+ /*.OH =*/ ne1,
3343
+ /*.N =*/ ne3,
3344
+ /*.s0 =*/ s0,
3345
+ /*.s1 =*/ s1,
3346
+ /*.p0 =*/ p0,
3347
+ /*.p1 =*/ p1,
3348
+ /*.d0 =*/ d0,
3349
+ /*.d1 =*/ d1,
3350
+ };
3351
+
3352
+ auto pipeline = ggml_metal_library_get_pipeline_conv_2d(lib, op);
3353
+
3354
+ int nth = ggml_metal_pipeline_max_theads_per_threadgroup(pipeline);
3355
+ nth = std::min(nth, 256);
3356
+ nth = std::max(nth, 1);
3357
+
3358
+ const uint64_t n_out = ggml_nelements(op);
3359
+
3360
+ uint64_t tg = (n_out + nth - 1)/nth;
3361
+ tg = std::max<uint64_t>(tg, 1);
3362
+ tg = std::min<uint64_t>(tg, (uint64_t) std::numeric_limits<int>::max());
3363
+
3364
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
3365
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3366
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
3367
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
3368
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
3369
+
3370
+ ggml_metal_encoder_dispatch_threadgroups(enc, tg, 1, 1, nth, 1, 1);
3371
+
3372
+ return 1;
3373
+ }
3374
+
2781
3375
  int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {
2782
3376
  ggml_tensor * op = ctx->node(idx);
2783
3377
 
@@ -2789,7 +3383,7 @@ int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {
2789
3383
  GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
2790
3384
  GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
2791
3385
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
2792
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3386
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
2793
3387
 
2794
3388
  const int32_t s0 = ((const int32_t *)(op->op_params))[0];
2795
3389
 
@@ -2810,7 +3404,7 @@ int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {
2810
3404
  /*.nb1 =*/ nb1,
2811
3405
  };
2812
3406
 
2813
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_conv_transpose_1d(lib, op);
3407
+ auto pipeline = ggml_metal_library_get_pipeline_conv_transpose_1d(lib, op);
2814
3408
 
2815
3409
  ggml_metal_encoder_set_pipeline(enc, pipeline);
2816
3410
  ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@@ -2823,6 +3417,62 @@ int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {
2823
3417
  return 1;
2824
3418
  }
2825
3419
 
3420
+ int ggml_metal_op_conv_transpose_2d(ggml_metal_op_t ctx, int idx) {
3421
+ ggml_tensor * op = ctx->node(idx);
3422
+
3423
+ ggml_metal_library_t lib = ctx->lib;
3424
+ ggml_metal_encoder_t enc = ctx->enc;
3425
+
3426
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3427
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3428
+ GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
3429
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
3430
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3431
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3432
+
3433
+ const int32_t s0 = ((const int32_t *)(op->op_params))[0];
3434
+
3435
+ const int32_t IC = op->src[1]->ne[2];
3436
+ const int32_t IH = op->src[1]->ne[1];
3437
+ const int32_t IW = op->src[1]->ne[0];
3438
+
3439
+ const int32_t KH = op->src[0]->ne[1];
3440
+ const int32_t KW = op->src[0]->ne[0];
3441
+
3442
+ const int32_t OW = op->ne[0];
3443
+ const int32_t OH = op->ne[1];
3444
+ const int32_t OC = op->ne[2];
3445
+
3446
+ ggml_metal_kargs_conv_transpose_2d args = {
3447
+ /*.IC =*/ IC,
3448
+ /*.IH =*/ IH,
3449
+ /*.IW =*/ IW,
3450
+ /*.KH =*/ KH,
3451
+ /*.KW =*/ KW,
3452
+ /*.OC =*/ OC,
3453
+ /*.s0 =*/ s0,
3454
+ /*.nb0 =*/ nb0,
3455
+ /*.nb1 =*/ nb1,
3456
+ /*.nb2 =*/ nb2,
3457
+ };
3458
+
3459
+ auto pipeline = ggml_metal_library_get_pipeline_conv_transpose_2d(lib, op);
3460
+
3461
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
3462
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3463
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
3464
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
3465
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
3466
+
3467
+ // Metal requires buffer size to be multiple of 16 bytes
3468
+ const size_t smem = GGML_PAD(KW * KH * sizeof(float), 16);
3469
+ ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
3470
+
3471
+ ggml_metal_encoder_dispatch_threadgroups(enc, OW, OH, OC, KW, KH, 1);
3472
+
3473
+ return 1;
3474
+ }
3475
+
2826
3476
  int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) {
2827
3477
  ggml_tensor * op = ctx->node(idx);
2828
3478
 
@@ -2832,7 +3482,7 @@ int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) {
2832
3482
  GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2833
3483
  GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2834
3484
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
2835
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3485
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
2836
3486
 
2837
3487
  const float sf0 = (float)ne0/op->src[0]->ne[0];
2838
3488
  const float sf1 = (float)ne1/op->src[0]->ne[1];
@@ -2862,7 +3512,7 @@ int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) {
2862
3512
  /*.sf3 =*/ sf3
2863
3513
  };
2864
3514
 
2865
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_upscale(lib, op);
3515
+ auto pipeline = ggml_metal_library_get_pipeline_upscale(lib, op);
2866
3516
 
2867
3517
  const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
2868
3518
 
@@ -2885,7 +3535,7 @@ int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) {
2885
3535
  GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2886
3536
  GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2887
3537
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
2888
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3538
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
2889
3539
 
2890
3540
  ggml_metal_kargs_pad args = {
2891
3541
  /*.ne00 =*/ ne00,
@@ -2906,7 +3556,7 @@ int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) {
2906
3556
  /*.nb3 =*/ nb3
2907
3557
  };
2908
3558
 
2909
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_pad(lib, op);
3559
+ auto pipeline = ggml_metal_library_get_pipeline_pad(lib, op);
2910
3560
 
2911
3561
  const int nth = std::min(1024, ne0);
2912
3562
 
@@ -2929,7 +3579,7 @@ int ggml_metal_op_pad_reflect_1d(ggml_metal_op_t ctx, int idx) {
2929
3579
  GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2930
3580
  GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2931
3581
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
2932
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3582
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
2933
3583
 
2934
3584
  ggml_metal_kargs_pad_reflect_1d args = {
2935
3585
  /*.ne00 =*/ ne00,
@@ -2952,7 +3602,7 @@ int ggml_metal_op_pad_reflect_1d(ggml_metal_op_t ctx, int idx) {
2952
3602
  /*.p1 =*/ ((const int32_t *)(op->op_params))[1]
2953
3603
  };
2954
3604
 
2955
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_pad_reflect_1d(lib, op);
3605
+ auto pipeline = ggml_metal_library_get_pipeline_pad_reflect_1d(lib, op);
2956
3606
 
2957
3607
  const int nth = std::min(1024, ne0);
2958
3608
 
@@ -2973,7 +3623,7 @@ int ggml_metal_op_arange(ggml_metal_op_t ctx, int idx) {
2973
3623
  ggml_metal_encoder_t enc = ctx->enc;
2974
3624
 
2975
3625
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
2976
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3626
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
2977
3627
 
2978
3628
  float start;
2979
3629
  float step;
@@ -2989,13 +3639,7 @@ int ggml_metal_op_arange(ggml_metal_op_t ctx, int idx) {
2989
3639
 
2990
3640
  const int nth = std::min(1024, ne0);
2991
3641
 
2992
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_arange(lib, op);
2993
-
2994
- //[encoder setComputePipelineState:pipeline];
2995
- //[encoder setBuffer:id_dst offset:offs_dst atIndex:0];
2996
- //[encoder setBytes:&args length:sizeof(args) atIndex:1];
2997
-
2998
- //[encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
3642
+ auto pipeline = ggml_metal_library_get_pipeline_arange(lib, op);
2999
3643
 
3000
3644
  ggml_metal_encoder_set_pipeline(enc, pipeline);
3001
3645
  ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@@ -3015,7 +3659,7 @@ int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx) {
3015
3659
  GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3016
3660
  GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3017
3661
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3018
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3662
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3019
3663
 
3020
3664
  const int dim = op->op_params[0];
3021
3665
  const int max_period = op->op_params[1];
@@ -3026,7 +3670,7 @@ int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx) {
3026
3670
  /*.max_period =*/ max_period,
3027
3671
  };
3028
3672
 
3029
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_timestep_embedding(lib, op);
3673
+ auto pipeline = ggml_metal_library_get_pipeline_timestep_embedding(lib, op);
3030
3674
 
3031
3675
  const int nth = std::max(1, std::min(1024, dim/2));
3032
3676
 
@@ -3049,14 +3693,14 @@ int ggml_metal_op_argmax(ggml_metal_op_t ctx, int idx) {
3049
3693
  GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3050
3694
  GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3051
3695
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3052
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3696
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3053
3697
 
3054
3698
  ggml_metal_kargs_argmax args = {
3055
3699
  /*.ne00 = */ ne00,
3056
3700
  /*.nb01 = */ nb01,
3057
3701
  };
3058
3702
 
3059
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_argmax(lib, op);
3703
+ auto pipeline = ggml_metal_library_get_pipeline_argmax(lib, op);
3060
3704
 
3061
3705
  const int64_t nrows = ggml_nrows(op->src[0]);
3062
3706
 
@@ -3065,7 +3709,7 @@ int ggml_metal_op_argmax(ggml_metal_op_t ctx, int idx) {
3065
3709
  nth *= 2;
3066
3710
  }
3067
3711
 
3068
- const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
3712
+ const size_t smem = pipeline.smem;
3069
3713
 
3070
3714
  ggml_metal_encoder_set_pipeline(enc, pipeline);
3071
3715
  ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@@ -3085,38 +3729,215 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
3085
3729
  ggml_metal_library_t lib = ctx->lib;
3086
3730
  ggml_metal_encoder_t enc = ctx->enc;
3087
3731
 
3732
+ GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
3733
+
3088
3734
  GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3089
3735
  GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3090
3736
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3091
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3737
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3738
+
3739
+ auto pipeline = ggml_metal_library_get_pipeline_argsort(lib, op);
3092
3740
 
3093
3741
  // bitonic sort requires the number of elements to be power of 2
3094
- int64_t ne00_padded = 1;
3095
- while (ne00_padded < ne00) {
3096
- ne00_padded *= 2;
3742
+ int nth = 1;
3743
+ while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
3744
+ nth *= 2;
3097
3745
  }
3098
3746
 
3099
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_argsort(lib, op);
3100
-
3101
- const int64_t nrows = ggml_nrows(op->src[0]);
3747
+ const int npr = (ne00 + nth - 1)/nth;
3102
3748
 
3103
3749
  // Metal kernels require the buffer size to be multiple of 16 bytes
3104
3750
  // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
3105
- const size_t smem = GGML_PAD(ne00_padded*sizeof(int32_t), 16);
3751
+ const size_t smem = GGML_PAD(nth*sizeof(int32_t), 16);
3752
+
3753
+ ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
3754
+ ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
3755
+
3756
+ ggml_metal_buffer_id bid_tmp = bid_dst;
3757
+ bid_tmp.offs += ggml_nbytes(op);
3758
+
3759
+ if ((int) ceil(std::log(npr) / std::log(2)) % 2 == 1) {
3760
+ std::swap(bid_dst, bid_tmp);
3761
+ }
3106
3762
 
3107
3763
  ggml_metal_kargs_argsort args = {
3108
- /*.ncols =*/ ne00,
3109
- /*.ncols_pad =*/ ne00_padded
3764
+ /*.ne00 =*/ ne00,
3765
+ /*.ne01 =*/ ne01,
3766
+ /*.ne02 =*/ ne02,
3767
+ /*.ne03 =*/ ne03,
3768
+ /*.nb00 =*/ nb00,
3769
+ /*.nb01 =*/ nb01,
3770
+ /*.nb02 =*/ nb02,
3771
+ /*.nb03 =*/ nb03,
3772
+ /*.ne0 =*/ ne0,
3773
+ /*.ne1 =*/ ne1,
3774
+ /*.ne2 =*/ ne2,
3775
+ /*.ne3 =*/ ne3,
3776
+ /*.top_k =*/ nth,
3110
3777
  };
3111
3778
 
3112
3779
  ggml_metal_encoder_set_pipeline(enc, pipeline);
3113
3780
  ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3114
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
3115
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
3781
+ ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
3782
+ ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
3116
3783
 
3117
3784
  ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
3118
3785
 
3119
- ggml_metal_encoder_dispatch_threadgroups(enc, 1, nrows, 1, ne00_padded, 1, 1);
3786
+ ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1);
3787
+
3788
+ auto pipeline_merge = ggml_metal_library_get_pipeline_argsort_merge(lib, op);
3789
+
3790
+ int len = nth;
3791
+
3792
+ while (len < ne00) {
3793
+ ggml_metal_op_concurrency_reset(ctx);
3794
+
3795
+ ggml_metal_kargs_argsort_merge args_merge = {
3796
+ /*.ne00 =*/ ne00,
3797
+ /*.ne01 =*/ ne01,
3798
+ /*.ne02 =*/ ne02,
3799
+ /*.ne03 =*/ ne03,
3800
+ /*.nb00 =*/ nb00,
3801
+ /*.nb01 =*/ nb01,
3802
+ /*.nb02 =*/ nb02,
3803
+ /*.nb03 =*/ nb03,
3804
+ /*.ne0 =*/ ne0,
3805
+ /*.ne1 =*/ ne1,
3806
+ /*.ne2 =*/ ne2,
3807
+ /*.ne3 =*/ ne3,
3808
+ /*.top_k =*/ ne00,
3809
+ /*.len =*/ len,
3810
+ };
3811
+
3812
+ // merges per row
3813
+ const int nm = (ne00 + 2*len - 1) / (2*len);
3814
+
3815
+ const int nth = std::min(512, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_merge));
3816
+
3817
+ ggml_metal_encoder_set_pipeline(enc, pipeline_merge);
3818
+ ggml_metal_encoder_set_bytes (enc, &args_merge, sizeof(args_merge), 0);
3819
+ ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
3820
+ ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
3821
+ ggml_metal_encoder_set_buffer (enc, bid_tmp, 3);
3822
+
3823
+ ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1);
3824
+
3825
+ std::swap(bid_dst, bid_tmp);
3826
+
3827
+ len <<= 1;
3828
+ }
3829
+
3830
+ return 1;
3831
+ }
3832
+
3833
+ int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) {
3834
+ ggml_tensor * op = ctx->node(idx);
3835
+
3836
+ ggml_metal_library_t lib = ctx->lib;
3837
+ ggml_metal_encoder_t enc = ctx->enc;
3838
+
3839
+ GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
3840
+
3841
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3842
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3843
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3844
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3845
+
3846
+ auto pipeline = ggml_metal_library_get_pipeline_top_k(lib, op);
3847
+
3848
+ // bitonic sort requires the number of elements to be power of 2
3849
+ int nth = 1;
3850
+ while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
3851
+ nth *= 2;
3852
+ }
3853
+
3854
+ // blocks per row
3855
+ const int npr = (ne00 + nth - 1)/nth;
3856
+
3857
+ const size_t smem = GGML_PAD(nth*sizeof(int32_t), 16);
3858
+
3859
+ ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
3860
+ ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
3861
+
3862
+ ggml_metal_buffer_id bid_tmp = bid_dst;
3863
+ bid_tmp.offs += sizeof(int32_t)*ggml_nelements(op->src[0]);
3864
+
3865
+ if ((int) ceil(std::log(npr) / std::log(2)) % 2 == 1) {
3866
+ std::swap(bid_dst, bid_tmp);
3867
+ }
3868
+
3869
+ const int top_k = ne0;
3870
+
3871
+ ggml_metal_kargs_argsort args = {
3872
+ /*.ne00 =*/ ne00,
3873
+ /*.ne01 =*/ ne01,
3874
+ /*.ne02 =*/ ne02,
3875
+ /*.ne03 =*/ ne03,
3876
+ /*.nb00 =*/ nb00,
3877
+ /*.nb01 =*/ nb01,
3878
+ /*.nb02 =*/ nb02,
3879
+ /*.nb03 =*/ nb03,
3880
+ /*.ne0 =*/ ne0,
3881
+ /*.ne1 =*/ ne1,
3882
+ /*.ne2 =*/ ne2,
3883
+ /*.ne3 =*/ ne3,
3884
+ /*.top_k =*/ std::min(nth, top_k), // for each block, keep just the top_k indices
3885
+ };
3886
+
3887
+ if (npr > 1) {
3888
+ args.ne0 = (npr - 1)*args.top_k + std::min(ne00 - (npr - 1)*nth, args.top_k);
3889
+ }
3890
+
3891
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
3892
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3893
+ ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
3894
+ ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
3895
+
3896
+ ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
3897
+
3898
+ ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1);
3899
+
3900
+ auto pipeline_merge = ggml_metal_library_get_pipeline_top_k_merge(lib, op);
3901
+
3902
+ int len = args.top_k;
3903
+
3904
+ while (len < args.ne0) {
3905
+ ggml_metal_op_concurrency_reset(ctx);
3906
+
3907
+ // merges per row
3908
+ const int nm = (args.ne0 + 2*len - 1) / (2*len);
3909
+
3910
+ const int nth = std::min(512, std::min(len, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_merge)));
3911
+
3912
+ ggml_metal_kargs_argsort_merge args_merge = {
3913
+ /*.ne00 =*/ ne00,
3914
+ /*.ne01 =*/ ne01,
3915
+ /*.ne02 =*/ ne02,
3916
+ /*.ne03 =*/ ne03,
3917
+ /*.nb00 =*/ nb00,
3918
+ /*.nb01 =*/ nb01,
3919
+ /*.nb02 =*/ nb02,
3920
+ /*.nb03 =*/ nb03,
3921
+ /*.ne0 =*/ args.ne0,
3922
+ /*.ne1 =*/ ne1,
3923
+ /*.ne2 =*/ ne2,
3924
+ /*.ne3 =*/ ne3,
3925
+ /*.top_k =*/ nm == 1 ? top_k : args.ne0, // the final merge outputs top_k elements
3926
+ /*.len =*/ len,
3927
+ };
3928
+
3929
+ ggml_metal_encoder_set_pipeline(enc, pipeline_merge);
3930
+ ggml_metal_encoder_set_bytes (enc, &args_merge, sizeof(args_merge), 0);
3931
+ ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
3932
+ ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
3933
+ ggml_metal_encoder_set_buffer (enc, bid_tmp, 3);
3934
+
3935
+ ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1);
3936
+
3937
+ std::swap(bid_dst, bid_tmp);
3938
+
3939
+ len <<= 1;
3940
+ }
3120
3941
 
3121
3942
  return 1;
3122
3943
  }
@@ -3130,7 +3951,7 @@ int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
3130
3951
  GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3131
3952
  GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3132
3953
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3133
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3954
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3134
3955
 
3135
3956
  float slope;
3136
3957
  memcpy(&slope, op->op_params, sizeof(float));
@@ -3139,7 +3960,7 @@ int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
3139
3960
  /*.slope =*/ slope
3140
3961
  };
3141
3962
 
3142
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
3963
+ auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
3143
3964
 
3144
3965
  int64_t n = ggml_nelements(op);
3145
3966
 
@@ -3156,3 +3977,185 @@ int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
3156
3977
 
3157
3978
  return 1;
3158
3979
  }
3980
+
3981
+ int ggml_metal_op_tri(ggml_metal_op_t ctx, int idx) {
3982
+ ggml_tensor * op = ctx->node(idx);
3983
+
3984
+ ggml_metal_library_t lib = ctx->lib;
3985
+ ggml_metal_encoder_t enc = ctx->enc;
3986
+
3987
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3988
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3989
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3990
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3991
+
3992
+ ggml_metal_kargs_tri args = {
3993
+ /*.ne00 =*/ ne00,
3994
+ /*.ne01 =*/ ne01,
3995
+ /*.ne02 =*/ ne02,
3996
+ /*.ne03 =*/ ne03,
3997
+ /*.nb00 =*/ nb00,
3998
+ /*.nb01 =*/ nb01,
3999
+ /*.nb02 =*/ nb02,
4000
+ /*.nb03 =*/ nb03,
4001
+ /*.ne0 =*/ ne0,
4002
+ /*.ne1 =*/ ne1,
4003
+ /*.ne2 =*/ ne2,
4004
+ /*.ne3 =*/ ne3,
4005
+ /*.nb0 =*/ nb0,
4006
+ /*.nb1 =*/ nb1,
4007
+ /*.nb2 =*/ nb2,
4008
+ /*.nb3 =*/ nb3,
4009
+ };
4010
+
4011
+ auto pipeline = ggml_metal_library_get_pipeline_tri(lib, op);
4012
+
4013
+ int nth = 32; // SIMD width
4014
+
4015
+ while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
4016
+ nth *= 2;
4017
+ }
4018
+
4019
+ nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
4020
+ nth = std::min(nth, ne00);
4021
+
4022
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
4023
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
4024
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
4025
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
4026
+
4027
+ ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
4028
+
4029
+ return 1;
4030
+ }
4031
+
4032
+ int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) {
4033
+ ggml_tensor * op = ctx->node(idx);
4034
+
4035
+ ggml_metal_library_t lib = ctx->lib;
4036
+ ggml_metal_encoder_t enc = ctx->enc;
4037
+
4038
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
4039
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
4040
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
4041
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
4042
+
4043
+ auto pipeline = ggml_metal_library_get_pipeline_opt_step_adamw(lib, op);
4044
+
4045
+ const int64_t np = ggml_nelements(op->src[0]);
4046
+ ggml_metal_kargs_opt_step_adamw args = {
4047
+ /*.np =*/ np,
4048
+ };
4049
+
4050
+ int ida = 0;
4051
+
4052
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
4053
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++);
4054
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);
4055
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++);
4056
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++);
4057
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), ida++);
4058
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), ida++);
4059
+
4060
+ const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
4061
+ const int64_t n = (np + nth - 1) / nth;
4062
+
4063
+ ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1);
4064
+
4065
+ return 1;
4066
+ }
4067
+
4068
+ int ggml_metal_op_opt_step_sgd(ggml_metal_op_t ctx, int idx) {
4069
+ ggml_tensor * op = ctx->node(idx);
4070
+
4071
+ ggml_metal_library_t lib = ctx->lib;
4072
+ ggml_metal_encoder_t enc = ctx->enc;
4073
+
4074
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
4075
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
4076
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
4077
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
4078
+
4079
+ auto pipeline = ggml_metal_library_get_pipeline_opt_step_sgd(lib, op);
4080
+
4081
+ const int64_t np = ggml_nelements(op->src[0]);
4082
+ ggml_metal_kargs_opt_step_sgd args = {
4083
+ /*.np =*/ np,
4084
+ };
4085
+
4086
+ int ida = 0;
4087
+
4088
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
4089
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++);
4090
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);
4091
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++);
4092
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++);
4093
+
4094
+ const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
4095
+ const int64_t n = (np + nth - 1) / nth;
4096
+
4097
+ ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1);
4098
+
4099
+ return 1;
4100
+ }
4101
+
4102
+ int ggml_metal_op_count_equal(ggml_metal_op_t ctx, int idx) {
4103
+ ggml_tensor * op = ctx->node(idx);
4104
+
4105
+ ggml_metal_library_t lib = ctx->lib;
4106
+ ggml_metal_encoder_t enc = ctx->enc;
4107
+
4108
+ GGML_TENSOR_LOCALS(int32_t, ne0, op->src[0], ne);
4109
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
4110
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
4111
+
4112
+ {
4113
+ ggml_metal_kargs_memset args = { /*.val =*/ 0 };
4114
+
4115
+ auto pipeline = ggml_metal_library_get_pipeline_memset(lib, op);
4116
+
4117
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
4118
+ ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
4119
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 1);
4120
+
4121
+ ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, 1, 1, 1);
4122
+ }
4123
+
4124
+ ggml_metal_op_concurrency_reset(ctx);
4125
+
4126
+ {
4127
+ ggml_metal_kargs_count_equal args = {
4128
+ /*.ne00 =*/ ne00,
4129
+ /*.ne01 =*/ ne01,
4130
+ /*.ne02 =*/ ne02,
4131
+ /*.ne03 =*/ ne03,
4132
+ /*.nb00 =*/ nb00,
4133
+ /*.nb01 =*/ nb01,
4134
+ /*.nb02 =*/ nb02,
4135
+ /*.nb03 =*/ nb03,
4136
+ /*.nb10 =*/ nb10,
4137
+ /*.nb11 =*/ nb11,
4138
+ /*.nb12 =*/ nb12,
4139
+ /*.nb13 =*/ nb13,
4140
+ };
4141
+
4142
+ auto pipeline = ggml_metal_library_get_pipeline_count_equal(lib, op);
4143
+
4144
+ const size_t smem = pipeline.smem;
4145
+
4146
+ const int nth = 32*pipeline.nsg;
4147
+
4148
+ GGML_ASSERT(nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
4149
+
4150
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
4151
+ ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
4152
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
4153
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
4154
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
4155
+
4156
+ ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
4157
+ ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
4158
+ }
4159
+
4160
+ return 1;
4161
+ }