whispercpp 1.3.3 → 1.3.5

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (963)
  1. checksums.yaml +4 -4
  2. data/README.md +60 -43
  3. data/ext/extconf.rb +2 -2
  4. data/ext/ruby_whisper.c +14 -2
  5. data/ext/ruby_whisper.h +39 -0
  6. data/ext/ruby_whisper_context.c +22 -22
  7. data/ext/ruby_whisper_model.c +12 -12
  8. data/ext/ruby_whisper_params.c +79 -25
  9. data/ext/ruby_whisper_segment.c +84 -19
  10. data/ext/ruby_whisper_token.c +351 -0
  11. data/ext/ruby_whisper_transcribe.cpp +1 -1
  12. data/ext/ruby_whisper_vad_context.c +75 -0
  13. data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
  14. data/ext/ruby_whisper_vad_segment.c +139 -0
  15. data/ext/ruby_whisper_vad_segments.c +106 -0
  16. data/ext/sources/CMakeLists.txt +4 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
  19. data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
  20. data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  21. data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
  22. data/ext/sources/examples/CMakeLists.txt +1 -0
  23. data/ext/sources/examples/addon.node/addon.cpp +19 -19
  24. data/ext/sources/examples/addon.node/index.js +7 -5
  25. data/ext/sources/examples/addon.node/vad-example.js +2 -2
  26. data/ext/sources/examples/bench/bench.cpp +26 -16
  27. data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
  28. data/ext/sources/examples/cli/cli.cpp +122 -111
  29. data/ext/sources/examples/command/command.cpp +26 -24
  30. data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
  31. data/ext/sources/examples/common-ggml.cpp +2 -0
  32. data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
  33. data/ext/sources/examples/lsp/lsp.cpp +19 -17
  34. data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
  35. data/ext/sources/examples/server/server.cpp +34 -24
  36. data/ext/sources/examples/server.py +6 -1
  37. data/ext/sources/examples/stream/stream.cpp +4 -2
  38. data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
  39. data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
  40. data/ext/sources/examples/talk-llama/CMakeLists.txt +7 -3
  41. data/ext/sources/examples/talk-llama/llama-adapter.cpp +113 -7
  42. data/ext/sources/examples/talk-llama/llama-adapter.h +13 -1
  43. data/ext/sources/examples/talk-llama/llama-arch.cpp +2136 -1491
  44. data/ext/sources/examples/talk-llama/llama-arch.h +125 -3
  45. data/ext/sources/examples/talk-llama/llama-batch.cpp +174 -100
  46. data/ext/sources/examples/talk-llama/llama-batch.h +46 -20
  47. data/ext/sources/examples/talk-llama/llama-chat.cpp +199 -8
  48. data/ext/sources/examples/talk-llama/llama-chat.h +11 -0
  49. data/ext/sources/examples/talk-llama/llama-context.cpp +1213 -413
  50. data/ext/sources/examples/talk-llama/llama-context.h +99 -36
  51. data/ext/sources/examples/talk-llama/llama-cparams.h +5 -4
  52. data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
  53. data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
  54. data/ext/sources/examples/talk-llama/llama-graph.cpp +883 -294
  55. data/ext/sources/examples/talk-llama/llama-graph.h +361 -161
  56. data/ext/sources/examples/talk-llama/llama-hparams.cpp +144 -6
  57. data/ext/sources/examples/talk-llama/llama-hparams.h +100 -23
  58. data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
  59. data/ext/sources/examples/talk-llama/llama-impl.h +3 -1
  60. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +328 -0
  61. data/ext/sources/examples/talk-llama/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +38 -29
  62. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2100 -0
  63. data/ext/sources/examples/talk-llama/llama-kv-cache.h +373 -27
  64. data/ext/sources/examples/talk-llama/llama-kv-cells.h +124 -30
  65. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +63 -41
  66. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +30 -29
  67. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +77 -35
  68. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +15 -16
  69. data/ext/sources/examples/talk-llama/llama-memory.h +16 -10
  70. data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
  71. data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
  72. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +93 -9
  73. data/ext/sources/examples/talk-llama/llama-model-loader.h +9 -2
  74. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
  75. data/ext/sources/examples/talk-llama/llama-model.cpp +3369 -10145
  76. data/ext/sources/examples/talk-llama/llama-model.h +104 -12
  77. data/ext/sources/examples/talk-llama/llama-quant.cpp +53 -30
  78. data/ext/sources/examples/talk-llama/llama-sampling.cpp +1520 -324
  79. data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
  80. data/ext/sources/examples/talk-llama/llama-vocab.cpp +562 -39
  81. data/ext/sources/examples/talk-llama/llama-vocab.h +50 -0
  82. data/ext/sources/examples/talk-llama/llama.cpp +794 -12
  83. data/ext/sources/examples/talk-llama/llama.h +246 -190
  84. data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
  85. data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
  86. data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
  87. data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
  88. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
  89. data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
  90. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
  91. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
  92. data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
  93. data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
  94. data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
  95. data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
  96. data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
  97. data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
  98. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
  99. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
  100. data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
  101. data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
  102. data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
  103. data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
  104. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
  105. data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
  106. data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
  107. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
  108. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
  109. data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
  110. data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
  111. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
  112. data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
  113. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
  114. data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
  115. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
  116. data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
  117. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
  118. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
  119. data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
  120. data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
  121. data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
  122. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
  123. data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
  124. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
  125. data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
  126. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
  127. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
  128. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
  129. data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
  130. data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
  131. data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
  132. data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
  133. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
  134. data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
  135. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
  136. data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
  137. data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
  138. data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
  139. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
  140. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
  141. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
  142. data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
  143. data/ext/sources/examples/talk-llama/models/models.h +569 -0
  144. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
  145. data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
  146. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
  147. data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
  148. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
  149. data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
  150. data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
  151. data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
  152. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
  153. data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
  154. data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
  155. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
  156. data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
  157. data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
  158. data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
  159. data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
  160. data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
  161. data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
  162. data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
  163. data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
  164. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
  165. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
  166. data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
  167. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
  168. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
  169. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
  170. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
  171. data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
  172. data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
  173. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
  174. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
  175. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
  176. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
  177. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
  178. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
  179. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
  180. data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
  181. data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
  182. data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
  183. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
  184. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
  185. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
  186. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
  187. data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
  188. data/ext/sources/examples/talk-llama/talk-llama.cpp +9 -6
  189. data/ext/sources/examples/talk-llama/unicode.cpp +309 -16
  190. data/ext/sources/examples/talk-llama/unicode.h +45 -0
  191. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
  192. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +4 -2
  193. data/ext/sources/examples/whisper.wasm/index-tmpl.html +18 -17
  194. data/ext/sources/ggml/CMakeLists.txt +135 -79
  195. data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
  196. data/ext/sources/ggml/include/ggml-alloc.h +9 -0
  197. data/ext/sources/ggml/include/ggml-backend.h +21 -2
  198. data/ext/sources/ggml/include/ggml-cpu.h +2 -1
  199. data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
  200. data/ext/sources/ggml/include/ggml-metal.h +1 -6
  201. data/ext/sources/ggml/include/ggml-opt.h +25 -6
  202. data/ext/sources/ggml/include/ggml-rpc.h +8 -11
  203. data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
  204. data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
  205. data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
  206. data/ext/sources/ggml/include/ggml.h +406 -23
  207. data/ext/sources/ggml/src/CMakeLists.txt +99 -13
  208. data/ext/sources/ggml/src/ggml-alloc.c +368 -161
  209. data/ext/sources/ggml/src/ggml-backend-impl.h +5 -5
  210. data/ext/sources/ggml/src/ggml-backend-reg.cpp +55 -14
  211. data/ext/sources/ggml/src/ggml-backend.cpp +290 -57
  212. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
  213. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -13
  214. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +14 -0
  215. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +59 -45
  216. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
  217. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2586 -1917
  218. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +348 -309
  219. data/ext/sources/ggml/src/ggml-cann/common.h +350 -133
  220. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +894 -625
  221. data/ext/sources/ggml/src/ggml-common.h +17 -0
  222. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +167 -75
  223. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +5 -2
  224. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  225. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +560 -622
  226. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1002 -270
  227. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +107 -587
  228. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
  229. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  230. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +373 -486
  231. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
  232. data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  233. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +521 -353
  234. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
  235. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
  236. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +4682 -1660
  237. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +82 -4
  238. data/ext/sources/ggml/src/ggml-cpu/common.h +14 -0
  239. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +18 -9
  240. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +263 -111
  241. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +39 -28
  242. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +683 -82
  243. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +38 -43
  244. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +435 -119
  245. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
  246. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1234 -1182
  247. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
  248. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +2167 -1480
  249. data/ext/sources/ggml/src/ggml-cpu/ops.h +10 -12
  250. data/ext/sources/ggml/src/ggml-cpu/quants.c +35 -0
  251. data/ext/sources/ggml/src/ggml-cpu/quants.h +8 -0
  252. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1132 -81
  253. data/ext/sources/ggml/src/ggml-cpu/repack.h +36 -0
  254. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +120 -93
  255. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1025 -0
  256. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
  257. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
  258. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
  259. data/ext/sources/ggml/src/ggml-cpu/traits.cpp +2 -2
  260. data/ext/sources/ggml/src/ggml-cpu/traits.h +1 -1
  261. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
  262. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
  263. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +294 -27
  264. data/ext/sources/ggml/src/ggml-cpu/vec.h +606 -48
  265. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +92 -17
  266. data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
  267. data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
  268. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
  269. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
  270. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
  271. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
  272. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
  273. data/ext/sources/ggml/src/ggml-cuda/common.cuh +588 -128
  274. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
  275. data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
  276. data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
  277. data/ext/sources/ggml/src/ggml-cuda/convert.cu +95 -22
  278. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +25 -0
  279. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
  280. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +335 -485
  281. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
  282. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  283. data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
  284. data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
  285. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
  286. data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
  287. data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
  288. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +519 -378
  289. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +750 -637
  290. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +49 -0
  291. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1244 -0
  292. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +586 -0
  293. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +98 -61
  294. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  295. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +230 -197
  296. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  297. data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
  298. data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
  299. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
  300. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1557 -294
  301. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
  302. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
  303. data/ext/sources/ggml/src/ggml-cuda/mean.cu +57 -2
  304. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +915 -69
  305. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +171 -0
  306. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +835 -0
  307. data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
  308. data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
  309. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +109 -67
  310. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1601 -733
  311. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +802 -0
  312. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +12 -0
  313. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +286 -149
  314. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
  315. data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
  316. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
  317. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
  318. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
  319. data/ext/sources/ggml/src/ggml-cuda/pad.cu +86 -32
  320. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
  321. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
  322. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +163 -10
  323. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
  324. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
  325. data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
  326. data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
  327. data/ext/sources/ggml/src/ggml-cuda/rope.cu +207 -98
  328. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
  329. data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
  330. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +330 -0
  331. data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  332. data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
  333. data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
  334. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
  335. data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
  336. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +325 -61
  337. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
  338. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
  339. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -12
  340. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +291 -104
  341. data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
  342. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +21 -4
  343. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  344. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  345. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  346. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  347. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  348. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  349. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  350. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  351. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  352. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
  353. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
  354. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
  355. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
  356. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
  357. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
  358. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
  359. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
  360. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
  361. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
  362. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
  363. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
  364. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
  365. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
  366. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
  367. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
  368. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
  369. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
  370. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
  371. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
  372. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
  373. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
  374. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
  375. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
  376. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
  377. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
  378. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
  379. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
  380. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
  381. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
  382. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
  383. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
  384. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
  385. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
  386. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
  387. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
  388. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +40 -19
  389. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
  390. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
  391. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
  392. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
  393. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
  394. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
  395. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
  396. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
  397. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
  398. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
  399. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
  400. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
  401. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
  402. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
  403. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
  404. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
  405. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
  406. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
  407. data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
  408. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +351 -0
  409. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +21 -0
  410. data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
  411. data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
  412. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
  413. data/ext/sources/ggml/src/ggml-cuda/unary.cu +189 -5
  414. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +44 -0
  415. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +248 -6
  416. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
  417. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +8 -0
  418. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +70 -37
  419. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +10 -3
  420. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
  421. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
  422. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
  423. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
  424. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
  425. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  426. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
  427. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
  428. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  429. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
  430. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
  431. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
  432. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
  433. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  434. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
  435. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
  436. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
  437. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
  438. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
  439. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
  440. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
  441. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
  442. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
  443. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
  444. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
  445. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
  446. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
  447. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  448. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
  449. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
  450. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
  451. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +16 -13
  452. data/ext/sources/ggml/src/ggml-impl.h +186 -15
  453. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -7
  454. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
  455. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
  456. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
  457. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +609 -0
  458. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1743 -0
  459. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +273 -0
  460. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1686 -0
  461. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +356 -61
  462. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +4161 -0
  463. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +94 -0
  464. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +724 -0
  465. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +4495 -1876
  466. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +21 -9
  467. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +29 -0
  468. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4005 -427
  469. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
  470. data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  471. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  472. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  473. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +147 -0
  474. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +66 -0
  475. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
  476. data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  477. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
  478. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +371 -0
  479. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
  480. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  481. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  482. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  483. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  484. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +177 -0
  485. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
  486. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  487. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  488. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
  489. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
  490. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  491. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  492. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +146 -0
  493. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +147 -0
  494. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  495. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  496. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
  497. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
  498. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
  499. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  500. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
  501. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
  502. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
  503. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
  504. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  505. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +94 -0
  506. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  507. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  508. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +208 -0
  509. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
  510. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
  511. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
  512. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
  513. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
  514. data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  515. data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  516. data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  517. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
  518. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +33 -0
  519. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +2 -2
  520. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  521. data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
  522. data/ext/sources/ggml/src/ggml-quants.c +111 -16
  523. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  524. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +497 -195
  525. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
  526. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
  527. data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
  528. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +8 -0
  529. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +6 -5
  530. data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
  531. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +50 -30
  532. data/ext/sources/ggml/src/ggml-sycl/conv.cpp +10 -4
  533. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +200 -99
  534. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  535. data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  536. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +72 -309
  537. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
  538. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
  539. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +67 -49
  540. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +77 -34
  541. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +397 -314
  542. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +12 -2
  543. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +14 -26
  544. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +9 -6
  545. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +643 -413
  546. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
  547. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +2 -2
  548. data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +80 -60
  549. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +223 -132
  550. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +230 -55
  551. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
  552. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
  553. data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
  554. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
  555. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
  556. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
  557. data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
  558. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +8 -9
  559. data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  560. data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  561. data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
  562. data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
  563. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +65 -59
  564. data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
  565. data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
  566. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
  567. data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  568. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +330 -165
  569. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
  570. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  571. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  572. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +12 -6
  573. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +60 -6
  574. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +16 -12
  575. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
  576. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +7398 -2635
  577. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  578. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
  579. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +43 -3
  580. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  581. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
  582. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  583. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +15 -6
  584. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +56 -39
  585. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  586. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  587. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  588. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  589. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  590. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  591. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +347 -0
  592. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  593. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  594. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +5 -5
  595. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +67 -13
  596. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  597. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  598. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  599. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
  600. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
  601. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
  602. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
  603. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  604. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +158 -16
  605. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +38 -3
  606. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  607. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  608. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  609. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +2 -2
  610. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  611. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +3 -2
  612. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +7 -6
  613. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +5 -3
  614. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  615. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +4 -4
  618. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +2 -2
  619. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  620. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  621. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +4 -4
  622. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  623. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  624. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +4 -4
  625. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +2 -2
  626. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  627. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
  628. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  629. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  630. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
  631. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  632. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +103 -36
  633. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +220 -0
  634. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +139 -45
  635. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +113 -38
  636. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +75 -14
  637. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  638. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  639. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  640. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  641. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  642. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  643. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  644. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +19 -17
  645. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
  646. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +21 -12
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +28 -18
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +4 -0
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +33 -17
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +125 -0
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +227 -0
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
  663. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
  664. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
  665. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
  666. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
  667. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
  668. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
  669. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +20 -14
  670. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  671. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
  672. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
  673. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
  674. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
  675. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
  676. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +143 -0
  677. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
  678. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -556
  679. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +230 -51
  680. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +566 -0
  681. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
  682. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +90 -223
  683. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
  684. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  685. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +195 -0
  686. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  687. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  688. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  689. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
  690. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +41 -5
  691. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  692. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +59 -9
  693. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  694. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  695. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  696. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  697. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +104 -14
  698. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  699. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
  700. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  701. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
  702. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  703. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -52
  704. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -35
  705. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -35
  706. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
  707. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
  708. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  709. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +5 -0
  710. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +3 -3
  711. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  712. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  713. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  714. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  715. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +30 -8
  716. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +6 -2
  717. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
  718. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
  719. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
  720. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
  721. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  722. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
  723. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
  724. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  725. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
  726. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
  727. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  728. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  729. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +16 -6
  730. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
  731. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  732. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
  733. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  734. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +5 -4
  735. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
  736. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
  737. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
  738. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
  739. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  740. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +435 -24
  741. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +148 -6
  742. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.glsl +25 -0
  743. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +619 -177
  744. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
  745. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +80 -0
  746. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
  747. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3087 -0
  748. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  749. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
  750. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
  751. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
  752. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
  753. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +147 -0
  754. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
  755. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
  756. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  757. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  758. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
  759. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
  760. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
  761. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
  762. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
  763. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +123 -0
  764. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  765. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
  766. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
  767. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
  768. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  769. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
  770. data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  771. data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
  772. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
  773. data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
  774. data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
  775. data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
  776. data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
  777. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
  778. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
  779. data/ext/sources/ggml/src/ggml.c +901 -129
  780. data/ext/sources/ggml/src/gguf.cpp +8 -1
  781. data/ext/sources/include/whisper.h +1 -0
  782. data/ext/sources/src/CMakeLists.txt +3 -1
  783. data/ext/sources/src/whisper.cpp +124 -81
  784. data/ext/sources/tests/CMakeLists.txt +8 -1
  785. data/ext/sources/tests/test-vad-full.cpp +7 -5
  786. data/ext/sources/tests/test-vad.cpp +3 -3
  787. data/extsources.rb +1 -0
  788. data/lib/whisper/model/uri.rb +17 -18
  789. data/sig/whisper.rbs +126 -2
  790. data/test/test_params.rb +24 -8
  791. data/test/test_segment.rb +0 -1
  792. data/test/test_token.rb +70 -0
  793. data/test/test_vad.rb +1 -1
  794. data/test/test_vad_context.rb +50 -0
  795. data/test/test_vad_segment.rb +19 -0
  796. data/test/test_vad_segments.rb +16 -0
  797. data/test/test_whisper.rb +8 -1
  798. data/whispercpp.gemspec +1 -1
  799. metadata +439 -179
  800. data/ext/sources/build-xcframework.sh +0 -547
  801. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +0 -279
  802. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +0 -1841
  803. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +0 -303
  804. data/ext/sources/ggml/include/ggml-kompute.h +0 -50
  805. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
  806. data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
  807. data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
  808. data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
  809. data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
  810. data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
  811. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  812. data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  813. data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  814. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  815. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  816. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  817. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  818. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  819. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  820. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
  821. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
  822. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
  823. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
  824. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
  825. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
  826. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
  827. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +0 -506
  828. data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +0 -11
  829. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
  830. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
  831. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
  832. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
  833. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
  834. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
  835. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
  836. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
  837. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
  838. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
  839. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
  840. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
  841. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
  842. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
  843. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
  844. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
  845. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
  846. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
  847. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
  848. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
  849. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
  850. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
  851. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
  852. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
  853. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
  854. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
  855. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
  856. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
  857. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
  858. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
  859. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
  860. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
  861. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
  862. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
  863. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
  864. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
  865. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
  866. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
  867. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
  868. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
  869. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
  870. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
  871. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
  872. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
  873. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
  874. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
  875. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
  876. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
  877. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
  878. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
  879. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
  880. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
  881. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
  882. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
  883. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
  884. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
  885. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
  886. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
  887. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
  888. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
  889. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
  890. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
  891. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
  892. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
  893. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
  894. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
  895. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
  896. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
  897. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
  898. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
  899. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
  900. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
  901. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
  902. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
  903. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
  904. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
  905. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
  906. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
  907. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
  908. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
  909. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
  910. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
  911. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
  912. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
  913. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
  914. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
  915. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  916. data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  917. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  918. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  919. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  920. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  921. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  922. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  923. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  924. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  925. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  926. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  927. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  928. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  929. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  930. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  931. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  932. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  933. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  934. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  935. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  936. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  937. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  938. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  939. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  940. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  941. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  942. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  943. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  944. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  945. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  946. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  947. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  948. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  949. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  950. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  951. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  952. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  953. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
  954. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -6280
  955. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +0 -162
  956. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +0 -118
  957. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -99
  958. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -58
  959. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  960. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  961. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  962. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  963. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
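
One of the largest single additions in this release is a new Metal op-encoding translation unit, shown in the ~4,161-line hunk below. It introduces an opaque ggml_metal_op context that walks a slice of a ggml_cgraph, filters out empty nodes, and encodes each remaining node (possibly fusing several) into a Metal command encoder. As a rough orientation before the hunk, here is a minimal, hypothetical driver loop over its public entry points; the dev, cmd_buf, and gf variables and the flag values are illustration-only placeholders, not part of the diff:

    // Hypothetical usage sketch of the ggml_metal_op API added in the hunk below.
    // Assumes dev (ggml_metal_device_t), cmd_buf (ggml_metal_cmd_buf_t) and
    // gf (ggml_cgraph *) were created elsewhere.
    ggml_metal_op_t op = ggml_metal_op_init(dev, cmd_buf, gf,
            /*idx_start =*/ 0, /*idx_end =*/ gf->n_nodes,
            /*use_fusion =*/ true, /*use_concurrency =*/ true, /*use_capture =*/ false,
            /*debug_graph =*/ 0, /*debug_fusion =*/ 0);

    // ggml_metal_op_encode() returns how many (possibly fused) nodes it consumed.
    for (int idx = 0; idx < ggml_metal_op_n_nodes(op); ) {
        idx += ggml_metal_op_encode(op, idx);
    }

    ggml_metal_op_free(op); // the destructor ends the encoding and frees the encoder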
@@ -0,0 +1,4161 @@
+ #include "ggml-metal-ops.h"
+
+ #include "ggml.h"
+ #include "ggml-impl.h"
+ #include "ggml-backend-impl.h"
+
+ #include "ggml-metal-impl.h"
+ #include "ggml-metal-common.h"
+ #include "ggml-metal-device.h"
+
+ #include <cassert>
+ #include <algorithm>
+ #include <limits>
+ #include <cmath>
+
+ static ggml_metal_buffer_id ggml_metal_get_buffer_id(const ggml_tensor * t) {
+     if (!t) {
+         return { nullptr, 0 };
+     }
+
+     ggml_backend_buffer_t buffer = t->view_src ? t->view_src->buffer : t->buffer;
+
+     ggml_metal_buffer_t ctx = (ggml_metal_buffer_t) buffer->context;
+
+     return ggml_metal_buffer_get_id(ctx, t);
+ }
+
+ struct ggml_metal_op {
+     ggml_metal_op(
+             ggml_metal_device_t dev,
+             ggml_metal_cmd_buf_t cmd_buf,
+             ggml_cgraph * gf,
+             int idx_start,
+             int idx_end,
+             bool use_fusion,
+             bool use_concurrency,
+             bool use_capture,
+             int debug_graph,
+             int debug_fusion) {
+         this->dev = dev;
+         this->lib = ggml_metal_device_get_library(dev);
+         this->enc = ggml_metal_encoder_init(cmd_buf, use_concurrency);
+         this->mem_ranges = ggml_mem_ranges_init(debug_graph);
+         this->idx_start = idx_start;
+         this->idx_end = idx_end;
+         this->use_fusion = use_fusion;
+         this->use_concurrency = use_concurrency;
+         this->use_capture = use_capture;
+         this->debug_graph = debug_graph;
+         this->debug_fusion = debug_fusion;
+         this->gf = gf;
+
+         idxs.reserve(gf->n_nodes);
+
+         // filter empty nodes
+         // TODO: this can be removed when the allocator starts filtering them earlier
+         // https://github.com/ggml-org/llama.cpp/pull/16130#issuecomment-3327905830
+         for (int i = idx_start; i < idx_end; i++) {
+             if (!ggml_op_is_empty(gf->nodes[i]->op) && !ggml_is_empty(gf->nodes[i])) {
+                 idxs.push_back(i);
+             }
+         }
+     }
+
+     ~ggml_metal_op() {
+         ggml_metal_encoder_end_encoding(this->enc);
+         ggml_metal_encoder_free(this->enc);
+         ggml_mem_ranges_free(this->mem_ranges);
+     }
+
+     int n_nodes() const {
+         return idxs.size();
+     }
+
+     ggml_tensor * node(int i) const {
+         assert(i >= 0 && i < (int) idxs.size());
+         return ggml_graph_node(gf, idxs[i]);
+     }
+
+     bool can_fuse(int i0, const ggml_op * ops, int n_ops) const {
+         assert(use_fusion);
+         assert(i0 >= 0 && i0 < n_nodes());
+
+         if (i0 + n_ops > n_nodes()) {
+             return false;
+         }
+
+         return ggml_can_fuse_ext(gf, idxs.data() + i0, ops, n_ops);
+     }
+
+     ggml_metal_device_t dev;
+     ggml_metal_library_t lib;
+     ggml_metal_encoder_t enc;
+     ggml_mem_ranges_t mem_ranges;
+
+     bool use_fusion;
+     bool use_concurrency;
+     bool use_capture;
+
+     int debug_graph;
+     int debug_fusion;
+
+ private:
+     ggml_cgraph * gf;
+
+     int idx_start;
+     int idx_end;
+
+     // non-empty node indices
+     std::vector<int> idxs;
+ };
+
+ ggml_metal_op_t ggml_metal_op_init(
+         ggml_metal_device_t dev,
+         ggml_metal_cmd_buf_t cmd_buf,
+         ggml_cgraph * gf,
+         int idx_start,
+         int idx_end,
+         bool use_fusion,
+         bool use_concurrency,
+         bool use_capture,
+         int debug_graph,
+         int debug_fusion) {
+     ggml_metal_op_t res = new ggml_metal_op(
+         dev,
+         cmd_buf,
+         gf,
+         idx_start,
+         idx_end,
+         use_fusion,
+         use_concurrency,
+         use_capture,
+         debug_graph,
+         debug_fusion);
+
+     return res;
+ }
+
+ void ggml_metal_op_free(ggml_metal_op_t ctx) {
+     delete ctx;
+ }
+
+ int ggml_metal_op_n_nodes(ggml_metal_op_t ctx) {
+     return ctx->n_nodes();
+ }
+
+ static bool ggml_metal_op_concurrency_reset(ggml_metal_op_t ctx) {
+     if (!ctx->mem_ranges) {
+         return true;
+     }
+
+     ggml_metal_encoder_memory_barrier(ctx->enc);
+
+     ggml_mem_ranges_reset(ctx->mem_ranges);
+
+     return true;
+ }
+
+ static bool ggml_metal_op_concurrency_check(ggml_metal_op_t ctx, const ggml_tensor * node) {
+     if (!ctx->mem_ranges) {
+         return false;
+     }
+
+     return ggml_mem_ranges_check(ctx->mem_ranges, node);
+ }
+
+ static bool ggml_metal_op_concurrency_add(ggml_metal_op_t ctx, const ggml_tensor * node) {
+     if (!ctx->mem_ranges) {
+         return true;
+     }
+
+     return ggml_mem_ranges_add(ctx->mem_ranges, node);
+ }
+
+ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
+     struct ggml_tensor * node = ctx->node(idx);
+
+     //GGML_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, idx, ggml_op_name(node->op));
+
+     if (ggml_is_empty(node)) {
+         return 1;
+     }
+
+     switch (node->op) {
+         case GGML_OP_NONE:
+         case GGML_OP_RESHAPE:
+         case GGML_OP_VIEW:
+         case GGML_OP_TRANSPOSE:
+         case GGML_OP_PERMUTE:
+             {
+                 // noop -> next node
+                 if (ctx->debug_graph > 0) {
+                     GGML_LOG_DEBUG("%s: node[%5d] - %-12s %s\n", __func__, idx, ggml_op_name(node->op), "(noop)");
+                 }
+             } return 1;
+         default:
+             {
+             } break;
+     }
+
+     if (!ggml_metal_device_supports_op(ctx->dev, node)) {
+         GGML_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(node));
+         GGML_ABORT("unsupported op");
+     }
+
+     int n_fuse = 1;
+
+     // check if the current node can run concurrently with other nodes before it
+     // the condition is that:
+     // - the current node cannot write to any previous src or dst ranges
+     // - the current node cannot read from any previous dst ranges
+     //
+     // if the condition is not satisfied, we put a memory barrier and clear all ranges
+     // otherwise, we add the new ranges to the encoding context and process the node concurrently
+     //
+     {
+         const bool is_concurrent = ggml_metal_op_concurrency_check(ctx, node);
+
+         if (!is_concurrent) {
+             ggml_metal_op_concurrency_reset(ctx);
+         }
+
+         if (ctx->debug_graph > 0) {
+             GGML_LOG_DEBUG("%s: node[%5d] - %-12s %-12s %s\n", __func__, idx, ggml_op_name(node->op), ggml_get_name(node), is_concurrent ? "(concurrent)" : "");
+         }
+         if (ctx->debug_graph > 1) {
+             GGML_TENSOR_LOCALS( int64_t, ne0, node->src[0], ne);
+             GGML_TENSOR_LOCALS(uint64_t, nb0, node->src[0], nb);
+             GGML_TENSOR_LOCALS( int64_t, ne1, node->src[1], ne);
+             GGML_TENSOR_LOCALS(uint64_t, nb1, node->src[1], nb);
+             GGML_TENSOR_LOCALS( int64_t, ne2, node->src[2], ne);
+             GGML_TENSOR_LOCALS(uint64_t, nb2, node->src[2], nb);
+             GGML_TENSOR_LOCALS( int64_t, ne3, node->src[3], ne);
+             GGML_TENSOR_LOCALS(uint64_t, nb3, node->src[3], nb);
+             GGML_TENSOR_LOCALS( int64_t, ne, node, ne);
+             GGML_TENSOR_LOCALS(uint64_t, nb, node, nb);
+
+             if (node->src[0]) {
+                 GGML_LOG_DEBUG("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[0]->type), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03,
+                         ggml_is_contiguous(node->src[0]), node->src[0]->name);
+             }
+             if (node->src[1]) {
+                 GGML_LOG_DEBUG("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[1]->type), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
+                         ggml_is_contiguous(node->src[1]), node->src[1]->name);
+             }
+             if (node->src[2]) {
+                 GGML_LOG_DEBUG("%s: src2 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[2]->type), ne20, ne21, ne22, ne23, nb20, nb21, nb22, nb23,
+                         ggml_is_contiguous(node->src[2]), node->src[2]->name);
+             }
+             if (node->src[3]) {
+                 GGML_LOG_DEBUG("%s: src3 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[3]->type), ne30, ne31, ne32, ne33, nb30, nb31, nb32, nb33,
+                         ggml_is_contiguous(node->src[3]), node->src[3]->name);
+             }
+             if (node) {
+                 GGML_LOG_DEBUG("%s: node - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(node->type), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
+                         node->name);
+             }
+         }
+     }
+
+     switch (node->op) {
+         case GGML_OP_CONCAT:
+             {
+                 n_fuse = ggml_metal_op_concat(ctx, idx);
+             } break;
+         case GGML_OP_ADD:
+         case GGML_OP_SUB:
+         case GGML_OP_MUL:
+         case GGML_OP_DIV:
+             {
+                 n_fuse = ggml_metal_op_bin(ctx, idx);
+             } break;
+         case GGML_OP_ADD_ID:
+             {
+                 n_fuse = ggml_metal_op_add_id(ctx, idx);
+             } break;
+         case GGML_OP_REPEAT:
+             {
+                 n_fuse = ggml_metal_op_repeat(ctx, idx);
+             } break;
+         case GGML_OP_ACC:
+             {
+                 n_fuse = ggml_metal_op_acc(ctx, idx);
+             } break;
+         case GGML_OP_SCALE:
+             {
+                 n_fuse = ggml_metal_op_scale(ctx, idx);
+             } break;
+         case GGML_OP_FILL:
+             {
+                 n_fuse = ggml_metal_op_fill(ctx, idx);
+             } break;
+         case GGML_OP_CLAMP:
+             {
+                 n_fuse = ggml_metal_op_clamp(ctx, idx);
+             } break;
+         case GGML_OP_SQR:
+         case GGML_OP_SQRT:
+         case GGML_OP_SIN:
+         case GGML_OP_COS:
+         case GGML_OP_LOG:
+         case GGML_OP_UNARY:
+             {
+                 n_fuse = ggml_metal_op_unary(ctx, idx);
+             } break;
+         case GGML_OP_GLU:
+             {
+                 n_fuse = ggml_metal_op_glu(ctx, idx);
+             } break;
+         case GGML_OP_SUM:
+             {
+                 n_fuse = ggml_metal_op_sum(ctx, idx);
+             } break;
+         case GGML_OP_SUM_ROWS:
+         case GGML_OP_MEAN:
+             {
+                 n_fuse = ggml_metal_op_sum_rows(ctx, idx);
+             } break;
+         case GGML_OP_CUMSUM:
+             {
+                 n_fuse = ggml_metal_op_cumsum(ctx, idx);
+             } break;
+         case GGML_OP_SOFT_MAX:
+             {
+                 n_fuse = ggml_metal_op_soft_max(ctx, idx);
+             } break;
+         case GGML_OP_SSM_CONV:
+             {
+                 n_fuse = ggml_metal_op_ssm_conv(ctx, idx);
+             } break;
+         case GGML_OP_SSM_SCAN:
+             {
+                 n_fuse = ggml_metal_op_ssm_scan(ctx, idx);
+             } break;
+         case GGML_OP_RWKV_WKV6:
+         case GGML_OP_RWKV_WKV7:
+             {
+                 n_fuse = ggml_metal_op_rwkv(ctx, idx);
+             } break;
+         case GGML_OP_MUL_MAT:
+             {
+                 n_fuse = ggml_metal_op_mul_mat(ctx, idx);
+             } break;
+         case GGML_OP_MUL_MAT_ID:
+             {
+                 n_fuse = ggml_metal_op_mul_mat_id(ctx, idx);
+             } break;
+         case GGML_OP_GET_ROWS:
+             {
+                 n_fuse = ggml_metal_op_get_rows(ctx, idx);
+             } break;
+         case GGML_OP_SET_ROWS:
+             {
+                 n_fuse = ggml_metal_op_set_rows(ctx, idx);
+             } break;
+         case GGML_OP_L2_NORM:
+             {
+                 n_fuse = ggml_metal_op_l2_norm(ctx, idx);
+             } break;
+         case GGML_OP_GROUP_NORM:
+             {
+                 n_fuse = ggml_metal_op_group_norm(ctx, idx);
+             } break;
+         case GGML_OP_NORM:
+         case GGML_OP_RMS_NORM:
+             {
+                 n_fuse = ggml_metal_op_norm(ctx, idx);
+             } break;
+         case GGML_OP_ROPE:
+             {
+                 n_fuse = ggml_metal_op_rope(ctx, idx);
+             } break;
+         case GGML_OP_IM2COL:
+             {
+                 n_fuse = ggml_metal_op_im2col(ctx, idx);
+             } break;
+         case GGML_OP_CONV_2D:
+             {
+                 n_fuse = ggml_metal_op_conv_2d(ctx, idx);
+             } break;
+         case GGML_OP_CONV_TRANSPOSE_1D:
+             {
+                 n_fuse = ggml_metal_op_conv_transpose_1d(ctx, idx);
+             } break;
+         case GGML_OP_CONV_TRANSPOSE_2D:
+             {
+                 n_fuse = ggml_metal_op_conv_transpose_2d(ctx, idx);
+             } break;
+         case GGML_OP_UPSCALE:
+             {
+                 n_fuse = ggml_metal_op_upscale(ctx, idx);
+             } break;
+         case GGML_OP_PAD:
+             {
+                 n_fuse = ggml_metal_op_pad(ctx, idx);
+             } break;
+         case GGML_OP_PAD_REFLECT_1D:
+             {
+                 n_fuse = ggml_metal_op_pad_reflect_1d(ctx, idx);
+             } break;
+         case GGML_OP_ARANGE:
+             {
+                 n_fuse = ggml_metal_op_arange(ctx, idx);
+             } break;
+         case GGML_OP_TIMESTEP_EMBEDDING:
+             {
+                 n_fuse = ggml_metal_op_timestep_embedding(ctx, idx);
+             } break;
+         case GGML_OP_ARGSORT:
+             {
+                 n_fuse = ggml_metal_op_argsort(ctx, idx);
+             } break;
+         case GGML_OP_TOP_K:
+             {
+                 n_fuse = ggml_metal_op_top_k(ctx, idx);
+             } break;
+         case GGML_OP_LEAKY_RELU:
+             {
+                 n_fuse = ggml_metal_op_leaky_relu(ctx, idx);
+             } break;
+         case GGML_OP_TRI:
+             {
+                 n_fuse = ggml_metal_op_tri(ctx, idx);
+             } break;
+         case GGML_OP_FLASH_ATTN_EXT:
+             {
+                 n_fuse = ggml_metal_op_flash_attn_ext(ctx, idx);
+             } break;
+         case GGML_OP_DUP:
+         case GGML_OP_CPY:
+         case GGML_OP_CONT:
+             {
+                 n_fuse = ggml_metal_op_cpy(ctx, idx);
+             } break;
+         case GGML_OP_POOL_2D:
+             {
+                 n_fuse = ggml_metal_op_pool_2d(ctx, idx);
+             } break;
+         case GGML_OP_ARGMAX:
+             {
+                 n_fuse = ggml_metal_op_argmax(ctx, idx);
+             } break;
+         case GGML_OP_OPT_STEP_ADAMW:
+             {
+                 n_fuse = ggml_metal_op_opt_step_adamw(ctx, idx);
+             } break;
+         case GGML_OP_OPT_STEP_SGD:
+             {
+                 n_fuse = ggml_metal_op_opt_step_sgd(ctx, idx);
+             } break;
+         case GGML_OP_COUNT_EQUAL:
+             {
+                 n_fuse = ggml_metal_op_count_equal(ctx, idx);
+             } break;
+         default:
+             {
+                 GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(node->op));
+                 GGML_ABORT("fatal error");
+             }
+     }
+
+     if (ctx->debug_graph > 0) {
+         if (n_fuse > 1) {
+             GGML_LOG_DEBUG("%s: fuse %d ops\n", __func__, n_fuse);
+         }
+     }
+
+     // update the mem ranges in the encoding context
+     for (int i = 0; i < n_fuse; ++i) {
+         if (!ggml_metal_op_concurrency_add(ctx, ctx->node(idx + i))) {
+             ggml_metal_op_concurrency_reset(ctx);
+         }
+     }
+
+     return n_fuse;
+ }
+
+ int ggml_metal_op_encode(ggml_metal_op_t ctx, int idx) {
+     if (ctx->use_capture) {
+         ggml_metal_encoder_debug_group_push(ctx->enc, ggml_op_desc(ctx->node(idx)));
+     }
+
+     int res = ggml_metal_op_encode_impl(ctx, idx);
+     if (idx + res > ctx->n_nodes()) {
+         GGML_ABORT("fusion error: nodes spanning multiple encoders have been fused. this indicates a bug in the fusion logic %s",
+                 "https://github.com/ggml-org/llama.cpp/pull/14849");
+     }
+
+     if (ctx->use_capture) {
+         ggml_metal_encoder_debug_group_pop(ctx->enc);
+     }
+
+     return res;
+ }
+
+ int ggml_metal_op_concat(ggml_metal_op_t ctx, int idx) {
+     ggml_tensor * op = ctx->node(idx);
+
+     ggml_metal_library_t lib = ctx->lib;
+     ggml_metal_encoder_t enc = ctx->enc;
+
+     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+     GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+     GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+     GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+     GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+     const int32_t dim = ((const int32_t *) op->op_params)[0];
+
+     ggml_metal_kargs_concat args = {
+         /*.ne00 =*/ ne00,
+         /*.ne01 =*/ ne01,
+         /*.ne02 =*/ ne02,
+         /*.ne03 =*/ ne03,
+         /*.nb00 =*/ nb00,
+         /*.nb01 =*/ nb01,
+         /*.nb02 =*/ nb02,
+         /*.nb03 =*/ nb03,
+         /*.ne10 =*/ ne10,
+         /*.ne11 =*/ ne11,
+         /*.ne12 =*/ ne12,
+         /*.ne13 =*/ ne13,
+         /*.nb10 =*/ nb10,
+         /*.nb11 =*/ nb11,
+         /*.nb12 =*/ nb12,
+         /*.nb13 =*/ nb13,
+         /*.ne0 =*/ ne0,
+         /*.ne1 =*/ ne1,
+         /*.ne2 =*/ ne2,
+         /*.ne3 =*/ ne3,
+         /*.nb0 =*/ nb0,
+         /*.nb1 =*/ nb1,
+         /*.nb2 =*/ nb2,
+         /*.nb3 =*/ nb3,
+         /*.dim =*/ dim,
+     };
+
+     auto pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_CONCAT);
+
+     ggml_metal_encoder_set_pipeline(enc, pipeline);
+     ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+     ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+     ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+     ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
+
+     const int nth = std::min(1024, ne0);
+
+     ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1);
+
+     return 1;
+ }
+
+ int ggml_metal_op_repeat(ggml_metal_op_t ctx, int idx) {
+     ggml_tensor * op = ctx->node(idx);
+
+     ggml_metal_library_t lib = ctx->lib;
+     ggml_metal_encoder_t enc = ctx->enc;
+
+     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+     GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+     GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+     auto pipeline = ggml_metal_library_get_pipeline_repeat(lib, op->type);
+
+     ggml_metal_kargs_repeat args = {
+         /*.ne00 =*/ ne00,
+         /*.ne01 =*/ ne01,
+         /*.ne02 =*/ ne02,
+         /*.ne03 =*/ ne03,
+         /*.nb00 =*/ nb00,
+         /*.nb01 =*/ nb01,
+         /*.nb02 =*/ nb02,
+         /*.nb03 =*/ nb03,
+         /*.ne0 =*/ ne0,
+         /*.ne1 =*/ ne1,
+         /*.ne2 =*/ ne2,
+         /*.ne3 =*/ ne3,
+         /*.nb0 =*/ nb0,
+         /*.nb1 =*/ nb1,
+         /*.nb2 =*/ nb2,
+         /*.nb3 =*/ nb3,
+     };
+
+     ggml_metal_encoder_set_pipeline(enc, pipeline);
+     ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+     ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+     ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
+
+     const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
+
+     ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1);
+
+     return 1;
+ }
+
+ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
+     ggml_tensor * op = ctx->node(idx);
+
+     ggml_metal_library_t lib = ctx->lib;
+     ggml_metal_encoder_t enc = ctx->enc;
+
+     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+     GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+     GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+     GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+     GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+     GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
+     GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
+     GGML_ASSERT(op->type == GGML_TYPE_F32);
+
+     GGML_ASSERT(ggml_is_contiguous(op->src[0]));
+     GGML_ASSERT(ggml_is_contiguous(op->src[1]));
+
+     const size_t pnb1 = ((const int32_t *) op->op_params)[0];
+     const size_t pnb2 = ((const int32_t *) op->op_params)[1];
+     const size_t pnb3 = ((const int32_t *) op->op_params)[2];
+     const size_t offs = ((const int32_t *) op->op_params)[3];
+
+     const bool inplace = (bool) ((const int32_t *) op->op_params)[4];
+
+     if (!inplace) {
+         // run a separate kernel to cpy src->dst
+         // not sure how to avoid this
+         // TODO: make a simpler cpy_bytes kernel
+
630
+ //const id<MTLComputePipelineState> pipeline = ctx->pipelines[GGML_METAL_PIPELINE_TYPE_CPY_F32_F32].obj;
631
+ auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
632
+
633
+ ggml_metal_kargs_cpy args = {
634
+ /*.nk0 =*/ ne00,
635
+ /*.ne00 =*/ ne00,
636
+ /*.ne01 =*/ ne01,
637
+ /*.ne02 =*/ ne02,
638
+ /*.ne03 =*/ ne03,
639
+ /*.nb00 =*/ nb00,
640
+ /*.nb01 =*/ nb01,
641
+ /*.nb02 =*/ nb02,
642
+ /*.nb03 =*/ nb03,
643
+ /*.ne0 =*/ ne0,
644
+ /*.ne1 =*/ ne1,
645
+ /*.ne2 =*/ ne2,
646
+ /*.ne3 =*/ ne3,
647
+ /*.nb0 =*/ nb0,
648
+ /*.nb1 =*/ nb1,
649
+ /*.nb2 =*/ nb2,
650
+ /*.nb3 =*/ nb3,
651
+ };
652
+
653
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
654
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
655
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
656
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
657
+
658
+ const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00);
659
+
660
+ ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
661
+
662
+ ggml_metal_op_concurrency_reset(ctx);
663
+ }
664
+
665
+ ggml_metal_kargs_bin args = {
666
+ /*.ne00 =*/ ne00,
667
+ /*.ne01 =*/ ne01,
668
+ /*.ne02 =*/ ne02,
669
+ /*.ne03 =*/ ne03,
670
+ /*.nb00 =*/ nb00,
671
+ /*.nb01 =*/ pnb1,
672
+ /*.nb02 =*/ pnb2,
673
+ /*.nb03 =*/ pnb3,
674
+ /*.ne10 =*/ ne10,
675
+ /*.ne11 =*/ ne11,
676
+ /*.ne12 =*/ ne12,
677
+ /*.ne13 =*/ ne13,
678
+ /*.nb10 =*/ nb10,
679
+ /*.nb11 =*/ nb11,
680
+ /*.nb12 =*/ nb12,
681
+ /*.nb13 =*/ nb13,
682
+ /*.ne0 =*/ ne0,
683
+ /*.ne1 =*/ ne1,
684
+ /*.ne2 =*/ ne2,
685
+ /*.ne3 =*/ ne3,
686
+ /*.nb0 =*/ nb0,
687
+ /*.nb1 =*/ pnb1,
688
+ /*.nb2 =*/ pnb2,
689
+ /*.nb3 =*/ pnb3,
690
+ /*.offs =*/ offs,
691
+ /*.o1 =*/ { 0 },
692
+ };
693
+
694
+ auto pipeline = ggml_metal_library_get_pipeline_bin(lib, GGML_OP_ADD, 1, false);
695
+
696
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
697
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
698
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
699
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
700
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
701
+
702
+ const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00);
703
+
704
+ ggml_metal_encoder_dispatch_threadgroups(enc, ne11, ne12, ne13, nth, 1, 1);
705
+
706
+ return 1;
707
+ }
708
+
709
+ int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) {
710
+ ggml_tensor * op = ctx->node(idx);
711
+
712
+ ggml_metal_library_t lib = ctx->lib;
713
+ ggml_metal_encoder_t enc = ctx->enc;
714
+
715
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
716
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
717
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
718
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
719
+
720
+ float scale;
721
+ float bias;
722
+ memcpy(&scale, ((const int32_t *) op->op_params) + 0, sizeof(float));
723
+ memcpy(&bias, ((const int32_t *) op->op_params) + 1, sizeof(float));
724
+
725
+ ggml_metal_kargs_scale args = {
726
+ /*.scale =*/ scale,
727
+ /*.bias =*/ bias,
728
+ };
729
+
730
+ int64_t n = ggml_nelements(op);
731
+
732
+ if (n % 4 == 0) {
733
+ n /= 4;
734
+ }
735
+
736
+ auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
737
+
738
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
739
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
740
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
741
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
742
+
743
+ ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
744
+
745
+ return 1;
746
+ }
747
+
748
+ int ggml_metal_op_fill(ggml_metal_op_t ctx, int idx) {
749
+ ggml_tensor * op = ctx->node(idx);
750
+
751
+ ggml_metal_library_t lib = ctx->lib;
752
+ ggml_metal_encoder_t enc = ctx->enc;
753
+
754
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
755
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
756
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
757
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
758
+
759
+ const float val = ggml_get_op_params_f32(op, 0);
760
+
761
+ ggml_metal_kargs_fill args = {
762
+ /*.val =*/ val
763
+ };
764
+
765
+ int64_t n = ggml_nelements(op);
766
+
767
+ if (n % 4 == 0) {
768
+ n /= 4;
769
+ }
770
+
771
+ auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
772
+
773
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
774
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
775
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
776
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
777
+
778
+ ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
779
+
780
+ return 1;
781
+ }
782
+
783
+ int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) {
784
+ ggml_tensor * op = ctx->node(idx);
785
+
786
+ ggml_metal_library_t lib = ctx->lib;
787
+ ggml_metal_encoder_t enc = ctx->enc;
788
+
789
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
790
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
791
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
792
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
793
+
794
+ float min;
795
+ float max;
796
+ memcpy(&min, ((const int32_t *) op->op_params) + 0, sizeof(float));
797
+ memcpy(&max, ((const int32_t *) op->op_params) + 1, sizeof(float));
798
+
799
+ ggml_metal_kargs_clamp args = {
800
+ /*.min =*/ min,
801
+ /*.max =*/ max,
802
+ };
803
+
804
+ int64_t n = ggml_nelements(op);
805
+
806
+ if (n % 4 == 0) {
807
+ n /= 4;
808
+ }
809
+
810
+ auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
811
+
812
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
813
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
814
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
815
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
816
+
817
+ ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
818
+
819
+ return 1;
820
+ }
821
+
822
+ int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
823
+ ggml_tensor * op = ctx->node(idx);
824
+
825
+ ggml_metal_library_t lib = ctx->lib;
826
+ ggml_metal_encoder_t enc = ctx->enc;
827
+
828
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
829
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
830
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
831
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
832
+
833
+ int64_t n = ggml_nelements(op);
834
+
835
+ if (n % 4 == 0) {
836
+ n /= 4;
837
+ }
838
+
839
+ auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
840
+
841
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
842
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 0);
843
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 1);
844
+
845
+ ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
846
+
847
+ return 1;
848
+ }
849
+
850
+ int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) {
851
+ ggml_tensor * op = ctx->node(idx);
852
+
853
+ ggml_metal_library_t lib = ctx->lib;
854
+ ggml_metal_encoder_t enc = ctx->enc;
855
+
856
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
857
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
858
+ GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
859
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
860
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
861
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
862
+
863
+ if (op->src[1]) {
864
+ GGML_ASSERT(ggml_are_same_shape(op->src[0], op->src[1]));
865
+ }
866
+
867
+ auto pipeline = ggml_metal_library_get_pipeline_glu(lib, op);
868
+
869
+ const int32_t swp = ggml_get_op_params_i32(op, 1);
870
+ const float alpha = ggml_get_op_params_f32(op, 2);
871
+ const float limit = ggml_get_op_params_f32(op, 3);
872
+
873
+ const int32_t i00 = swp ? ne0 : 0;
874
+ const int32_t i10 = swp ? 0 : ne0;
875
+
876
+ ggml_metal_kargs_glu args = {
877
+ /*.ne00 =*/ ne00,
878
+ /*.nb01 =*/ nb01,
879
+ /*.ne10 =*/ op->src[1] ? ne10 : ne00,
880
+ /*.nb11 =*/ op->src[1] ? nb11 : nb01,
881
+ /*.ne0 =*/ ne0,
882
+ /*.nb1 =*/ nb1,
883
+ /*.i00 =*/ op->src[1] ? 0 : i00,
884
+ /*.i10 =*/ op->src[1] ? 0 : i10,
885
+ /*.alpha=*/ alpha,
886
+ /*.limit=*/ limit
887
+ };
888
+
889
+ const int64_t nrows = ggml_nrows(op->src[0]);
890
+
891
+ const int32_t nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00/2);
892
+
893
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
894
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
895
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
896
+ if (op->src[1]) {
897
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
898
+ } else {
899
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 2);
900
+ }
901
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
902
+
903
+ ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1);
904
+
905
+ return 1;
906
+ }
907
+
908
+ int ggml_metal_op_sum(ggml_metal_op_t ctx, int idx) {
909
+ ggml_tensor * op = ctx->node(idx);
910
+
911
+ ggml_metal_library_t lib = ctx->lib;
912
+ ggml_metal_encoder_t enc = ctx->enc;
913
+
914
+ const uint64_t n = (uint64_t) ggml_nelements(op->src[0]);
915
+
916
+ ggml_metal_kargs_sum args = {
917
+ /*.np =*/ n,
918
+ };
919
+
920
+ auto pipeline = ggml_metal_library_get_pipeline_sum(lib, op);
921
+
922
+ int nth = 32; // SIMD width
923
+
924
+ while (nth < (int) n && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
925
+ nth *= 2;
926
+ }
927
+
928
+ nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
929
+ nth = std::min(nth, (int) n);
930
+
931
+ const int nsg = (nth + 31) / 32;
932
+
933
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
934
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
935
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
936
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
937
+
938
+ ggml_metal_encoder_set_threadgroup_memory_size(enc, nsg * sizeof(float), 0);
939
+
940
+ ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, nth, 1, 1);
941
+
942
+ return 1;
943
+ }
944
+
945
+ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
946
+ ggml_tensor * op = ctx->node(idx);
947
+
948
+ ggml_metal_library_t lib = ctx->lib;
949
+ ggml_metal_encoder_t enc = ctx->enc;
950
+
951
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
952
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
953
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
954
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
955
+
956
+ ggml_metal_kargs_sum_rows args = {
957
+ /*.ne00 =*/ ne00,
958
+ /*.ne01 =*/ ne01,
959
+ /*.ne02 =*/ ne02,
960
+ /*.ne03 =*/ ne03,
961
+ /*.nb00 =*/ nb00,
962
+ /*.nb01 =*/ nb01,
963
+ /*.nb02 =*/ nb02,
964
+ /*.nb03 =*/ nb03,
965
+ /*.ne0 =*/ ne0,
966
+ /*.ne1 =*/ ne1,
967
+ /*.ne2 =*/ ne2,
968
+ /*.ne3 =*/ ne3,
969
+ /*.nb0 =*/ nb0,
970
+ /*.nb1 =*/ nb1,
971
+ /*.nb2 =*/ nb2,
972
+ /*.nb3 =*/ nb3,
973
+ };
974
+
975
+ auto pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op);
976
+
977
+ int nth = 32; // SIMD width
978
+
979
+ while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
980
+ nth *= 2;
981
+ }
982
+
983
+ nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
984
+ nth = std::min(nth, ne00);
985
+
986
+ const size_t smem = pipeline.smem;
987
+
988
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
989
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
990
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
991
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
992
+
993
+ ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
994
+
995
+ ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
996
+
997
+ return 1;
998
+ }
999
+
1000
+ int ggml_metal_op_cumsum(ggml_metal_op_t ctx, int idx) {
1001
+ ggml_tensor * op = ctx->node(idx);
1002
+
1003
+ ggml_metal_library_t lib = ctx->lib;
1004
+ ggml_metal_encoder_t enc = ctx->enc;
1005
+
1006
+ GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
1007
+
1008
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1009
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1010
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1011
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1012
+
1013
+ auto pipeline_blk = ggml_metal_library_get_pipeline_cumsum_blk(lib, op);
1014
+
1015
+ int nth = 1;
1016
+ while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_blk)) {
1017
+ nth *= 2;
1018
+ }
1019
+
1020
+ GGML_ASSERT(ne00 <= nth*nth);
1021
+
1022
+ const int64_t net0 = (ne00 + nth - 1) / nth;
1023
+ const int64_t net1 = ne01;
1024
+ const int64_t net2 = ne02;
1025
+ const int64_t net3 = ne03;
1026
+
1027
+ const uint64_t nbt0 = sizeof(float);
1028
+ const uint64_t nbt1 = net0*nbt0;
1029
+ const uint64_t nbt2 = net1*nbt1;
1030
+ const uint64_t nbt3 = net2*nbt2;
1031
+
1032
+ const size_t smem = GGML_PAD(32*sizeof(float), 16);
1033
+
1034
+ ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
1035
+ ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
1036
+
1037
+ ggml_metal_buffer_id bid_tmp = bid_dst;
1038
+ bid_tmp.offs += ggml_nbytes(op);
1039
+
1040
+ {
1041
+ ggml_metal_kargs_cumsum_blk args = {
1042
+ /*.ne00 =*/ ne00,
1043
+ /*.ne01 =*/ ne01,
1044
+ /*.ne02 =*/ ne02,
1045
+ /*.ne03 =*/ ne03,
1046
+ /*.nb00 =*/ nb00,
1047
+ /*.nb01 =*/ nb01,
1048
+ /*.nb02 =*/ nb02,
1049
+ /*.nb03 =*/ nb03,
1050
+ /*.net0 =*/ net0,
1051
+ /*.net1 =*/ net1,
1052
+ /*.net2 =*/ net2,
1053
+ /*.net3 =*/ net3,
1054
+ /*.nbt0 =*/ nbt0,
1055
+ /*.nbt1 =*/ nbt1,
1056
+ /*.nbt2 =*/ nbt2,
1057
+ /*.nbt3 =*/ nbt3,
1058
+ /*.outb =*/ ne00 > nth,
1059
+ };
1060
+
1061
+ ggml_metal_encoder_set_pipeline(enc, pipeline_blk);
1062
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1063
+ ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
1064
+ ggml_metal_encoder_set_buffer (enc, bid_tmp, 2);
1065
+ ggml_metal_encoder_set_buffer (enc, bid_dst, 3);
1066
+
1067
+ ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
1068
+
1069
+ ggml_metal_encoder_dispatch_threadgroups(enc, net0*ne01, ne02, ne03, nth, 1, 1);
1070
+ }
1071
+
1072
+ if (ne00 > nth) {
1073
+ ggml_metal_op_concurrency_reset(ctx);
1074
+
1075
+ {
1076
+ ggml_metal_kargs_cumsum_blk args = {
1077
+ /*.ne00 =*/ net0,
1078
+ /*.ne01 =*/ net1,
1079
+ /*.ne02 =*/ net2,
1080
+ /*.ne03 =*/ net3,
1081
+ /*.nb00 =*/ nbt0,
1082
+ /*.nb01 =*/ nbt1,
1083
+ /*.nb02 =*/ nbt2,
1084
+ /*.nb03 =*/ nbt3,
1085
+ /*.net0 =*/ net0,
1086
+ /*.net1 =*/ net1,
1087
+ /*.net2 =*/ net2,
1088
+ /*.net3 =*/ net3,
1089
+ /*.nbt0 =*/ nbt0,
1090
+ /*.nbt1 =*/ nbt1,
1091
+ /*.nbt2 =*/ nbt2,
1092
+ /*.nbt3 =*/ nbt3,
1093
+ /*.outb =*/ false,
1094
+ };
1095
+
1096
+ ggml_metal_encoder_set_pipeline(enc, pipeline_blk);
1097
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1098
+ ggml_metal_encoder_set_buffer (enc, bid_tmp, 1);
1099
+ ggml_metal_encoder_set_buffer (enc, bid_tmp, 2);
1100
+ ggml_metal_encoder_set_buffer (enc, bid_tmp, 3);
1101
+
1102
+ ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
1103
+
1104
+ ggml_metal_encoder_dispatch_threadgroups(enc, net1, net2, net3, nth, 1, 1);
1105
+ }
1106
+
1107
+ ggml_metal_op_concurrency_reset(ctx);
1108
+
1109
+ {
1110
+ auto pipeline_add = ggml_metal_library_get_pipeline_cumsum_add(lib, op);
1111
+
1112
+ ggml_metal_kargs_cumsum_add args = {
1113
+ /*.ne00 =*/ ne00,
1114
+ /*.ne01 =*/ ne01,
1115
+ /*.ne02 =*/ ne02,
1116
+ /*.ne03 =*/ ne03,
1117
+ /*.nb00 =*/ nb00,
1118
+ /*.nb01 =*/ nb01,
1119
+ /*.nb02 =*/ nb02,
1120
+ /*.nb03 =*/ nb03,
1121
+ /*.net0 =*/ net0,
1122
+ /*.net1 =*/ net1,
1123
+ /*.net2 =*/ net2,
1124
+ /*.net3 =*/ net3,
1125
+ /*.nbt0 =*/ nbt0,
1126
+ /*.nbt1 =*/ nbt1,
1127
+ /*.nbt2 =*/ nbt2,
1128
+ /*.nbt3 =*/ nbt3,
1129
+ };
1130
+
1131
+ ggml_metal_encoder_set_pipeline(enc, pipeline_add);
1132
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1133
+ ggml_metal_encoder_set_buffer (enc, bid_tmp, 1);
1134
+ ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
1135
+
1136
+ ggml_metal_encoder_dispatch_threadgroups(enc, net0*ne01, ne02, ne03, nth, 1, 1);
1137
+ }
1138
+ }
1139
+
1140
+ return 1;
1141
+ }
1142
+
1143
+ int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) {
1144
+ ggml_tensor * op = ctx->node(idx);
1145
+
1146
+ ggml_metal_library_t lib = ctx->lib;
1147
+ ggml_metal_encoder_t enc = ctx->enc;
1148
+
1149
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1150
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1151
+ GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
1152
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
1153
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1154
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1155
+
1156
+ auto pipeline = ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type);
1157
+
1158
+ ggml_metal_kargs_get_rows args = {
1159
+ /*.ne00t =*/ ggml_is_quantized(op->src[0]->type) ? ne00/16 : ne00,
1160
+ /*.ne00 =*/ ne00,
1161
+ /*.nb01 =*/ nb01,
1162
+ /*.nb02 =*/ nb02,
1163
+ /*.nb03 =*/ nb03,
1164
+ /*.ne10 =*/ ne10,
1165
+ /*.nb10 =*/ nb10,
1166
+ /*.nb11 =*/ nb11,
1167
+ /*.nb12 =*/ nb12,
1168
+ /*.nb1 =*/ nb1,
1169
+ /*.nb2 =*/ nb2,
1170
+ /*.nb3 =*/ nb3,
1171
+ };
1172
+
1173
+ const int nth = std::min(args.ne00t, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
1174
+
1175
+ const int nw0 = (args.ne00t + nth - 1)/nth;
1176
+
1177
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
1178
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1179
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1180
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
1181
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
1182
+
1183
+ ggml_metal_encoder_dispatch_threadgroups(enc, nw0*ne10, ne11, ne12, nth, 1, 1);
1184
+
1185
+ return 1;
1186
+ }
1187
+
1188
+ int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) {
1189
+ ggml_tensor * op = ctx->node(idx);
1190
+
1191
+ ggml_metal_library_t lib = ctx->lib;
1192
+ ggml_metal_encoder_t enc = ctx->enc;
1193
+
1194
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1195
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1196
+ GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
1197
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
1198
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1199
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1200
+
1201
+ auto pipeline = ggml_metal_library_get_pipeline_set_rows(lib, op->src[1]->type, op->type);
1202
+
1203
+ const int32_t nk0 = ne0/ggml_blck_size(op->type);
1204
+
1205
+ int nth = 32; // SIMD width
1206
+
1207
+ while (nth < nk0 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
1208
+ nth *= 2;
1209
+ }
1210
+
1211
+ int nrptg = 1;
1212
+ if (nth > nk0) {
1213
+ nrptg = (nth + nk0 - 1)/nk0;
1214
+ nth = nk0;
1215
+
1216
+ if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
1217
+ nrptg--;
1218
+ }
1219
+ }
1220
+
1221
+ nth = std::min(nth, nk0);
1222
+
1223
+ ggml_metal_kargs_set_rows args = {
1224
+ /*.nk0 =*/ nk0,
1225
+ /*.ne01 =*/ ne01,
1226
+ /*.nb01 =*/ nb01,
1227
+ /*.nb02 =*/ nb02,
1228
+ /*.nb03 =*/ nb03,
1229
+ /*.ne11 =*/ ne11,
1230
+ /*.ne12 =*/ ne12,
1231
+ /*.nb10 =*/ nb10,
1232
+ /*.nb11 =*/ nb11,
1233
+ /*.nb12 =*/ nb12,
1234
+ /*.nb1 =*/ nb1,
1235
+ /*.nb2 =*/ nb2,
1236
+ /*.nb3 =*/ nb3,
1237
+ };
1238
+
1239
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
1240
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1241
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1242
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
1243
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
1244
+
1245
+ ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nrptg - 1)/nrptg, ne02, ne03, nth, nrptg, 1);
1246
+
1247
+ return 1;
1248
+ }
+
+ int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+ float scale;
+ float max_bias;
+
+ memcpy(&scale, ((const int32_t *) op->op_params) + 0, sizeof(scale));
+ memcpy(&max_bias, ((const int32_t *) op->op_params) + 1, sizeof(max_bias));
+
+ const uint32_t n_head = op->src[0]->ne[2];
+ const int32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+ // softmax
+
+ ggml_metal_kargs_soft_max args = {
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne11 =*/ ne11,
+ /*.ne12 =*/ ne12,
+ /*.ne13 =*/ ne13,
+ /*.nb11 =*/ nb11,
+ /*.nb12 =*/ nb12,
+ /*.nb13 =*/ nb13,
+ /*.nb1 =*/ nb1,
+ /*.nb2 =*/ nb2,
+ /*.nb3 =*/ nb3,
+ /*.scale =*/ scale,
+ /*.max_bias =*/ max_bias,
+ /*.m0 =*/ m0,
+ /*.m1 =*/ m1,
+ /*.n_head_log2 =*/ n_head_log2,
+ };
+
+ auto pipeline = ggml_metal_library_get_pipeline_soft_max(lib, op);
+
+ int nth = 32; // SIMD width
+
+ if (ne00%4 == 0) {
+ while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) {
+ nth *= 2;
+ }
+ } else {
+ while (nth < ne00 && nth*ne01*ne02*ne03 < 256) {
+ nth *= 2;
+ }
+ }
+
+ const size_t smem = pipeline.smem;
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+ if (op->src[1]) {
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+ } else {
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 2);
+ }
+ if (op->src[2]) {
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[2]), 3);
+ } else {
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 3);
+ }
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 4);
+
+ ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
+
+ return 1;
+ }
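
m0 and m1 parameterize the ALiBi head biases: heads below n_head_log2 take successive powers of m0, the remaining heads take odd powers of m1. A sketch of the per-head slope this implies, following the usual ggml ALiBi convention (the helper name is illustrative):

    #include <cmath>

    // illustrative: per-head ALiBi slope derived from the m0/m1 computed above
    static float alibi_slope(float max_bias, int n_head, int h) {
        if (max_bias <= 0.0f) {
            return 1.0f; // no bias requested
        }
        const int   n_head_log2 = 1 << (int) floorf(log2f((float) n_head));
        const float m0 = powf(2.0f, -(max_bias       )/n_head_log2);
        const float m1 = powf(2.0f, -(max_bias / 2.0f)/n_head_log2);
        // first n_head_log2 heads: m0^(h+1); the rest: m1^(2*(h - n_head_log2) + 1)
        return h < n_head_log2 ? powf(m0, h + 1)
                               : powf(m1, 2*(h - n_head_log2) + 1);
    }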
+
+ int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+ ggml_metal_kargs_ssm_conv args = {
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.ne10 =*/ ne10,
+ /*.ne11 =*/ ne11,
+ /*.nb10 =*/ nb10,
+ /*.nb11 =*/ nb11,
+ /*.ne0 =*/ ne0,
+ /*.ne1 =*/ ne1,
+ /*.ne2 =*/ ne2,
+ /*.nb0 =*/ nb0,
+ /*.nb1 =*/ nb1,
+ /*.nb2 =*/ nb2,
+ };
+
+ // Use batched kernel for prefill (ne1 > 1) to reduce threadgroup dispatch overhead
+ const bool use_batched = (ne1 > 1);
+
+ if (use_batched) {
+ // Choose a power-of-two batch size that grows with ne1, capped at 256 (with a floor of 2)
+ int BATCH_SIZE;
+ if (ne1 > 128) BATCH_SIZE = 256;
+ else if (ne1 > 64 ) BATCH_SIZE = 128;
+ else if (ne1 > 32 ) BATCH_SIZE = 64;
+ else if (ne1 > 16 ) BATCH_SIZE = 32;
+ else if (ne1 > 8 ) BATCH_SIZE = 16;
+ else if (ne1 > 4 ) BATCH_SIZE = 8;
+ else BATCH_SIZE = 2;
+
+ auto pipeline = ggml_metal_library_get_pipeline_ssm_conv_batched(lib, op, BATCH_SIZE);
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
+
+ // Dispatch: ne01 rows, ceil(ne1/BATCH_SIZE) token batches, ne02 sequences
+ // Each threadgroup has BATCH_SIZE threads, each handling one token
+ const int n_token_batches = (ne1 + BATCH_SIZE - 1) / BATCH_SIZE;
+ ggml_metal_encoder_dispatch_threadgroups(enc, ne01, n_token_batches, ne02, BATCH_SIZE, 1, 1);
+ } else {
+ auto pipeline = ggml_metal_library_get_pipeline_ssm_conv(lib, op);
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1);
+ }
+
+ return 1;
+ }
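
The if-ladder above effectively picks the next power of two for BATCH_SIZE, capped at 256, except at the small end where ne1 of 3 or 4 still maps to 2. A sketch of the general form it approximates:

    // sketch: next power of two >= x, clamped to [lo, hi]; the ladder above
    // hardcodes this choice (with a floor of 2 for ne1 <= 4)
    static int next_pow2_clamped(int x, int lo, int hi) {
        int p = lo;
        while (p < hi && p < x) {
            p *= 2;
        }
        return p;
    }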
+
+ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne4, op->src[4], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb4, op->src[4], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne5, op->src[5], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb5, op->src[5], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne6, op->src[6], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb6, op->src[6], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+ const ggml_tensor * src3 = op->src[3];
+ const ggml_tensor * src4 = op->src[4];
+ const ggml_tensor * src5 = op->src[5];
+ const ggml_tensor * src6 = op->src[6];
+
+ GGML_ASSERT(src3);
+ GGML_ASSERT(src4);
+ GGML_ASSERT(src5);
+ GGML_ASSERT(src6);
+
+ const int64_t d_state = ne00;
+ const int64_t d_inner = ne01;
+ const int64_t n_head = ne02;
+ const int64_t n_group = ne41;
+ const int64_t n_seq_tokens = ne12;
+ const int64_t n_seqs = ne13;
+
+ ggml_metal_kargs_ssm_scan args = {
+ /*.d_state =*/ d_state,
+ /*.d_inner =*/ d_inner,
+ /*.n_head =*/ n_head,
+ /*.n_group =*/ n_group,
+ /*.n_seq_tokens =*/ n_seq_tokens,
+ /*.n_seqs =*/ n_seqs,
+ /*.s_off =*/ ggml_nelements(op->src[1]) * sizeof(float),
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.nb10 =*/ nb10,
+ /*.nb11 =*/ nb11,
+ /*.nb12 =*/ nb12,
+ /*.ns12 =*/ nb12/nb10,
+ /*.nb13 =*/ nb13,
+ /*.nb20 =*/ nb20,
+ /*.nb21 =*/ nb21,
+ /*.ns21 =*/ nb21/nb20,
+ /*.nb22 =*/ nb22,
+ /*.ne30 =*/ ne30,
+ /*.nb31 =*/ nb31,
+ /*.nb41 =*/ nb41,
+ /*.nb42 =*/ nb42,
+ /*.ns42 =*/ nb42/nb40,
+ /*.nb43 =*/ nb43,
+ /*.nb51 =*/ nb51,
+ /*.nb52 =*/ nb52,
+ /*.ns52 =*/ nb52/nb50,
+ /*.nb53 =*/ nb53,
+ /*.nb0 =*/ nb0,
+ };
+
+ auto pipeline = ggml_metal_library_get_pipeline_ssm_scan(lib, op);
+
+ GGML_ASSERT(d_state <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+
+ const size_t smem = pipeline.smem;
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), 4);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), 5);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[5]), 6);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[6]), 7);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 8);
+
+ ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1);
+
+ return 1;
+ }
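
A note on s_off: it appears to be the byte offset of the recurrent-state section inside dst, placed right after the per-token outputs, whose count equals ggml_nelements(op->src[1]). A sketch of the layout this implies (dst_base is a hypothetical host-side pointer; the actual partitioning happens inside the kernel):

    // assumed dst layout: [ y : ggml_nelements(src1) floats | updated state ]
    const size_t s_off = (size_t) ggml_nelements(op->src[1])*sizeof(float);

    float * y     = (float *) dst_base;                     // token outputs
    float * state = (float *) ((char *) dst_base + s_off);  // updated SSM state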
+
+ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+ const int64_t B = op->op == GGML_OP_RWKV_WKV6 ? op->src[5]->ne[1] : op->src[6]->ne[1];
+ const int64_t T = op->src[0]->ne[2];
+ const int64_t C = op->ne[0];
+ const int64_t H = op->src[0]->ne[1];
+
+ auto pipeline = ggml_metal_library_get_pipeline_rwkv(lib, op);
+
+ int ida = 0;
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), ida++);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), ida++);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[5]), ida++);
+ if (op->op == GGML_OP_RWKV_WKV7) {
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[6]), ida++);
+ }
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), ida++);
+ ggml_metal_encoder_set_bytes (enc, (void *) &B, sizeof(B), ida++);
+ ggml_metal_encoder_set_bytes (enc, (void *) &T, sizeof(T), ida++);
+ ggml_metal_encoder_set_bytes (enc, (void *) &C, sizeof(C), ida++);
+ ggml_metal_encoder_set_bytes (enc, (void *) &H, sizeof(H), ida++);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, B * H, 1, 1, C/H, 1, 1);
+
+ return 1;
+ }
+
+ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+ auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
+
+ GGML_ASSERT(ne00 % ggml_blck_size(op->src[0]->type) == 0);
+
+ int64_t nk0 = ne00;
+ if (ggml_is_quantized(op->src[0]->type)) {
+ nk0 = ne00/16;
+ } else if (ggml_is_quantized(op->type)) {
+ nk0 = ne00/ggml_blck_size(op->type);
+ }
+
+ int nth = std::min<int>(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+
+ // when rows are small, we can batch them together in a single threadgroup
+ int nrptg = 1;
+
+ // TODO: relax this constraint in the future
+ if (ggml_blck_size(op->src[0]->type) == 1 && ggml_blck_size(op->type) == 1) {
+ if (nth > nk0) {
+ nrptg = (nth + nk0 - 1)/nk0;
+ nth = nk0;
+
+ if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+ nrptg--;
+ }
+ }
+ }
+
+ nth = std::min<int>(nth, nk0);
+
+ ggml_metal_kargs_cpy args = {
+ /*.nk0 =*/ nk0,
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.ne03 =*/ ne03,
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne0 =*/ ne0,
+ /*.ne1 =*/ ne1,
+ /*.ne2 =*/ ne2,
+ /*.ne3 =*/ ne3,
+ /*.nb0 =*/ nb0,
+ /*.nb1 =*/ nb1,
+ /*.nb2 =*/ nb2,
+ /*.nb3 =*/ nb3,
+ };
+
+ const int nw0 = nrptg == 1 ? (nk0 + nth - 1)/nth : 1;
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, nw0*(ne01 + nrptg - 1)/nrptg, ne02, ne03, nth, nrptg, 1);
+
+ return 1;
+ }
+
+ int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+ const int32_t * opts = op->op_params;
+ ggml_op_pool op_pool = (ggml_op_pool) opts[0];
+
+ const int32_t k0 = opts[1];
+ const int32_t k1 = opts[2];
+ const int32_t s0 = opts[3];
+ const int32_t s1 = opts[4];
+ const int32_t p0 = opts[5];
+ const int32_t p1 = opts[6];
+
+ const int64_t IH = op->src[0]->ne[1];
+ const int64_t IW = op->src[0]->ne[0];
+
+ const int64_t N = op->ne[3];
+ const int64_t OC = op->ne[2];
+ const int64_t OH = op->ne[1];
+ const int64_t OW = op->ne[0];
+
+ const int64_t np = N * OC * OH * OW;
+
+ ggml_metal_kargs_pool_2d args_pool_2d = {
+ /* .k0 = */ k0,
+ /* .k1 = */ k1,
+ /* .s0 = */ s0,
+ /* .s1 = */ s1,
+ /* .p0 = */ p0,
+ /* .p1 = */ p1,
+ /* .IH = */ IH,
+ /* .IW = */ IW,
+ /* .OH = */ OH,
+ /* .OW = */ OW,
+ /* .np = */ np
+ };
+
+ auto pipeline = ggml_metal_library_get_pipeline_pool_2d(lib, op, op_pool);
+
+ const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), (int) np);
+ const int ntg = (np + nth - 1) / nth;
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args_pool_2d, sizeof(args_pool_2d), 0);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, ntg, 1, 1, nth, 1, 1);
+
+ return 1;
+ }
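
The dispatch flattens the whole N x OC x OH x OW output into np elements and splits it into ceil(np/nth) threadgroups. With illustrative shapes:

    // sketch: flat-index dispatch arithmetic used above (shapes illustrative)
    const int64_t N = 1, OC = 64, OH = 56, OW = 56;
    const int64_t np  = N*OC*OH*OW;          // 200704 output elements
    const int     nth = 1024;                // capped by the pipeline limit
    const int     ntg = (np + nth - 1)/nth;  // 196 threadgroups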
+
+ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx->dev);
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+ GGML_ASSERT(ne00 == ne10);
+
+ GGML_ASSERT(ne12 % ne02 == 0);
+ GGML_ASSERT(ne13 % ne03 == 0);
+
+ const int16_t r2 = ne12/ne02;
+ const int16_t r3 = ne13/ne03;
+
+ // find the break-even point where the matrix-matrix kernel becomes more efficient compared
+ // to the matrix-vector kernel
+ const int ne11_mm_min = 8;
+
+ // first try to use small-batch mat-mv kernels
+ // these should be efficient for BS [2, ~8]
+ if (op->src[1]->type == GGML_TYPE_F32 && (ne00%128 == 0) &&
+ (
+ (
+ (
+ op->src[0]->type == GGML_TYPE_F32 || // TODO: helper function
+ op->src[0]->type == GGML_TYPE_F16 ||
+ op->src[0]->type == GGML_TYPE_Q4_0 ||
+ op->src[0]->type == GGML_TYPE_Q4_1 ||
+ op->src[0]->type == GGML_TYPE_Q5_0 ||
+ op->src[0]->type == GGML_TYPE_Q5_1 ||
+ op->src[0]->type == GGML_TYPE_Q8_0 ||
+ op->src[0]->type == GGML_TYPE_MXFP4 ||
+ op->src[0]->type == GGML_TYPE_IQ4_NL ||
+ false) && (ne11 >= 2 && ne11 <= 8)
+ ) ||
+ (
+ (
+ op->src[0]->type == GGML_TYPE_Q4_K ||
+ op->src[0]->type == GGML_TYPE_Q5_K ||
+ op->src[0]->type == GGML_TYPE_Q6_K ||
+ false) && (ne11 >= 4 && ne11 <= 8)
+ )
+ )
+ ) {
+ // TODO: determine the optimal parameters based on grid utilization
+ // I still don't know why we should not always use the maximum available threads:
+ //
+ // nsg = pipeline.maxTotalThreadsPerThreadgroup / 32
+ //
+ // my current hypothesis is that the work grid is not evenly divisible for different nsg
+ // values and there can be some tail effects when nsg is high. need to confirm this
+ //
+ const int nsg = 2; // num simdgroups per threadgroup
+
+ // num threads along row per simdgroup
+ int16_t nxpsg = 0;
+ if (ne00 % 256 == 0 && ne11 < 3) {
+ nxpsg = 16;
+ } else if (ne00 % 128 == 0) {
+ nxpsg = 8;
+ } else {
+ nxpsg = 4;
+ }
+
+ const int16_t nypsg = 32/nxpsg; // num threads along col per simdgroup (i.e. a simdgroup processes that many src0 rows at a time)
+ const int16_t r0ptg = nypsg*nsg; // num src0 rows per threadgroup
+ int16_t r1ptg = 4; // num src1 rows per threadgroup
+
+ // note: not sure how optimal these are across all different hardware. there might be something cleverer
+ switch (ne11) {
+ case 2:
+ r1ptg = 2; break;
+ case 3:
+ case 6:
+ r1ptg = 3; break;
+ case 4:
+ case 7:
+ case 8:
+ r1ptg = 4; break;
+ case 5:
+ r1ptg = 5; break;
+ default:
+ GGML_ABORT("unsupported ne11");
+ };
+
+ auto pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, nsg, nxpsg, r1ptg);
+
+ ggml_metal_kargs_mul_mv_ext args = {
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne10 =*/ ne10,
+ /*.ne11 =*/ ne11,
+ /*.ne12 =*/ ne12,
+ /*.nb10 =*/ nb10,
+ /*.nb11 =*/ nb11,
+ /*.nb12 =*/ nb12,
+ /*.nb13 =*/ nb13,
+ /*.ne0 =*/ ne0,
+ /*.ne1 =*/ ne1,
+ /*.r2 =*/ r2,
+ /*.r3 =*/ r3,
+ };
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, ((ne01 + r0ptg - 1)/r0ptg), ((ne11 + r1ptg - 1)/r1ptg), ne12*ne13, 32, nsg, 1);
+ } else if (
+ !ggml_is_transposed(op->src[0]) &&
+ !ggml_is_transposed(op->src[1]) &&
+ // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
+ // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
+ props_dev->has_simdgroup_mm && ne00 >= 64 && ne11 > ne11_mm_min) {
+ //GGML_LOG_INFO("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
+
+ // some Metal matrix data types require aligned pointers
+ // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
+ //switch (op->src[0]->type) {
+ // case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
+ // case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
+ // case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
+ // default: break;
+ //}
+
+ auto pipeline = ggml_metal_library_get_pipeline_mul_mm(lib, op);
+
+ ggml_metal_kargs_mul_mm args = {
+ /*.ne00 =*/ ne00,
+ /*.ne02 =*/ ne02,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne12 =*/ ne12,
+ /*.nb10 =*/ nb10,
+ /*.nb11 =*/ nb11,
+ /*.nb12 =*/ nb12,
+ /*.nb13 =*/ nb13,
+ /*.ne0 =*/ ne0,
+ /*.ne1 =*/ ne1,
+ /*.r2 =*/ r2,
+ /*.r3 =*/ r3,
+ };
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
+
+ const size_t smem = pipeline.smem;
+
+ ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+ ggml_metal_encoder_dispatch_threadgroups(enc, ((ne11 + 31)/32), ((ne01 + 63)/64), ne12*ne13, 128, 1, 1);
+ } else {
+ auto pipeline = ggml_metal_library_get_pipeline_mul_mv(lib, op);
+
+ const int nr0 = pipeline.nr0;
+ const int nr1 = pipeline.nr1;
+ const int nsg = pipeline.nsg;
+
+ const size_t smem = pipeline.smem;
+
+ ggml_metal_kargs_mul_mv args = {
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne10 =*/ ne10,
+ /*.ne11 =*/ ne11,
+ /*.ne12 =*/ ne12,
+ /*.nb10 =*/ nb10,
+ /*.nb11 =*/ nb11,
+ /*.nb12 =*/ nb12,
+ /*.nb13 =*/ nb13,
+ /*.ne0 =*/ ne0,
+ /*.ne1 =*/ ne1,
+ /*.nr0 =*/ nr0,
+ /*.r2 =*/ r2,
+ /*.r3 =*/ r3,
+ };
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
+
+ ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+
+ if (op->src[0]->type == GGML_TYPE_F32 ||
+ op->src[0]->type == GGML_TYPE_F16 ||
+ op->src[0]->type == GGML_TYPE_BF16 ||
+ op->src[0]->type == GGML_TYPE_Q8_0) {
+ ggml_metal_encoder_dispatch_threadgroups(enc, ((ne01 + nr0 - 1)/(nr0)), ((ne11 + nr1 - 1)/nr1), ne12*ne13, 32, nsg, 1);
+ } else {
+ ggml_metal_encoder_dispatch_threadgroups(enc, ((ne01 + nr0*nsg - 1)/(nr0*nsg)), ((ne11 + nr1 - 1)/nr1), ne12*ne13, 32, nsg, 1);
+ }
+ }
+
+ return 1;
+ }
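
Condensed, the function above chooses between three kernels: the small-batch mat-vec extension, the simdgroup matrix-matrix kernel, and the general matrix-vector fallback. A sketch of the decision (the type checks are abbreviated into a flag; see the full predicate in the code):

    // sketch: the three-way kernel choice made by ggml_metal_op_mul_mat
    enum mm_kernel { MM_MV_EXT, MM_MM, MM_MV };

    static mm_kernel pick_mul_mat_kernel(bool small_batch_ok, bool has_simdgroup_mm,
                                         bool transposed, int ne00, int ne11) {
        if (small_batch_ok) {
            return MM_MV_EXT;            // batch size in [2, 8], supported src0 type
        }
        if (!transposed && has_simdgroup_mm && ne00 >= 64 && ne11 > 8) {
            return MM_MM;                // simdgroup matrix-matrix kernel
        }
        return MM_MV;                    // general matrix-vector fallback
    }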
+
+ size_t ggml_metal_op_mul_mat_id_extra_tpe(const ggml_tensor * op) {
+ assert(op->op == GGML_OP_MUL_MAT_ID);
+
+ const int64_t ne02 = op->src[0]->ne[2]; // n_expert
+
+ return ggml_type_size(GGML_TYPE_I32)*ne02;
+ }
+
+ size_t ggml_metal_op_mul_mat_id_extra_ids(const ggml_tensor * op) {
+ assert(op->op == GGML_OP_MUL_MAT_ID);
+
+ const int64_t ne02 = op->src[0]->ne[2]; // n_expert
+ const int64_t ne21 = op->src[2]->ne[1]; // n_token
+
+ return ggml_type_size(GGML_TYPE_I32)*ne02*ne21;
+ }
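
For a feel of the sizes involved: with ne02 = 8 experts and ne21 = 512 tokens (values illustrative), extra_tpe reserves sizeof(int32_t)*8 = 32 bytes and extra_ids reserves sizeof(int32_t)*8*512 = 16384 bytes, i.e. one i32 slot per expert plus one i32 slot per (expert, token) pair; reading the two maps as per-expert counters and grouped token ids is an assumption based on the names tpe and ids.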
+
+ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx->dev);
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+ // src2 = ids
+ GGML_ASSERT(op->src[2]->type == GGML_TYPE_I32);
+
+ GGML_ASSERT(!ggml_is_transposed(op->src[0]));
+ GGML_ASSERT(!ggml_is_transposed(op->src[1]));
+
+ GGML_ASSERT(ne03 == 1);
+ GGML_ASSERT(ne13 == 1);
+
+ ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
+ ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
+ ggml_metal_buffer_id bid_src2 = ggml_metal_get_buffer_id(op->src[2]);
+ ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
+
+ const uint32_t r2 = 1;
+ const uint32_t r3 = 1;
+
+ // find the break-even point where the matrix-matrix kernel becomes more efficient compared
+ // to the matrix-vector kernel
+ // ne20 = n_used_experts
+ // ne21 = n_rows (batch size)
+ const int ne21_mm_id_min = 32;
+
+ if (props_dev->has_simdgroup_mm && ne00 >= 64 && (ne21 >= ne21_mm_id_min)) {
+ // some Metal matrix data types require aligned pointers
+ // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
+ //switch (op->src[0]->type) {
+ // case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
+ // case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
+ // case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
+ // default: break;
+ //}
+
+ // extra buffers for intermediate id mapping
+ ggml_metal_buffer_id bid_tpe = bid_dst;
+ bid_tpe.offs += ggml_nbytes(op);
+
+ ggml_metal_buffer_id bid_ids = bid_tpe;
+ bid_ids.offs += ggml_metal_op_mul_mat_id_extra_tpe(op);
+
+ {
+ ggml_metal_kargs_mul_mm_id_map0 args = {
+ ne02,
+ ne10,
+ ne11, // n_expert_used (bcast)
+ nb11,
+ nb12,
+ ne21, // n_tokens
+ ne20, // n_expert_used
+ nb21,
+ };
+
+ auto pipeline = ggml_metal_library_get_pipeline_mul_mm_id_map0(lib, ne02, ne20);
+
+ const size_t smem = pipeline.smem;
+
+ GGML_ASSERT(ne02 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+
+ GGML_ASSERT(smem <= props_dev->max_theadgroup_memory_size);
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, bid_src2, 1);
+ ggml_metal_encoder_set_buffer (enc, bid_tpe, 2);
+ ggml_metal_encoder_set_buffer (enc, bid_ids, 3);
+
+ ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, ne02, 1, 1);
+ }
+
+ // this barrier is always needed because the next kernel has to wait for the id maps to be computed
+ ggml_metal_op_concurrency_reset(ctx);
+
+ {
+ auto pipeline = ggml_metal_library_get_pipeline_mul_mm_id(lib, op);
+
+ ggml_metal_kargs_mul_mm_id args = {
+ /*.ne00 =*/ ne00,
+ /*.ne02 =*/ ne02,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne11 =*/ ne11, // n_expert_used (bcast)
+ /*.nb10 =*/ nb10,
+ /*.nb11 =*/ nb11,
+ /*.nb12 =*/ nb12,
+ /*.nb13 =*/ nb13,
+ /*.ne20 =*/ ne20, // n_expert_used
+ /*.ne21 =*/ ne21, // n_tokens
+ /*.ne0 =*/ ne0,
+ /*.ne1 =*/ ne1,
+ /*.r2 =*/ r2,
+ /*.r3 =*/ r3,
+ };
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
+ ggml_metal_encoder_set_buffer (enc, bid_src1, 2);
+ ggml_metal_encoder_set_buffer (enc, bid_tpe, 3);
+ ggml_metal_encoder_set_buffer (enc, bid_ids, 4);
+ ggml_metal_encoder_set_buffer (enc, bid_dst, 5);
+
+ const size_t smem = pipeline.smem;
+
+ ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, (ne21 + 31)/32, (ne01 + 63)/64, ne02, 128, 1, 1);
+ }
+ } else {
+ auto pipeline = ggml_metal_library_get_pipeline_mul_mv_id(lib, op);
+
+ const int nr0 = pipeline.nr0;
+ const int nr1 = pipeline.nr1;
+ const int nsg = pipeline.nsg;
+
+ const size_t smem = pipeline.smem;
+
+ ggml_metal_kargs_mul_mv_id args = {
+ /*.nei0 =*/ ne20,
+ /*.nei1 =*/ ne21,
+ /*.nbi1 =*/ nb21,
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.ne10 =*/ ne10,
+ /*.ne11 =*/ ne11,
+ /*.ne12 =*/ ne12,
+ /*.ne13 =*/ ne13,
+ /*.nb10 =*/ nb10,
+ /*.nb11 =*/ nb11,
+ /*.nb12 =*/ nb12,
+ /*.ne0 =*/ ne0,
+ /*.ne1 =*/ ne1,
+ /*.nb1 =*/ nb1,
+ /*.nr0 =*/ nr0,
+ };
+
+ if (ggml_is_quantized(op->src[0]->type)) {
+ GGML_ASSERT(ne00 >= nsg*nr0);
+ }
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer(enc, bid_src0, 1);
+ ggml_metal_encoder_set_buffer(enc, bid_src1, 2);
+ ggml_metal_encoder_set_buffer(enc, bid_dst, 3);
+ ggml_metal_encoder_set_buffer(enc, bid_src2, 4);
+
+ const int64_t _ne1 = 1;
+ const int64_t ne123 = ne20*ne21;
+
+ ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+
+ if (op->src[0]->type == GGML_TYPE_F32 ||
+ op->src[0]->type == GGML_TYPE_F16 ||
+ op->src[0]->type == GGML_TYPE_BF16 ||
+ op->src[0]->type == GGML_TYPE_Q8_0) {
+ ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nr0 - 1)/(nr0), (_ne1 + nr1 - 1)/nr1, ne123, 32, nsg, 1);
+ } else {
+ ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nr0*nsg - 1)/(nr0*nsg), (_ne1 + nr1 - 1)/nr1, ne123, 32, nsg, 1);
+ }
+ }
+
+ return 1;
+ }
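
The simdgroup path above is a two-phase protocol: a map0 kernel first groups the router ids, then the expert matmul consumes them, with an explicit barrier in between because the backend may otherwise encode ops concurrently. The shape of the pattern (encode_map0 and encode_mul_mm_id are hypothetical stand-ins for the two encode blocks above):

    encode_map0(enc);                     // phase 1: build the tpe/ids maps from src2
    ggml_metal_op_concurrency_reset(ctx); // barrier: phase 2 must see phase 1's writes
    encode_mul_mm_id(enc);                // phase 2: expert matmul reading tpe/ids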
+
+ int ggml_metal_op_add_id(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+
+ GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
+ GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
+ GGML_ASSERT(op->src[2]->type == GGML_TYPE_I32);
+ GGML_ASSERT(op->type == GGML_TYPE_F32);
+
+ GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
+
+ ggml_metal_kargs_add_id args = {
+ /*.ne0 =*/ ne0,
+ /*.ne1 =*/ ne1,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb11 =*/ nb11,
+ /*.nb21 =*/ nb21,
+ };
+
+ auto pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_ADD_ID);
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 4);
+
+ const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, 1, nth, 1, 1);
+
+ return 1;
+ }
+
+ bool ggml_metal_op_flash_attn_ext_use_vec(const ggml_tensor * op) {
+ assert(op->op == GGML_OP_FLASH_ATTN_EXT);
+
+ const int64_t ne00 = op->src[0]->ne[0]; // head size
+ const int64_t ne01 = op->src[0]->ne[1]; // batch size
+
+ // use vec kernel if the batch size is small and if the head size is supported
+ return (ne01 < 20) && (ne00 % 32 == 0);
+ }
+
+ size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) {
+ assert(op->op == GGML_OP_FLASH_ATTN_EXT);
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
+
+ size_t res = 0;
+
+ const bool has_mask = op->src[3] != nullptr;
+
+ // note: the non-vec kernel requires more extra memory, so always reserve for it
+ GGML_ASSERT(OP_FLASH_ATTN_EXT_NCPSG >= OP_FLASH_ATTN_EXT_VEC_NCPSG);
+
+ //if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
+ if (false) {
+ // note: always reserve the padding space to avoid graph reallocations
+ //const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0;
+ const bool has_kvpad = true;
+
+ if (has_kvpad) {
+ res += OP_FLASH_ATTN_EXT_VEC_NCPSG*(
+ nb11*ne12*ne13 +
+ nb21*ne22*ne23 +
+ (has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
+ }
+ } else {
+ //const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_NCPSG != 0;
+ const bool has_kvpad = true;
+
+ if (has_kvpad) {
+ res += OP_FLASH_ATTN_EXT_NCPSG*(
+ nb11*ne12*ne13 +
+ nb21*ne22*ne23 +
+ (has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
+ }
+ }
+
+ return res;
+ }
+
+ size_t ggml_metal_op_flash_attn_ext_extra_blk(const ggml_tensor * op) {
+ assert(op->op == GGML_OP_FLASH_ATTN_EXT);
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ //GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ //GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+ //GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+ //GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
+ //GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
+
+ size_t res = 0;
+
+ const bool has_mask = op->src[3] != nullptr;
+
+ if (!has_mask) {
+ return res;
+ }
+
+ const bool is_vec = ggml_metal_op_flash_attn_ext_use_vec(op);
+
+ // this optimization is not useful for the vector kernels
+ // note: always reserve the blk buffer to avoid graph reallocations
+ //if (is_vec) {
+ // return res;
+ //}
+
+ const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPTG : OP_FLASH_ATTN_EXT_NQPTG;
+ const int ncpsg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG;
+
+ const int64_t ne1 = (ne01 + nqptg - 1)/nqptg;
+ const int64_t ne0 = (ne30 + ncpsg - 1)/ncpsg;
+
+ res += GGML_PAD(ggml_type_size(GGML_TYPE_I8)*ne0*ne1*ne32*ne33, 32);
+
+ return res;
+ }
+
+ size_t ggml_metal_op_flash_attn_ext_extra_tmp(const ggml_tensor * op) {
+ assert(op->op == GGML_OP_FLASH_ATTN_EXT);
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ //GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+ //GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
+ //GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
+ //GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
+
+ size_t res = 0;
+
+ // note: always reserve the temp buffer to avoid graph reallocations
+ //if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
+ if (true) {
+ const int64_t nwg = 32;
+ const int64_t ne01_max = std::min(ne01, 32);
+
+ // temp buffer for writing the results from each workgroup
+ // - ne20: the size of the Value head
+ // - + 2: the S and M values for each intermediate result
+ res += ggml_type_size(GGML_TYPE_F32)*(ne01_max*ne02*ne03*nwg*(ne20 + 2));
+ }
+
+ return res;
+ }
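
A worked size for the temp buffer: with ne01 = 8 queries, ne02 = 32 heads, ne03 = 1 and a value head of ne20 = 128 (shapes illustrative), each of the nwg = 32 workgroups stores a partial value vector plus its S and M accumulators per row:

    const int64_t nwg      = 32;
    const int64_t ne01_max = std::min<int64_t>(8, 32);            // = 8
    const size_t  res      = sizeof(float)*(ne01_max*32*1*nwg*(128 + 2));
    // = 4*8*32*32*130 = 4259840 bytes, about 4.1 MiB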
+
+ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx->dev);
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS( int32_t, nb, op, nb);
+
+ GGML_ASSERT(ne00 % 4 == 0);
+
+ GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
+ GGML_ASSERT(op->src[1]->type == op->src[2]->type);
+
+ //GGML_ASSERT(ggml_are_same_shape (src1, src2));
+ GGML_ASSERT(ne11 == ne21);
+ GGML_ASSERT(ne12 == ne22);
+
+ GGML_ASSERT(!op->src[3] || op->src[3]->type == GGML_TYPE_F16);
+ GGML_ASSERT(!op->src[3] || op->src[3]->ne[1] >= op->src[0]->ne[1] &&
+ "the Flash-Attention Metal kernel requires the mask to be at least n_queries big");
+
+ float scale;
+ float max_bias;
+ float logit_softcap;
+
+ memcpy(&scale, ((const int32_t *) op->op_params) + 0, sizeof(scale));
+ memcpy(&max_bias, ((const int32_t *) op->op_params) + 1, sizeof(max_bias));
+ memcpy(&logit_softcap, ((const int32_t *) op->op_params) + 2, sizeof(logit_softcap));
+
+ if (logit_softcap != 0.0f) {
+ scale /= logit_softcap;
+ }
+
+ const bool has_mask = op->src[3] != NULL;
+ const bool has_sinks = op->src[4] != NULL;
+ const bool has_bias = max_bias != 0.0f;
+ const bool has_scap = logit_softcap != 0.0f;
+
+ const uint32_t n_head = op->src[0]->ne[2];
+ const int32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+ GGML_ASSERT(ne01 < 65536);
+
+ ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
+ ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
+ ggml_metal_buffer_id bid_src2 = ggml_metal_get_buffer_id(op->src[2]);
+ ggml_metal_buffer_id bid_src3 = has_mask ? ggml_metal_get_buffer_id(op->src[3]) : bid_src0;
+ ggml_metal_buffer_id bid_src4 = has_sinks ? ggml_metal_get_buffer_id(op->src[4]) : bid_src0;
+
+ ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
+
+ ggml_metal_buffer_id bid_pad = bid_dst;
+ bid_pad.offs += ggml_nbytes(op);
+
+ ggml_metal_buffer_id bid_blk = bid_pad;
+ bid_blk.offs += ggml_metal_op_flash_attn_ext_extra_pad(op);
+
+ ggml_metal_buffer_id bid_tmp = bid_blk;
+ bid_tmp.offs += ggml_metal_op_flash_attn_ext_extra_blk(op);
+
+ if (!ggml_metal_op_flash_attn_ext_use_vec(op)) {
+ // half8x8 kernel
+ const int nqptg = OP_FLASH_ATTN_EXT_NQPTG; // queries per threadgroup
+ const int ncpsg = OP_FLASH_ATTN_EXT_NCPSG; // cache values per simdgroup
+
+ GGML_ASSERT(nqptg <= 32);
+ GGML_ASSERT(nqptg % 8 == 0);
+ GGML_ASSERT(ncpsg % 32 == 0);
+
+ bool need_sync = false;
+
+ const bool has_kvpad = ne11 % ncpsg != 0;
+
+ if (has_kvpad) {
+ assert(ggml_metal_op_flash_attn_ext_extra_pad(op) != 0);
+
+ ggml_metal_kargs_flash_attn_ext_pad args0 = {
+ /*.ne11 =*/ne11,
+ /*.ne_12_2 =*/ne12,
+ /*.ne_12_3 =*/ne13,
+ /*.nb11 =*/nb11,
+ /*.nb12 =*/nb12,
+ /*.nb13 =*/nb13,
+ /*.nb21 =*/nb21,
+ /*.nb22 =*/nb22,
+ /*.nb23 =*/nb23,
+ /*.ne31 =*/ne31,
+ /*.ne32 =*/ne32,
+ /*.ne33 =*/ne33,
+ /*.nb31 =*/nb31,
+ /*.nb32 =*/nb32,
+ /*.nb33 =*/nb33,
+ };
+
+ auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline0);
+ ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
+ ggml_metal_encoder_set_buffer (enc, bid_src1, 1);
+ ggml_metal_encoder_set_buffer (enc, bid_src2, 2);
+ ggml_metal_encoder_set_buffer (enc, bid_src3, 3);
+ ggml_metal_encoder_set_buffer (enc, bid_pad, 4);
+
+ assert(ne12 == ne22);
+ assert(ne13 == ne23);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
+
+ need_sync = true;
+ }
+
+ if (has_mask) {
+ assert(ggml_metal_op_flash_attn_ext_extra_blk(op) != 0);
+
+ ggml_metal_kargs_flash_attn_ext_blk args0 = {
+ /*.ne01 =*/ ne01,
+ /*.ne30 =*/ ne30,
+ /*.ne31 =*/ ne31,
+ /*.ne32 =*/ ne32,
+ /*.ne33 =*/ ne33,
+ /*.nb31 =*/ nb31,
+ /*.nb32 =*/ nb32,
+ /*.nb33 =*/ nb33,
+ };
+
+ auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_blk(lib, op, nqptg, ncpsg);
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline0);
+ ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
+ ggml_metal_encoder_set_buffer (enc, bid_src3, 1);
+ ggml_metal_encoder_set_buffer (enc, bid_blk, 2);
+
+ const int32_t nblk1 = ((ne01 + nqptg - 1)/nqptg);
+ const int32_t nblk0 = ((ne30 + ncpsg - 1)/ncpsg);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, nblk0, nblk1, ne32*ne33, 32, 1, 1);
+
+ need_sync = true;
+ }
+
+ if (need_sync) {
+ ggml_metal_op_concurrency_reset(ctx);
+ }
+
+ const int is_q = ggml_is_quantized(op->src[1]->type) ? 1 : 0;
+
+ // 2*(2*ncpsg)
+ // ncpsg soft_max values + ncpsg mask values
+ //
+ // 16*32*(nsg)
+ // the shared memory needed for the simdgroups to load the KV cache
+ // each thread loads (dequantizes) 16 head elements, there are 32 threads in the SG
+ //
+ #define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*GGML_PAD(ne20, 64) + 2*(2*ncpsg)) + is_q*(16*32*(nsg)))*(sizeof(float)/2), 16))
+
+ //int64_t nsgmax = 4;
+ //
+ //if (is_q) {
+ // nsgmax = 2;
+ // while (true) {
+ // const size_t smem = FATTN_SMEM(nsgmax);
+ // if (smem > props_dev->max_theadgroup_memory_size) {
+ // break;
+ // }
+ // nsgmax *= 2;
+ // }
+ // nsgmax /= 2;
+ //}
+
+ // simdgroups per threadgroup (a.k.a. warps)
+ //nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
+ int32_t nsg = 4;
+
+ const size_t smem = FATTN_SMEM(nsg);
+
+ ggml_metal_kargs_flash_attn_ext args = {
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.ne03 =*/ ne03,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne11 =*/ ne11,
+ /*.ne_12_2 =*/ ne12,
+ /*.ne_12_3 =*/ ne13,
+ /*.ns10 =*/ int32_t(nb11/nb10),
+ /*.nb11 =*/ nb11,
+ /*.nb12 =*/ nb12,
+ /*.nb13 =*/ nb13,
+ /*.ns20 =*/ int32_t(nb21/nb20),
+ /*.nb21 =*/ nb21,
+ /*.nb22 =*/ nb22,
+ /*.nb23 =*/ nb23,
+ /*.ne31 =*/ ne31,
+ /*.ne32 =*/ ne32,
+ /*.ne33 =*/ ne33,
+ /*.nb31 =*/ nb31,
+ /*.nb32 =*/ nb32,
+ /*.nb33 =*/ nb33,
+ /*.ne1 =*/ ne1,
+ /*.ne2 =*/ ne2,
+ /*.ne3 =*/ ne3,
+ /*.scale =*/ scale,
+ /*.max_bias =*/ max_bias,
+ /*.m0 =*/ m0,
+ /*.m1 =*/ m1,
+ /*.n_head_log2 =*/ n_head_log2,
+ /*.logit_softcap =*/ logit_softcap,
+ };
+
+ auto pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg);
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
+ ggml_metal_encoder_set_buffer (enc, bid_src1, 2);
+ ggml_metal_encoder_set_buffer (enc, bid_src2, 3);
+ ggml_metal_encoder_set_buffer (enc, bid_src3, 4);
+ ggml_metal_encoder_set_buffer (enc, bid_src4, 5);
+ ggml_metal_encoder_set_buffer (enc, bid_pad, 6);
+ ggml_metal_encoder_set_buffer (enc, bid_blk, 7);
+ ggml_metal_encoder_set_buffer (enc, bid_dst, 8);
+
+ ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03, 32, nsg, 1);
+ #undef FATTN_SMEM
+ } else {
+ // half4x4 kernel
+ const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPTG; // queries per threadgroup
+ const int ncpsg = OP_FLASH_ATTN_EXT_VEC_NCPSG; // cache values per simdgroup !! sync with kernel template arguments !!
+ const int nkpsg = 1*ncpsg;
+
+ GGML_ASSERT(nqptg <= 32);
+ GGML_ASSERT(nqptg % 1 == 0);
+ GGML_ASSERT(ncpsg % 32 == 0);
+
+ bool need_sync = false;
+
+ const bool has_kvpad = ne11 % ncpsg != 0;
+
+ if (has_kvpad) {
+ assert(ggml_metal_op_flash_attn_ext_extra_pad(op) != 0);
+
+ ggml_metal_kargs_flash_attn_ext_pad args0 = {
+ /*.ne11 =*/ne11,
+ /*.ne_12_2 =*/ne12,
+ /*.ne_12_3 =*/ne13,
+ /*.nb11 =*/nb11,
+ /*.nb12 =*/nb12,
+ /*.nb13 =*/nb13,
+ /*.nb21 =*/nb21,
+ /*.nb22 =*/nb22,
+ /*.nb23 =*/nb23,
+ /*.ne31 =*/ne31,
+ /*.ne32 =*/ne32,
+ /*.ne33 =*/ne33,
+ /*.nb31 =*/nb31,
+ /*.nb32 =*/nb32,
+ /*.nb33 =*/nb33,
+ };
+
+ auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline0);
+ ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
+ ggml_metal_encoder_set_buffer (enc, bid_src1, 1);
+ ggml_metal_encoder_set_buffer (enc, bid_src2, 2);
+ ggml_metal_encoder_set_buffer (enc, bid_src3, 3);
+ ggml_metal_encoder_set_buffer (enc, bid_pad, 4);
+
+ assert(ne12 == ne22);
+ assert(ne13 == ne23);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
+
+ need_sync = true;
+ }
+
+ if (need_sync) {
+ ggml_metal_op_concurrency_reset(ctx);
+ }
+
+ // ne00 + 2*ncpsg*(nsg)
+ // for each query, we load it as f16 in shared memory (ne00)
+ // and store the soft_max values and the mask
+ //
+ // ne20*(nsg)
+ // each simdgroup has a full f32 head vector in shared mem to accumulate results
+ //
+ #define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*GGML_PAD(ne20, 128)*(nsg))*(sizeof(float)/2), 16))
+
+ int64_t nsgmax = 2;
+ while (true) {
+ const size_t smem = FATTN_SMEM(nsgmax);
+ // avoid using more than half of the threadgroup memory - can cause slowdowns especially for large head sizes
+ if (smem > props_dev->max_theadgroup_memory_size/2) {
+ break;
+ }
+ nsgmax *= 2;
+ }
+ nsgmax /= 2;
+
+ // simdgroups per threadgroup (a.k.a. warps)
+ //const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));
+ const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) 1024/32)));
+
+ int64_t nsg = 1;
+ while (nsg <= nsgt) {
+ nsg *= 2;
+ }
+ nsg /= 2;
+
+ // workgroups
+ // each workgroup handles nsg*nkpsg cache values
+ int32_t nwg = 1;
+ if (false) {
+ // for small KV caches, we could launch a single workgroup and write the results directly to dst;
+ // however, this does not lead to significant improvement, so disabled
+ nwg = 1;
+ nsg = 4;
+ } else {
+ nwg = 32;
+ nsg = 1;
+ while (2*nwg*nsg*nkpsg < ne11 && nsg < 4) {
+ nsg *= 2;
+ }
+ }
+
+ ggml_metal_kargs_flash_attn_ext_vec args = {
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.ne03 =*/ ne03,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne11 =*/ ne11,
+ /*.ne_12_2 =*/ ne12,
+ /*.ne_12_3 =*/ ne13,
+ /*.ns10 =*/ int32_t(nb11/nb10),
+ /*.nb11 =*/ nb11,
+ /*.nb12 =*/ nb12,
+ /*.nb13 =*/ nb13,
+ /*.ns20 =*/ int32_t(nb21/nb20),
+ /*.nb21 =*/ nb21,
+ /*.nb22 =*/ nb22,
+ /*.nb23 =*/ nb23,
+ /*.ne31 =*/ ne31,
+ /*.ne32 =*/ ne32,
+ /*.ne33 =*/ ne33,
+ /*.nb31 =*/ nb31,
+ /*.nb32 =*/ nb32,
+ /*.nb33 =*/ nb33,
+ /*.ne1 =*/ ne1,
+ /*.ne2 =*/ ne2,
+ /*.ne3 =*/ ne3,
+ /*.scale =*/ scale,
+ /*.max_bias =*/ max_bias,
+ /*.m0 =*/ m0,
+ /*.m1 =*/ m1,
+ /*.n_head_log2 =*/ n_head_log2,
+ /*.logit_softcap =*/ logit_softcap,
+ };
+
+ auto pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg, nwg);
+
+ GGML_ASSERT(nsg*32 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
+ ggml_metal_encoder_set_buffer (enc, bid_src1, 2);
+ ggml_metal_encoder_set_buffer (enc, bid_src2, 3);
+ ggml_metal_encoder_set_buffer (enc, bid_src3, 4);
+ ggml_metal_encoder_set_buffer (enc, bid_src4, 5);
+
+ const size_t smem = FATTN_SMEM(nsg);
+
+ //printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, props_dev->max_theadgroup_memory_size, (int) nsg, (int) nsgmax);
+ GGML_ASSERT(smem <= props_dev->max_theadgroup_memory_size);
+
+ if (nwg == 1) {
+ assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) == 0);
+
+ // using 1 workgroup -> write the result directly into dst
+ ggml_metal_encoder_set_buffer(enc, bid_pad, 6);
+ ggml_metal_encoder_set_buffer(enc, bid_dst, 7);
+
+ ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
+ } else {
+ // sanity checks
+ assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) != 0);
+
+ GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
+ GGML_ASSERT((uint64_t)ne1*ne2*ne3 <= (1u << 31));
+
+ // write the results from each workgroup into a temp buffer
+ ggml_metal_encoder_set_buffer(enc, bid_pad, 6);
+ ggml_metal_encoder_set_buffer(enc, bid_tmp, 7);
+
+ ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+ ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
+
+ // sync the 2 kernels
+ ggml_metal_op_concurrency_reset(ctx);
+
+ // reduce the results from the workgroups
+ {
+ const int32_t nrows = ne1*ne2*ne3;
+
+ ggml_metal_kargs_flash_attn_ext_vec_reduce args0 = {
+ nrows,
+ };
+
+ auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(lib, op, ne20, nwg);
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline0);
+ ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
+ ggml_metal_encoder_set_buffer (enc, bid_tmp, 1);
+ ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, 32*nwg, 1, 1);
+ }
+ }
+ #undef FATTN_SMEM
+ }
+
+ return 1;
+ }
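
The nsgmax loop in the vec branch is a doubling search for the largest simdgroup count whose shared-memory footprint stays under half the device limit. Extracted as a standalone sketch (smem_for stands in for the FATTN_SMEM macro):

    // sketch: shared-memory-bounded doubling search for the simdgroup count
    static int64_t max_simdgroups(size_t (*smem_for)(int64_t), size_t smem_limit) {
        int64_t nsg = 2;
        while (smem_for(nsg) <= smem_limit/2) { // stay under half the limit
            nsg *= 2;
        }
        return nsg/2; // last count that still fit (1 if even nsg = 2 did not)
    }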
2727
+
2728
+ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
2729
+ ggml_tensor * op = ctx->node(idx);
2730
+
2731
+ ggml_metal_library_t lib = ctx->lib;
2732
+ ggml_metal_encoder_t enc = ctx->enc;
2733
+
2734
+ const bool use_fusion = ctx->use_fusion;
2735
+
2736
+ const int debug_fusion = ctx->debug_fusion;
2737
+
2738
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2739
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2740
+ GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
2741
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
2742
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
2743
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
2744
+
2745
+ GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
2746
+ GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
2747
+
2748
+ GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
2749
+ GGML_ASSERT(ggml_is_contiguous_rows(op->src[1]));
2750
+
2751
+ bool bcast_row = false;
2752
+
2753
+ ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
2754
+ ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
2755
+ ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
2756
+
2757
+ ggml_metal_kargs_bin args = {
2758
+ /*.ne00 =*/ ne00,
2759
+ /*.ne01 =*/ ne01,
2760
+ /*.ne02 =*/ ne02,
2761
+ /*.ne03 =*/ ne03,
2762
+ /*.nb00 =*/ nb00,
2763
+ /*.nb01 =*/ nb01,
2764
+ /*.nb02 =*/ nb02,
2765
+ /*.nb03 =*/ nb03,
2766
+ /*.ne10 =*/ ne10,
2767
+ /*.ne11 =*/ ne11,
2768
+ /*.ne12 =*/ ne12,
2769
+ /*.ne13 =*/ ne13,
2770
+ /*.nb10 =*/ nb10,
2771
+ /*.nb11 =*/ nb11,
2772
+ /*.nb12 =*/ nb12,
2773
+ /*.nb13 =*/ nb13,
2774
+ /*.ne0 =*/ ne0,
2775
+ /*.ne1 =*/ ne1,
2776
+ /*.ne2 =*/ ne2,
2777
+ /*.ne3 =*/ ne3,
2778
+ /*.nb0 =*/ nb0,
2779
+ /*.nb1 =*/ nb1,
2780
+ /*.nb2 =*/ nb2,
2781
+ /*.nb3 =*/ nb3,
2782
+ /*.offs =*/ 0,
2783
+ /*.o1 =*/ { bid_src1.offs },
2784
+ };
2785
+
2786
+ ggml_op fops[8];
2787
+
2788
+ int n_fuse = 1;
2789
+
2790
+ // c[0] = add(a, b[0])
2791
+ // c[1] = add(c[0], b[1])
2792
+ // c[2] = add(c[1], b[2])
2793
+ // ...
+     if (use_fusion) {
+         fops[0] = GGML_OP_ADD;
+         fops[1] = GGML_OP_ADD;
+         fops[2] = GGML_OP_ADD;
+         fops[3] = GGML_OP_ADD;
+         fops[4] = GGML_OP_ADD;
+         fops[5] = GGML_OP_ADD;
+         fops[6] = GGML_OP_ADD;
+         fops[7] = GGML_OP_ADD;
+
+         // note: in metal, we sometimes encode the graph in parallel so we have to avoid fusing ops
+         //       across splits. idx_end indicates the last node in the current split
+         for (n_fuse = 0; n_fuse <= 6; ++n_fuse) {
+             if (!ctx->can_fuse(idx + n_fuse, fops + n_fuse, 2)) {
+                 break;
+             }
+
+             ggml_tensor * f0 = ctx->node(idx + n_fuse);
+             ggml_tensor * f1 = ctx->node(idx + n_fuse + 1);
+
+             if (f0 != f1->src[0]) {
+                 break;
+             }
+
+             // b[0] === b[1] === ...
+             if (!ggml_are_same_layout(f0->src[1], f1->src[1])) {
+                 break;
+             }
+
+             // only fuse ops if src1 is in the same Metal buffer
+             ggml_metal_buffer_id bid_fuse = ggml_metal_get_buffer_id(f1->src[1]);
+             if (bid_fuse.metal != bid_src1.metal) {
+                 break;
+             }
+
+             //ctx->fuse_cnt[ops[n_fuse + 1]->op]++;
+
+             args.o1[n_fuse + 1] = bid_fuse.offs;
+         }
+
+         ++n_fuse;
+
+         if (debug_fusion > 1 && n_fuse > 1) {
+             GGML_LOG_DEBUG("%s: fuse: ADD x %d\n", __func__, n_fuse);
+         }
+     }
+
+     // the offsets of src1 and all fused buffers are relative to the start of the src1 buffer
+     bid_src1.offs = 0;
+
+     struct ggml_metal_pipeline_with_params pipeline;
+
+     if (ggml_nelements(op->src[1]) == ne10 && ggml_is_contiguous(op->src[1]) && ne00 % 4 == 0 && ne10 % 4 == 0) {
+         GGML_ASSERT(ggml_is_contiguous(op->src[0]));
+
+         // src1 is a row
+         GGML_ASSERT(ne11 == 1);
+
+         pipeline = ggml_metal_library_get_pipeline_bin(lib, op->op, n_fuse, true);
+
+         bcast_row = true;
+     } else {
+         pipeline = ggml_metal_library_get_pipeline_bin(lib, op->op, n_fuse, false);
+     }
+
+     if (n_fuse > 1) {
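+         // when ops were fused, the final result lives in the last fused node; if any of the
+         // intermediate nodes conflicts with concurrently encoded work, reset the concurrency
+         // tracking before dispatching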
+         bid_dst = ggml_metal_get_buffer_id(ctx->node(idx + n_fuse - 1));
+
+         for (int i = 1; i < n_fuse; ++i) {
+             if (!ggml_metal_op_concurrency_check(ctx, ctx->node(idx + i))) {
+                 ggml_metal_op_concurrency_reset(ctx);
+
+                 break;
+             }
+         }
+     }
+
+     ggml_metal_encoder_set_pipeline(enc, pipeline);
+     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
+     ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);
+     ggml_metal_encoder_set_buffer  (enc, bid_src1, 2);
+     ggml_metal_encoder_set_buffer  (enc, bid_dst, 3);
+
+     if (bcast_row) {
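+         // row-broadcast fast path: the kernel processes float4 elements, hence n = nelements/4 below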
+         const int64_t n = ggml_nelements(op)/4;
+
+         ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
+     } else {
+         int nth = 32;
+
+         while (16*nth < ne0 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+             nth *= 2;
+         }
+
+         ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
+     }
+
+     return n_fuse;
+ }
+
+ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
+     ggml_tensor * op = ctx->node(idx);
+
+     ggml_metal_library_t lib = ctx->lib;
+     ggml_metal_encoder_t enc = ctx->enc;
+
+     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+     GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+     GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+     float eps;
+     memcpy(&eps, op->op_params, sizeof(float));
+
+     int nth = 32; // SIMD width
+
+     ggml_metal_kargs_l2_norm args = {
+         /*.ne00   =*/ ne00,
+         /*.ne00_4 =*/ ne00/4,
+         /*.nb01   =*/ nb01,
+         /*.eps    =*/ eps,
+     };
+
+     auto pipeline = ggml_metal_library_get_pipeline_l2_norm(lib, op);
+
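+     // one threadgroup per row; grow nth in powers of two up to the row size (in float4 units)
+     // and the pipeline's thread limit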
+     while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+         nth *= 2;
+     }
+
+     nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+     nth = std::min(nth, ne00/4);
+
+     const size_t smem = pipeline.smem;
+
+     const int64_t nrows = ggml_nrows(op->src[0]);
+
+     ggml_metal_encoder_set_pipeline(enc, pipeline);
+     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
+     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op), 2);
+
+     ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+
+     ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1);
+
+     return 1;
+ }
+
+ int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) {
+     ggml_tensor * op = ctx->node(idx);
+
+     ggml_metal_library_t lib = ctx->lib;
+     ggml_metal_encoder_t enc = ctx->enc;
+
+     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+     GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+     GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+     const int32_t ngrp = ((const int32_t *) op->op_params)[0];
+
+     float eps;
+     memcpy(&eps, op->op_params + 1, sizeof(float));
+
+     ggml_metal_kargs_group_norm args = {
+         /*.ne00 =*/ ne00,
+         /*.ne01 =*/ ne01,
+         /*.ne02 =*/ ne02,
+         /*.nb00 =*/ nb00,
+         /*.nb01 =*/ nb01,
+         /*.nb02 =*/ nb02,
+         /*.ngrp =*/ ngrp,
+         /*.eps  =*/ eps,
+     };
+
+     auto pipeline = ggml_metal_library_get_pipeline_group_norm(lib, op);
+
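+     // one threadgroup per group (ngrp in total), 32 threads each; the wider threadgroup
+     // sizing below is kept disabled in the source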
+     int nth = 32; // SIMD width
+     //while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+     //    nth *= 2;
+     //}
+
+     //nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+     //nth = std::min(nth, ne00/4);
+
+     const size_t smem = pipeline.smem;
+
+     ggml_metal_encoder_set_pipeline(enc, pipeline);
+     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
+     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op), 2);
+
+     ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+
+     ggml_metal_encoder_dispatch_threadgroups(enc, ngrp, 1, 1, nth, 1, 1);
+
+     return 1;
+ }
+
+ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) {
+     ggml_tensor * op = ctx->node(idx);
+
+     ggml_metal_library_t lib = ctx->lib;
+     ggml_metal_encoder_t enc = ctx->enc;
+
+     const bool use_fusion = ctx->use_fusion;
+
+     const int debug_fusion = ctx->debug_fusion;
+
+     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+     GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+     GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+     float eps;
+     memcpy(&eps, op->op_params, sizeof(float));
+
+     ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
+     ggml_metal_buffer_id bid_dst  = ggml_metal_get_buffer_id(op);
+
+     ggml_metal_kargs_norm args = {
+         /*.ne00   =*/ ne00,
+         /*.ne00_t =*/ ne00 % 4 == 0 ? ne00/4 : ne00,
+         /*.nb1    =*/ nb1,
+         /*.nb2    =*/ nb2,
+         /*.nb3    =*/ nb3,
+         /*.eps    =*/ eps,
+         /*.nef1   =*/ { ne01 },
+         /*.nef2   =*/ { ne02 },
+         /*.nef3   =*/ { ne03 },
+         /*.nbf1   =*/ { nb01 },
+         /*.nbf2   =*/ { nb02 },
+         /*.nbf3   =*/ { nb03 },
+     };
+
+     ggml_op fops[8];
+
+     int n_fuse = 1;
+
+     ggml_metal_buffer_id bid_fuse[2] = { bid_src0, bid_src0 };
+
+     // d[0] = norm(a)
+     // d[1] = mul(d[0], b)
+     // d[2] = add(d[1], c)
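+     // the common (rms_)norm -> scale (MUL) -> bias (ADD) tail is fused into a single kernel;
+     // bid_fuse collects the scale/bias buffers, args.nef*/nbf* their shapes and strides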
+     if (use_fusion) {
+         fops[0] = op->op;
+         fops[1] = GGML_OP_MUL;
+         fops[2] = GGML_OP_ADD;
+
+         for (n_fuse = 0; n_fuse <= 1; ++n_fuse) {
+             if (!ctx->can_fuse(idx + n_fuse, fops + n_fuse, 2)) {
+                 break;
+             }
+
+             ggml_tensor * f0 = ctx->node(idx + n_fuse);
+             ggml_tensor * f1 = ctx->node(idx + n_fuse + 1);
+
+             if (f0 != f1->src[0]) {
+                 break;
+             }
+
+             if (f1->src[1]->ne[0] != op->ne[0]) {
+                 break;
+             }
+
+             if (!ggml_is_contiguous_rows(f1->src[1])) {
+                 break;
+             }
+
+             if (f1->type != GGML_TYPE_F32) {
+                 break;
+             }
+
+             //ctx->fuse_cnt[f1->op]++;
+
+             bid_fuse[n_fuse] = ggml_metal_get_buffer_id(f1->src[1]);
+
+             args.nef1[n_fuse + 1] = f1->src[1]->ne[1];
+             args.nef2[n_fuse + 1] = f1->src[1]->ne[2];
+             args.nef3[n_fuse + 1] = f1->src[1]->ne[3];
+
+             args.nbf1[n_fuse + 1] = f1->src[1]->nb[1];
+             args.nbf2[n_fuse + 1] = f1->src[1]->nb[2];
+             args.nbf3[n_fuse + 1] = f1->src[1]->nb[3];
+         }
+
+         ++n_fuse;
+
+         if (debug_fusion > 1 && n_fuse > 1) {
+             if (n_fuse == 2) {
+                 GGML_LOG_DEBUG("%s: fuse: %s + MUL\n", __func__, ggml_op_name(op->op));
+             }
+             if (n_fuse == 3) {
+                 GGML_LOG_DEBUG("%s: fuse: %s + MUL + ADD\n", __func__, ggml_op_name(op->op));
+             }
+         }
+     }
+
+     if (n_fuse > 1) {
+         bid_dst = ggml_metal_get_buffer_id(ctx->node(idx + n_fuse - 1));
+
+         for (int i = 1; i < n_fuse; ++i) {
+             if (!ggml_metal_op_concurrency_check(ctx, ctx->node(idx + i))) {
+                 ggml_metal_op_concurrency_reset(ctx);
+
+                 break;
+             }
+         }
+     }
+
+     auto pipeline = ggml_metal_library_get_pipeline_norm(lib, op, n_fuse);
+
+     int nth = 32; // SIMD width
+
+     while (nth < args.ne00_t && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+         nth *= 2;
+     }
+
+     nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+     nth = std::min(nth, args.ne00_t);
+
+     const size_t smem = pipeline.smem;
+
+     ggml_metal_encoder_set_pipeline(enc, pipeline);
+     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
+     ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);
+     ggml_metal_encoder_set_buffer  (enc, bid_fuse[0], 2);
+     ggml_metal_encoder_set_buffer  (enc, bid_fuse[1], 3);
+     ggml_metal_encoder_set_buffer  (enc, bid_dst, 4);
+
+     ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+
+     ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
+
+     return n_fuse;
+ }
+
+ int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) {
+     ggml_tensor * op = ctx->node(idx);
+
+     ggml_metal_library_t lib = ctx->lib;
+     ggml_metal_encoder_t enc = ctx->enc;
+
+     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+     GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+     GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+     GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+     GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+     // make sure we have one or more position ids (ne10) per token (ne02)
+     GGML_ASSERT(ne10 % ne02 == 0);
+     GGML_ASSERT(ne10 >= ne02);
+
+     const int nth = std::min(1024, ne00);
+
+     const int n_past     = ((const int32_t *) op->op_params)[0];
+     const int n_dims     = ((const int32_t *) op->op_params)[1];
+     //const int mode     = ((const int32_t *) op->op_params)[2];
+     // skip index 3, n_ctx, used in GLM RoPE, unimplemented in Metal
+     const int n_ctx_orig = ((const int32_t *) op->op_params)[4];
+
+     float freq_base;
+     float freq_scale;
+     float ext_factor;
+     float attn_factor;
+     float beta_fast;
+     float beta_slow;
+
+     memcpy(&freq_base,   (const int32_t *) op->op_params +  5, sizeof(float));
+     memcpy(&freq_scale,  (const int32_t *) op->op_params +  6, sizeof(float));
+     memcpy(&ext_factor,  (const int32_t *) op->op_params +  7, sizeof(float));
+     memcpy(&attn_factor, (const int32_t *) op->op_params +  8, sizeof(float));
+     memcpy(&beta_fast,   (const int32_t *) op->op_params +  9, sizeof(float));
+     memcpy(&beta_slow,   (const int32_t *) op->op_params + 10, sizeof(float));
+
+     // mrope
+     const int sect_0 = ((const int32_t *) op->op_params)[11];
+     const int sect_1 = ((const int32_t *) op->op_params)[12];
+     const int sect_2 = ((const int32_t *) op->op_params)[13];
+     const int sect_3 = ((const int32_t *) op->op_params)[14];
+
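+     // op_params layout recap (as read above): [0]=n_past, [1]=n_dims, [2]=mode, [3]=n_ctx (skipped),
+     // [4]=n_ctx_orig, [5..10]=freq/ext/attn/beta floats, [11..14]=mrope section sizes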
+     ggml_metal_kargs_rope args = {
+         /*.ne00        =*/ ne00,
+         /*.ne01        =*/ ne01,
+         /*.ne02        =*/ ne02,
+         /*.ne03        =*/ ne03,
+         /*.nb00        =*/ nb00,
+         /*.nb01        =*/ nb01,
+         /*.nb02        =*/ nb02,
+         /*.nb03        =*/ nb03,
+         /*.ne0         =*/ ne0,
+         /*.ne1         =*/ ne1,
+         /*.ne2         =*/ ne2,
+         /*.ne3         =*/ ne3,
+         /*.nb0         =*/ nb0,
+         /*.nb1         =*/ nb1,
+         /*.nb2         =*/ nb2,
+         /*.nb3         =*/ nb3,
+         /*.n_past      =*/ n_past,
+         /*.n_dims      =*/ n_dims,
+         /*.n_ctx_orig  =*/ n_ctx_orig,
+         /*.freq_base   =*/ freq_base,
+         /*.freq_scale  =*/ freq_scale,
+         /*.ext_factor  =*/ ext_factor,
+         /*.attn_factor =*/ attn_factor,
+         /*.beta_fast   =*/ beta_fast,
+         /*.beta_slow   =*/ beta_slow,
+         /* sect_0      =*/ sect_0,
+         /* sect_1      =*/ sect_1,
+         /* sect_2      =*/ sect_2,
+         /* sect_3      =*/ sect_3,
+         /* src2        =*/ op->src[2] != nullptr,
+     };
+
+     auto pipeline = ggml_metal_library_get_pipeline_rope(lib, op);
+
+     ggml_metal_encoder_set_pipeline(enc, pipeline);
+     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
+     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
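+     // src2 holds the optional frequency factors; the kernel expects slot 3 to be bound
+     // either way, so src0 is bound again as a placeholder when src2 is absent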
+     if (op->src[2]) {
+         ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[2]), 3);
+     } else {
+         ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 3);
+     }
+     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op), 4);
+
+     ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
+
+     return 1;
+ }
+
+ int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) {
+     ggml_tensor * op = ctx->node(idx);
+
+     ggml_metal_library_t lib = ctx->lib;
+     ggml_metal_encoder_t enc = ctx->enc;
+
+     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+     GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+     GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+     const int32_t s0 = ((const int32_t *)(op->op_params))[0];
+     const int32_t s1 = ((const int32_t *)(op->op_params))[1];
+     const int32_t p0 = ((const int32_t *)(op->op_params))[2];
+     const int32_t p1 = ((const int32_t *)(op->op_params))[3];
+     const int32_t d0 = ((const int32_t *)(op->op_params))[4];
+     const int32_t d1 = ((const int32_t *)(op->op_params))[5];
+
+     const bool is_2D = ((const int32_t *)(op->op_params))[6] == 1;
+
+     const int32_t N  = op->src[1]->ne[is_2D ? 3 : 2];
+     const int32_t IC = op->src[1]->ne[is_2D ? 2 : 1];
+     const int32_t IH = is_2D ? op->src[1]->ne[1] : 1;
+     const int32_t IW = op->src[1]->ne[0];
+
+     const int32_t KH = is_2D ? op->src[0]->ne[1] : 1;
+     const int32_t KW = op->src[0]->ne[0];
+
+     const int32_t OH = is_2D ? op->ne[2] : 1;
+     const int32_t OW = op->ne[1];
+
+     const int32_t CHW = IC * KH * KW;
+
+     const uint64_t ofs0 = op->src[1]->nb[is_2D ? 3 : 2] / 4;
+     const uint64_t ofs1 = op->src[1]->nb[is_2D ? 2 : 1] / 4;
+
+     ggml_metal_kargs_im2col args = {
+         /*.ofs0 =*/ ofs0,
+         /*.ofs1 =*/ ofs1,
+         /*.IW   =*/ IW,
+         /*.IH   =*/ IH,
+         /*.CHW  =*/ CHW,
+         /*.s0   =*/ s0,
+         /*.s1   =*/ s1,
+         /*.p0   =*/ p0,
+         /*.p1   =*/ p1,
+         /*.d0   =*/ d0,
+         /*.d1   =*/ d1,
+         /*.N    =*/ N,
+         /*.KH   =*/ KH,
+         /*.KW   =*/ KW,
+         /*.KHW  =*/ KH * KW,
+     };
+
+     auto pipeline = ggml_metal_library_get_pipeline_im2col(lib, op);
+
+     GGML_ASSERT(KH*KW <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+
+     const uint64_t ntptg0 = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)/(KH*KW), N);
+
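+     // grid: one threadgroup per (IC, OH, OW) output position; each threadgroup spans
+     // (ntptg0, KH, KW) threads, with the batch coverage ntptg0 capped by the thread limit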
+     ggml_metal_encoder_set_pipeline(enc, pipeline);
+     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
+     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 1);
+     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op), 2);
+
+     ggml_metal_encoder_dispatch_threadgroups(enc, IC, OH, OW, ntptg0, KH, KW);
+
+     return 1;
+ }
+
+ int ggml_metal_op_conv_2d(ggml_metal_op_t ctx, int idx) {
+     ggml_tensor * op = ctx->node(idx);
+
+     ggml_metal_library_t lib = ctx->lib;
+     ggml_metal_encoder_t enc = ctx->enc;
+
+     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+     GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+     GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+     GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+     GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+     GGML_ASSERT(ggml_is_contiguous(op->src[0]));
+     GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
+     GGML_ASSERT(op->type == GGML_TYPE_F32);
+     GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
+
+     const int32_t s0 = ((const int32_t *) op->op_params)[0];
+     const int32_t s1 = ((const int32_t *) op->op_params)[1];
+     const int32_t p0 = ((const int32_t *) op->op_params)[2];
+     const int32_t p1 = ((const int32_t *) op->op_params)[3];
+     const int32_t d0 = ((const int32_t *) op->op_params)[4];
+     const int32_t d1 = ((const int32_t *) op->op_params)[5];
+
+     ggml_metal_kargs_conv_2d args = {
+         /*.nb00 =*/ nb00,
+         /*.nb01 =*/ nb01,
+         /*.nb02 =*/ nb02,
+         /*.nb03 =*/ nb03,
+         /*.nb10 =*/ nb10,
+         /*.nb11 =*/ nb11,
+         /*.nb12 =*/ nb12,
+         /*.nb13 =*/ nb13,
+         /*.nb0  =*/ nb0,
+         /*.nb1  =*/ nb1,
+         /*.nb2  =*/ nb2,
+         /*.nb3  =*/ nb3,
+         /*.IW   =*/ ne10,
+         /*.IH   =*/ ne11,
+         /*.KW   =*/ ne00,
+         /*.KH   =*/ ne01,
+         /*.IC   =*/ ne02,
+         /*.OC   =*/ ne03,
+         /*.OW   =*/ ne0,
+         /*.OH   =*/ ne1,
+         /*.N    =*/ ne3,
+         /*.s0   =*/ s0,
+         /*.s1   =*/ s1,
+         /*.p0   =*/ p0,
+         /*.p1   =*/ p1,
+         /*.d0   =*/ d0,
+         /*.d1   =*/ d1,
+     };
+
+     auto pipeline = ggml_metal_library_get_pipeline_conv_2d(lib, op);
+
+     int nth = ggml_metal_pipeline_max_theads_per_threadgroup(pipeline);
+     nth = std::min(nth, 256);
+     nth = std::max(nth, 1);
+
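+     // flat 1D launch: one thread per output element, ceil(n_out/nth) threadgroups,
+     // clamped to the int range accepted by the dispatch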
+     const uint64_t n_out = ggml_nelements(op);
+
+     uint64_t tg = (n_out + nth - 1)/nth;
+     tg = std::max<uint64_t>(tg, 1);
+     tg = std::min<uint64_t>(tg, (uint64_t) std::numeric_limits<int>::max());
+
+     ggml_metal_encoder_set_pipeline(enc, pipeline);
+     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
+     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op), 3);
+
+     ggml_metal_encoder_dispatch_threadgroups(enc, tg, 1, 1, nth, 1, 1);
+
+     return 1;
+ }
+
+ int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {
+     ggml_tensor * op = ctx->node(idx);
+
+     ggml_metal_library_t lib = ctx->lib;
+     ggml_metal_encoder_t enc = ctx->enc;
+
+     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+     GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+     GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+     GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+     GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+     const int32_t s0 = ((const int32_t *)(op->op_params))[0];
+
+     const int32_t IC = op->src[1]->ne[1];
+     const int32_t IL = op->src[1]->ne[0];
+
+     const int32_t K  = op->src[0]->ne[0];
+
+     const int32_t OL = op->ne[0];
+     const int32_t OC = op->ne[1];
+
+     ggml_metal_kargs_conv_transpose_1d args = {
+         /*.IC  =*/ IC,
+         /*.IL  =*/ IL,
+         /*.K   =*/ K,
+         /*.s0  =*/ s0,
+         /*.nb0 =*/ nb0,
+         /*.nb1 =*/ nb1,
+     };
+
+     auto pipeline = ggml_metal_library_get_pipeline_conv_transpose_1d(lib, op);
+
+     ggml_metal_encoder_set_pipeline(enc, pipeline);
+     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
+     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op), 3);
+
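+     // one single-thread threadgroup per (OL, OC) output element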
+     ggml_metal_encoder_dispatch_threadgroups(enc, OL, OC, 1, 1, 1, 1);
+
+     return 1;
+ }
+
+ int ggml_metal_op_conv_transpose_2d(ggml_metal_op_t ctx, int idx) {
+     ggml_tensor * op = ctx->node(idx);
+
+     ggml_metal_library_t lib = ctx->lib;
+     ggml_metal_encoder_t enc = ctx->enc;
+
+     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+     GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+     GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+     GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+     GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+     const int32_t s0 = ((const int32_t *)(op->op_params))[0];
+
+     const int32_t IC = op->src[1]->ne[2];
+     const int32_t IH = op->src[1]->ne[1];
+     const int32_t IW = op->src[1]->ne[0];
+
+     const int32_t KH = op->src[0]->ne[1];
+     const int32_t KW = op->src[0]->ne[0];
+
+     const int32_t OW = op->ne[0];
+     const int32_t OH = op->ne[1];
+     const int32_t OC = op->ne[2];
+
+     ggml_metal_kargs_conv_transpose_2d args = {
+         /*.IC  =*/ IC,
+         /*.IH  =*/ IH,
+         /*.IW  =*/ IW,
+         /*.KH  =*/ KH,
+         /*.KW  =*/ KW,
+         /*.OC  =*/ OC,
+         /*.s0  =*/ s0,
+         /*.nb0 =*/ nb0,
+         /*.nb1 =*/ nb1,
+         /*.nb2 =*/ nb2,
+     };
+
+     auto pipeline = ggml_metal_library_get_pipeline_conv_transpose_2d(lib, op);
+
+     ggml_metal_encoder_set_pipeline(enc, pipeline);
+     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
+     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op), 3);
+
+     // Metal requires the threadgroup memory size to be a multiple of 16 bytes
+     const size_t smem = GGML_PAD(KW * KH * sizeof(float), 16);
+     ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+
+     ggml_metal_encoder_dispatch_threadgroups(enc, OW, OH, OC, KW, KH, 1);
+
+     return 1;
+ }
+
+ int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) {
+     ggml_tensor * op = ctx->node(idx);
+
+     ggml_metal_library_t lib = ctx->lib;
+     ggml_metal_encoder_t enc = ctx->enc;
+
+     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+     GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+     GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+     const float sf0 = (float)ne0/op->src[0]->ne[0];
+     const float sf1 = (float)ne1/op->src[0]->ne[1];
+     const float sf2 = (float)ne2/op->src[0]->ne[2];
+     const float sf3 = (float)ne3/op->src[0]->ne[3];
+
+     ggml_metal_kargs_upscale args = {
+         /*.ne00 =*/ ne00,
+         /*.ne01 =*/ ne01,
+         /*.ne02 =*/ ne02,
+         /*.ne03 =*/ ne03,
+         /*.nb00 =*/ nb00,
+         /*.nb01 =*/ nb01,
+         /*.nb02 =*/ nb02,
+         /*.nb03 =*/ nb03,
+         /*.ne0  =*/ ne0,
+         /*.ne1  =*/ ne1,
+         /*.ne2  =*/ ne2,
+         /*.ne3  =*/ ne3,
+         /*.nb0  =*/ nb0,
+         /*.nb1  =*/ nb1,
+         /*.nb2  =*/ nb2,
+         /*.nb3  =*/ nb3,
+         /*.sf0  =*/ sf0,
+         /*.sf1  =*/ sf1,
+         /*.sf2  =*/ sf2,
+         /*.sf3  =*/ sf3
+     };
+
+     auto pipeline = ggml_metal_library_get_pipeline_upscale(lib, op);
+
+     const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
+
+     ggml_metal_encoder_set_pipeline(enc, pipeline);
+     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
+     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op), 2);
+
+     ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1);
+
+     return 1;
+ }
+
+ int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) {
+     ggml_tensor * op = ctx->node(idx);
+
+     ggml_metal_library_t lib = ctx->lib;
+     ggml_metal_encoder_t enc = ctx->enc;
+
+     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+     GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+     GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+     ggml_metal_kargs_pad args = {
+         /*.ne00 =*/ ne00,
+         /*.ne01 =*/ ne01,
+         /*.ne02 =*/ ne02,
+         /*.ne03 =*/ ne03,
+         /*.nb00 =*/ nb00,
+         /*.nb01 =*/ nb01,
+         /*.nb02 =*/ nb02,
+         /*.nb03 =*/ nb03,
+         /*.ne0  =*/ ne0,
+         /*.ne1  =*/ ne1,
+         /*.ne2  =*/ ne2,
+         /*.ne3  =*/ ne3,
+         /*.nb0  =*/ nb0,
+         /*.nb1  =*/ nb1,
+         /*.nb2  =*/ nb2,
+         /*.nb3  =*/ nb3
+     };
+
+     auto pipeline = ggml_metal_library_get_pipeline_pad(lib, op);
+
+     const int nth = std::min(1024, ne0);
+
+     ggml_metal_encoder_set_pipeline(enc, pipeline);
+     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
+     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op), 2);
+
+     ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1);
+
+     return 1;
+ }
+
+ int ggml_metal_op_pad_reflect_1d(ggml_metal_op_t ctx, int idx) {
+     ggml_tensor * op = ctx->node(idx);
+
+     ggml_metal_library_t lib = ctx->lib;
+     ggml_metal_encoder_t enc = ctx->enc;
+
+     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+     GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+     GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+     ggml_metal_kargs_pad_reflect_1d args = {
+         /*.ne00 =*/ ne00,
+         /*.ne01 =*/ ne01,
+         /*.ne02 =*/ ne02,
+         /*.ne03 =*/ ne03,
+         /*.nb00 =*/ nb00,
+         /*.nb01 =*/ nb01,
+         /*.nb02 =*/ nb02,
+         /*.nb03 =*/ nb03,
+         /*.ne0  =*/ ne0,
+         /*.ne1  =*/ ne1,
+         /*.ne2  =*/ ne2,
+         /*.ne3  =*/ ne3,
+         /*.nb0  =*/ nb0,
+         /*.nb1  =*/ nb1,
+         /*.nb2  =*/ nb2,
+         /*.nb3  =*/ nb3,
+         /*.p0   =*/ ((const int32_t *)(op->op_params))[0],
+         /*.p1   =*/ ((const int32_t *)(op->op_params))[1]
+     };
+
+     auto pipeline = ggml_metal_library_get_pipeline_pad_reflect_1d(lib, op);
+
+     const int nth = std::min(1024, ne0);
+
+     ggml_metal_encoder_set_pipeline(enc, pipeline);
+     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
+     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op), 2);
+
+     ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1);
+
+     return 1;
+ }
+
+ int ggml_metal_op_arange(ggml_metal_op_t ctx, int idx) {
+     ggml_tensor * op = ctx->node(idx);
+
+     ggml_metal_library_t lib = ctx->lib;
+     ggml_metal_encoder_t enc = ctx->enc;
+
+     GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+     GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+     float start;
+     float step;
+
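+     // op_params layout (assumed to follow ggml_arange): [0]=start, [1]=stop (implied by ne0,
+     // not read here), [2]=step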
+     memcpy(&start, ((const int32_t *) op->op_params) + 0, sizeof(float));
+     memcpy(&step,  ((const int32_t *) op->op_params) + 2, sizeof(float));
+
+     ggml_metal_kargs_arange args = {
+         /*.ne0   =*/ ne0,
+         /*.start =*/ start,
+         /*.step  =*/ step
+     };
+
+     const int nth = std::min(1024, ne0);
+
+     auto pipeline = ggml_metal_library_get_pipeline_arange(lib, op);
+
+     ggml_metal_encoder_set_pipeline(enc, pipeline);
+     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
+     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op), 1);
+
+     ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, nth, 1, 1);
+
+     return 1;
+ }
+
+ int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx) {
+     ggml_tensor * op = ctx->node(idx);
+
+     ggml_metal_library_t lib = ctx->lib;
+     ggml_metal_encoder_t enc = ctx->enc;
+
+     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+     GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+     GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+     const int dim        = op->op_params[0];
+     const int max_period = op->op_params[1];
+
+     ggml_metal_kargs_timestep_embedding args = {
+         /*.nb1        =*/ nb1,
+         /*.dim        =*/ dim,
+         /*.max_period =*/ max_period,
+     };
+
+     auto pipeline = ggml_metal_library_get_pipeline_timestep_embedding(lib, op);
+
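+     // dim/2 frequencies per timestep; presumably each thread fills one (cos, sin) pair,
+     // one threadgroup per timestep (ne00 of them)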
+     const int nth = std::max(1, std::min(1024, dim/2));
+
+     ggml_metal_encoder_set_pipeline(enc, pipeline);
+     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
+     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op), 2);
+
+     ggml_metal_encoder_dispatch_threadgroups(enc, ne00, 1, 1, nth, 1, 1);
+
+     return 1;
+ }
+
+ int ggml_metal_op_argmax(ggml_metal_op_t ctx, int idx) {
+     ggml_tensor * op = ctx->node(idx);
+
+     ggml_metal_library_t lib = ctx->lib;
+     ggml_metal_encoder_t enc = ctx->enc;
+
+     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+     GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+     GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+     ggml_metal_kargs_argmax args = {
+         /*.ne00 = */ ne00,
+         /*.nb01 = */ nb01,
+     };
+
+     auto pipeline = ggml_metal_library_get_pipeline_argmax(lib, op);
+
+     const int64_t nrows = ggml_nrows(op->src[0]);
+
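+     // heuristic: widen the per-row threadgroup only while the total thread count stays
+     // below ~256; with many rows, 32 threads per row already keep the GPU busy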
+     int nth = 32; // SIMD width
+     while (nth < ne00 && nth*ne01*ne02*ne03 < 256) {
+         nth *= 2;
+     }
+
+     const size_t smem = pipeline.smem;
+
+     ggml_metal_encoder_set_pipeline(enc, pipeline);
+     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
+     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op), 2);
+
+     ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+
+     ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1);
+
+     return 1;
+ }
+
+ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
+     ggml_tensor * op = ctx->node(idx);
+
+     ggml_metal_library_t lib = ctx->lib;
+     ggml_metal_encoder_t enc = ctx->enc;
+
+     GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
+
+     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+     GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+     GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+     auto pipeline = ggml_metal_library_get_pipeline_argsort(lib, op);
+
+     // bitonic sort requires the number of elements to be a power of 2
+     int nth = 1;
+     while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+         nth *= 2;
+     }
+
+     const int npr = (ne00 + nth - 1)/nth;
+
+     // Metal kernels require the threadgroup memory size to be a multiple of 16 bytes
+     // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
+     const size_t smem = GGML_PAD(nth*sizeof(int32_t), 16);
+
+     ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
+     ggml_metal_buffer_id bid_dst  = ggml_metal_get_buffer_id(op);
+
+     ggml_metal_buffer_id bid_tmp = bid_dst;
+     bid_tmp.offs += ggml_nbytes(op);
+
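+     // dst and a scratch region right after it serve as ping-pong buffers during the merge
+     // passes; if the number of passes (ceil(log2(npr))) is odd, start swapped so the final
+     // pass writes into the real dst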
+     if ((int) ceil(std::log(npr) / std::log(2)) % 2 == 1) {
+         std::swap(bid_dst, bid_tmp);
+     }
+
+     ggml_metal_kargs_argsort args = {
+         /*.ne00  =*/ ne00,
+         /*.ne01  =*/ ne01,
+         /*.ne02  =*/ ne02,
+         /*.ne03  =*/ ne03,
+         /*.nb00  =*/ nb00,
+         /*.nb01  =*/ nb01,
+         /*.nb02  =*/ nb02,
+         /*.nb03  =*/ nb03,
+         /*.ne0   =*/ ne0,
+         /*.ne1   =*/ ne1,
+         /*.ne2   =*/ ne2,
+         /*.ne3   =*/ ne3,
+         /*.top_k =*/ nth,
+     };
+
+     ggml_metal_encoder_set_pipeline(enc, pipeline);
+     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
+     ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);
+     ggml_metal_encoder_set_buffer  (enc, bid_dst, 2);
+
+     ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+
+     ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1);
+
+     auto pipeline_merge = ggml_metal_library_get_pipeline_argsort_merge(lib, op);
+
+     int len = nth;
+
+     while (len < ne00) {
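+         // each pass merges pairs of sorted runs of length len; the run length doubles every
+         // pass until the whole row is sorted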
+         ggml_metal_op_concurrency_reset(ctx);
+
+         ggml_metal_kargs_argsort_merge args_merge = {
+             /*.ne00  =*/ ne00,
+             /*.ne01  =*/ ne01,
+             /*.ne02  =*/ ne02,
+             /*.ne03  =*/ ne03,
+             /*.nb00  =*/ nb00,
+             /*.nb01  =*/ nb01,
+             /*.nb02  =*/ nb02,
+             /*.nb03  =*/ nb03,
+             /*.ne0   =*/ ne0,
+             /*.ne1   =*/ ne1,
+             /*.ne2   =*/ ne2,
+             /*.ne3   =*/ ne3,
+             /*.top_k =*/ ne00,
+             /*.len   =*/ len,
+         };
+
+         // merges per row
+         const int nm = (ne00 + 2*len - 1) / (2*len);
+
+         const int nth = std::min(512, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_merge));
+
+         ggml_metal_encoder_set_pipeline(enc, pipeline_merge);
+         ggml_metal_encoder_set_bytes   (enc, &args_merge, sizeof(args_merge), 0);
+         ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);
+         ggml_metal_encoder_set_buffer  (enc, bid_dst, 2);
+         ggml_metal_encoder_set_buffer  (enc, bid_tmp, 3);
+
+         ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1);
+
+         std::swap(bid_dst, bid_tmp);
+
+         len <<= 1;
+     }
+
+     return 1;
+ }
+
+ int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) {
+     ggml_tensor * op = ctx->node(idx);
+
+     ggml_metal_library_t lib = ctx->lib;
+     ggml_metal_encoder_t enc = ctx->enc;
+
+     GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
+
+     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+     GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+     GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+     auto pipeline = ggml_metal_library_get_pipeline_top_k(lib, op);
+
+     // bitonic sort requires the number of elements to be a power of 2
+     int nth = 1;
+     while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+         nth *= 2;
+     }
+
+     // blocks per row
+     const int npr = (ne00 + nth - 1)/nth;
+
+     const size_t smem = GGML_PAD(nth*sizeof(int32_t), 16);
+
+     ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
+     ggml_metal_buffer_id bid_dst  = ggml_metal_get_buffer_id(op);
+
+     ggml_metal_buffer_id bid_tmp = bid_dst;
+     bid_tmp.offs += sizeof(int32_t)*ggml_nelements(op->src[0]);
+
+     if ((int) ceil(std::log(npr) / std::log(2)) % 2 == 1) {
+         std::swap(bid_dst, bid_tmp);
+     }
+
+     const int top_k = ne0;
+
+     ggml_metal_kargs_argsort args = {
+         /*.ne00  =*/ ne00,
+         /*.ne01  =*/ ne01,
+         /*.ne02  =*/ ne02,
+         /*.ne03  =*/ ne03,
+         /*.nb00  =*/ nb00,
+         /*.nb01  =*/ nb01,
+         /*.nb02  =*/ nb02,
+         /*.nb03  =*/ nb03,
+         /*.ne0   =*/ ne0,
+         /*.ne1   =*/ ne1,
+         /*.ne2   =*/ ne2,
+         /*.ne3   =*/ ne3,
+         /*.top_k =*/ std::min(nth, top_k), // for each block, keep just the top_k indices
+     };
+
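+     // with several blocks per row, the intermediate row holds top_k candidates from every
+     // full block plus whatever remains of the last, partial block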
+     if (npr > 1) {
+         args.ne0 = (npr - 1)*args.top_k + std::min(ne00 - (npr - 1)*nth, args.top_k);
+     }
+
+     ggml_metal_encoder_set_pipeline(enc, pipeline);
+     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
+     ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);
+     ggml_metal_encoder_set_buffer  (enc, bid_dst, 2);
+
+     ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+
+     ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1);
+
+     auto pipeline_merge = ggml_metal_library_get_pipeline_top_k_merge(lib, op);
+
+     int len = args.top_k;
+
+     while (len < args.ne0) {
+         ggml_metal_op_concurrency_reset(ctx);
+
+         // merges per row
+         const int nm = (args.ne0 + 2*len - 1) / (2*len);
+
+         const int nth = std::min(512, std::min(len, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_merge)));
+
+         ggml_metal_kargs_argsort_merge args_merge = {
+             /*.ne00  =*/ ne00,
+             /*.ne01  =*/ ne01,
+             /*.ne02  =*/ ne02,
+             /*.ne03  =*/ ne03,
+             /*.nb00  =*/ nb00,
+             /*.nb01  =*/ nb01,
+             /*.nb02  =*/ nb02,
+             /*.nb03  =*/ nb03,
+             /*.ne0   =*/ args.ne0,
+             /*.ne1   =*/ ne1,
+             /*.ne2   =*/ ne2,
+             /*.ne3   =*/ ne3,
+             /*.top_k =*/ nm == 1 ? top_k : args.ne0, // the final merge outputs top_k elements
+             /*.len   =*/ len,
+         };
+
+         ggml_metal_encoder_set_pipeline(enc, pipeline_merge);
+         ggml_metal_encoder_set_bytes   (enc, &args_merge, sizeof(args_merge), 0);
+         ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);
+         ggml_metal_encoder_set_buffer  (enc, bid_dst, 2);
+         ggml_metal_encoder_set_buffer  (enc, bid_tmp, 3);
+
+         ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1);
+
+         std::swap(bid_dst, bid_tmp);
+
+         len <<= 1;
+     }
+
+     return 1;
+ }
+
+ int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
+     ggml_tensor * op = ctx->node(idx);
+
+     ggml_metal_library_t lib = ctx->lib;
+     ggml_metal_encoder_t enc = ctx->enc;
+
+     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+     GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+     GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+     float slope;
+     memcpy(&slope, op->op_params, sizeof(float));
+
+     ggml_metal_kargs_leaky_relu args = {
+         /*.slope =*/ slope
+     };
+
+     auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
+
+     int64_t n = ggml_nelements(op);
+
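+     // when the element count is divisible by 4, the unary pipeline (presumably its float4
+     // variant) covers 4 elements per thread, so a quarter of the threadgroups suffice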
+     if (n % 4 == 0) {
+         n /= 4;
+     }
+
+     ggml_metal_encoder_set_pipeline(enc, pipeline);
+     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
+     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op), 2);
+
+     ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
+
+     return 1;
+ }
+
+ int ggml_metal_op_tri(ggml_metal_op_t ctx, int idx) {
+     ggml_tensor * op = ctx->node(idx);
+
+     ggml_metal_library_t lib = ctx->lib;
+     ggml_metal_encoder_t enc = ctx->enc;
+
+     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+     GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+     GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+     ggml_metal_kargs_tri args = {
+         /*.ne00 =*/ ne00,
+         /*.ne01 =*/ ne01,
+         /*.ne02 =*/ ne02,
+         /*.ne03 =*/ ne03,
+         /*.nb00 =*/ nb00,
+         /*.nb01 =*/ nb01,
+         /*.nb02 =*/ nb02,
+         /*.nb03 =*/ nb03,
+         /*.ne0  =*/ ne0,
+         /*.ne1  =*/ ne1,
+         /*.ne2  =*/ ne2,
+         /*.ne3  =*/ ne3,
+         /*.nb0  =*/ nb0,
+         /*.nb1  =*/ nb1,
+         /*.nb2  =*/ nb2,
+         /*.nb3  =*/ nb3,
+     };
+
+     auto pipeline = ggml_metal_library_get_pipeline_tri(lib, op);
+
+     int nth = 32; // SIMD width
+
+     while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+         nth *= 2;
+     }
+
+     nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+     nth = std::min(nth, ne00);
+
+     ggml_metal_encoder_set_pipeline(enc, pipeline);
+     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
+     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op), 2);
+
+     ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
+
+     return 1;
+ }
+
+ int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) {
+     ggml_tensor * op = ctx->node(idx);
+
+     ggml_metal_library_t lib = ctx->lib;
+     ggml_metal_encoder_t enc = ctx->enc;
+
+     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+     GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+     GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+     auto pipeline = ggml_metal_library_get_pipeline_opt_step_adamw(lib, op);
+
+     const int64_t np = ggml_nelements(op->src[0]);
+     ggml_metal_kargs_opt_step_adamw args = {
+         /*.np =*/ np,
+     };
+
+     int ida = 0;
+
+     ggml_metal_encoder_set_pipeline(enc, pipeline);
+     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), ida++);
+     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);
+     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), ida++);
+     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[2]), ida++);
+     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[3]), ida++);
+     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[4]), ida++);
+
+     const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
+     const int64_t n = (np + nth - 1) / nth;
+
+     ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1);
+
+     return 1;
+ }
+
+ int ggml_metal_op_opt_step_sgd(ggml_metal_op_t ctx, int idx) {
+     ggml_tensor * op = ctx->node(idx);
+
+     ggml_metal_library_t lib = ctx->lib;
+     ggml_metal_encoder_t enc = ctx->enc;
+
+     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+     GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+     GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+     auto pipeline = ggml_metal_library_get_pipeline_opt_step_sgd(lib, op);
+
+     const int64_t np = ggml_nelements(op->src[0]);
+     ggml_metal_kargs_opt_step_sgd args = {
+         /*.np =*/ np,
+     };
+
+     int ida = 0;
+
+     ggml_metal_encoder_set_pipeline(enc, pipeline);
+     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), ida++);
+     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);
+     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), ida++);
+     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[2]), ida++);
+
+     const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
+     const int64_t n = (np + nth - 1) / nth;
+
+     ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1);
+
+     return 1;
+ }
+
+ int ggml_metal_op_count_equal(ggml_metal_op_t ctx, int idx) {
+     ggml_tensor * op = ctx->node(idx);
+
+     ggml_metal_library_t lib = ctx->lib;
+     ggml_metal_encoder_t enc = ctx->enc;
+
+     GGML_TENSOR_LOCALS(int32_t, ne0, op->src[0], ne);
+     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+     GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+
+     {
+         ggml_metal_kargs_memset args = { /*.val =*/ 0 };
+
+         auto pipeline = ggml_metal_library_get_pipeline_memset(lib, op);
+
+         ggml_metal_encoder_set_pipeline(enc, pipeline);
+         ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
+         ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 1);
+
+         ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, 1, 1, 1);
+     }
+
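+     // barrier between the two passes: the counting kernel below must observe the zeroed
+     // counter written by the memset pass above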
+     ggml_metal_op_concurrency_reset(ctx);
+
+     {
+         ggml_metal_kargs_count_equal args = {
+             /*.ne00 =*/ ne00,
+             /*.ne01 =*/ ne01,
+             /*.ne02 =*/ ne02,
+             /*.ne03 =*/ ne03,
+             /*.nb00 =*/ nb00,
+             /*.nb01 =*/ nb01,
+             /*.nb02 =*/ nb02,
+             /*.nb03 =*/ nb03,
+             /*.nb10 =*/ nb10,
+             /*.nb11 =*/ nb11,
+             /*.nb12 =*/ nb12,
+             /*.nb13 =*/ nb13,
+         };
+
+         auto pipeline = ggml_metal_library_get_pipeline_count_equal(lib, op);
+
+         const size_t smem = pipeline.smem;
+
+         const int nth = 32*pipeline.nsg;
+
+         GGML_ASSERT(nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+
+         ggml_metal_encoder_set_pipeline(enc, pipeline);
+         ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
+         ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+         ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+         ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
+
+         ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+         ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
+     }
+
+     return 1;
+ }