whispercpp 1.3.1 → 1.3.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (797)
  1. checksums.yaml +4 -4
  2. data/.gitignore +4 -3
  3. data/README.md +92 -31
  4. data/Rakefile +26 -7
  5. data/ext/.gitignore +5 -7
  6. data/ext/dependencies.rb +61 -0
  7. data/ext/extconf.rb +21 -198
  8. data/ext/options.rb +221 -0
  9. data/ext/ruby_whisper.c +159 -0
  10. data/ext/ruby_whisper.h +17 -2
  11. data/ext/ruby_whisper_context.c +641 -0
  12. data/ext/ruby_whisper_error.c +52 -0
  13. data/ext/ruby_whisper_model.c +232 -0
  14. data/ext/ruby_whisper_params.c +1301 -0
  15. data/ext/ruby_whisper_segment.c +143 -0
  16. data/ext/ruby_whisper_transcribe.cpp +87 -0
  17. data/ext/ruby_whisper_vad_params.c +288 -0
  18. data/ext/sources/.dockerignore +3 -0
  19. data/ext/sources/.github/workflows/bindings-ruby.yml +21 -0
  20. data/ext/sources/CMakeGraphVizOptions.cmake +8 -0
  21. data/ext/sources/CMakeLists.txt +251 -0
  22. data/ext/sources/bindings/javascript/CMakeLists.txt +41 -0
  23. data/ext/sources/bindings/javascript/emscripten.cpp +93 -0
  24. data/ext/sources/bindings/javascript/libwhisper.worker.js +1 -0
  25. data/ext/sources/bindings/javascript/package-tmpl.json +26 -0
  26. data/ext/sources/bindings/javascript/package.json +26 -0
  27. data/ext/sources/bindings/javascript/whisper.js +19 -0
  28. data/ext/sources/build-xcframework.sh +547 -0
  29. data/ext/sources/ci/run.sh +336 -0
  30. data/ext/sources/close-issue.yml +28 -0
  31. data/ext/sources/cmake/DefaultTargetOptions.cmake +16 -0
  32. data/ext/sources/cmake/FindFFmpeg.cmake +163 -0
  33. data/ext/sources/cmake/build-info.cmake +60 -0
  34. data/ext/sources/cmake/git-vars.cmake +22 -0
  35. data/ext/sources/cmake/whisper-config.cmake.in +65 -0
  36. data/ext/sources/cmake/whisper.pc.in +10 -0
  37. data/ext/sources/examples/CMakeLists.txt +124 -0
  38. data/ext/sources/examples/addon.node/CMakeLists.txt +31 -0
  39. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +37 -0
  40. data/ext/sources/examples/addon.node/addon.cpp +438 -0
  41. data/ext/sources/examples/addon.node/index.js +54 -0
  42. data/ext/sources/examples/addon.node/package.json +16 -0
  43. data/ext/sources/examples/bench/CMakeLists.txt +8 -0
  44. data/ext/sources/examples/bench/bench.cpp +175 -0
  45. data/ext/sources/examples/bench.wasm/CMakeLists.txt +49 -0
  46. data/ext/sources/examples/bench.wasm/emscripten.cpp +87 -0
  47. data/ext/sources/examples/bench.wasm/index-tmpl.html +284 -0
  48. data/ext/sources/examples/cli/CMakeLists.txt +8 -0
  49. data/ext/sources/examples/cli/cli.cpp +1294 -0
  50. data/ext/sources/examples/coi-serviceworker.js +146 -0
  51. data/ext/sources/examples/command/CMakeLists.txt +10 -0
  52. data/ext/sources/examples/command/command.cpp +776 -0
  53. data/ext/sources/examples/command/commands.txt +9 -0
  54. data/ext/sources/examples/command.wasm/CMakeLists.txt +50 -0
  55. data/ext/sources/examples/command.wasm/emscripten.cpp +327 -0
  56. data/ext/sources/examples/command.wasm/index-tmpl.html +414 -0
  57. data/ext/sources/examples/common-ggml.cpp +238 -0
  58. data/ext/sources/examples/common-ggml.h +18 -0
  59. data/ext/sources/examples/common-sdl.cpp +227 -0
  60. data/ext/sources/examples/common-sdl.h +49 -0
  61. data/ext/sources/examples/common-whisper.cpp +168 -0
  62. data/ext/sources/examples/common-whisper.h +24 -0
  63. data/ext/sources/examples/common.cpp +675 -0
  64. data/ext/sources/examples/common.h +322 -0
  65. data/ext/sources/examples/deprecation-warning/CMakeLists.txt +6 -0
  66. data/ext/sources/examples/deprecation-warning/deprecation-warning.cpp +38 -0
  67. data/ext/sources/examples/ffmpeg-transcode.cpp +368 -0
  68. data/ext/sources/examples/generate-karaoke.sh +57 -0
  69. data/ext/sources/examples/grammar-parser.cpp +423 -0
  70. data/ext/sources/examples/grammar-parser.h +29 -0
  71. data/ext/sources/examples/helpers.js +191 -0
  72. data/ext/sources/examples/json.hpp +24596 -0
  73. data/ext/sources/examples/livestream.sh +112 -0
  74. data/ext/sources/examples/lsp/CMakeLists.txt +9 -0
  75. data/ext/sources/examples/lsp/lsp.cpp +467 -0
  76. data/ext/sources/examples/lsp/whisper.vim +362 -0
  77. data/ext/sources/examples/miniaudio.h +93468 -0
  78. data/ext/sources/examples/python/test_whisper_processor.py +7 -0
  79. data/ext/sources/examples/python/whisper_processor.py +54 -0
  80. data/ext/sources/examples/quantize/CMakeLists.txt +6 -0
  81. data/ext/sources/examples/quantize/quantize.cpp +223 -0
  82. data/ext/sources/examples/server/CMakeLists.txt +12 -0
  83. data/ext/sources/examples/server/bench.js +29 -0
  84. data/ext/sources/examples/server/httplib.h +10497 -0
  85. data/ext/sources/examples/server/server.cpp +1091 -0
  86. data/ext/sources/examples/server.py +115 -0
  87. data/ext/sources/examples/stb_vorbis.c +5584 -0
  88. data/ext/sources/examples/stream/CMakeLists.txt +10 -0
  89. data/ext/sources/examples/stream/stream.cpp +429 -0
  90. data/ext/sources/examples/stream.wasm/CMakeLists.txt +49 -0
  91. data/ext/sources/examples/stream.wasm/emscripten.cpp +216 -0
  92. data/ext/sources/examples/stream.wasm/index-tmpl.html +414 -0
  93. data/ext/sources/examples/sycl/CMakeLists.txt +9 -0
  94. data/ext/sources/examples/sycl/build.sh +22 -0
  95. data/ext/sources/examples/sycl/ls-sycl-device.cpp +11 -0
  96. data/ext/sources/examples/sycl/run-whisper.sh +17 -0
  97. data/ext/sources/examples/talk-llama/CMakeLists.txt +40 -0
  98. data/ext/sources/examples/talk-llama/eleven-labs.py +80 -0
  99. data/ext/sources/examples/talk-llama/llama-adapter.cpp +388 -0
  100. data/ext/sources/examples/talk-llama/llama-adapter.h +76 -0
  101. data/ext/sources/examples/talk-llama/llama-arch.cpp +1746 -0
  102. data/ext/sources/examples/talk-llama/llama-arch.h +437 -0
  103. data/ext/sources/examples/talk-llama/llama-batch.cpp +374 -0
  104. data/ext/sources/examples/talk-llama/llama-batch.h +89 -0
  105. data/ext/sources/examples/talk-llama/llama-chat.cpp +663 -0
  106. data/ext/sources/examples/talk-llama/llama-chat.h +58 -0
  107. data/ext/sources/examples/talk-llama/llama-context.cpp +2676 -0
  108. data/ext/sources/examples/talk-llama/llama-context.h +276 -0
  109. data/ext/sources/examples/talk-llama/llama-cparams.cpp +5 -0
  110. data/ext/sources/examples/talk-llama/llama-cparams.h +41 -0
  111. data/ext/sources/examples/talk-llama/llama-grammar.cpp +1229 -0
  112. data/ext/sources/examples/talk-llama/llama-grammar.h +173 -0
  113. data/ext/sources/examples/talk-llama/llama-graph.cpp +1618 -0
  114. data/ext/sources/examples/talk-llama/llama-graph.h +640 -0
  115. data/ext/sources/examples/talk-llama/llama-hparams.cpp +95 -0
  116. data/ext/sources/examples/talk-llama/llama-hparams.h +190 -0
  117. data/ext/sources/examples/talk-llama/llama-impl.cpp +167 -0
  118. data/ext/sources/examples/talk-llama/llama-impl.h +61 -0
  119. data/ext/sources/examples/talk-llama/llama-io.cpp +15 -0
  120. data/ext/sources/examples/talk-llama/llama-io.h +35 -0
  121. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2739 -0
  122. data/ext/sources/examples/talk-llama/llama-kv-cache.h +502 -0
  123. data/ext/sources/examples/talk-llama/llama-kv-cells.h +379 -0
  124. data/ext/sources/examples/talk-llama/llama-memory.cpp +1 -0
  125. data/ext/sources/examples/talk-llama/llama-memory.h +32 -0
  126. data/ext/sources/examples/talk-llama/llama-mmap.cpp +600 -0
  127. data/ext/sources/examples/talk-llama/llama-mmap.h +68 -0
  128. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +1138 -0
  129. data/ext/sources/examples/talk-llama/llama-model-loader.h +169 -0
  130. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +281 -0
  131. data/ext/sources/examples/talk-llama/llama-model-saver.h +37 -0
  132. data/ext/sources/examples/talk-llama/llama-model.cpp +13814 -0
  133. data/ext/sources/examples/talk-llama/llama-model.h +425 -0
  134. data/ext/sources/examples/talk-llama/llama-quant.cpp +966 -0
  135. data/ext/sources/examples/talk-llama/llama-quant.h +1 -0
  136. data/ext/sources/examples/talk-llama/llama-sampling.cpp +2575 -0
  137. data/ext/sources/examples/talk-llama/llama-sampling.h +32 -0
  138. data/ext/sources/examples/talk-llama/llama-vocab.cpp +3340 -0
  139. data/ext/sources/examples/talk-llama/llama-vocab.h +131 -0
  140. data/ext/sources/examples/talk-llama/llama.cpp +354 -0
  141. data/ext/sources/examples/talk-llama/llama.h +1377 -0
  142. data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +23 -0
  143. data/ext/sources/examples/talk-llama/speak +40 -0
  144. data/ext/sources/examples/talk-llama/speak.bat +1 -0
  145. data/ext/sources/examples/talk-llama/speak.ps1 +14 -0
  146. data/ext/sources/examples/talk-llama/talk-llama.cpp +808 -0
  147. data/ext/sources/examples/talk-llama/unicode-data.cpp +7034 -0
  148. data/ext/sources/examples/talk-llama/unicode-data.h +20 -0
  149. data/ext/sources/examples/talk-llama/unicode.cpp +849 -0
  150. data/ext/sources/examples/talk-llama/unicode.h +66 -0
  151. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +8 -0
  152. data/ext/sources/examples/vad-speech-segments/speech.cpp +143 -0
  153. data/ext/sources/examples/wchess/CMakeLists.txt +10 -0
  154. data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +19 -0
  155. data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +803 -0
  156. data/ext/sources/examples/wchess/libwchess/Chessboard.h +33 -0
  157. data/ext/sources/examples/wchess/libwchess/WChess.cpp +193 -0
  158. data/ext/sources/examples/wchess/libwchess/WChess.h +63 -0
  159. data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +117 -0
  160. data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +8 -0
  161. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +249 -0
  162. data/ext/sources/examples/whisper.wasm/CMakeLists.txt +50 -0
  163. data/ext/sources/examples/whisper.wasm/emscripten.cpp +118 -0
  164. data/ext/sources/examples/whisper.wasm/index-tmpl.html +658 -0
  165. data/ext/sources/ggml/CMakeLists.txt +390 -0
  166. data/ext/sources/ggml/cmake/BuildTypes.cmake +54 -0
  167. data/ext/sources/ggml/cmake/GitVars.cmake +22 -0
  168. data/ext/sources/ggml/cmake/common.cmake +26 -0
  169. data/ext/sources/ggml/cmake/ggml-config.cmake.in +152 -0
  170. data/ext/{ggml → sources/ggml}/include/ggml-alloc.h +1 -1
  171. data/ext/{ggml → sources/ggml}/include/ggml-backend.h +9 -7
  172. data/ext/{ggml → sources/ggml}/include/ggml-cpp.h +2 -1
  173. data/ext/{ggml → sources/ggml}/include/ggml-cpu.h +9 -1
  174. data/ext/{ggml → sources/ggml}/include/ggml-metal.h +1 -1
  175. data/ext/{ggml → sources/ggml}/include/ggml-opt.h +49 -28
  176. data/ext/{ggml → sources/ggml}/include/ggml-rpc.h +6 -1
  177. data/ext/{ggml → sources/ggml}/include/ggml-vulkan.h +0 -2
  178. data/ext/{ggml → sources/ggml}/include/ggml.h +182 -265
  179. data/ext/sources/ggml/include/gguf.h +202 -0
  180. data/ext/sources/ggml/src/CMakeLists.txt +346 -0
  181. data/ext/{ggml → sources/ggml}/src/ggml-alloc.c +34 -29
  182. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +107 -0
  183. data/ext/{ggml → sources/ggml}/src/ggml-backend-impl.h +1 -2
  184. data/ext/{ggml → sources/ggml}/src/ggml-backend-reg.cpp +87 -53
  185. data/ext/{ggml → sources/ggml}/src/ggml-backend.cpp +26 -14
  186. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +87 -0
  187. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +74 -0
  188. data/ext/sources/ggml/src/ggml-cann/Doxyfile +2579 -0
  189. data/ext/{ggml → sources/ggml}/src/ggml-cann/acl_tensor.cpp +10 -4
  190. data/ext/{ggml → sources/ggml}/src/ggml-cann/acl_tensor.h +5 -5
  191. data/ext/{ggml → sources/ggml}/src/ggml-cann/aclnn_ops.cpp +1272 -1506
  192. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +1125 -0
  193. data/ext/{ggml → sources/ggml}/src/ggml-cann/common.h +135 -1
  194. data/ext/{ggml → sources/ggml}/src/ggml-cann/ggml-cann.cpp +564 -146
  195. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +30 -0
  196. data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/dup.cpp +3 -5
  197. data/ext/{ggml → sources/ggml}/src/ggml-common.h +12 -8
  198. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +504 -0
  199. data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/amx.cpp +2 -1
  200. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
  201. data/ext/sources/ggml/src/ggml-cpu/binary-ops.h +16 -0
  202. data/ext/sources/ggml/src/ggml-cpu/cmake/FindSIMD.cmake +100 -0
  203. data/ext/sources/ggml/src/ggml-cpu/common.h +72 -0
  204. data/ext/{ggml → sources/ggml}/src/ggml-cpu/cpu-feats-x86.cpp +5 -1
  205. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +6431 -0
  206. data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-impl.h +163 -41
  207. data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-quants.c +4029 -1117
  208. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +3510 -0
  209. data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu.cpp +67 -18
  210. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +337 -0
  211. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +95 -0
  212. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +482 -0
  213. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.h +17 -0
  214. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3544 -0
  215. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +14 -0
  216. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +8903 -0
  217. data/ext/sources/ggml/src/ggml-cpu/ops.h +110 -0
  218. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +892 -0
  219. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
  220. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +28 -0
  221. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +252 -0
  222. data/ext/sources/ggml/src/ggml-cpu/vec.h +818 -0
  223. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +184 -0
  224. data/ext/sources/ggml/src/ggml-cuda/acc.cu +61 -0
  225. data/ext/sources/ggml/src/ggml-cuda/acc.cuh +5 -0
  226. data/ext/sources/ggml/src/ggml-cuda/arange.cu +34 -0
  227. data/ext/sources/ggml/src/ggml-cuda/arange.cuh +5 -0
  228. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +91 -0
  229. data/ext/sources/ggml/src/ggml-cuda/argmax.cuh +3 -0
  230. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +104 -0
  231. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +3 -0
  232. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +363 -0
  233. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +9 -0
  234. data/ext/sources/ggml/src/ggml-cuda/clamp.cu +45 -0
  235. data/ext/sources/ggml/src/ggml-cuda/clamp.cuh +5 -0
  236. data/ext/sources/ggml/src/ggml-cuda/common.cuh +828 -0
  237. data/ext/sources/ggml/src/ggml-cuda/concat.cu +221 -0
  238. data/ext/sources/ggml/src/ggml-cuda/concat.cuh +5 -0
  239. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +89 -0
  240. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cuh +5 -0
  241. data/ext/sources/ggml/src/ggml-cuda/convert.cu +730 -0
  242. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +26 -0
  243. data/ext/sources/ggml/src/ggml-cuda/count-equal.cu +64 -0
  244. data/ext/sources/ggml/src/ggml-cuda/count-equal.cuh +5 -0
  245. data/ext/sources/ggml/src/ggml-cuda/cp-async.cuh +57 -0
  246. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +705 -0
  247. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +11 -0
  248. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +189 -0
  249. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cuh +7 -0
  250. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +103 -0
  251. data/ext/sources/ggml/src/ggml-cuda/diagmask.cu +40 -0
  252. data/ext/sources/ggml/src/ggml-cuda/diagmask.cuh +5 -0
  253. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +881 -0
  254. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +1471 -0
  255. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +357 -0
  256. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +3 -0
  257. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +365 -0
  258. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +3 -0
  259. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +482 -0
  260. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +472 -0
  261. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +634 -0
  262. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +3 -0
  263. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +346 -0
  264. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +3 -0
  265. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +275 -0
  266. data/ext/sources/ggml/src/ggml-cuda/getrows.cuh +15 -0
  267. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +3505 -0
  268. data/ext/sources/ggml/src/ggml-cuda/gla.cu +93 -0
  269. data/ext/sources/ggml/src/ggml-cuda/gla.cuh +3 -0
  270. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +103 -0
  271. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +5 -0
  272. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +396 -0
  273. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +324 -0
  274. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +3217 -0
  275. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +336 -0
  276. data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +12 -0
  277. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +595 -0
  278. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +12 -0
  279. data/ext/sources/ggml/src/ggml-cuda/norm.cu +458 -0
  280. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +11 -0
  281. data/ext/sources/ggml/src/ggml-cuda/opt-step-adamw.cu +78 -0
  282. data/ext/sources/ggml/src/ggml-cuda/opt-step-adamw.cuh +5 -0
  283. data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +68 -0
  284. data/ext/sources/ggml/src/ggml-cuda/out-prod.cuh +3 -0
  285. data/ext/sources/ggml/src/ggml-cuda/pad.cu +49 -0
  286. data/ext/sources/ggml/src/ggml-cuda/pad.cuh +5 -0
  287. data/ext/sources/ggml/src/ggml-cuda/pool2d.cu +94 -0
  288. data/ext/sources/ggml/src/ggml-cuda/pool2d.cuh +5 -0
  289. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +190 -0
  290. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +27 -0
  291. data/ext/sources/ggml/src/ggml-cuda/rope.cu +456 -0
  292. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +7 -0
  293. data/ext/sources/ggml/src/ggml-cuda/scale.cu +31 -0
  294. data/ext/sources/ggml/src/ggml-cuda/scale.cuh +5 -0
  295. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +283 -0
  296. data/ext/sources/ggml/src/ggml-cuda/softmax.cuh +7 -0
  297. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +148 -0
  298. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +3 -0
  299. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +153 -0
  300. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cuh +3 -0
  301. data/ext/sources/ggml/src/ggml-cuda/sum.cu +45 -0
  302. data/ext/sources/ggml/src/ggml-cuda/sum.cuh +5 -0
  303. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +39 -0
  304. data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +5 -0
  305. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +5 -0
  306. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +10 -0
  307. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu +10 -0
  308. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu +10 -0
  309. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +10 -0
  310. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +5 -0
  311. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +10 -0
  312. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +10 -0
  313. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu +10 -0
  314. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu +10 -0
  315. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +5 -0
  316. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu +10 -0
  317. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +10 -0
  318. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +10 -0
  319. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu +10 -0
  320. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu +10 -0
  321. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu +10 -0
  322. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +10 -0
  323. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +10 -0
  324. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +5 -0
  325. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +5 -0
  326. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +5 -0
  327. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +5 -0
  328. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +5 -0
  329. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +5 -0
  330. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +5 -0
  331. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +5 -0
  332. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +5 -0
  333. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +5 -0
  334. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +5 -0
  335. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +5 -0
  336. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +5 -0
  337. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +5 -0
  338. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +5 -0
  339. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +5 -0
  340. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +5 -0
  341. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +5 -0
  342. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +5 -0
  343. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +5 -0
  344. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +5 -0
  345. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +5 -0
  346. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +5 -0
  347. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +5 -0
  348. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +5 -0
  349. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +5 -0
  350. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +5 -0
  351. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +5 -0
  352. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +5 -0
  353. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +5 -0
  354. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +5 -0
  355. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +5 -0
  356. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +5 -0
  357. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +5 -0
  358. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +5 -0
  359. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +5 -0
  360. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +5 -0
  361. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +5 -0
  362. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +5 -0
  363. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +5 -0
  364. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +5 -0
  365. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +5 -0
  366. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +5 -0
  367. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +5 -0
  368. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +5 -0
  369. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +5 -0
  370. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +5 -0
  371. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +5 -0
  372. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +5 -0
  373. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +5 -0
  374. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +5 -0
  375. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +5 -0
  376. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +5 -0
  377. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +5 -0
  378. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +5 -0
  379. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +5 -0
  380. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +5 -0
  381. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +5 -0
  382. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +5 -0
  383. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +5 -0
  384. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +5 -0
  385. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +5 -0
  386. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +5 -0
  387. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +5 -0
  388. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +5 -0
  389. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +5 -0
  390. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +5 -0
  391. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +5 -0
  392. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +5 -0
  393. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +5 -0
  394. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +5 -0
  395. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +5 -0
  396. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +5 -0
  397. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +5 -0
  398. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +5 -0
  399. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +5 -0
  400. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +5 -0
  401. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +5 -0
  402. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +5 -0
  403. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +5 -0
  404. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +5 -0
  405. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +5 -0
  406. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +5 -0
  407. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +5 -0
  408. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +5 -0
  409. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +5 -0
  410. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +78 -0
  411. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s.cu +5 -0
  412. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_s.cu +5 -0
  413. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xs.cu +5 -0
  414. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xxs.cu +5 -0
  415. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_s.cu +5 -0
  416. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_xxs.cu +5 -0
  417. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_nl.cu +5 -0
  418. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_xs.cu +5 -0
  419. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q2_k.cu +5 -0
  420. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q3_k.cu +5 -0
  421. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_0.cu +5 -0
  422. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_1.cu +5 -0
  423. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_k.cu +5 -0
  424. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_0.cu +5 -0
  425. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_1.cu +5 -0
  426. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_k.cu +5 -0
  427. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q6_k.cu +5 -0
  428. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q8_0.cu +5 -0
  429. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +47 -0
  430. data/ext/sources/ggml/src/ggml-cuda/tsembd.cuh +5 -0
  431. data/ext/sources/ggml/src/ggml-cuda/unary.cu +289 -0
  432. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +59 -0
  433. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +51 -0
  434. data/ext/sources/ggml/src/ggml-cuda/upscale.cuh +5 -0
  435. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +1135 -0
  436. data/ext/{ggml → sources/ggml}/src/ggml-cuda/vendors/cuda.h +1 -0
  437. data/ext/{ggml → sources/ggml}/src/ggml-cuda/vendors/hip.h +57 -0
  438. data/ext/{ggml → sources/ggml}/src/ggml-cuda/vendors/musa.h +7 -1
  439. data/ext/sources/ggml/src/ggml-cuda/wkv.cu +199 -0
  440. data/ext/sources/ggml/src/ggml-cuda/wkv.cuh +7 -0
  441. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +131 -0
  442. data/ext/{ggml → sources/ggml}/src/ggml-impl.h +64 -19
  443. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +166 -0
  444. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +112 -0
  445. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +58 -0
  446. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +25 -0
  447. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +52 -0
  448. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +52 -0
  449. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +52 -0
  450. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +52 -0
  451. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +30 -0
  452. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +22 -0
  453. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +17 -0
  454. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +31 -0
  455. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +31 -0
  456. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +38 -0
  457. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +39 -0
  458. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +44 -0
  459. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +52 -0
  460. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +69 -0
  461. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +51 -0
  462. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +33 -0
  463. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +35 -0
  464. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +140 -0
  465. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +106 -0
  466. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +73 -0
  467. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +52 -0
  468. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +28 -0
  469. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +84 -0
  470. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +21 -0
  471. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +53 -0
  472. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +52 -0
  473. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +52 -0
  474. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +52 -0
  475. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +52 -0
  476. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +19 -0
  477. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +23 -0
  478. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +22 -0
  479. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +72 -0
  480. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +71 -0
  481. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +120 -0
  482. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +622 -0
  483. data/ext/{ggml → sources/ggml}/src/ggml-metal/ggml-metal.m +2178 -1064
  484. data/ext/{ggml → sources/ggml}/src/ggml-metal/ggml-metal.metal +1575 -1218
  485. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +113 -0
  486. data/ext/sources/ggml/src/ggml-musa/mudnn.cu +112 -0
  487. data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +12 -0
  488. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +96 -0
  489. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +5124 -0
  490. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +83 -0
  491. data/ext/sources/ggml/src/ggml-opencl/kernels/clamp.cl +20 -0
  492. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +184 -0
  493. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +118 -0
  494. data/ext/sources/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl +58 -0
  495. data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +26 -0
  496. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +62 -0
  497. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle.cl +268 -0
  498. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general.cl +274 -0
  499. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +163 -0
  500. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +57 -0
  501. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +57 -0
  502. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +79 -0
  503. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_Ab_Bi_8x4.cl +139 -0
  504. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f16.cl +118 -0
  505. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32.cl +118 -0
  506. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_1row.cl +94 -0
  507. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_l4.cl +84 -0
  508. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f32_f32.cl +118 -0
  509. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32.cl +192 -0
  510. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_16x_flat.cl +307 -0
  511. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_8x_flat.cl +265 -0
  512. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_8x_flat.cl +272 -0
  513. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_v.cl +254 -0
  514. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl +190 -0
  515. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +81 -0
  516. data/ext/sources/ggml/src/ggml-opencl/kernels/relu.cl +16 -0
  517. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +96 -0
  518. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +721 -0
  519. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +16 -0
  520. data/ext/sources/ggml/src/ggml-opencl/kernels/silu.cl +30 -0
  521. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +87 -0
  522. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +87 -0
  523. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +86 -0
  524. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +86 -0
  525. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +84 -0
  526. data/ext/{ggml → sources/ggml}/src/ggml-opt.cpp +373 -190
  527. data/ext/{ggml → sources/ggml}/src/ggml-quants.c +114 -120
  528. data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +9 -0
  529. data/ext/{ggml → sources/ggml}/src/ggml-rpc/ggml-rpc.cpp +480 -73
  530. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +189 -0
  531. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +37 -0
  532. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +345 -0
  533. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +39 -0
  534. data/ext/{ggml → sources/ggml}/src/ggml-sycl/common.cpp +20 -32
  535. data/ext/sources/ggml/src/ggml-sycl/common.hpp +589 -0
  536. data/ext/{ggml → sources/ggml}/src/ggml-sycl/concat.cpp +32 -33
  537. data/ext/sources/ggml/src/ggml-sycl/concat.hpp +20 -0
  538. data/ext/{ggml → sources/ggml}/src/ggml-sycl/conv.cpp +4 -2
  539. data/ext/sources/ggml/src/ggml-sycl/conv.hpp +20 -0
  540. data/ext/{ggml → sources/ggml}/src/ggml-sycl/convert.cpp +104 -28
  541. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +34 -0
  542. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +700 -0
  543. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +11 -0
  544. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +791 -0
  545. data/ext/{ggml → sources/ggml}/src/ggml-sycl/dmmv.cpp +156 -17
  546. data/ext/sources/ggml/src/ggml-sycl/dmmv.hpp +27 -0
  547. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +2957 -0
  548. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +1511 -0
  549. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +75 -0
  550. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +99 -0
  551. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +309 -0
  552. data/ext/sources/ggml/src/ggml-sycl/getrows.hpp +20 -0
  553. data/ext/{ggml → sources/ggml}/src/ggml-sycl/ggml-sycl.cpp +1004 -1240
  554. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +106 -0
  555. data/ext/sources/ggml/src/ggml-sycl/gla.hpp +8 -0
  556. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +136 -0
  557. data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +21 -0
  558. data/ext/{ggml → sources/ggml}/src/ggml-sycl/mmq.cpp +0 -1
  559. data/ext/sources/ggml/src/ggml-sycl/mmq.hpp +33 -0
  560. data/ext/{ggml → sources/ggml}/src/ggml-sycl/mmvq.cpp +261 -166
  561. data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +27 -0
  562. data/ext/{ggml → sources/ggml}/src/ggml-sycl/norm.cpp +204 -81
  563. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +26 -0
  564. data/ext/{ggml → sources/ggml}/src/ggml-sycl/outprod.cpp +8 -17
  565. data/ext/sources/ggml/src/ggml-sycl/outprod.hpp +10 -0
  566. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +74 -0
  567. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +83 -0
  568. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +361 -0
  569. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +20 -0
  570. data/ext/{ggml → sources/ggml}/src/ggml-sycl/softmax.cpp +35 -25
  571. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +20 -0
  572. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +13 -0
  573. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +23 -0
  574. data/ext/{ggml → sources/ggml}/src/ggml-sycl/tsembd.cpp +3 -3
  575. data/ext/sources/ggml/src/ggml-sycl/tsembd.hpp +20 -0
  576. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +1215 -0
  577. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +293 -0
  578. data/ext/sources/ggml/src/ggml-sycl/wkv.hpp +10 -0
  579. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +196 -0
  580. data/ext/sources/ggml/src/ggml-vulkan/cmake/host-toolchain.cmake.in +15 -0
  581. data/ext/{ggml → sources/ggml}/src/ggml-vulkan/ggml-vulkan.cpp +3130 -1087
  582. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +39 -0
  583. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +29 -0
  584. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +29 -0
  585. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +51 -0
  586. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +69 -0
  587. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +17 -0
  588. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +41 -0
  589. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +49 -0
  590. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +105 -0
  591. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +23 -0
  592. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +51 -0
  593. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +242 -0
  594. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +17 -0
  595. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +31 -0
  596. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +20 -0
  597. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +462 -0
  598. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +699 -0
  599. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.comp +13 -0
  600. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +42 -0
  601. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +35 -0
  602. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +44 -0
  603. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +43 -0
  604. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +48 -0
  605. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +39 -0
  606. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +49 -0
  607. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +32 -0
  608. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +34 -0
  609. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +34 -0
  610. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +42 -0
  611. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +30 -0
  612. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +32 -0
  613. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +68 -0
  614. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +34 -0
  615. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +35 -0
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +70 -0
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +33 -0
  618. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +31 -0
  619. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +34 -0
  620. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +27 -0
  621. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +337 -0
  622. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +162 -0
  623. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +360 -0
  624. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +267 -0
  625. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +59 -0
  626. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +25 -0
  627. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +23 -0
  628. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +64 -0
  629. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.comp +9 -0
  630. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp +76 -0
  631. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +33 -0
  632. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +41 -0
  633. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +66 -0
  634. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +100 -0
  635. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +41 -0
  636. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +22 -0
  637. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +27 -0
  638. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_split_k_reduce.comp +48 -0
  639. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +169 -0
  640. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +118 -0
  641. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +82 -0
  642. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +79 -0
  643. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +90 -0
  644. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +87 -0
  645. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +87 -0
  646. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +90 -0
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +88 -0
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +118 -0
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +154 -0
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +130 -0
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +132 -0
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +136 -0
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +167 -0
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +130 -0
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +868 -0
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +441 -0
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +442 -0
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +99 -0
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +44 -0
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +42 -0
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +28 -0
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +74 -0
  663. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +77 -0
  664. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +21 -0
  665. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +26 -0
  666. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +37 -0
  667. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +52 -0
  668. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +55 -0
  669. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +58 -0
  670. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +60 -0
  671. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +43 -0
  672. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +43 -0
  673. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +47 -0
  674. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +24 -0
  675. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +20 -0
  676. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +22 -0
  677. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +26 -0
  678. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +17 -0
  679. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +173 -0
  680. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +50 -0
  681. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +17 -0
  682. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +29 -0
  683. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +37 -0
  684. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +20 -0
  685. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_bfloat16_support.comp +7 -0
  686. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat2_support.comp +7 -0
  687. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat_support.comp +7 -0
  688. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_integer_dot_support.comp +7 -0
  689. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +41 -0
  690. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +1373 -0
  691. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +36 -0
  692. data/ext/{ggml → sources/ggml}/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +193 -35
  693. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp +87 -0
  694. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp +91 -0
  695. data/ext/{ggml → sources/ggml}/src/ggml.c +676 -1820
  696. data/ext/sources/ggml/src/gguf.cpp +1330 -0
  697. data/ext/{include → sources/include}/whisper.h +68 -2
  698. data/ext/sources/src/CMakeLists.txt +143 -0
  699. data/ext/{src → sources/src}/coreml/whisper-decoder-impl.h +27 -15
  700. data/ext/{src → sources/src}/coreml/whisper-decoder-impl.m +35 -10
  701. data/ext/{src → sources/src}/coreml/whisper-encoder-impl.h +21 -9
  702. data/ext/{src → sources/src}/coreml/whisper-encoder-impl.m +28 -3
  703. data/ext/sources/src/coreml/whisper-encoder.mm +73 -0
  704. data/ext/sources/src/whisper-arch.h +197 -0
  705. data/ext/{src → sources/src}/whisper.cpp +1905 -374
  706. data/ext/sources/tests/CMakeLists.txt +105 -0
  707. data/ext/sources/tests/earnings21/eval.mk +58 -0
  708. data/ext/sources/tests/earnings21/eval.py +68 -0
  709. data/ext/sources/tests/earnings21/normalizers/__init__.py +2 -0
  710. data/ext/sources/tests/earnings21/normalizers/basic.py +80 -0
  711. data/ext/sources/tests/earnings21/normalizers/english.json +1741 -0
  712. data/ext/sources/tests/earnings21/normalizers/english.py +550 -0
  713. data/ext/sources/tests/earnings21/requirements.txt +6 -0
  714. data/ext/sources/tests/en-0-ref.txt +1 -0
  715. data/ext/sources/tests/en-1-ref.txt +1 -0
  716. data/ext/sources/tests/en-2-ref.txt +1 -0
  717. data/ext/sources/tests/es-0-ref.txt +1 -0
  718. data/ext/sources/tests/librispeech/eval.mk +39 -0
  719. data/ext/sources/tests/librispeech/eval.py +47 -0
  720. data/ext/sources/tests/librispeech/normalizers/__init__.py +2 -0
  721. data/ext/sources/tests/librispeech/normalizers/basic.py +80 -0
  722. data/ext/sources/tests/librispeech/normalizers/english.json +1741 -0
  723. data/ext/sources/tests/librispeech/normalizers/english.py +550 -0
  724. data/ext/sources/tests/librispeech/requirements.txt +6 -0
  725. data/ext/sources/tests/run-tests.sh +130 -0
  726. data/ext/sources/tests/test-c.c +3 -0
  727. data/ext/sources/tests/test-vad-full.cpp +54 -0
  728. data/ext/sources/tests/test-vad.cpp +83 -0
  729. data/ext/sources/tests/test-whisper.js +58 -0
  730. data/extsources.rb +33 -5
  731. data/lib/whisper/model/uri.rb +149 -128
  732. data/sig/whisper.rbs +480 -0
  733. data/tests/helper.rb +28 -0
  734. data/tests/test_callback.rb +45 -3
  735. data/tests/test_error.rb +2 -2
  736. data/tests/test_model.rb +38 -0
  737. data/tests/test_package.rb +18 -3
  738. data/tests/test_params.rb +145 -8
  739. data/tests/test_segment.rb +10 -19
  740. data/tests/test_vad.rb +19 -0
  741. data/tests/test_vad_params.rb +103 -0
  742. data/tests/test_whisper.rb +37 -37
  743. data/whispercpp.gemspec +5 -4
  744. metadata +766 -111
  745. data/ext/cpu.mk +0 -9
  746. data/ext/examples/dr_wav.h +0 -8815
  747. data/ext/ggml/src/ggml-cann/aclnn_ops.h +0 -592
  748. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -4262
  749. data/ext/ggml/src/ggml-cpu/ggml-cpu.c +0 -14123
  750. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.cpp +0 -1884
  751. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.h +0 -14
  752. data/ext/ggml/src/ggml-metal/ggml-metal-impl.h +0 -288
  753. data/ext/ggml/src/ggml-sycl/element_wise.cpp +0 -1030
  754. data/ext/ggml/src/ggml-sycl/im2col.cpp +0 -126
  755. data/ext/ggml/src/ggml-sycl/rope.cpp +0 -276
  756. data/ext/ggml/src/ggml-sycl/wkv6.cpp +0 -141
  757. data/ext/metal-embed.mk +0 -17
  758. data/ext/metal.mk +0 -6
  759. data/ext/ruby_whisper.cpp +0 -1909
  760. data/ext/scripts/get-flags.mk +0 -38
  761. data/lib/whisper.rb +0 -2
  762. /data/ext/{ggml → sources/ggml}/include/ggml-blas.h +0 -0
  763. /data/ext/{ggml → sources/ggml}/include/ggml-cann.h +0 -0
  764. /data/ext/{ggml → sources/ggml}/include/ggml-cuda.h +0 -0
  765. /data/ext/{ggml → sources/ggml}/include/ggml-kompute.h +0 -0
  766. /data/ext/{ggml → sources/ggml}/include/ggml-opencl.h +0 -0
  767. /data/ext/{ggml → sources/ggml}/include/ggml-sycl.h +0 -0
  768. /data/ext/{ggml → sources/ggml}/src/ggml-amx/common.h +0 -0
  769. /data/ext/{ggml → sources/ggml}/src/ggml-amx/ggml-amx.cpp +0 -0
  770. /data/ext/{ggml → sources/ggml}/src/ggml-amx/mmq.cpp +0 -0
  771. /data/ext/{ggml → sources/ggml}/src/ggml-amx/mmq.h +0 -0
  772. /data/ext/{ggml → sources/ggml}/src/ggml-blas/ggml-blas.cpp +0 -0
  773. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/ascendc_kernels.h +0 -0
  774. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_f16.cpp +0 -0
  775. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_f32.cpp +0 -0
  776. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -0
  777. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -0
  778. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -0
  779. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -0
  780. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -0
  781. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/amx.h +0 -0
  782. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/common.h +0 -0
  783. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/mmq.cpp +0 -0
  784. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/mmq.h +0 -0
  785. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-aarch64.h +0 -0
  786. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-hbm.cpp +0 -0
  787. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-hbm.h +0 -0
  788. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-quants.h +0 -0
  789. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-traits.cpp +0 -0
  790. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-traits.h +0 -0
  791. /data/ext/{ggml → sources/ggml}/src/ggml-kompute/ggml-kompute.cpp +0 -0
  792. /data/ext/{ggml → sources/ggml}/src/ggml-quants.h +0 -0
  793. /data/ext/{ggml → sources/ggml}/src/ggml-threading.cpp +0 -0
  794. /data/ext/{ggml → sources/ggml}/src/ggml-threading.h +0 -0
  795. /data/ext/{src → sources/src}/coreml/whisper-encoder.h +0 -0
  796. /data/ext/{src → sources/src}/openvino/whisper-openvino-encoder.cpp +0 -0
  797. /data/ext/{src → sources/src}/openvino/whisper-openvino-encoder.h +0 -0
@@ -19,7 +19,17 @@
19
19
  // max number of MTLCommandBuffer used to submit a graph for processing
20
20
  #define GGML_METAL_MAX_COMMAND_BUFFERS 8
21
21
 
22
- #define UNUSED(x) (void)(x)
22
+ #ifndef TARGET_OS_VISION
23
+ #define TARGET_OS_VISION 0
24
+ #endif
25
+
26
+ // create residency sets only on macOS >= 15.0
27
+ #if !TARGET_CPU_X86_64 && TARGET_OS_OSX && __MAC_OS_X_VERSION_MAX_ALLOWED >= 150000 || \
28
+ TARGET_OS_IOS && __IPHONE_OS_VERSION_MAX_ALLOWED >= 180000 || \
29
+ TARGET_OS_TV && __TV_OS_VERSION_MAX_ALLOWED >= 180000 || \
30
+ TARGET_OS_VISION && __VISION_OS_VERSION_MAX_ALLOWED >= 200000
31
+ #define GGML_METAL_HAS_RESIDENCY_SETS 1
32
+ #endif
23
33
 
24
34
  // globals
25
35
 
@@ -34,11 +44,13 @@ static struct ggml_backend_device g_ggml_backend_metal_device;
34
44
  // note: assumes single GPU device - the default one
35
45
  // TODO: support multiple GPU devices
36
46
  static struct ggml_backend_metal_device_context {
37
- id<MTLDevice> mtl_device;
38
- int mtl_device_ref_count;
47
+ id<MTLDevice> mtl_device;
48
+ int mtl_device_ref_count;
49
+ id<MTLLibrary> mtl_library;
39
50
 
40
51
  bool has_simdgroup_reduction;
41
52
  bool has_simdgroup_mm;
53
+ bool has_residency_sets;
42
54
  bool has_bfloat;
43
55
  bool use_bfloat;
44
56
 
@@ -46,8 +58,10 @@ static struct ggml_backend_metal_device_context {
46
58
  } g_ggml_ctx_dev_main = {
47
59
  /*.mtl_device =*/ nil,
48
60
  /*.mtl_device_ref_count =*/ 0,
61
+ /*.mtl_library =*/ nil,
49
62
  /*.has_simdgroup_reduction =*/ false,
50
63
  /*.has_simdgroup_mm =*/ false,
64
+ /*.has_residency_sets =*/ false,
51
65
  /*.has_bfloat =*/ false,
52
66
  /*.use_bfloat =*/ false,
53
67
  /*.name =*/ "",
@@ -59,12 +73,18 @@ static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
59
73
 
60
74
  if (ctx->mtl_device == nil) {
61
75
  ctx->mtl_device = MTLCreateSystemDefaultDevice();
76
+ }
62
77
 
78
+ if (ctx->mtl_device) {
63
79
  ctx->has_simdgroup_reduction = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
64
80
  ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
65
81
 
66
82
  ctx->has_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
67
83
 
84
+ #if defined(GGML_METAL_HAS_RESIDENCY_SETS)
85
+ ctx->has_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == NULL;
86
+ #endif
87
+
68
88
  ctx->has_bfloat = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
69
89
  ctx->has_bfloat |= [ctx->mtl_device supportsFamily:MTLGPUFamilyApple6];
70
90
 
@@ -90,8 +110,15 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
90
110
  ctx->mtl_device_ref_count--;
91
111
 
92
112
  if (ctx->mtl_device_ref_count == 0) {
93
- [ctx->mtl_device release];
94
- ctx->mtl_device = nil;
113
+ if (ctx->mtl_library) {
114
+ [ctx->mtl_library release];
115
+ ctx->mtl_library = nil;
116
+ }
117
+
118
+ if (ctx->mtl_device) {
119
+ [ctx->mtl_device release];
120
+ ctx->mtl_device = nil;
121
+ }
95
122
  }
96
123
  }
97
124
 
@@ -122,6 +149,8 @@ enum ggml_metal_kernel_type {
122
149
  GGML_METAL_KERNEL_TYPE_SIGMOID,
123
150
  GGML_METAL_KERNEL_TYPE_GELU,
124
151
  GGML_METAL_KERNEL_TYPE_GELU_4,
152
+ GGML_METAL_KERNEL_TYPE_GELU_ERF,
153
+ GGML_METAL_KERNEL_TYPE_GELU_ERF_4,
125
154
  GGML_METAL_KERNEL_TYPE_GELU_QUICK,
126
155
  GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
127
156
  GGML_METAL_KERNEL_TYPE_SILU,
@@ -157,10 +186,13 @@ enum ggml_metal_kernel_type {
157
186
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
158
187
  GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
159
188
  GGML_METAL_KERNEL_TYPE_RMS_NORM,
189
+ GGML_METAL_KERNEL_TYPE_L2_NORM,
160
190
  GGML_METAL_KERNEL_TYPE_GROUP_NORM,
161
191
  GGML_METAL_KERNEL_TYPE_NORM,
162
192
  GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
163
193
  GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
194
+ GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
195
+ GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,
164
196
  GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
165
197
  GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
166
198
  GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
@@ -276,30 +308,36 @@ enum ggml_metal_kernel_type {
276
308
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,
277
309
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
278
310
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
279
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
280
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
281
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32,
282
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
283
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32,
284
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32,
285
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32,
286
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32,
287
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32,
288
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32,
289
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32,
290
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32,
291
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32,
292
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,
293
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
294
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,
295
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32,
296
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32,
297
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,
298
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,
299
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
300
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
311
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16,
312
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32,
313
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16,
314
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16,
315
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16,
316
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F16,
317
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F16,
318
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16,
319
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16,
320
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16,
321
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16,
322
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16,
323
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16,
324
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F16,
325
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F16,
326
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F16,
327
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F16,
328
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F16,
329
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F16,
330
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F16,
331
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F16,
332
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16,
333
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16,
334
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16,
301
335
  GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,
302
336
  GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,
337
+ GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32,
338
+ GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16,
339
+ GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32,
340
+ GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16,
303
341
  GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,
304
342
  GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
305
343
  GGML_METAL_KERNEL_TYPE_IM2COL_F16,
@@ -321,43 +359,78 @@ enum ggml_metal_kernel_type {
321
359
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
322
360
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
323
361
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
362
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192,
363
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128,
324
364
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
365
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512,
325
366
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,
326
367
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,
327
368
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,
328
369
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112,
329
370
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128,
371
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192,
372
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128,
330
373
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,
374
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512,
331
375
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,
332
376
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,
333
377
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,
334
378
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112,
335
379
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128,
380
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192,
381
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128,
336
382
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256,
383
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512,
337
384
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64,
338
385
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80,
339
386
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96,
340
387
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112,
341
388
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128,
389
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192,
390
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128,
342
391
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256,
392
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512,
343
393
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64,
344
394
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80,
345
395
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96,
346
396
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112,
347
397
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128,
398
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192,
399
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128,
348
400
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256,
401
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512,
349
402
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64,
350
403
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80,
351
404
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96,
352
405
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112,
353
406
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128,
407
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192,
408
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128,
354
409
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256,
410
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512,
355
411
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64,
356
412
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80,
357
413
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96,
358
414
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112,
359
415
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128,
416
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192,
417
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128,
360
418
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
419
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512,
420
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64,
421
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64,
422
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64,
423
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64,
424
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64,
425
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64,
426
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64,
427
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96,
428
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96,
429
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96,
430
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96,
431
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96,
432
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96,
433
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96,
361
434
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
362
435
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128,
363
436
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,
@@ -365,6 +438,20 @@ enum ggml_metal_kernel_type {
365
438
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128,
366
439
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128,
367
440
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128,
441
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H192,
442
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H192,
443
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H192,
444
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H192,
445
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H192,
446
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H192,
447
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H192,
448
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128,
449
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK192_HV128,
450
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK192_HV128,
451
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK192_HV128,
452
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK192_HV128,
453
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK192_HV128,
454
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128,
368
455
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
369
456
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256,
370
457
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256,
@@ -372,6 +459,13 @@ enum ggml_metal_kernel_type {
372
459
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
373
460
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256,
374
461
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,
462
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512,
463
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512,
464
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512,
465
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512,
466
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512,
467
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512,
468
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512,
375
469
  GGML_METAL_KERNEL_TYPE_SET_I32,
376
470
  GGML_METAL_KERNEL_TYPE_SET_F32,
377
471
  GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
@@ -387,11 +481,22 @@ enum ggml_metal_kernel_type {
387
481
  GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
388
482
  GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
389
483
  GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
484
+ GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32,
485
+ GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16,
486
+ GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32,
487
+ GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16,
488
+ GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32,
489
+ GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16,
490
+ GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32,
491
+ GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16,
492
+ GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32,
493
+ GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16,
390
494
  GGML_METAL_KERNEL_TYPE_CONCAT,
391
495
  GGML_METAL_KERNEL_TYPE_SQR,
392
496
  GGML_METAL_KERNEL_TYPE_SQRT,
393
497
  GGML_METAL_KERNEL_TYPE_SIN,
394
498
  GGML_METAL_KERNEL_TYPE_COS,
499
+ GGML_METAL_KERNEL_TYPE_NEG,
395
500
  GGML_METAL_KERNEL_TYPE_SUM_ROWS,
396
501
  GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
397
502
  GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
@@ -400,7 +505,264 @@ enum ggml_metal_kernel_type {
400
505
  GGML_METAL_KERNEL_TYPE_COUNT
401
506
  };
402
507
 
508
+ //
509
+ // ggml_metal_heap
510
+ //
511
+
512
// Bookkeeping for one MTLHeap and the buffers that have been placed in it.
struct ggml_metal_heap {
    // number of consecutive resets at which the heap held no buffers
    // (set back to 0 by a reset that finds at least one buffer)
    int n_unused;

    // total number of buffer allocations in this heap across all computes
    int64_t n_alloc;

    // current offset in the heap - set back to 0 in order to reuse the memory
    size_t offs;

    // the Metal heap that backs the buffers
    id<MTLHeap> obj;

    // the currently allocated MTLBuffer objects in this heap
    NSMutableArray * bufs;
};
527
+
528
// Create a placement MTLHeap of `size` bytes in private storage on `device`, together with
// the bookkeeping struct that tracks the buffers placed in it.
//
// Returns a heap that the caller must release with ggml_metal_heap_free, or NULL if either
// the bookkeeping struct or the MTLHeap could not be created.
static struct ggml_metal_heap * ggml_metal_heap_init(id<MTLDevice> device, size_t size) {
    // calloc leaves n_unused, n_alloc and offs at 0
    struct ggml_metal_heap * heap = calloc(1, sizeof(struct ggml_metal_heap));
    if (heap == NULL) {
        GGML_LOG_ERROR("%s: error: failed to allocate heap bookkeeping\n", __func__);
        return NULL;
    }

    MTLHeapDescriptor * desc = [[MTLHeapDescriptor alloc] init];
    desc.storageMode  = MTLStorageModePrivate;
    desc.cpuCacheMode = MTLCPUCacheModeDefaultCache;
    desc.type         = MTLHeapTypePlacement;
    desc.size         = size;

    heap->obj = [device newHeapWithDescriptor:desc];

    // the descriptor is only needed for the call above - release it before branching so that
    // the failure path does not leak it
    [desc release];

    if (heap->obj == nil) {
        GGML_LOG_ERROR("%s: error: failed to create MTLHeap with size %zu\n", __func__, size);

        free(heap);

        return NULL;
    }

    heap->bufs = [[NSMutableArray alloc] init];

    return heap;
}
555
+
556
// Return the heap to its empty state: rewind the offset, update the idle counter, drop all
// buffers that were placed in the heap and mark the heap memory as volatile.
static void ggml_metal_heap_reset(struct ggml_metal_heap * heap) {
    const bool was_used = [heap->bufs count] > 0;

    heap->offs = 0;

    // count how many graph computes the heap ended up being unused
    heap->n_unused = was_used ? 0 : heap->n_unused + 1;

    // every buffer is released explicitly here and once more when it leaves the array below
    for (id<MTLBuffer> cur in heap->bufs) {
        [cur release];
    }
    [heap->bufs removeAllObjects];

    // tell the OS that it can reuse this memory if needed
    // ref: https://developer.apple.com/documentation/metal/mtlpurgeablestate?language=objc
    [heap->obj setPurgeableState:MTLPurgeableStateVolatile];
}
575
+
576
// Destroy a heap created by ggml_metal_heap_init. Passing a null pointer is a no-op.
static void ggml_metal_heap_free(struct ggml_metal_heap * heap) {
    if (!heap) {
        return;
    }

    // releases all buffers that are still tracked by the heap
    ggml_metal_heap_reset(heap);

    [heap->obj  release];
    [heap->bufs release];

    free(heap);
}
588
+
589
// Objective-C box around a `struct ggml_metal_heap *` so that heaps can be stored in an
// NSMutableArray. The pointer is held with `assign` semantics: the box neither owns nor
// frees the heap it points to.
@interface ggml_metal_heap_ptr : NSObject

@property (nonatomic, assign) struct ggml_metal_heap * data;

@end

@implementation ggml_metal_heap_ptr
@end
597
+
598
+ //
599
+ // ggml_metal_mem_pool
600
+ //
601
+
602
// A growable set of heaps from which temporary buffers are allocated.
struct ggml_metal_mem_pool {
    // device used to create new heaps - not set by ggml_metal_mem_pool_init, the owner assigns it
    id<MTLDevice> device;

    // total number of heaps ever created (including those that were removed)
    int n_heaps;

    // ggml_metal_heap_ptr boxes of the heaps that are currently alive
    NSMutableArray * heaps;

    // scratch list of indices into `heaps`, filled and emptied during a pool reset
    NSMutableArray * heaps_to_remove;
};
610
+
611
// Allocate an empty memory pool. The `device` member is left unset (nil) and has to be
// assigned by the caller before the pool is asked to allocate.
static struct ggml_metal_mem_pool * ggml_metal_mem_pool_init(void) {
    struct ggml_metal_mem_pool * pool = calloc(1, sizeof(*pool));

    pool->heaps           = [[NSMutableArray alloc] init];
    pool->heaps_to_remove = [[NSMutableArray alloc] init];

    pool->n_heaps = 0;

    return pool;
}
621
+
622
// Destroy a memory pool: log per-heap statistics, free every heap together with its box,
// release the pool's arrays and finally the pool itself.
static void ggml_metal_mem_pool_free(struct ggml_metal_mem_pool * mem_pool) {
    GGML_LOG_DEBUG("%s: freeing memory pool, num heaps = %zu (total = %d)\n", __func__, [mem_pool->heaps count], mem_pool->n_heaps);

    size_t size_all = 0; // combined size of all heaps
    size_t size_cur = 0; // combined size of the heaps that still hold buffers

    const NSUInteger n_heaps_cur = [mem_pool->heaps count];

    for (NSUInteger i = 0; i < n_heaps_cur; i++) {
        ggml_metal_heap_ptr * box = [mem_pool->heaps objectAtIndex:i];

        struct ggml_metal_heap * heap = box.data;

        GGML_LOG_DEBUG("%s: heap: %p\n", __func__, (void *) heap);
        GGML_LOG_DEBUG("%s: n_alloc: %" PRId64 "\n", __func__, heap->n_alloc);
        GGML_LOG_DEBUG("%s: n_unused: %d\n", __func__, heap->n_unused);
        GGML_LOG_DEBUG("%s: size: %.2f MiB\n", __func__, [heap->obj size] / 1024.0 / 1024.0);
        GGML_LOG_DEBUG("%s: bufs: %zu\n", __func__, [heap->bufs count]);

        // the statistics have to be gathered before the heap is freed
        if ([heap->bufs count] > 0) {
            size_cur += [heap->obj size];
        }
        size_all += [heap->obj size];

        ggml_metal_heap_free(heap);

        // drops the reference from the creation of the box - the array still holds its own
        // reference until it is released below
        [box release];
    }

    [mem_pool->heaps           release];
    [mem_pool->heaps_to_remove release];

    if (size_all > 0) {
        GGML_LOG_DEBUG("%s: size_all: %.2f MiB\n", __func__, size_all / 1024.0 / 1024.0);
        GGML_LOG_DEBUG("%s: size_cur: %.2f MiB\n", __func__, size_cur / 1024.0 / 1024.0);
    }

    free(mem_pool);
}
653
+
654
// Reset every heap of the pool and destroy the heaps that have stayed unused for 128 or
// more consecutive resets.
static void ggml_metal_mem_pool_reset(struct ggml_metal_mem_pool * mem_pool) {
    // pass 1: reset the heaps and remember the indices of the stale ones
    NSUInteger idx = 0;
    for (ggml_metal_heap_ptr * box in mem_pool->heaps) {
        ggml_metal_heap_reset(box.data);

        // if the heap hasn't been used for a while, remove it
        if (box.data->n_unused >= 128) {
            [mem_pool->heaps_to_remove addObject:@(idx)];
        }

        idx++;
    }

    if ([mem_pool->heaps_to_remove count] == 0) {
        return;
    }

    // pass 2: remove in reverse order, so that the indices that are still pending stay valid
    for (NSNumber * boxed_index in [mem_pool->heaps_to_remove reverseObjectEnumerator]) {
        const NSUInteger index = [boxed_index intValue];

        ggml_metal_heap_ptr * box = [mem_pool->heaps objectAtIndex:index];

        ggml_metal_heap_free(box.data);

        // the box survives the removal from the array through the reference from its
        // creation, which is dropped right after
        [mem_pool->heaps removeObjectAtIndex:index];
        [box release];
    }

    [mem_pool->heaps_to_remove removeAllObjects];
}
687
+
688
// Rewind the allocation offset of every heap to 0 without releasing any buffer, so that
// subsequent allocations are placed at the start of the heaps again.
static void ggml_metal_mem_pool_clear(struct ggml_metal_mem_pool * mem_pool) {
    const NSUInteger n_heaps_cur = [mem_pool->heaps count];

    for (NSUInteger i = 0; i < n_heaps_cur; i++) {
        ggml_metal_heap_ptr * box = [mem_pool->heaps objectAtIndex:i];

        box.data->offs = 0;
    }
}
693
+
694
// Place a buffer of `size_aligned` bytes at the current offset of `heap` and record it in
// the heap's bookkeeping.
//
// Returns the new buffer, or nil if Metal fails to create it - in that case the counters
// and the buffer list of the heap are left untouched.
static id<MTLBuffer> ggml_metal_heap_new_buffer(struct ggml_metal_heap * heap, size_t size_aligned) {
    // if this is the first buffer in the heap for the current command buffer, tell the OS that
    // it cannot free the memory used by the heap
    // ref: https://developer.apple.com/documentation/metal/mtlpurgeablestate?language=objc
    if ([heap->bufs count] == 0) {
        [heap->obj setPurgeableState:MTLPurgeableStateNonVolatile];
    }

    id<MTLBuffer> buf = [heap->obj newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate offset:heap->offs];
    if (buf == nil) {
        return nil;
    }

    heap->n_alloc++;
    heap->offs += size_aligned;

    [heap->bufs addObject:buf];

    return buf;
}

// Allocate a temporary buffer of at least `size` bytes (padded to a multiple of 256) from
// the pool. The first existing heap with enough room after its current offset is used;
// if there is none, a new heap that fits exactly this buffer is created and added to the pool.
//
// The returned buffer remains owned by its heap and is released when the heap is reset.
// Returns nil if the buffer or a new heap could not be created.
static id<MTLBuffer> ggml_metal_mem_pool_alloc(struct ggml_metal_mem_pool * mem_pool, size_t size) {
    const size_t alignment = 256;

    const size_t size_aligned = GGML_PAD(size, alignment);

    // try one of the existing heaps
    for (ggml_metal_heap_ptr * ptr in mem_pool->heaps) {
        struct ggml_metal_heap * heap = ptr.data;
        if (heap->offs + size_aligned > [heap->obj size]) {
            continue;
        }

        id<MTLBuffer> buf = ggml_metal_heap_new_buffer(heap, size_aligned);
        if (buf == nil) {
            GGML_LOG_ERROR("%s: error: failed to create MTLBuffer with size %zu\n", __func__, size_aligned);
        }

        return buf;
    }

    // create a new heap that can fit this buffer
    struct ggml_metal_heap * heap = ggml_metal_heap_init(mem_pool->device, size_aligned);
    if (heap == NULL) {
        GGML_LOG_ERROR("%s: error: failed to create heap of size %zu\n", __func__, size_aligned);
        return nil;
    }

    //GGML_LOG_DEBUG("%s: creating new heap of size %zu, got %zu\n", __func__, size_aligned, [heap->obj size]);

    ggml_metal_heap_reset(heap);

    id<MTLBuffer> buf = ggml_metal_heap_new_buffer(heap, size_aligned);
    if (buf == nil) {
        GGML_LOG_ERROR("%s: error: failed to create MTLBuffer with size %zu\n", __func__, size_aligned);

        // the heap has not been handed to the pool yet, so it has to be destroyed here
        ggml_metal_heap_free(heap);

        return nil;
    }

    // box the heap only after everything that can fail has succeeded, so that no error path
    // has to clean up the box
    ggml_metal_heap_ptr * heap_ptr = [[ggml_metal_heap_ptr alloc] init];
    heap_ptr.data = heap;

    // the reference from the creation of the box is kept on purpose: the pool reset/free
    // code releases the box explicitly in addition to removing it from the array
    [mem_pool->heaps addObject:heap_ptr];
    mem_pool->n_heaps++;

    return buf;
}
756
+
757
// A Metal command buffer paired with the memory pool that serves it.
struct ggml_metal_command_buffer {
    // the Metal command buffer object
    id<MTLCommandBuffer> obj;

    // each command buffer has a memory pool from which it can allocate temporary buffers during the compute
    struct ggml_metal_mem_pool * mem_pool;
};
763
+
403
764
  struct ggml_backend_metal_context {
765
+ id<MTLDevice> device;
404
766
  id<MTLCommandQueue> queue;
405
767
 
406
768
  dispatch_queue_t d_queue;
@@ -425,7 +787,7 @@ struct ggml_backend_metal_context {
425
787
  void (^encode_async)(size_t ith);
426
788
 
427
789
  // n_cb command buffers + 1 used by the main thread
428
- id<MTLCommandBuffer> command_buffers[GGML_METAL_MAX_COMMAND_BUFFERS + 1];
790
+ struct ggml_metal_command_buffer cmd_bufs[GGML_METAL_MAX_COMMAND_BUFFERS + 1];
429
791
 
430
792
  // abort ggml_metal_graph_compute if callback returns true
431
793
  ggml_abort_callback abort_callback;
@@ -437,11 +799,13 @@ struct ggml_backend_metal_context {
437
799
  // for now it is easier to work in a separate file
438
800
  // static NSString * const msl_library_source = @"see metal.metal";
439
801
 
802
#if !GGML_METAL_EMBED_LIBRARY
// Here to assist with NSBundle Path Hack:
// an empty marker class that is passed to +[NSBundle bundleForClass:] in order to locate
// the bundle that holds the Metal library resources. It is only compiled when the library
// is not embedded.
@interface GGMLMetalClass : NSObject
@end

@implementation GGMLMetalClass
@end
#endif
445
809
 
446
810
  static void * ggml_metal_host_malloc(size_t n) {
447
811
  void * data = NULL;
@@ -463,159 +827,176 @@ static void * ggml_metal_host_malloc(size_t n) {
463
827
  return data;
464
828
  }
465
829
 
466
- static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t dev) {
467
- GGML_LOG_INFO("%s: allocating\n", __func__);
468
-
469
- #if TARGET_OS_OSX && !GGML_METAL_NDEBUG
470
- // Show all the Metal device instances in the system
471
- NSArray * devices = MTLCopyAllDevices();
472
- for (id<MTLDevice> device in devices) {
473
- GGML_LOG_INFO("%s: found device: %s\n", __func__, [[device name] UTF8String]);
474
- }
475
- [devices release]; // since it was created by a *Copy* C method
476
- #endif
477
-
478
- // init context
479
- struct ggml_backend_metal_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_context));
480
- struct ggml_backend_metal_device_context * ctx_dev = dev->context;
830
+ // load library
831
+ //
832
+ // - first check if the library is embedded
833
+ // - then check if the library is in the bundle
834
+ // - if not found, load the source and compile it
835
+ // - if that fails, return NULL
836
+ static id<MTLLibrary> ggml_metal_load_library(id<MTLDevice> device, bool use_bfloat) {
837
+ id<MTLLibrary> metal_library = nil;
838
+ NSError * error = nil;
839
+ NSString * src = nil;
481
840
 
482
- id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
483
- GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
841
+ #if GGML_METAL_EMBED_LIBRARY
842
+ GGML_LOG_INFO("%s: using embedded metal library\n", __func__);
484
843
 
485
- ctx->queue = [device newCommandQueue];
486
- ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
844
+ extern const char ggml_metallib_start[];
845
+ extern const char ggml_metallib_end[];
487
846
 
488
- id<MTLLibrary> metal_library;
847
+ src = [[NSString alloc] initWithBytes:ggml_metallib_start length:(ggml_metallib_end-ggml_metallib_start) encoding:NSUTF8StringEncoding];
489
848
 
490
- // load library
491
- //
492
- // - first check if the library is embedded
493
- // - then check if the library is in the bundle
494
- // - if not found, load the source and compile it
495
- // - if that fails, return NULL
496
- {
497
- NSBundle * bundle = nil;
498
- #ifdef SWIFT_PACKAGE
499
- bundle = SWIFTPM_MODULE_BUNDLE;
500
849
  #else
501
- bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
502
- #endif
503
-
504
- NSError * error = nil;
505
850
 
506
- #if GGML_METAL_EMBED_LIBRARY
507
- const bool try_metallib = false;
851
+ #ifdef SWIFT_PACKAGE
852
+ NSBundle * bundle = SWIFTPM_MODULE_BUNDLE;
508
853
  #else
509
- const bool try_metallib = true;
854
+ NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
510
855
  #endif
511
856
 
512
- NSString * path_lib = [bundle pathForResource:@"default" ofType:@"metallib"];
513
- if (path_lib == nil) {
514
- // Try to find the resource in the directory where the current binary located.
515
- NSString * current_binary = [[NSProcessInfo processInfo] arguments][0];
516
- NSString * bin_dir = [current_binary stringByDeletingLastPathComponent];
517
- NSString * default_metallib_path = [NSString pathWithComponents:@[bin_dir, @"default.metallib"]];
518
- if ([[NSFileManager defaultManager] isReadableFileAtPath:default_metallib_path]) {
519
- GGML_LOG_INFO("%s: found '%s'\n", __func__, [default_metallib_path UTF8String]);
520
- NSDictionary * atts = [[NSFileManager defaultManager] attributesOfItemAtPath:default_metallib_path error:&error];
521
- if (atts && atts[NSFileType] == NSFileTypeSymbolicLink) {
522
- // Optionally, if this is a symlink, try to resolve it.
523
- default_metallib_path = [[NSFileManager defaultManager] destinationOfSymbolicLinkAtPath:default_metallib_path error:&error];
524
- if (default_metallib_path && [default_metallib_path length] > 0 && ![[default_metallib_path substringToIndex:1] isEqualToString:@"/"]) {
525
- // It is a relative path, adding the binary directory as directory prefix.
526
- default_metallib_path = [NSString pathWithComponents:@[bin_dir, default_metallib_path]];
527
- }
528
- if (!default_metallib_path || ![[NSFileManager defaultManager] isReadableFileAtPath:default_metallib_path]) {
529
- // Link to the resource could not be resolved.
530
- default_metallib_path = nil;
531
- } else {
532
- GGML_LOG_INFO("%s: symlink resolved '%s'\n", __func__, [default_metallib_path UTF8String]);
533
- }
857
+ NSString * path_lib = [bundle pathForResource:@"default" ofType:@"metallib"];
858
+ if (path_lib == nil) {
859
+ // Try to find the resource in the directory where the current binary located.
860
+ NSString * current_binary = [[NSProcessInfo processInfo] arguments][0];
861
+ NSString * bin_dir = [current_binary stringByDeletingLastPathComponent];
862
+ NSString * default_metallib_path = [NSString pathWithComponents:@[bin_dir, @"default.metallib"]];
863
+ if ([[NSFileManager defaultManager] isReadableFileAtPath:default_metallib_path]) {
864
+ GGML_LOG_INFO("%s: found '%s'\n", __func__, [default_metallib_path UTF8String]);
865
+ NSDictionary * atts = [[NSFileManager defaultManager] attributesOfItemAtPath:default_metallib_path error:&error];
866
+ if (atts && atts[NSFileType] == NSFileTypeSymbolicLink) {
867
+ // Optionally, if this is a symlink, try to resolve it.
868
+ default_metallib_path = [[NSFileManager defaultManager] destinationOfSymbolicLinkAtPath:default_metallib_path error:&error];
869
+ if (default_metallib_path && [default_metallib_path length] > 0 && ![[default_metallib_path substringToIndex:1] isEqualToString:@"/"]) {
870
+ // It is a relative path, adding the binary directory as directory prefix.
871
+ default_metallib_path = [NSString pathWithComponents:@[bin_dir, default_metallib_path]];
872
+ }
873
+ if (!default_metallib_path || ![[NSFileManager defaultManager] isReadableFileAtPath:default_metallib_path]) {
874
+ // Link to the resource could not be resolved.
875
+ default_metallib_path = nil;
876
+ } else {
877
+ GGML_LOG_INFO("%s: symlink resolved '%s'\n", __func__, [default_metallib_path UTF8String]);
534
878
  }
535
- } else {
536
- // The resource couldn't be found in the binary's directory.
537
- default_metallib_path = nil;
538
879
  }
539
- path_lib = default_metallib_path;
880
+ } else {
881
+ // The resource couldn't be found in the binary's directory.
882
+ default_metallib_path = nil;
540
883
  }
884
+ path_lib = default_metallib_path;
885
+ }
541
886
 
542
- if (try_metallib && path_lib != nil) {
543
- // pre-compiled library found
544
- NSURL * libURL = [NSURL fileURLWithPath:path_lib];
545
- GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_lib UTF8String]);
887
+ if (path_lib != nil) {
888
+ // pre-compiled library found
889
+ NSURL * libURL = [NSURL fileURLWithPath:path_lib];
890
+ GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_lib UTF8String]);
546
891
 
547
- metal_library = [device newLibraryWithURL:libURL error:&error];
548
- if (error) {
549
- GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
550
- return NULL;
551
- }
892
+ metal_library = [device newLibraryWithURL:libURL error:&error];
893
+ if (error) {
894
+ GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
895
+ return NULL;
896
+ }
897
+ } else {
898
+ GGML_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
899
+
900
+ NSString * path_source;
901
+ NSString * path_resource = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
902
+
903
+ GGML_LOG_INFO("%s: GGML_METAL_PATH_RESOURCES = %s\n", __func__, path_resource ? [path_resource UTF8String] : "nil");
904
+
905
+ if (path_resource) {
906
+ path_source = [path_resource stringByAppendingPathComponent:@"ggml-metal.metal"];
552
907
  } else {
553
- #if GGML_METAL_EMBED_LIBRARY
554
- GGML_LOG_INFO("%s: using embedded metal library\n", __func__);
908
+ path_source = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
909
+ }
555
910
 
556
- extern const char ggml_metallib_start[];
557
- extern const char ggml_metallib_end[];
911
+ if (path_source == nil) {
912
+ GGML_LOG_WARN("%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\n", __func__);
913
+ path_source = @"ggml-metal.metal";
914
+ }
558
915
 
559
- NSString * src = [[NSString alloc] initWithBytes:ggml_metallib_start length:(ggml_metallib_end-ggml_metallib_start) encoding:NSUTF8StringEncoding];
560
- #else
561
- GGML_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
916
+ GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_source UTF8String]);
562
917
 
563
- NSString * path_source;
564
- NSString * path_resource = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
918
+ src = [NSString stringWithContentsOfFile:path_source encoding:NSUTF8StringEncoding error:&error];
919
+ if (error) {
920
+ GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
921
+ return NULL;
922
+ }
923
+ }
924
+ #endif
565
925
 
566
- GGML_LOG_INFO("%s: GGML_METAL_PATH_RESOURCES = %s\n", __func__, path_resource ? [path_resource UTF8String] : "nil");
926
+ if (!metal_library) {
927
+ @autoreleasepool {
928
+ // dictionary of preprocessor macros
929
+ NSMutableDictionary * prep = [NSMutableDictionary dictionary];
567
930
 
568
- if (path_resource) {
569
- path_source = [path_resource stringByAppendingPathComponent:@"ggml-metal.metal"];
570
- } else {
571
- path_source = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
931
+ if (use_bfloat) {
932
+ [prep setObject:@"1" forKey:@"GGML_METAL_USE_BF16"];
572
933
  }
573
934
 
574
- if (path_source == nil) {
575
- GGML_LOG_WARN("%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\n", __func__);
576
- path_source = @"ggml-metal.metal";
577
- }
935
+ #if GGML_METAL_EMBED_LIBRARY
936
+ [prep setObject:@"1" forKey:@"GGML_METAL_EMBED_LIBRARY"];
937
+ #endif
938
+
939
+ MTLCompileOptions * options = [MTLCompileOptions new];
940
+ options.preprocessorMacros = prep;
578
941
 
579
- GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_source UTF8String]);
942
+ //[options setFastMathEnabled:false];
580
943
 
581
- NSString * src = [NSString stringWithContentsOfFile:path_source encoding:NSUTF8StringEncoding error:&error];
944
+ metal_library = [device newLibraryWithSource:src options:options error:&error];
582
945
  if (error) {
583
946
  GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
584
947
  return NULL;
585
948
  }
949
+
950
+ #if !__has_feature(objc_arc)
951
+ [options release];
952
+ #endif
953
+ }
954
+ }
955
+
956
+ #if GGML_METAL_EMBED_LIBRARY
957
+ [src release];
586
958
  #endif // GGML_METAL_EMBED_LIBRARY
587
959
 
588
- @autoreleasepool {
589
- // dictionary of preprocessor macros
590
- NSMutableDictionary * prep = [NSMutableDictionary dictionary];
960
+ return metal_library;
961
+ }
591
962
 
592
- if (ctx_dev->use_bfloat) {
593
- [prep setObject:@"1" forKey:@"GGML_METAL_USE_BF16"];
594
- }
963
+ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t dev) {
964
+ GGML_LOG_INFO("%s: allocating\n", __func__);
595
965
 
596
- #if GGML_METAL_EMBED_LIBRARY
597
- [prep setObject:@"1" forKey:@"GGML_METAL_EMBED_LIBRARY"];
966
+ #if TARGET_OS_OSX && !GGML_METAL_NDEBUG
967
+ // Show all the Metal device instances in the system
968
+ NSArray * devices = MTLCopyAllDevices();
969
+ for (id<MTLDevice> device in devices) {
970
+ GGML_LOG_INFO("%s: found device: %s\n", __func__, [[device name] UTF8String]);
971
+ }
972
+ [devices release]; // since it was created by a *Copy* C method
598
973
  #endif
599
974
 
600
- MTLCompileOptions * options = [MTLCompileOptions new];
601
- options.preprocessorMacros = prep;
975
+ // init context
976
+ struct ggml_backend_metal_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_context));
977
+ struct ggml_backend_metal_device_context * ctx_dev = dev->context;
602
978
 
603
- //[options setFastMathEnabled:false];
979
+ id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
604
980
 
605
- metal_library = [device newLibraryWithSource:src options:options error:&error];
606
- if (error) {
607
- GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
608
- return NULL;
609
- }
981
+ GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
610
982
 
611
- #if !__has_feature(objc_arc)
612
- [options release];
613
- #endif
614
- }
615
- #if GGML_METAL_EMBED_LIBRARY
616
- [src release];
617
- #endif // GGML_METAL_EMBED_LIBRARY
618
- }
983
+ ctx->device = device;
984
+ ctx->queue = [device newCommandQueue];
985
+ if (ctx->queue == nil) {
986
+ GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__);
987
+ return NULL;
988
+ }
989
+
990
+ ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
991
+
992
+ // load library
993
+ if (ctx_dev->mtl_library == nil) {
994
+ ctx_dev->mtl_library = ggml_metal_load_library(device, ctx_dev->use_bfloat);
995
+ }
996
+ id<MTLLibrary> metal_library = ctx_dev->mtl_library;
997
+ if (metal_library == nil) {
998
+ GGML_LOG_ERROR("%s: error: metal library is nil\n", __func__);
999
+ return NULL;
619
1000
  }
620
1001
 
621
1002
  // print MTL GPU family:
@@ -649,6 +1030,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
649
1030
 
650
1031
  GGML_LOG_INFO("%s: simdgroup reduction = %s\n", __func__, ctx_dev->has_simdgroup_reduction ? "true" : "false");
651
1032
  GGML_LOG_INFO("%s: simdgroup matrix mul. = %s\n", __func__, ctx_dev->has_simdgroup_mm ? "true" : "false");
1033
+ GGML_LOG_INFO("%s: has residency sets = %s\n", __func__, ctx_dev->has_residency_sets ? "true" : "false");
652
1034
  GGML_LOG_INFO("%s: has bfloat = %s\n", __func__, ctx_dev->has_bfloat ? "true" : "false");
653
1035
  GGML_LOG_INFO("%s: use bfloat = %s\n", __func__, ctx_dev->use_bfloat ? "true" : "false");
654
1036
  GGML_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false");
@@ -660,7 +1042,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
660
1042
  ctx->gf = nil;
661
1043
  ctx->encode_async = nil;
662
1044
  for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
663
- ctx->command_buffers[i] = nil;
1045
+ ctx->cmd_bufs[i].obj = nil;
1046
+
1047
+ ctx->cmd_bufs[i].mem_pool = ggml_metal_mem_pool_init();
1048
+ ctx->cmd_bufs[i].mem_pool->device = device;
664
1049
  }
665
1050
 
666
1051
  #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
@@ -688,7 +1073,6 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
688
1073
  [metal_function release]; \
689
1074
  if (error) { \
690
1075
  GGML_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
691
- [metal_library release]; \
692
1076
  return NULL; \
693
1077
  } \
694
1078
  } else { \
@@ -701,304 +1085,380 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
701
1085
 
702
1086
  // simd_sum and simd_max requires MTLGPUFamilyApple7
703
1087
 
704
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true);
705
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true);
706
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB, sub, true);
707
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW, sub_row, true);
708
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true);
709
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
710
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
711
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
712
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true);
713
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true);
714
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true);
715
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I16, repeat_i16, true);
716
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
717
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
718
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
719
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
720
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
721
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true);
722
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
723
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
724
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
725
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
726
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
727
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
728
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ELU, elu, true);
729
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, has_simdgroup_reduction);
730
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, has_simdgroup_reduction);
731
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, has_simdgroup_reduction);
732
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, has_simdgroup_reduction);
733
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
734
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true);
735
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true);
736
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true);
737
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16, get_rows_bf16, use_bfloat);
738
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true);
739
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true);
740
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true);
741
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true);
742
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true);
743
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true);
744
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true);
745
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true);
746
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true);
747
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true);
748
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true);
749
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true);
750
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true);
751
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true);
752
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true);
753
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true);
754
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, get_rows_iq1_m, true);
755
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
756
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
757
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
758
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
759
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
760
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
761
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
762
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
763
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
764
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat);
765
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat);
766
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, has_simdgroup_reduction && use_bfloat);
767
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, has_simdgroup_reduction && use_bfloat);
768
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, has_simdgroup_reduction);
769
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, has_simdgroup_reduction);
770
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, has_simdgroup_reduction);
771
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, has_simdgroup_reduction);
772
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, has_simdgroup_reduction);
773
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, has_simdgroup_reduction);
774
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction);
775
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction);
776
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction);
777
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction);
778
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction);
779
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction);
780
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5, mul_mv_ext_f16_f32_r1_5, has_simdgroup_reduction);
781
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2, mul_mv_ext_q4_0_f32_r1_2, has_simdgroup_reduction);
782
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3, mul_mv_ext_q4_0_f32_r1_3, has_simdgroup_reduction);
783
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4, mul_mv_ext_q4_0_f32_r1_4, has_simdgroup_reduction);
784
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5, mul_mv_ext_q4_0_f32_r1_5, has_simdgroup_reduction);
785
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2, mul_mv_ext_q4_1_f32_r1_2, has_simdgroup_reduction);
786
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3, mul_mv_ext_q4_1_f32_r1_3, has_simdgroup_reduction);
787
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4, mul_mv_ext_q4_1_f32_r1_4, has_simdgroup_reduction);
788
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5, mul_mv_ext_q4_1_f32_r1_5, has_simdgroup_reduction);
789
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2, mul_mv_ext_q5_0_f32_r1_2, has_simdgroup_reduction);
790
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3, mul_mv_ext_q5_0_f32_r1_3, has_simdgroup_reduction);
791
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4, mul_mv_ext_q5_0_f32_r1_4, has_simdgroup_reduction);
792
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5, mul_mv_ext_q5_0_f32_r1_5, has_simdgroup_reduction);
793
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2, mul_mv_ext_q5_1_f32_r1_2, has_simdgroup_reduction);
794
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3, mul_mv_ext_q5_1_f32_r1_3, has_simdgroup_reduction);
795
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4, mul_mv_ext_q5_1_f32_r1_4, has_simdgroup_reduction);
796
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5, mul_mv_ext_q5_1_f32_r1_5, has_simdgroup_reduction);
797
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2, mul_mv_ext_q8_0_f32_r1_2, has_simdgroup_reduction);
798
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3, mul_mv_ext_q8_0_f32_r1_3, has_simdgroup_reduction);
799
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4, mul_mv_ext_q8_0_f32_r1_4, has_simdgroup_reduction);
800
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5, mul_mv_ext_q8_0_f32_r1_5, has_simdgroup_reduction);
801
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2, mul_mv_ext_q4_K_f32_r1_2, has_simdgroup_reduction);
802
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3, mul_mv_ext_q4_K_f32_r1_3, has_simdgroup_reduction);
803
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4, mul_mv_ext_q4_K_f32_r1_4, has_simdgroup_reduction);
804
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5, mul_mv_ext_q4_K_f32_r1_5, has_simdgroup_reduction);
805
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2, mul_mv_ext_q5_K_f32_r1_2, has_simdgroup_reduction);
806
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3, mul_mv_ext_q5_K_f32_r1_3, has_simdgroup_reduction);
807
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4, mul_mv_ext_q5_K_f32_r1_4, has_simdgroup_reduction);
808
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5, mul_mv_ext_q5_K_f32_r1_5, has_simdgroup_reduction);
809
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2, mul_mv_ext_q6_K_f32_r1_2, has_simdgroup_reduction);
810
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3, mul_mv_ext_q6_K_f32_r1_3, has_simdgroup_reduction);
811
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4, mul_mv_ext_q6_K_f32_r1_4, has_simdgroup_reduction);
812
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5, mul_mv_ext_q6_K_f32_r1_5, has_simdgroup_reduction);
813
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2, mul_mv_ext_iq4_nl_f32_r1_2, has_simdgroup_reduction);
814
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3, mul_mv_ext_iq4_nl_f32_r1_3, has_simdgroup_reduction);
815
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4, mul_mv_ext_iq4_nl_f32_r1_4, has_simdgroup_reduction);
816
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5, mul_mv_ext_iq4_nl_f32_r1_5, has_simdgroup_reduction);
817
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, has_simdgroup_reduction);
818
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, has_simdgroup_reduction);
819
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, has_simdgroup_reduction);
820
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, has_simdgroup_reduction);
821
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, has_simdgroup_reduction);
822
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, has_simdgroup_reduction);
823
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, has_simdgroup_reduction);
824
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, has_simdgroup_reduction);
825
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, has_simdgroup_reduction);
826
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, has_simdgroup_reduction);
827
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, has_simdgroup_reduction);
828
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, has_simdgroup_reduction);
829
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, has_simdgroup_reduction);
830
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, has_simdgroup_reduction);
831
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, has_simdgroup_reduction);
832
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, has_simdgroup_reduction);
833
- //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, has_simdgroup_reduction);
834
- //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, has_simdgroup_reduction);
835
- //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, has_simdgroup_reduction);
836
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32, mul_mv_id_bf16_f32, has_simdgroup_reduction && use_bfloat);
837
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, has_simdgroup_reduction);
838
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, has_simdgroup_reduction);
839
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, has_simdgroup_reduction);
840
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, has_simdgroup_reduction);
841
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, has_simdgroup_reduction);
842
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, has_simdgroup_reduction);
843
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, has_simdgroup_reduction);
844
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, has_simdgroup_reduction);
845
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, has_simdgroup_reduction);
846
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, has_simdgroup_reduction);
847
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, has_simdgroup_reduction);
848
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, has_simdgroup_reduction);
849
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, has_simdgroup_reduction);
850
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, has_simdgroup_reduction);
851
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, has_simdgroup_reduction);
852
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, has_simdgroup_reduction);
853
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, has_simdgroup_reduction);
854
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, has_simdgroup_reduction);
855
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, has_simdgroup_reduction);
856
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, has_simdgroup_mm);
857
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, has_simdgroup_mm);
858
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, has_simdgroup_mm && use_bfloat);
859
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, has_simdgroup_mm);
860
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, has_simdgroup_mm);
861
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, has_simdgroup_mm);
862
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, has_simdgroup_mm);
863
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, has_simdgroup_mm);
864
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, has_simdgroup_mm);
865
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, has_simdgroup_mm);
866
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, has_simdgroup_mm);
867
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, has_simdgroup_mm);
868
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, has_simdgroup_mm);
869
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, has_simdgroup_mm);
870
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, has_simdgroup_mm);
871
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, has_simdgroup_mm);
872
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, has_simdgroup_mm);
873
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, has_simdgroup_mm);
874
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, has_simdgroup_mm);
875
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, has_simdgroup_mm);
876
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm);
877
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm);
878
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, has_simdgroup_mm);
879
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, has_simdgroup_mm);
880
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, mul_mm_id_bf16_f32, has_simdgroup_mm && use_bfloat);
881
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, has_simdgroup_mm);
882
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, has_simdgroup_mm);
883
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, has_simdgroup_mm);
884
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, has_simdgroup_mm);
885
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, has_simdgroup_mm);
886
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, has_simdgroup_mm);
887
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, has_simdgroup_mm);
888
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, has_simdgroup_mm);
889
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, has_simdgroup_mm);
890
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, has_simdgroup_mm);
891
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, has_simdgroup_mm);
892
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, has_simdgroup_mm);
893
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, has_simdgroup_mm);
894
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, has_simdgroup_mm);
895
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, has_simdgroup_mm);
896
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, has_simdgroup_mm);
897
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, has_simdgroup_mm);
898
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, has_simdgroup_mm);
899
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, has_simdgroup_mm);
900
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
901
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
902
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
903
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true);
904
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
905
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
906
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, im2col_ext_f16, true);
907
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, im2col_ext_f32, true);
908
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32, conv_transpose_1d_f32_f32, true);
909
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32, conv_transpose_1d_f16_f32, true);
910
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
911
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
912
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32, pad_reflect_1d_f32, true);
913
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
914
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true);
915
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
916
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
917
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
918
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, has_simdgroup_mm);
919
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, has_simdgroup_mm);
920
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, has_simdgroup_mm);
921
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, has_simdgroup_mm);
922
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, has_simdgroup_mm);
923
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, has_simdgroup_mm);
924
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, flash_attn_ext_bf16_h64, has_simdgroup_mm && use_bfloat);
925
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, flash_attn_ext_bf16_h80, has_simdgroup_mm && use_bfloat);
926
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, flash_attn_ext_bf16_h96, has_simdgroup_mm && use_bfloat);
927
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112, flash_attn_ext_bf16_h112, has_simdgroup_mm && use_bfloat);
928
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128, flash_attn_ext_bf16_h128, has_simdgroup_mm && use_bfloat);
929
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, flash_attn_ext_bf16_h256, has_simdgroup_mm && use_bfloat);
930
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, has_simdgroup_mm);
931
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, has_simdgroup_mm);
932
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, has_simdgroup_mm);
933
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112, flash_attn_ext_q4_0_h112, has_simdgroup_mm);
934
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128, flash_attn_ext_q4_0_h128, has_simdgroup_mm);
935
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, flash_attn_ext_q4_0_h256, has_simdgroup_mm);
936
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, flash_attn_ext_q4_1_h64, has_simdgroup_mm);
937
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, flash_attn_ext_q4_1_h80, has_simdgroup_mm);
938
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, flash_attn_ext_q4_1_h96, has_simdgroup_mm);
939
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112, flash_attn_ext_q4_1_h112, has_simdgroup_mm);
940
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128, flash_attn_ext_q4_1_h128, has_simdgroup_mm);
941
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, flash_attn_ext_q4_1_h256, has_simdgroup_mm);
942
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, flash_attn_ext_q5_0_h64, has_simdgroup_mm);
943
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, flash_attn_ext_q5_0_h80, has_simdgroup_mm);
944
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, flash_attn_ext_q5_0_h96, has_simdgroup_mm);
945
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112, flash_attn_ext_q5_0_h112, has_simdgroup_mm);
946
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128, flash_attn_ext_q5_0_h128, has_simdgroup_mm);
947
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, flash_attn_ext_q5_0_h256, has_simdgroup_mm);
948
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, flash_attn_ext_q5_1_h64, has_simdgroup_mm);
949
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, flash_attn_ext_q5_1_h80, has_simdgroup_mm);
950
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, flash_attn_ext_q5_1_h96, has_simdgroup_mm);
951
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112, flash_attn_ext_q5_1_h112, has_simdgroup_mm);
952
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128, flash_attn_ext_q5_1_h128, has_simdgroup_mm);
953
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, flash_attn_ext_q5_1_h256, has_simdgroup_mm);
954
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, has_simdgroup_mm);
955
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, has_simdgroup_mm);
956
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, has_simdgroup_mm);
957
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112, flash_attn_ext_q8_0_h112, has_simdgroup_mm);
958
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, flash_attn_ext_q8_0_h128, has_simdgroup_mm);
959
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm);
960
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, has_simdgroup_reduction);
961
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, flash_attn_ext_vec_bf16_h128, has_simdgroup_reduction && use_bfloat);
962
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, has_simdgroup_reduction);
963
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, flash_attn_ext_vec_q4_1_h128, has_simdgroup_reduction);
964
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, flash_attn_ext_vec_q5_0_h128, has_simdgroup_reduction);
965
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, flash_attn_ext_vec_q5_1_h128, has_simdgroup_reduction);
966
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, flash_attn_ext_vec_q8_0_h128, has_simdgroup_reduction);
967
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, has_simdgroup_reduction);
968
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256, flash_attn_ext_vec_bf16_h256, has_simdgroup_reduction && use_bfloat);
969
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, flash_attn_ext_vec_q4_0_h256, has_simdgroup_reduction);
970
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, flash_attn_ext_vec_q4_1_h256, has_simdgroup_reduction);
971
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction);
972
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, has_simdgroup_reduction);
973
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, has_simdgroup_reduction);
974
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_F32, set_f32, true);
975
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_I32, set_i32, true);
976
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
977
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
978
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, use_bfloat);
979
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
980
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
981
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, use_bfloat);
982
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16, cpy_bf16_bf16, use_bfloat);
983
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
984
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
985
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
986
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
987
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
988
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
989
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
990
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
991
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true);
992
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
993
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
994
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
995
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
996
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
997
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
1088
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true);
1089
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true);
1090
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB, sub, true);
1091
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW, sub_row, true);
1092
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true);
1093
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
1094
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
1095
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
1096
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true);
1097
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true);
1098
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true);
1099
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I16, repeat_i16, true);
1100
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
1101
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
1102
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
1103
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
1104
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
1105
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true);
1106
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
1107
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
1108
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_ERF, gelu_erf, true);
1109
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_ERF_4, gelu_erf_4, true);
1110
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
1111
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
1112
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
1113
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
1114
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ELU, elu, true);
1115
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, has_simdgroup_reduction);
1116
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, has_simdgroup_reduction);
1117
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, has_simdgroup_reduction);
1118
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, has_simdgroup_reduction);
1119
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
1120
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true);
1121
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true);
1122
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true);
1123
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16, get_rows_bf16, use_bfloat);
1124
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true);
1125
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true);
1126
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true);
1127
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true);
1128
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true);
1129
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true);
1130
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true);
1131
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true);
1132
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true);
1133
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true);
1134
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true);
1135
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true);
1136
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true);
1137
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true);
1138
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true);
1139
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true);
1140
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, get_rows_iq1_m, true);
1141
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
1142
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
1143
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
1144
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
1145
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
1146
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
1147
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
1148
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
1149
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
1150
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true);
1151
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true);
1152
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
1153
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat);
1154
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat);
1155
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, has_simdgroup_reduction && use_bfloat);
1156
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, has_simdgroup_reduction && use_bfloat);
1157
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, has_simdgroup_reduction);
1158
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, has_simdgroup_reduction);
1159
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, has_simdgroup_reduction);
1160
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, has_simdgroup_reduction);
1161
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, has_simdgroup_reduction);
1162
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, has_simdgroup_reduction);
1163
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction);
1164
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction);
1165
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction);
1166
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction);
1167
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction);
1168
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction);
1169
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5, mul_mv_ext_f16_f32_r1_5, has_simdgroup_reduction);
1170
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2, mul_mv_ext_q4_0_f32_r1_2, has_simdgroup_reduction);
1171
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3, mul_mv_ext_q4_0_f32_r1_3, has_simdgroup_reduction);
1172
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4, mul_mv_ext_q4_0_f32_r1_4, has_simdgroup_reduction);
1173
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5, mul_mv_ext_q4_0_f32_r1_5, has_simdgroup_reduction);
1174
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2, mul_mv_ext_q4_1_f32_r1_2, has_simdgroup_reduction);
1175
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3, mul_mv_ext_q4_1_f32_r1_3, has_simdgroup_reduction);
1176
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4, mul_mv_ext_q4_1_f32_r1_4, has_simdgroup_reduction);
1177
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5, mul_mv_ext_q4_1_f32_r1_5, has_simdgroup_reduction);
1178
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2, mul_mv_ext_q5_0_f32_r1_2, has_simdgroup_reduction);
1179
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3, mul_mv_ext_q5_0_f32_r1_3, has_simdgroup_reduction);
1180
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4, mul_mv_ext_q5_0_f32_r1_4, has_simdgroup_reduction);
1181
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5, mul_mv_ext_q5_0_f32_r1_5, has_simdgroup_reduction);
1182
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2, mul_mv_ext_q5_1_f32_r1_2, has_simdgroup_reduction);
1183
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3, mul_mv_ext_q5_1_f32_r1_3, has_simdgroup_reduction);
1184
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4, mul_mv_ext_q5_1_f32_r1_4, has_simdgroup_reduction);
1185
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5, mul_mv_ext_q5_1_f32_r1_5, has_simdgroup_reduction);
1186
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2, mul_mv_ext_q8_0_f32_r1_2, has_simdgroup_reduction);
1187
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3, mul_mv_ext_q8_0_f32_r1_3, has_simdgroup_reduction);
1188
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4, mul_mv_ext_q8_0_f32_r1_4, has_simdgroup_reduction);
1189
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5, mul_mv_ext_q8_0_f32_r1_5, has_simdgroup_reduction);
1190
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2, mul_mv_ext_q4_K_f32_r1_2, has_simdgroup_reduction);
1191
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3, mul_mv_ext_q4_K_f32_r1_3, has_simdgroup_reduction);
1192
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4, mul_mv_ext_q4_K_f32_r1_4, has_simdgroup_reduction);
1193
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5, mul_mv_ext_q4_K_f32_r1_5, has_simdgroup_reduction);
1194
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2, mul_mv_ext_q5_K_f32_r1_2, has_simdgroup_reduction);
1195
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3, mul_mv_ext_q5_K_f32_r1_3, has_simdgroup_reduction);
1196
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4, mul_mv_ext_q5_K_f32_r1_4, has_simdgroup_reduction);
1197
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5, mul_mv_ext_q5_K_f32_r1_5, has_simdgroup_reduction);
1198
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2, mul_mv_ext_q6_K_f32_r1_2, has_simdgroup_reduction);
1199
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3, mul_mv_ext_q6_K_f32_r1_3, has_simdgroup_reduction);
1200
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4, mul_mv_ext_q6_K_f32_r1_4, has_simdgroup_reduction);
1201
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5, mul_mv_ext_q6_K_f32_r1_5, has_simdgroup_reduction);
1202
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2, mul_mv_ext_iq4_nl_f32_r1_2, has_simdgroup_reduction);
1203
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3, mul_mv_ext_iq4_nl_f32_r1_3, has_simdgroup_reduction);
1204
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4, mul_mv_ext_iq4_nl_f32_r1_4, has_simdgroup_reduction);
1205
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5, mul_mv_ext_iq4_nl_f32_r1_5, has_simdgroup_reduction);
1206
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, has_simdgroup_reduction);
1207
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, has_simdgroup_reduction);
1208
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, has_simdgroup_reduction);
1209
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, has_simdgroup_reduction);
1210
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, has_simdgroup_reduction);
1211
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, has_simdgroup_reduction);
1212
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, has_simdgroup_reduction);
1213
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, has_simdgroup_reduction);
1214
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, has_simdgroup_reduction);
1215
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, has_simdgroup_reduction);
1216
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, has_simdgroup_reduction);
1217
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, has_simdgroup_reduction);
1218
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, has_simdgroup_reduction);
1219
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, has_simdgroup_reduction);
1220
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, has_simdgroup_reduction);
1221
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, has_simdgroup_reduction);
1222
+ //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, has_simdgroup_reduction);
1223
+ //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, has_simdgroup_reduction);
1224
+ //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, has_simdgroup_reduction);
1225
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32, mul_mv_id_bf16_f32, has_simdgroup_reduction && use_bfloat);
1226
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, has_simdgroup_reduction);
1227
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, has_simdgroup_reduction);
1228
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, has_simdgroup_reduction);
1229
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, has_simdgroup_reduction);
1230
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, has_simdgroup_reduction);
1231
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, has_simdgroup_reduction);
1232
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, has_simdgroup_reduction);
1233
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, has_simdgroup_reduction);
1234
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, has_simdgroup_reduction);
1235
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, has_simdgroup_reduction);
1236
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, has_simdgroup_reduction);
1237
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, has_simdgroup_reduction);
1238
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, has_simdgroup_reduction);
1239
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, has_simdgroup_reduction);
1240
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, has_simdgroup_reduction);
1241
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, has_simdgroup_reduction);
1242
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, has_simdgroup_reduction);
1243
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, has_simdgroup_reduction);
1244
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, has_simdgroup_reduction);
1245
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, has_simdgroup_mm);
1246
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, has_simdgroup_mm);
1247
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, has_simdgroup_mm && use_bfloat);
1248
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, has_simdgroup_mm);
1249
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, has_simdgroup_mm);
1250
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, has_simdgroup_mm);
1251
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, has_simdgroup_mm);
1252
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, has_simdgroup_mm);
1253
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, has_simdgroup_mm);
1254
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, has_simdgroup_mm);
1255
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, has_simdgroup_mm);
1256
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, has_simdgroup_mm);
1257
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, has_simdgroup_mm);
1258
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, has_simdgroup_mm);
1259
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, has_simdgroup_mm);
1260
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, has_simdgroup_mm);
1261
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, has_simdgroup_mm);
1262
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, has_simdgroup_mm);
1263
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, has_simdgroup_mm);
1264
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, has_simdgroup_mm);
1265
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm);
1266
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm);
1267
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16, mul_mm_id_map0_f16, has_simdgroup_mm);
1268
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32, mul_mm_id_map1_f32, has_simdgroup_mm);
1269
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, mul_mm_id_f32_f16, has_simdgroup_mm);
1270
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16, mul_mm_id_f16_f16, has_simdgroup_mm);
1271
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16, mul_mm_id_bf16_f16, has_simdgroup_mm && use_bfloat);
1272
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F16, mul_mm_id_q4_0_f16, has_simdgroup_mm);
1273
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F16, mul_mm_id_q4_1_f16, has_simdgroup_mm);
1274
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16, mul_mm_id_q5_0_f16, has_simdgroup_mm);
1275
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16, mul_mm_id_q5_1_f16, has_simdgroup_mm);
1276
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16, mul_mm_id_q8_0_f16, has_simdgroup_mm);
1277
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16, mul_mm_id_q2_K_f16, has_simdgroup_mm);
1278
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16, mul_mm_id_q3_K_f16, has_simdgroup_mm);
1279
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16, mul_mm_id_q4_K_f16, has_simdgroup_mm);
1280
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F16, mul_mm_id_q5_K_f16, has_simdgroup_mm);
1281
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F16, mul_mm_id_q6_K_f16, has_simdgroup_mm);
1282
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F16, mul_mm_id_iq2_xxs_f16, has_simdgroup_mm);
1283
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F16, mul_mm_id_iq2_xs_f16, has_simdgroup_mm);
1284
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F16, mul_mm_id_iq3_xxs_f16, has_simdgroup_mm);
1285
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F16, mul_mm_id_iq3_s_f16, has_simdgroup_mm);
1286
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F16, mul_mm_id_iq2_s_f16, has_simdgroup_mm);
1287
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F16, mul_mm_id_iq1_s_f16, has_simdgroup_mm);
1288
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16, mul_mm_id_iq1_m_f16, has_simdgroup_mm);
1289
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16, mul_mm_id_iq4_nl_f16, has_simdgroup_mm);
1290
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16, mul_mm_id_iq4_xs_f16, has_simdgroup_mm);
1291
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
1292
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
1293
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32, rope_multi_f32, true);
1294
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16, rope_multi_f16, true);
1295
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32, rope_vision_f32, true);
1296
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16, rope_vision_f16, true);
1297
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
1298
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true);
1299
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
1300
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
1301
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, im2col_ext_f16, true);
1302
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, im2col_ext_f32, true);
1303
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32, conv_transpose_1d_f32_f32, true);
1304
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32, conv_transpose_1d_f16_f32, true);
1305
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
1306
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
1307
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32, pad_reflect_1d_f32, true);
1308
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
1309
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true);
1310
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
1311
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
1312
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
1313
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, has_simdgroup_mm);
1314
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, has_simdgroup_mm);
1315
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, has_simdgroup_mm);
1316
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, has_simdgroup_mm);
1317
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, has_simdgroup_mm);
1318
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192, flash_attn_ext_f16_h192, has_simdgroup_mm);
1319
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128, flash_attn_ext_f16_hk192_hv128, has_simdgroup_mm);
1320
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, has_simdgroup_mm);
1321
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512, flash_attn_ext_f16_hk576_hv512, has_simdgroup_mm);
1322
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, flash_attn_ext_bf16_h64, has_simdgroup_mm && use_bfloat);
1323
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, flash_attn_ext_bf16_h80, has_simdgroup_mm && use_bfloat);
1324
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, flash_attn_ext_bf16_h96, has_simdgroup_mm && use_bfloat);
1325
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112, flash_attn_ext_bf16_h112, has_simdgroup_mm && use_bfloat);
1326
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128, flash_attn_ext_bf16_h128, has_simdgroup_mm && use_bfloat);
1327
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192, flash_attn_ext_bf16_h192, has_simdgroup_mm && use_bfloat);
1328
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128, flash_attn_ext_bf16_hk192_hv128, has_simdgroup_mm && use_bfloat);
1329
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, flash_attn_ext_bf16_h256, has_simdgroup_mm && use_bfloat);
1330
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512, flash_attn_ext_bf16_hk576_hv512, has_simdgroup_mm && use_bfloat);
1331
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, has_simdgroup_mm);
1332
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, has_simdgroup_mm);
1333
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, has_simdgroup_mm);
1334
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112, flash_attn_ext_q4_0_h112, has_simdgroup_mm);
1335
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128, flash_attn_ext_q4_0_h128, has_simdgroup_mm);
1336
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192, flash_attn_ext_q4_0_h192, has_simdgroup_mm);
1337
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128, flash_attn_ext_q4_0_hk192_hv128, has_simdgroup_mm);
1338
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, flash_attn_ext_q4_0_h256, has_simdgroup_mm);
1339
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512, flash_attn_ext_q4_0_hk576_hv512, has_simdgroup_mm);
1340
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, flash_attn_ext_q4_1_h64, has_simdgroup_mm);
1341
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, flash_attn_ext_q4_1_h80, has_simdgroup_mm);
1342
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, flash_attn_ext_q4_1_h96, has_simdgroup_mm);
1343
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112, flash_attn_ext_q4_1_h112, has_simdgroup_mm);
1344
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128, flash_attn_ext_q4_1_h128, has_simdgroup_mm);
1345
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192, flash_attn_ext_q4_1_h192, has_simdgroup_mm);
1346
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128, flash_attn_ext_q4_1_hk192_hv128, has_simdgroup_mm);
1347
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, flash_attn_ext_q4_1_h256, has_simdgroup_mm);
1348
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512, flash_attn_ext_q4_1_hk576_hv512, has_simdgroup_mm);
1349
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, flash_attn_ext_q5_0_h64, has_simdgroup_mm);
1350
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, flash_attn_ext_q5_0_h80, has_simdgroup_mm);
1351
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, flash_attn_ext_q5_0_h96, has_simdgroup_mm);
1352
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112, flash_attn_ext_q5_0_h112, has_simdgroup_mm);
1353
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128, flash_attn_ext_q5_0_h128, has_simdgroup_mm);
1354
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192, flash_attn_ext_q5_0_h192, has_simdgroup_mm);
1355
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128, flash_attn_ext_q5_0_hk192_hv128, has_simdgroup_mm);
1356
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, flash_attn_ext_q5_0_h256, has_simdgroup_mm);
1357
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512, flash_attn_ext_q5_0_hk576_hv512, has_simdgroup_mm);
1358
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, flash_attn_ext_q5_1_h64, has_simdgroup_mm);
1359
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, flash_attn_ext_q5_1_h80, has_simdgroup_mm);
1360
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, flash_attn_ext_q5_1_h96, has_simdgroup_mm);
1361
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112, flash_attn_ext_q5_1_h112, has_simdgroup_mm);
1362
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128, flash_attn_ext_q5_1_h128, has_simdgroup_mm);
1363
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192, flash_attn_ext_q5_1_h192, has_simdgroup_mm);
1364
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128, flash_attn_ext_q5_1_hk192_hv128, has_simdgroup_mm);
1365
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, flash_attn_ext_q5_1_h256, has_simdgroup_mm);
1366
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512, flash_attn_ext_q5_1_hk576_hv512, has_simdgroup_mm);
1367
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, has_simdgroup_mm);
1368
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, has_simdgroup_mm);
1369
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, has_simdgroup_mm);
1370
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112, flash_attn_ext_q8_0_h112, has_simdgroup_mm);
1371
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, flash_attn_ext_q8_0_h128, has_simdgroup_mm);
1372
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192, flash_attn_ext_q8_0_h192, has_simdgroup_mm);
1373
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, flash_attn_ext_q8_0_hk192_hv128, has_simdgroup_mm);
1374
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm);
1375
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, flash_attn_ext_q8_0_hk576_hv512, has_simdgroup_mm);
1376
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64, flash_attn_ext_vec_f16_h64, has_simdgroup_reduction);
1377
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64, flash_attn_ext_vec_bf16_h64, has_simdgroup_reduction && use_bfloat);
1378
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64, flash_attn_ext_vec_q4_0_h64, has_simdgroup_reduction);
1379
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64, flash_attn_ext_vec_q4_1_h64, has_simdgroup_reduction);
1380
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64, flash_attn_ext_vec_q5_0_h64, has_simdgroup_reduction);
1381
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64, flash_attn_ext_vec_q5_1_h64, has_simdgroup_reduction);
1382
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64, flash_attn_ext_vec_q8_0_h64, has_simdgroup_reduction);
1383
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96, flash_attn_ext_vec_f16_h96, has_simdgroup_reduction);
1384
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96, flash_attn_ext_vec_bf16_h96, has_simdgroup_reduction && use_bfloat);
1385
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96, flash_attn_ext_vec_q4_0_h96, has_simdgroup_reduction);
1386
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96, flash_attn_ext_vec_q4_1_h96, has_simdgroup_reduction);
1387
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96, flash_attn_ext_vec_q5_0_h96, has_simdgroup_reduction);
1388
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96, flash_attn_ext_vec_q5_1_h96, has_simdgroup_reduction);
1389
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96, flash_attn_ext_vec_q8_0_h96, has_simdgroup_reduction);
1390
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, has_simdgroup_reduction);
1391
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, flash_attn_ext_vec_bf16_h128, has_simdgroup_reduction && use_bfloat);
1392
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, has_simdgroup_reduction);
1393
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, flash_attn_ext_vec_q4_1_h128, has_simdgroup_reduction);
1394
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, flash_attn_ext_vec_q5_0_h128, has_simdgroup_reduction);
1395
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, flash_attn_ext_vec_q5_1_h128, has_simdgroup_reduction);
1396
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, flash_attn_ext_vec_q8_0_h128, has_simdgroup_reduction);
1397
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H192, flash_attn_ext_vec_f16_h192, has_simdgroup_reduction);
1398
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H192, flash_attn_ext_vec_bf16_h192, has_simdgroup_reduction && use_bfloat);
1399
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H192, flash_attn_ext_vec_q4_0_h192, has_simdgroup_reduction);
1400
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H192, flash_attn_ext_vec_q4_1_h192, has_simdgroup_reduction);
1401
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H192, flash_attn_ext_vec_q5_0_h192, has_simdgroup_reduction);
1402
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H192, flash_attn_ext_vec_q5_1_h192, has_simdgroup_reduction);
1403
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H192, flash_attn_ext_vec_q8_0_h192, has_simdgroup_reduction);
1404
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128, flash_attn_ext_vec_f16_hk192_hv128, has_simdgroup_reduction);
1405
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK192_HV128, flash_attn_ext_vec_bf16_hk192_hv128, has_simdgroup_reduction && use_bfloat);
1406
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK192_HV128, flash_attn_ext_vec_q4_0_hk192_hv128, has_simdgroup_reduction);
1407
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK192_HV128, flash_attn_ext_vec_q4_1_hk192_hv128, has_simdgroup_reduction);
1408
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK192_HV128, flash_attn_ext_vec_q5_0_hk192_hv128, has_simdgroup_reduction);
1409
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK192_HV128, flash_attn_ext_vec_q5_1_hk192_hv128, has_simdgroup_reduction);
1410
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128, flash_attn_ext_vec_q8_0_hk192_hv128, has_simdgroup_reduction);
1411
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, has_simdgroup_reduction);
1412
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256, flash_attn_ext_vec_bf16_h256, has_simdgroup_reduction && use_bfloat);
1413
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, flash_attn_ext_vec_q4_0_h256, has_simdgroup_reduction);
1414
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, flash_attn_ext_vec_q4_1_h256, has_simdgroup_reduction);
1415
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction);
1416
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, has_simdgroup_reduction);
1417
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, has_simdgroup_reduction);
1418
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512, flash_attn_ext_vec_f16_hk576_hv512, has_simdgroup_reduction);
1419
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512, flash_attn_ext_vec_bf16_hk576_hv512, has_simdgroup_reduction && use_bfloat);
1420
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512, flash_attn_ext_vec_q4_0_hk576_hv512, has_simdgroup_reduction);
1421
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512, flash_attn_ext_vec_q4_1_hk576_hv512, has_simdgroup_reduction);
1422
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512, flash_attn_ext_vec_q5_0_hk576_hv512, has_simdgroup_reduction);
1423
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512, flash_attn_ext_vec_q5_1_hk576_hv512, has_simdgroup_reduction);
1424
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512, flash_attn_ext_vec_q8_0_hk576_hv512, has_simdgroup_reduction);
1425
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_F32, set_f32, true);
1426
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_I32, set_i32, true);
1427
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
1428
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
1429
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, use_bfloat);
1430
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
1431
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
1432
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, use_bfloat);
1433
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16, cpy_bf16_bf16, use_bfloat);
1434
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
1435
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
1436
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
1437
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
1438
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
1439
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
1440
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32, cpy_q4_0_f32, true);
1441
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16, cpy_q4_0_f16, true);
1442
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32, cpy_q4_1_f32, true);
1443
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16, cpy_q4_1_f16, true);
1444
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32, cpy_q5_0_f32, true);
1445
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16, cpy_q5_0_f16, true);
1446
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32, cpy_q5_1_f32, true);
1447
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16, cpy_q5_1_f16, true);
1448
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32, cpy_q8_0_f32, true);
1449
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16, cpy_q8_0_f16, true);
1450
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
1451
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
1452
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true);
1453
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
1454
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
1455
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
1456
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
1457
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
1458
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
1459
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
998
1460
  }
999
1461
 
1000
- [metal_library release];
1001
-
1002
1462
  return ctx;
1003
1463
  }
1004
1464
 
@@ -1013,6 +1473,12 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
1013
1473
 
1014
1474
  [ctx->queue release];
1015
1475
 
1476
+ for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
1477
+ // ctx->cmd_bufs[i].obj is auto released
1478
+
1479
+ ggml_metal_mem_pool_free(ctx->cmd_bufs[i].mem_pool);
1480
+ }
1481
+
1016
1482
  dispatch_release(ctx->d_queue);
1017
1483
 
1018
1484
  free(ctx);
@@ -1035,8 +1501,70 @@ struct ggml_backend_metal_buffer_context {
1035
1501
  // multiple buffers are used only to avoid the maximum buffer size limitation when using mmap
1036
1502
  int n_buffers;
1037
1503
  struct ggml_backend_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
1504
+
1505
+ // optional MTLResidencySet
1506
+ id rset;
1038
1507
  };
1039
1508
 
1509
// rset init
//
// Creates the optional residency set for a buffer context and adds every Metal
// buffer of the context to it.
//
//   ctx     - buffer context; ctx->rset is always reset to nil first and is only
//             set to a residency set object when one is successfully created
//   ctx_dev - device context; when has_residency_sets is false nothing is created
//   device  - Metal device used to create the residency set
//
// Returns true when residency sets are not used (disabled on the device, not
// compiled in, or not available on the running OS) or when the set was created
// and populated. Returns false only when creating the residency set failed.
static bool ggml_backend_metal_buffer_rset_init(
        struct ggml_backend_metal_buffer_context * ctx,
        struct ggml_backend_metal_device_context * ctx_dev,
        id<MTLDevice> device) {
    ctx->rset = nil;

    if (!ctx_dev->has_residency_sets) {
        return true;
    }

#if defined(GGML_METAL_HAS_RESIDENCY_SETS)
    if (@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, *)) {
        MTLResidencySetDescriptor * desc = [[MTLResidencySetDescriptor alloc] init];
        desc.label = @"ggml_backend_metal";
        desc.initialCapacity = ctx->n_buffers;

        // must be initialized: this file uses manual retain/release, so an
        // uninitialized local would hold an indeterminate pointer
        NSError * error = nil;
        ctx->rset = [device newResidencySetWithDescriptor:desc error:&error];

        // the descriptor is only needed for the creation call above
        [desc release];

        // detect failure through the returned object, not through the error out-parameter
        if (ctx->rset == nil) {
            GGML_LOG_ERROR("%s: error: %s\n", __func__, error ? [[error description] UTF8String] : "unknown");
            return false;
        }

        for (int i = 0; i < ctx->n_buffers; i++) {
            [ctx->rset addAllocation:ctx->buffers[i].metal];
        }

        [ctx->rset commit];
        [ctx->rset requestResidency];

        return true;
    }
#else
    GGML_UNUSED(ctx_dev);
    GGML_UNUSED(device);
#endif

    return true;
}
1552
+
1553
// rset free
//
// Tears down the residency set stored in ctx->rset, if there is one: ends
// residency, removes all allocations from the set and releases the object.
// Does nothing when residency sets are not compiled in, when the OS
// availability check fails, or when ctx->rset is nil.
static void ggml_backend_metal_buffer_rset_free(struct ggml_backend_metal_buffer_context * ctx) {
#if defined(GGML_METAL_HAS_RESIDENCY_SETS)
    if (@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, *)) {
        id residency_set = ctx->rset;

        // no residency set was created for this buffer context
        if (residency_set == nil) {
            return;
        }

        [residency_set endResidency];
        [residency_set removeAllAllocations];
        [residency_set release];
    }
#else
    GGML_UNUSED(ctx);
#endif
}
1567
+
1040
1568
  // finds the Metal buffer that contains the tensor data on the GPU device
1041
1569
  // the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
1042
1570
  // Metal buffer based on the host memory pointer
@@ -1089,10 +1617,12 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
1089
1617
  case GGML_UNARY_OP_RELU:
1090
1618
  case GGML_UNARY_OP_SIGMOID:
1091
1619
  case GGML_UNARY_OP_GELU:
1620
+ case GGML_UNARY_OP_GELU_ERF:
1092
1621
  case GGML_UNARY_OP_GELU_QUICK:
1093
1622
  case GGML_UNARY_OP_SILU:
1094
1623
  case GGML_UNARY_OP_ELU:
1095
- return ggml_is_contiguous(op->src[0]);
1624
+ case GGML_UNARY_OP_NEG:
1625
+ return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
1096
1626
  default:
1097
1627
  return false;
1098
1628
  }
@@ -1102,61 +1632,73 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
1102
1632
  case GGML_OP_TRANSPOSE:
1103
1633
  case GGML_OP_PERMUTE:
1104
1634
  case GGML_OP_CONCAT:
1635
+ return true;
1105
1636
  case GGML_OP_ADD:
1106
1637
  case GGML_OP_SUB:
1107
- case GGML_OP_ACC:
1108
1638
  case GGML_OP_MUL:
1109
1639
  case GGML_OP_DIV:
1640
+ return op->src[0]->type == GGML_TYPE_F32;
1641
+ case GGML_OP_ACC:
1110
1642
  case GGML_OP_REPEAT:
1111
1643
  case GGML_OP_SCALE:
1112
- case GGML_OP_CLAMP:
1113
1644
  case GGML_OP_CONV_TRANSPOSE_1D:
1114
1645
  return true;
1646
+ case GGML_OP_CLAMP:
1647
+ return op->src[0]->type == GGML_TYPE_F32;
1115
1648
  case GGML_OP_SQR:
1116
1649
  case GGML_OP_SQRT:
1117
1650
  case GGML_OP_SIN:
1118
1651
  case GGML_OP_COS:
1119
- return ggml_is_contiguous(op->src[0]);
1652
+ return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
1653
+ case GGML_OP_LOG:
1654
+ return false; // TODO: implement
1120
1655
  case GGML_OP_SUM_ROWS:
1121
1656
  case GGML_OP_SOFT_MAX:
1122
1657
  case GGML_OP_GROUP_NORM:
1123
- return has_simdgroup_reduction;
1658
+ return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
1124
1659
  case GGML_OP_RMS_NORM:
1125
- return has_simdgroup_reduction && (op->ne[0] % 4 == 0);
1660
+ case GGML_OP_L2_NORM:
1661
+ return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
1126
1662
  case GGML_OP_ARGMAX:
1127
- case GGML_OP_NORM:
1128
1663
  return true;
1664
+ case GGML_OP_NORM:
1665
+ return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
1129
1666
  case GGML_OP_ROPE:
1130
- {
1131
- const int mode = ((const int32_t *) op->op_params)[2];
1132
- if (mode & GGML_ROPE_TYPE_MROPE) {
1133
- return false;
1134
- }
1135
- if (mode & GGML_ROPE_TYPE_VISION) {
1136
- return false;
1137
- }
1138
- return true;
1139
- }
1667
+ return true;
1140
1668
  case GGML_OP_IM2COL:
1141
1669
  return op->src[0]->type == GGML_TYPE_F16;
1142
1670
  case GGML_OP_POOL_1D:
1143
1671
  return false;
1144
- case GGML_OP_POOL_2D:
1145
1672
  case GGML_OP_UPSCALE:
1673
+ return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
1674
+ case GGML_OP_POOL_2D:
1146
1675
  case GGML_OP_PAD:
1147
1676
  case GGML_OP_PAD_REFLECT_1D:
1148
- case GGML_OP_ARANGE:
1149
1677
  case GGML_OP_TIMESTEP_EMBEDDING:
1150
1678
  case GGML_OP_ARGSORT:
1151
1679
  case GGML_OP_LEAKY_RELU:
1680
+ return op->src[0]->type == GGML_TYPE_F32;
1681
+ case GGML_OP_ARANGE:
1152
1682
  return true;
1153
1683
  case GGML_OP_FLASH_ATTN_EXT:
1684
+ if (op->src[0]->ne[0] == 32) {
1685
+ // head size == 32 (e.g. bert-bge-small)
1686
+ // TODO: not sure if it is worth adding kernels for this size
1687
+ return false;
1688
+ }
1689
+ if (op->src[0]->ne[0] == 576) {
1690
+ // DeepSeek sizes
1691
+ // TODO: disabled for now, until optimized
1692
+ return false;
1693
+ }
1154
1694
  if (op->src[1]->type != op->src[2]->type) {
1155
1695
  return false;
1156
1696
  }
1157
1697
  return has_simdgroup_mm; // TODO: over-restricted for vec-kernels
1158
1698
  case GGML_OP_SSM_CONV:
1159
1699
  case GGML_OP_SSM_SCAN:
1700
+ case GGML_OP_RWKV_WKV6:
1701
+ case GGML_OP_RWKV_WKV7:
1160
1702
  return true;
1161
1703
  case GGML_OP_MUL_MAT:
1162
1704
  case GGML_OP_MUL_MAT_ID:
@@ -1198,6 +1740,18 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
1198
1740
  default:
1199
1741
  return false;
1200
1742
  }
1743
+ case GGML_TYPE_Q4_0:
1744
+ case GGML_TYPE_Q4_1:
1745
+ case GGML_TYPE_Q5_0:
1746
+ case GGML_TYPE_Q5_1:
1747
+ case GGML_TYPE_Q8_0:
1748
+ switch (op->type) {
1749
+ case GGML_TYPE_F32:
1750
+ case GGML_TYPE_F16:
1751
+ return true;
1752
+ default:
1753
+ return false;
1754
+ }
1201
1755
  default:
1202
1756
  return false;
1203
1757
  };
@@ -1222,10 +1776,11 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
1222
1776
  }
1223
1777
  }
1224
1778
 
1225
- static void ggml_metal_encode_node(
1779
+ static bool ggml_metal_encode_node(
1226
1780
  ggml_backend_t backend,
1227
1781
  int idx,
1228
- id<MTLComputeCommandEncoder> encoder) {
1782
+ id<MTLComputeCommandEncoder> encoder,
1783
+ struct ggml_metal_mem_pool * mem_pool) {
1229
1784
  struct ggml_backend_metal_context * ctx = backend->context;
1230
1785
  struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
1231
1786
 
@@ -1241,7 +1796,7 @@ static void ggml_metal_encode_node(
1241
1796
  struct ggml_tensor * dst = node;
1242
1797
 
1243
1798
  if (ggml_is_empty(dst)) {
1244
- return;
1799
+ return true;
1245
1800
  }
1246
1801
 
1247
1802
  switch (dst->op) {
@@ -1252,7 +1807,7 @@ static void ggml_metal_encode_node(
1252
1807
  case GGML_OP_PERMUTE:
1253
1808
  {
1254
1809
  // noop -> next node
1255
- } return;
1810
+ } return true;
1256
1811
  default:
1257
1812
  {
1258
1813
  } break;
@@ -1263,6 +1818,8 @@ static void ggml_metal_encode_node(
1263
1818
  GGML_ABORT("unsupported op");
1264
1819
  }
1265
1820
 
1821
+ ggml_metal_mem_pool_clear(mem_pool);
1822
+
1266
1823
  const int64_t ne00 = src0 ? src0->ne[0] : 0;
1267
1824
  const int64_t ne01 = src0 ? src0->ne[1] : 0;
1268
1825
  const int64_t ne02 = src0 ? src0->ne[2] : 0;
@@ -1699,6 +2256,25 @@ static void ggml_metal_encode_node(
1699
2256
 
1700
2257
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1701
2258
  } break;
2259
+ case GGML_UNARY_OP_GELU_ERF:
2260
+ {
2261
+ int64_t n = ggml_nelements(dst);
2262
+
2263
+ id<MTLComputePipelineState> pipeline = nil;
2264
+
2265
+ if (n % 4 == 0) {
2266
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_ERF_4].pipeline;
2267
+ n /= 4;
2268
+ } else {
2269
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_ERF].pipeline;
2270
+ }
2271
+
2272
+ [encoder setComputePipelineState:pipeline];
2273
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2274
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2275
+
2276
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2277
+ } break;
1702
2278
  case GGML_UNARY_OP_GELU_QUICK:
1703
2279
  {
1704
2280
  int64_t n = ggml_nelements(dst);
@@ -1749,6 +2325,18 @@ static void ggml_metal_encode_node(
1749
2325
 
1750
2326
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1751
2327
  } break;
2328
+ case GGML_UNARY_OP_NEG:
2329
+ {
2330
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NEG].pipeline;
2331
+
2332
+ [encoder setComputePipelineState:pipeline];
2333
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2334
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2335
+
2336
+ const int64_t n = ggml_nelements(dst);
2337
+
2338
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2339
+ } break;
1752
2340
  default:
1753
2341
  {
1754
2342
  GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
@@ -1817,34 +2405,38 @@ static void ggml_metal_encode_node(
1817
2405
 
1818
2406
  id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
1819
2407
 
1820
- // TODO: add ggml_metal_kargs struct
2408
+
2409
+ ggml_metal_kargs_sum_rows args = {
2410
+ /*.ne00 =*/ ne00,
2411
+ /*.ne01 =*/ ne01,
2412
+ /*.ne02 =*/ ne02,
2413
+ /*.ne03 =*/ ne03,
2414
+ /*.nb00 =*/ nb00,
2415
+ /*.nb01 =*/ nb01,
2416
+ /*.nb02 =*/ nb02,
2417
+ /*.nb03 =*/ nb03,
2418
+ /*.ne10 =*/ ne10,
2419
+ /*.ne11 =*/ ne11,
2420
+ /*.ne12 =*/ ne12,
2421
+ /*.ne13 =*/ ne13,
2422
+ /*.nb10 =*/ nb10,
2423
+ /*.nb11 =*/ nb11,
2424
+ /*.nb12 =*/ nb12,
2425
+ /*.nb13 =*/ nb13,
2426
+ /*.ne0 =*/ ne0,
2427
+ /*.ne1 =*/ ne1,
2428
+ /*.ne2 =*/ ne2,
2429
+ /*.ne3 =*/ ne3,
2430
+ /*.nb0 =*/ nb0,
2431
+ /*.nb1 =*/ nb1,
2432
+ /*.nb2 =*/ nb2,
2433
+ /*.nb3 =*/ nb3,
2434
+ };
2435
+
1821
2436
  [encoder setComputePipelineState:pipeline];
1822
2437
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1823
2438
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1824
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
1825
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
1826
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
1827
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
1828
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
1829
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
1830
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
1831
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
1832
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
1833
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
1834
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
1835
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
1836
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
1837
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
1838
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
1839
- [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
1840
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
1841
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19];
1842
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20];
1843
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21];
1844
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22];
1845
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23];
1846
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24];
1847
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25];
2439
+ [encoder setBytes:&args length:sizeof(args) atIndex:2];
1848
2440
 
1849
2441
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1850
2442
  } break;
@@ -1893,24 +2485,76 @@ static void ggml_metal_encode_node(
1893
2485
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
1894
2486
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
1895
2487
 
1896
- // TODO: add ggml_metal_kargs struct
1897
- // TODO: optimize (see https://github.com/ggerganov/llama.cpp/pull/10238/commits/7941b6b9ec29a2866fec6fa6c51612515ca509f6)
2488
+ // use this branch to test the ggml_metal_mem_pool functionality
2489
+ #if 0
2490
+ // cpy to tmp buffer in MTLHeap
2491
+
2492
+ id<MTLBuffer> h_src0 = h_src0 = ggml_metal_mem_pool_alloc(mem_pool, ggml_nbytes(src0));
2493
+ if (!h_src0) {
2494
+ GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, ggml_nbytes(src0));
2495
+ return false;
2496
+ }
2497
+
2498
+ offs_src0 = 0;
2499
+
2500
+ ggml_metal_kargs_cpy args_cpy = {
2501
+ /*.ne00 =*/ ne00,
2502
+ /*.ne01 =*/ ne01,
2503
+ /*.ne02 =*/ ne02,
2504
+ /*.ne03 =*/ ne03,
2505
+ /*.nb00 =*/ nb00,
2506
+ /*.nb01 =*/ nb01,
2507
+ /*.nb02 =*/ nb02,
2508
+ /*.nb03 =*/ nb03,
2509
+ /*.ne0 =*/ ne00,
2510
+ /*.ne1 =*/ ne01,
2511
+ /*.ne2 =*/ ne02,
2512
+ /*.ne3 =*/ ne03,
2513
+ /*.nb0 =*/ nb00,
2514
+ /*.nb1 =*/ nb01,
2515
+ /*.nb2 =*/ nb02,
2516
+ /*.nb3 =*/ nb03,
2517
+ };
2518
+
2519
+ if (src0->type == GGML_TYPE_F16) {
2520
+ [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline];
2521
+ } else {
2522
+ [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline];
2523
+ }
2524
+ [encoder setBytes:&args_cpy length:sizeof(args_cpy) atIndex:0];
2525
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
2526
+ [encoder setBuffer:h_src0 offset:0 atIndex:2];
2527
+
2528
+ GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
2529
+ int nth_cpy = MIN(1024, ne00 / ggml_blck_size(src0->type));
2530
+
2531
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth_cpy, 1, 1)];
2532
+
2533
+ #else
2534
+ id<MTLBuffer> h_src0 = id_src0;
2535
+ #endif
2536
+ // softmax
2537
+
2538
+ ggml_metal_kargs_soft_max args = {
2539
+ /*.ne00 =*/ ne00,
2540
+ /*.ne01 =*/ ne01,
2541
+ /*.ne02 =*/ ne02,
2542
+ /*.scale =*/ scale,
2543
+ /*.max_bias =*/ max_bias,
2544
+ /*.m0 =*/ m0,
2545
+ /*.m1 =*/ m1,
2546
+ /*.n_head_log2 =*/ n_head_log2,
2547
+ };
2548
+
1898
2549
  [encoder setComputePipelineState:pipeline];
1899
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2550
+ [encoder setBuffer:h_src0 offset:offs_src0 atIndex:0];
1900
2551
  if (id_src1) {
1901
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2552
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1902
2553
  } else {
1903
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
2554
+ [encoder setBuffer:h_src0 offset:offs_src0 atIndex:1];
1904
2555
  }
1905
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1906
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1907
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1908
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1909
- [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
1910
- [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7];
1911
- [encoder setBytes:&m0 length:sizeof(m0) atIndex:8];
1912
- [encoder setBytes:&m1 length:sizeof(m1) atIndex:9];
1913
- [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10];
2556
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
2557
+ [encoder setBytes:&args length:sizeof(args) atIndex:3];
1914
2558
 
1915
2559
  [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
1916
2560
 
@@ -1928,13 +2572,16 @@ static void ggml_metal_encode_node(
1928
2572
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline;
1929
2573
  }
1930
2574
 
1931
- // TODO: add ggml_metal_kargs struct
2575
+ ggml_metal_kargs_diag_mask_inf args = {
2576
+ /*.ne00 =*/ ne00,
2577
+ /*.ne01 =*/ ne01,
2578
+ /*.n_past =*/ n_past,
2579
+ };
2580
+
1932
2581
  [encoder setComputePipelineState:pipeline];
1933
2582
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1934
2583
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1935
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
1936
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
1937
- [encoder setBytes:&n_past length:sizeof(int) atIndex:4];
2584
+ [encoder setBytes:&args length:sizeof(args) atIndex:2];
1938
2585
 
1939
2586
  if (ne00%8 == 0) {
1940
2587
  [encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
@@ -1953,27 +2600,30 @@ static void ggml_metal_encode_node(
1953
2600
 
1954
2601
  id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline;
1955
2602
 
1956
- // TODO: add ggml_metal_kargs struct
2603
+ ggml_metal_kargs_ssm_conv args = {
2604
+ /*.ne00 =*/ ne00,
2605
+ /*.ne01 =*/ ne01,
2606
+ /*.ne02 =*/ ne02,
2607
+ /*.nb00 =*/ nb00,
2608
+ /*.nb01 =*/ nb01,
2609
+ /*.nb02 =*/ nb02,
2610
+ /*.ne10 =*/ ne10,
2611
+ /*.ne11 =*/ ne11,
2612
+ /*.nb10 =*/ nb10,
2613
+ /*.nb11 =*/ nb11,
2614
+ /*.ne0 =*/ ne0,
2615
+ /*.ne1 =*/ ne1,
2616
+ /*.ne2 =*/ ne2,
2617
+ /*.nb0 =*/ nb0,
2618
+ /*.nb1 =*/ nb1,
2619
+ /*.nb2 =*/ nb2,
2620
+ };
2621
+
1957
2622
  [encoder setComputePipelineState:pipeline];
1958
2623
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1959
2624
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1960
2625
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1961
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1962
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1963
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1964
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
1965
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
1966
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
1967
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
1968
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
1969
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:11];
1970
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:12];
1971
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
1972
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
1973
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:15];
1974
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:16];
1975
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:17];
1976
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:18];
2626
+ [encoder setBytes:&args length:sizeof(args) atIndex:3];
1977
2627
 
1978
2628
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1979
2629
  } break;
@@ -2024,7 +2674,31 @@ static void ggml_metal_encode_node(
2024
2674
 
2025
2675
  id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
2026
2676
 
2027
- // TODO: add ggml_metal_kargs struct
2677
+ ggml_metal_kargs_ssm_scan args = {
2678
+ /*.d_state =*/ d_state,
2679
+ /*.d_inner =*/ d_inner,
2680
+ /*.n_seq_tokens =*/ n_seq_tokens,
2681
+ /*.n_seqs =*/ n_seqs,
2682
+ /*.nb00 =*/ nb00,
2683
+ /*.nb01 =*/ nb01,
2684
+ /*.nb02 =*/ nb02,
2685
+ /*.nb10 =*/ nb10,
2686
+ /*.nb11 =*/ nb11,
2687
+ /*.nb12 =*/ nb12,
2688
+ /*.nb13 =*/ nb13,
2689
+ /*.nb20 =*/ nb20,
2690
+ /*.nb21 =*/ nb21,
2691
+ /*.nb22 =*/ nb22,
2692
+ /*.nb30 =*/ nb30,
2693
+ /*.nb31 =*/ nb31,
2694
+ /*.nb40 =*/ nb40,
2695
+ /*.nb41 =*/ nb41,
2696
+ /*.nb42 =*/ nb42,
2697
+ /*.nb50 =*/ nb50,
2698
+ /*.nb51 =*/ nb51,
2699
+ /*.nb52 =*/ nb52,
2700
+ };
2701
+
2028
2702
  [encoder setComputePipelineState:pipeline];
2029
2703
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2030
2704
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@@ -2033,33 +2707,87 @@ static void ggml_metal_encode_node(
2033
2707
  [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
2034
2708
  [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
2035
2709
  [encoder setBuffer:id_dst offset:offs_dst atIndex:6];
2036
-
2037
- [encoder setBytes:&d_state length:sizeof(d_state) atIndex:7];
2038
- [encoder setBytes:&d_inner length:sizeof(d_inner) atIndex:8];
2039
- [encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:9];
2040
- [encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:10];
2041
-
2042
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:11];
2043
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:12];
2044
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:13];
2045
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
2046
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
2047
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
2048
- [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
2049
- [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:18];
2050
- [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:19];
2051
- [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:20];
2052
- [encoder setBytes:&nb30 length:sizeof(nb30) atIndex:21];
2053
- [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22];
2054
- [encoder setBytes:&nb40 length:sizeof(nb40) atIndex:23];
2055
- [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:24];
2056
- [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:25];
2057
- [encoder setBytes:&nb50 length:sizeof(nb50) atIndex:26];
2058
- [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:27];
2059
- [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:28];
2710
+ [encoder setBytes:&args length:sizeof(args) atIndex:7];
2060
2711
 
2061
2712
  [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2062
2713
  } break;
2714
+ case GGML_OP_RWKV_WKV6:
2715
+ {
2716
+ const int64_t B = dst->src[5]->ne[1];
2717
+ const int64_t T = dst->src[0]->ne[2];
2718
+ const int64_t C = dst->ne[0];
2719
+ const int64_t H = dst->src[0]->ne[1];
2720
+
2721
+ GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
2722
+ GGML_ASSERT(C % H == 0);
2723
+ GGML_ASSERT(C / H == 64);
2724
+
2725
+ size_t offs_src3 = 0;
2726
+ size_t offs_src4 = 0;
2727
+ size_t offs_src5 = 0;
2728
+
2729
+ id<MTLBuffer> id_src3 = dst->src[3] ? ggml_metal_get_buffer(dst->src[3], &offs_src3) : nil;
2730
+ id<MTLBuffer> id_src4 = dst->src[4] ? ggml_metal_get_buffer(dst->src[4], &offs_src4) : nil;
2731
+ id<MTLBuffer> id_src5 = dst->src[5] ? ggml_metal_get_buffer(dst->src[5], &offs_src5) : nil;
2732
+
2733
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32].pipeline;
2734
+
2735
+ [encoder setComputePipelineState:pipeline];
2736
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2737
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2738
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
2739
+ [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
2740
+ [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
2741
+ [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
2742
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:6];
2743
+
2744
+ [encoder setBytes:&B length:sizeof(B) atIndex:7];
2745
+ [encoder setBytes:&T length:sizeof(T) atIndex:8];
2746
+ [encoder setBytes:&C length:sizeof(C) atIndex:9];
2747
+ [encoder setBytes:&H length:sizeof(H) atIndex:10];
2748
+
2749
+ [encoder dispatchThreadgroups:MTLSizeMake(B * H, 1, 1) threadsPerThreadgroup:MTLSizeMake(C/ H, 1, 1)];
2750
+ } break;
2751
+ case GGML_OP_RWKV_WKV7:
2752
+ {
2753
+ const int64_t B = dst->src[6]->ne[1];
2754
+ const int64_t T = dst->src[0]->ne[2];
2755
+ const int64_t C = dst->ne[0];
2756
+ const int64_t H = dst->src[0]->ne[1];
2757
+
2758
+ GGML_ASSERT(dst->src[6]->type == GGML_TYPE_F32);
2759
+ GGML_ASSERT(C % H == 0);
2760
+ GGML_ASSERT(C / H == 64);
2761
+
2762
+ size_t offs_src3 = 0;
2763
+ size_t offs_src4 = 0;
2764
+ size_t offs_src5 = 0;
2765
+ size_t offs_src6 = 0;
2766
+
2767
+ id<MTLBuffer> id_src3 = dst->src[3] ? ggml_metal_get_buffer(dst->src[3], &offs_src3) : nil;
2768
+ id<MTLBuffer> id_src4 = dst->src[4] ? ggml_metal_get_buffer(dst->src[4], &offs_src4) : nil;
2769
+ id<MTLBuffer> id_src5 = dst->src[5] ? ggml_metal_get_buffer(dst->src[5], &offs_src5) : nil;
2770
+ id<MTLBuffer> id_src6 = dst->src[6] ? ggml_metal_get_buffer(dst->src[6], &offs_src6) : nil;
2771
+
2772
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32].pipeline;
2773
+
2774
+ [encoder setComputePipelineState:pipeline];
2775
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2776
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2777
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
2778
+ [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
2779
+ [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
2780
+ [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
2781
+ [encoder setBuffer:id_src6 offset:offs_src6 atIndex:6];
2782
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:7];
2783
+
2784
+ [encoder setBytes:&B length:sizeof(B) atIndex:8];
2785
+ [encoder setBytes:&T length:sizeof(T) atIndex:9];
2786
+ [encoder setBytes:&C length:sizeof(C) atIndex:10];
2787
+ [encoder setBytes:&H length:sizeof(H) atIndex:11];
2788
+
2789
+ [encoder dispatchThreadgroups:MTLSizeMake(B * H, 1, 1) threadsPerThreadgroup:MTLSizeMake(C/ H, 1, 1)];
2790
+ } break;
2063
2791
  case GGML_OP_MUL_MAT:
2064
2792
  {
2065
2793
  GGML_ASSERT(ne00 == ne10);
@@ -2067,8 +2795,8 @@ static void ggml_metal_encode_node(
2067
2795
  GGML_ASSERT(ne12 % ne02 == 0);
2068
2796
  GGML_ASSERT(ne13 % ne03 == 0);
2069
2797
 
2070
- const uint r2 = ne12/ne02;
2071
- const uint r3 = ne13/ne03;
2798
+ const uint32_t r2 = ne12/ne02;
2799
+ const uint32_t r3 = ne13/ne03;
2072
2800
 
2073
2801
  // find the break-even point where the matrix-matrix kernel becomes more efficient compared
2074
2802
  // to the matrix-vector kernel
@@ -2317,173 +3045,182 @@ static void ggml_metal_encode_node(
2317
3045
  [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
2318
3046
 
2319
3047
  [encoder setThreadgroupMemoryLength:8192 atIndex:0];
2320
- [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
3048
+ [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
2321
3049
  } else {
2322
- int nth0 = 32;
2323
- int nth1 = 1;
2324
- int nrows = 1;
2325
- //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
2326
-
2327
3050
  id<MTLComputePipelineState> pipeline = nil;
2328
3051
 
3052
+ int nsg = 0; // number of simdgroups
3053
+ int nr0 = 0; // number of src0 rows per simdgroup
3054
+ int nr1 = 1; // number of src1 rows per threadgroup
3055
+
3056
+ size_t smem = 0; // shared memory
3057
+
2329
3058
  // use custom matrix x vector kernel
2330
3059
  switch (src0t) {
2331
3060
  case GGML_TYPE_F32:
2332
3061
  {
2333
3062
  GGML_ASSERT(src1t == GGML_TYPE_F32);
3063
+ nsg = 1;
3064
+ nr0 = 1;
3065
+ nr1 = 4;
2334
3066
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
2335
- nrows = 4;
2336
3067
  } break;
2337
3068
  case GGML_TYPE_F16:
2338
3069
  {
2339
- nth0 = 32;
2340
- nth1 = 1;
3070
+ nsg = 1;
3071
+ nr0 = 1;
2341
3072
  if (src1t == GGML_TYPE_F32) {
2342
3073
  if (ne11 * ne12 < 4) {
2343
3074
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
2344
3075
  } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
2345
3076
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
2346
- nrows = ne11;
3077
+ nr1 = ne11;
2347
3078
  } else {
2348
3079
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline;
2349
- nrows = 4;
3080
+ nr1 = 4;
2350
3081
  }
2351
3082
  } else {
2352
3083
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline;
2353
- nrows = 4;
3084
+ nr1 = 4;
2354
3085
  }
2355
3086
  } break;
2356
3087
  case GGML_TYPE_BF16:
2357
3088
  {
2358
- nth0 = 32;
2359
- nth1 = 1;
3089
+ nsg = 1;
3090
+ nr0 = 1;
2360
3091
  if (src1t == GGML_TYPE_F32) {
2361
3092
  if (ne11 * ne12 < 4) {
2362
3093
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline;
2363
3094
  } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
2364
3095
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline;
2365
- nrows = ne11;
3096
+ nr1 = ne11;
2366
3097
  } else {
2367
3098
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline;
2368
- nrows = 4;
3099
+ nr1 = 4;
2369
3100
  }
2370
3101
  } else {
2371
3102
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16].pipeline;
2372
- nrows = 4;
3103
+ nr1 = 4;
2373
3104
  }
2374
3105
  } break;
2375
3106
  case GGML_TYPE_Q4_0:
2376
3107
  {
2377
- nth0 = 8;
2378
- nth1 = 8;
3108
+ nsg = N_SG_Q4_0;
3109
+ nr0 = N_R0_Q4_0;
2379
3110
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline;
2380
3111
  } break;
2381
3112
  case GGML_TYPE_Q4_1:
2382
3113
  {
2383
- nth0 = 8;
2384
- nth1 = 8;
3114
+ nsg = N_SG_Q4_1;
3115
+ nr0 = N_R0_Q4_1;
2385
3116
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline;
2386
3117
  } break;
2387
3118
  case GGML_TYPE_Q5_0:
2388
3119
  {
2389
- nth0 = 8;
2390
- nth1 = 8;
3120
+ nsg = N_SG_Q5_0;
3121
+ nr0 = N_R0_Q5_0;
2391
3122
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline;
2392
3123
  } break;
2393
3124
  case GGML_TYPE_Q5_1:
2394
3125
  {
2395
- nth0 = 8;
2396
- nth1 = 8;
3126
+ nsg = N_SG_Q5_1;
3127
+ nr0 = N_R0_Q5_1;
2397
3128
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline;
2398
3129
  } break;
2399
3130
  case GGML_TYPE_Q8_0:
2400
3131
  {
2401
- nth0 = 8;
2402
- nth1 = 8;
3132
+ nsg = N_SG_Q8_0;
3133
+ nr0 = N_R0_Q8_0;
2403
3134
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
2404
3135
  } break;
2405
3136
  case GGML_TYPE_Q2_K:
2406
3137
  {
2407
- nth0 = 2;
2408
- nth1 = 32;
3138
+ nsg = N_SG_Q2_K;
3139
+ nr0 = N_R0_Q2_K;
2409
3140
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline;
2410
3141
  } break;
2411
3142
  case GGML_TYPE_Q3_K:
2412
3143
  {
2413
- nth0 = 2;
2414
- nth1 = 32;
3144
+ nsg = N_SG_Q3_K;
3145
+ nr0 = N_R0_Q3_K;
2415
3146
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline;
2416
3147
  } break;
2417
3148
  case GGML_TYPE_Q4_K:
2418
3149
  {
2419
- nth0 = 4; //1;
2420
- nth1 = 8; //32;
3150
+ nsg = N_SG_Q4_K;
3151
+ nr0 = N_R0_Q4_K;
2421
3152
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline;
2422
3153
  } break;
2423
3154
  case GGML_TYPE_Q5_K:
2424
3155
  {
2425
- nth0 = 2;
2426
- nth1 = 32;
3156
+ nsg = N_SG_Q5_K;
3157
+ nr0 = N_R0_Q5_K;
2427
3158
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline;
2428
3159
  } break;
2429
3160
  case GGML_TYPE_Q6_K:
2430
3161
  {
2431
- nth0 = 2;
2432
- nth1 = 32;
3162
+ nsg = N_SG_Q6_K;
3163
+ nr0 = N_R0_Q6_K;
2433
3164
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline;
2434
3165
  } break;
2435
3166
  case GGML_TYPE_IQ2_XXS:
2436
3167
  {
2437
- nth0 = 4;
2438
- nth1 = 16;
3168
+ nsg = N_SG_IQ2_XXS;
3169
+ nr0 = N_R0_IQ2_XXS;
3170
+ smem = 256*8+128;
2439
3171
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline;
2440
3172
  } break;
2441
3173
  case GGML_TYPE_IQ2_XS:
2442
3174
  {
2443
- nth0 = 4;
2444
- nth1 = 16;
3175
+ nsg = N_SG_IQ2_XS;
3176
+ nr0 = N_R0_IQ2_XS;
3177
+ smem = 512*8+128;
2445
3178
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline;
2446
3179
  } break;
2447
3180
  case GGML_TYPE_IQ3_XXS:
2448
3181
  {
2449
- nth0 = 4;
2450
- nth1 = 16;
3182
+ nsg = N_SG_IQ3_XXS;
3183
+ nr0 = N_R0_IQ3_XXS;
3184
+ smem = 256*4+128;
2451
3185
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
2452
3186
  } break;
2453
3187
  case GGML_TYPE_IQ3_S:
2454
3188
  {
2455
- nth0 = 4;
2456
- nth1 = 16;
3189
+ nsg = N_SG_IQ3_S;
3190
+ nr0 = N_R0_IQ3_S;
3191
+ smem = 512*4;
2457
3192
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline;
2458
3193
  } break;
2459
3194
  case GGML_TYPE_IQ2_S:
2460
3195
  {
2461
- nth0 = 4;
2462
- nth1 = 16;
3196
+ nsg = N_SG_IQ2_S;
3197
+ nr0 = N_R0_IQ2_S;
2463
3198
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32].pipeline;
2464
3199
  } break;
2465
3200
  case GGML_TYPE_IQ1_S:
2466
3201
  {
2467
- nth0 = 4;
2468
- nth1 = 16;
3202
+ nsg = N_SG_IQ1_S;
3203
+ nr0 = N_R0_IQ1_S;
2469
3204
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline;
2470
3205
  } break;
2471
3206
  case GGML_TYPE_IQ1_M:
2472
3207
  {
2473
- nth0 = 4;
2474
- nth1 = 16;
3208
+ nsg = N_SG_IQ1_M;
3209
+ nr0 = N_R0_IQ1_M;
2475
3210
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32].pipeline;
2476
3211
  } break;
2477
3212
  case GGML_TYPE_IQ4_NL:
2478
3213
  {
2479
- nth0 = 4;
2480
- nth1 = 16;
3214
+ nsg = N_SG_IQ4_NL;
3215
+ nr0 = N_R0_IQ4_NL;
3216
+ smem = 32*sizeof(float);
2481
3217
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline;
2482
3218
  } break;
2483
3219
  case GGML_TYPE_IQ4_XS:
2484
3220
  {
2485
- nth0 = 4;
2486
- nth1 = 16;
3221
+ nsg = N_SG_IQ4_XS;
3222
+ nr0 = N_R0_IQ4_XS;
3223
+ smem = 32*sizeof(float);
2487
3224
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline;
2488
3225
  } break;
2489
3226
  default:
@@ -2520,47 +3257,14 @@ static void ggml_metal_encode_node(
2520
3257
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
2521
3258
  [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
2522
3259
 
2523
- if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
2524
- src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
2525
- src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
2526
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2527
- }
2528
- else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
2529
- const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
2530
- [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
2531
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2532
- }
2533
- else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
2534
- const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
2535
- [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
2536
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2537
- }
2538
- else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
2539
- const int mem_size = 32*sizeof(float);
2540
- [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
2541
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2542
- }
2543
- else if (src0t == GGML_TYPE_Q4_K) {
2544
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2545
- }
2546
- else if (src0t == GGML_TYPE_Q3_K) {
2547
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2548
- }
2549
- else if (src0t == GGML_TYPE_Q5_K) {
2550
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2551
- }
2552
- else if (src0t == GGML_TYPE_Q6_K) {
2553
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2554
- } else {
2555
- const int64_t ny = (ne11 + nrows - 1)/nrows;
2556
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
3260
+ if (smem > 0) {
3261
+ [encoder setThreadgroupMemoryLength:smem atIndex:0];
2557
3262
  }
3263
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0*nsg - 1)/(nr0*nsg), (ne11 + nr1 - 1)/nr1, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
2558
3264
  }
2559
3265
  } break;
2560
3266
  case GGML_OP_MUL_MAT_ID:
2561
3267
  {
2562
- const int n_as = src0->ne[2];
2563
-
2564
3268
  // src2 = ids
2565
3269
  const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t);
2566
3270
 
@@ -2574,26 +3278,22 @@ static void ggml_metal_encode_node(
2574
3278
  GGML_ASSERT(ne03 == 1);
2575
3279
  GGML_ASSERT(ne13 == 1);
2576
3280
 
3281
+ const uint32_t r2 = 1;
3282
+ const uint32_t r3 = 1;
3283
+
2577
3284
  // find the break-even point where the matrix-matrix kernel becomes more efficient compared
2578
3285
  // to the matrix-vector kernel
2579
3286
  // ne20 = n_used_experts
2580
- // ne21 = n_rows
2581
- const int dst_rows = ne20*ne21;
2582
- const int dst_rows_min = n_as;
2583
- const int dst_rows_max = (device.maxThreadgroupMemoryLength - 32 - 8192)/4;
2584
-
2585
- // max size of the rowids array in the kernel shared buffer
2586
- GGML_ASSERT(dst_rows <= dst_rows_max);
3287
+ // ne21 = n_rows (batch size)
3288
+ const int ne21_mm_id_min = 32;
2587
3289
 
2588
3290
  // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
2589
3291
  // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
2590
- // !!!
2591
- // TODO: for now, always use mat-vec kernels until we figure out how to improve the
2592
- // indirect matrix multiplication
2593
- // !!!
2594
3292
  if ([device supportsFamily:MTLGPUFamilyApple7] &&
2595
3293
  ne00 % 32 == 0 && ne00 >= 64 &&
2596
- dst_rows > dst_rows_min) {
3294
+ (ne21 >= ne21_mm_id_min)) {
3295
+ GGML_ASSERT(ne00 % 4 == 0);
3296
+
2597
3297
  // some Metal matrix data types require aligned pointers
2598
3298
  // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
2599
3299
  switch (src0->type) {
@@ -2603,203 +3303,319 @@ static void ggml_metal_encode_node(
2603
3303
  default: break;
2604
3304
  }
2605
3305
 
2606
- id<MTLComputePipelineState> pipeline = nil;
3306
+ const int64_t neh10 = ne10; // n_embd
3307
+ const int64_t neh11 = ne21; // n_tokens
3308
+ const int64_t neh12 = ne02; // n_expert
2607
3309
 
2608
- switch (src0->type) {
2609
- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break;
2610
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break;
2611
- case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32 ].pipeline; break;
2612
- case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break;
2613
- case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break;
2614
- case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break;
2615
- case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32 ].pipeline; break;
2616
- case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32 ].pipeline; break;
2617
- case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32 ].pipeline; break;
2618
- case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32 ].pipeline; break;
2619
- case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32 ].pipeline; break;
2620
- case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32 ].pipeline; break;
2621
- case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32 ].pipeline; break;
2622
- case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
2623
- case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
2624
- case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break;
2625
- case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32 ].pipeline; break;
2626
- case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32 ].pipeline; break;
2627
- case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break;
2628
- case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32 ].pipeline; break;
2629
- case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
2630
- case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break;
2631
- default: GGML_ABORT("MUL_MAT_ID not implemented");
3310
+ const uint64_t nbh10 = ggml_type_size(GGML_TYPE_F16);
3311
+ const uint64_t nbh11 = nbh10*neh10;
3312
+ const uint64_t nbh12 = nbh11*neh11;
3313
+ const uint64_t nbh13 = nbh12*neh12;
3314
+
3315
+ const size_t s_src1 = ggml_type_size(GGML_TYPE_F16)*neh10*neh11*neh12;
3316
+ id<MTLBuffer> h_src1 = ggml_metal_mem_pool_alloc(mem_pool, s_src1);
3317
+ if (!h_src1) {
3318
+ GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_src1);
3319
+ return false;
2632
3320
  }
2633
3321
 
2634
- ggml_metal_kargs_mul_mm_id args = {
2635
- /*.nei0 =*/ ne20,
2636
- /*.nei1 =*/ ne21,
2637
- /*.nbi1 =*/ nb21,
2638
- /*.ne00 =*/ ne00,
2639
- /*.ne02 =*/ ne02,
2640
- /*.nb01 =*/ nb01,
2641
- /*.nb02 =*/ nb02,
2642
- /*.ne11 =*/ ne11,
2643
- /*.ne12 =*/ ne12,
2644
- /*.ne13 =*/ ne13,
2645
- /*.nb10 =*/ nb10,
2646
- /*.nb11 =*/ nb11,
2647
- /*.nb12 =*/ nb12,
2648
- /*.ne0 =*/ ne0,
2649
- /*.ne1 =*/ ne1,
2650
- };
3322
+ const int64_t neh0 = ne0;
3323
+ const int64_t neh1 = ne21;
3324
+ const int64_t neh2 = ne02;
3325
+
3326
+ const uint64_t nbh0 = ggml_type_size(GGML_TYPE_F32);
3327
+ const uint64_t nbh1 = nbh0*neh0;
3328
+ const uint64_t nbh2 = nbh1*neh1;
3329
+ //const uint64_t nbh3 = nbh2*neh2;
3330
+
3331
+ const size_t s_dst = ggml_type_size(GGML_TYPE_F32)*neh0*neh1*neh2;
3332
+ id<MTLBuffer> h_dst = ggml_metal_mem_pool_alloc(mem_pool, s_dst);
3333
+ if (!h_dst) {
3334
+ GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_dst);
3335
+ return false;
3336
+ }
3337
+
3338
+ // tokens per expert
3339
+ const size_t s_tpe = ggml_type_size(GGML_TYPE_I32)*ne02;
3340
+ id<MTLBuffer> h_tpe = ggml_metal_mem_pool_alloc(mem_pool, s_tpe);
3341
+ if (!h_tpe) {
3342
+ GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tpe);
3343
+ return false;
3344
+ }
3345
+
3346
+ // id map
3347
+ // [n_expert_used, n_tokens]
3348
+ const size_t s_ids = ggml_type_size(GGML_TYPE_I32)*ne20*ne21;
3349
+ id<MTLBuffer> h_ids = ggml_metal_mem_pool_alloc(mem_pool, s_ids);
3350
+ if (!h_ids) {
3351
+ GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_ids);
3352
+ return false;
3353
+ }
3354
+
3355
+ {
3356
+ const int nth = MIN(1024, ne10/4);
3357
+
3358
+ ggml_metal_kargs_mul_mm_id_map0 args = {
3359
+ ne10,
3360
+ ne11, // n_expert_used (bcast)
3361
+ nb11,
3362
+ nb12,
3363
+ neh11, // n_tokens
3364
+ nbh11,
3365
+ ne20, // n_expert_used
3366
+ nb21,
3367
+ };
3368
+
3369
+ id<MTLComputePipelineState> pipeline = nil;
3370
+
3371
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16].pipeline;
3372
+
3373
+ [encoder setComputePipelineState:pipeline];
3374
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
3375
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
3376
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
3377
+ [encoder setBuffer: h_src1 offset:0 atIndex:3];
3378
+ [encoder setBuffer: h_tpe offset:0 atIndex:4];
3379
+ [encoder setBuffer: h_ids offset:0 atIndex:5];
3380
+
3381
+ [encoder dispatchThreadgroups:MTLSizeMake(ne02, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
3382
+ }
3383
+
3384
+ {
3385
+ id<MTLComputePipelineState> pipeline = nil;
3386
+
3387
+ switch (src0->type) {
3388
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16 ].pipeline; break;
3389
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16 ].pipeline; break;
3390
+ case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16 ].pipeline; break;
3391
+ case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F16 ].pipeline; break;
3392
+ case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F16 ].pipeline; break;
3393
+ case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16 ].pipeline; break;
3394
+ case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16 ].pipeline; break;
3395
+ case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16 ].pipeline; break;
3396
+ case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16 ].pipeline; break;
3397
+ case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16 ].pipeline; break;
3398
+ case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16 ].pipeline; break;
3399
+ case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F16 ].pipeline; break;
3400
+ case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F16 ].pipeline; break;
3401
+ case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F16].pipeline; break;
3402
+ case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F16 ].pipeline; break;
3403
+ case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F16].pipeline; break;
3404
+ case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F16 ].pipeline; break;
3405
+ case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F16 ].pipeline; break;
3406
+ case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F16 ].pipeline; break;
3407
+ case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16 ].pipeline; break;
3408
+ case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16 ].pipeline; break;
3409
+ case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16 ].pipeline; break;
3410
+ default: GGML_ABORT("MUL_MAT_ID not implemented");
3411
+ }
2651
3412
 
2652
- [encoder setComputePipelineState:pipeline];
2653
- [encoder setBytes:&args length:sizeof(args) atIndex:0];
2654
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
2655
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
2656
- [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
2657
- [encoder setBuffer:id_src2 offset:offs_src2 atIndex:4];
3413
+ ggml_metal_kargs_mul_mm_id args = {
3414
+ /*.ne00 =*/ ne00,
3415
+ /*.ne02 =*/ ne02,
3416
+ /*.nb01 =*/ nb01,
3417
+ /*.nb02 =*/ nb02,
3418
+ /*.nb03 =*/ nb03,
3419
+ /*.neh12 =*/ neh12,
3420
+ /*.nbh10 =*/ nbh10,
3421
+ /*.nbh11 =*/ nbh11,
3422
+ /*.nbh12 =*/ nbh12,
3423
+ /*.nbh13 =*/ nbh13,
3424
+ /*.neh0 =*/ neh0,
3425
+ /*.neh1 =*/ neh1,
3426
+ /*.r2 =*/ r2,
3427
+ /*.r3 =*/ r3,
3428
+ };
3429
+
3430
+ [encoder setComputePipelineState:pipeline];
3431
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
3432
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
3433
+ [encoder setBuffer: h_src1 offset:0 atIndex:2];
3434
+ [encoder setBuffer: h_tpe offset:0 atIndex:3];
3435
+ [encoder setBuffer: h_dst offset:0 atIndex:4];
3436
+
3437
+ [encoder setThreadgroupMemoryLength:8192 atIndex:0];
3438
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, ne02) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
3439
+ }
2658
3440
 
2659
- [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0];
3441
+ {
3442
+ GGML_ASSERT(ne0 % 4 == 0);
2660
3443
 
2661
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
2662
- } else {
2663
- int nth0 = 32;
2664
- int nth1 = 1;
2665
- int nrows = 1;
2666
- //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
3444
+ const int nth = MIN(1024, ne0/4);
2667
3445
 
3446
+ ggml_metal_kargs_mul_mm_id_map1 args = {
3447
+ ne20, // n_expert_used
3448
+ neh0,
3449
+ neh1,
3450
+ nbh1,
3451
+ nbh2,
3452
+ ne0,
3453
+ nb1,
3454
+ nb2,
3455
+ };
3456
+
3457
+ id<MTLComputePipelineState> pipeline = nil;
3458
+
3459
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32].pipeline;
3460
+
3461
+ [encoder setComputePipelineState:pipeline];
3462
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
3463
+ [encoder setBuffer: h_dst offset:0 atIndex:1];
3464
+ [encoder setBuffer: h_ids offset:0 atIndex:2];
3465
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
3466
+
3467
+ [encoder dispatchThreadgroups:MTLSizeMake(ne20, ne21, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
3468
+ }
3469
+ } else {
2668
3470
  id<MTLComputePipelineState> pipeline = nil;
2669
3471
 
3472
+ int nsg = 0; // number of simdgroups
3473
+ int nr0 = 0; // number of src0 rows per simdgroup
3474
+ int nr1 = 1; // number of src1 rows per threadgroup
3475
+
3476
+ size_t smem = 0; // shared memory
3477
+
2670
3478
  // use custom matrix x vector kernel
2671
3479
  switch (src0t) {
2672
3480
  case GGML_TYPE_F32:
2673
3481
  {
2674
3482
  GGML_ASSERT(src1t == GGML_TYPE_F32);
3483
+ nsg = 1;
3484
+ nr0 = 1;
2675
3485
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline;
2676
3486
  } break;
2677
3487
  case GGML_TYPE_F16:
2678
3488
  {
2679
3489
  GGML_ASSERT(src1t == GGML_TYPE_F32);
2680
- nth0 = 32;
2681
- nth1 = 1;
3490
+ nsg = 1;
3491
+ nr0 = 1;
2682
3492
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline;
2683
3493
  } break;
2684
3494
  case GGML_TYPE_BF16:
2685
3495
  {
2686
3496
  GGML_ASSERT(src1t == GGML_TYPE_F32);
2687
- nth0 = 32;
2688
- nth1 = 1;
3497
+ nsg = 1;
3498
+ nr0 = 1;
2689
3499
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32].pipeline;
2690
3500
  } break;
2691
3501
  case GGML_TYPE_Q4_0:
2692
3502
  {
2693
- nth0 = 8;
2694
- nth1 = 8;
3503
+ nsg = N_SG_Q4_0;
3504
+ nr0 = N_R0_Q4_0;
2695
3505
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline;
2696
3506
  } break;
2697
3507
  case GGML_TYPE_Q4_1:
2698
3508
  {
2699
- nth0 = 8;
2700
- nth1 = 8;
3509
+ nsg = N_SG_Q4_1;
3510
+ nr0 = N_R0_Q4_1;
2701
3511
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline;
2702
3512
  } break;
2703
3513
  case GGML_TYPE_Q5_0:
2704
3514
  {
2705
- nth0 = 8;
2706
- nth1 = 8;
3515
+ nsg = N_SG_Q5_0;
3516
+ nr0 = N_R0_Q5_0;
2707
3517
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline;
2708
3518
  } break;
2709
3519
  case GGML_TYPE_Q5_1:
2710
3520
  {
2711
- nth0 = 8;
2712
- nth1 = 8;
3521
+ nsg = N_SG_Q5_1;
3522
+ nr0 = N_R0_Q5_1;
2713
3523
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline;
2714
3524
  } break;
2715
3525
  case GGML_TYPE_Q8_0:
2716
3526
  {
2717
- nth0 = 8;
2718
- nth1 = 8;
3527
+ nsg = N_SG_Q8_0;
3528
+ nr0 = N_R0_Q8_0;
2719
3529
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline;
2720
3530
  } break;
2721
3531
  case GGML_TYPE_Q2_K:
2722
3532
  {
2723
- nth0 = 2;
2724
- nth1 = 32;
3533
+ nsg = N_SG_Q2_K;
3534
+ nr0 = N_R0_Q2_K;
2725
3535
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline;
2726
3536
  } break;
2727
3537
  case GGML_TYPE_Q3_K:
2728
3538
  {
2729
- nth0 = 2;
2730
- nth1 = 32;
3539
+ nsg = N_SG_Q3_K;
3540
+ nr0 = N_R0_Q3_K;
2731
3541
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline;
2732
3542
  } break;
2733
3543
  case GGML_TYPE_Q4_K:
2734
3544
  {
2735
- nth0 = 4; //1;
2736
- nth1 = 8; //32;
3545
+ nsg = N_SG_Q4_K;
3546
+ nr0 = N_R0_Q4_K;
2737
3547
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline;
2738
3548
  } break;
2739
3549
  case GGML_TYPE_Q5_K:
2740
3550
  {
2741
- nth0 = 2;
2742
- nth1 = 32;
3551
+ nsg = N_SG_Q5_K;
3552
+ nr0 = N_R0_Q5_K;
2743
3553
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline;
2744
3554
  } break;
2745
3555
  case GGML_TYPE_Q6_K:
2746
3556
  {
2747
- nth0 = 2;
2748
- nth1 = 32;
3557
+ nsg = N_SG_Q6_K;
3558
+ nr0 = N_R0_Q6_K;
2749
3559
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline;
2750
3560
  } break;
2751
3561
  case GGML_TYPE_IQ2_XXS:
2752
3562
  {
2753
- nth0 = 4;
2754
- nth1 = 16;
3563
+ nsg = N_SG_IQ2_XXS;
3564
+ nr0 = N_R0_IQ2_XXS;
3565
+ smem = 256*8+128;
2755
3566
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32].pipeline;
2756
3567
  } break;
2757
3568
  case GGML_TYPE_IQ2_XS:
2758
3569
  {
2759
- nth0 = 4;
2760
- nth1 = 16;
3570
+ nsg = N_SG_IQ2_XS;
3571
+ nr0 = N_R0_IQ2_XS;
3572
+ smem = 512*8+128;
2761
3573
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline;
2762
3574
  } break;
2763
3575
  case GGML_TYPE_IQ3_XXS:
2764
3576
  {
2765
- nth0 = 4;
2766
- nth1 = 16;
3577
+ nsg = N_SG_IQ3_XXS;
3578
+ nr0 = N_R0_IQ3_XXS;
3579
+ smem = 256*4+128;
2767
3580
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline;
2768
3581
  } break;
2769
3582
  case GGML_TYPE_IQ3_S:
2770
3583
  {
2771
- nth0 = 4;
2772
- nth1 = 16;
3584
+ nsg = N_SG_IQ3_S;
3585
+ nr0 = N_R0_IQ3_S;
3586
+ smem = 512*4;
2773
3587
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32].pipeline;
2774
3588
  } break;
2775
3589
  case GGML_TYPE_IQ2_S:
2776
3590
  {
2777
- nth0 = 4;
2778
- nth1 = 16;
3591
+ nsg = N_SG_IQ2_S;
3592
+ nr0 = N_R0_IQ2_S;
2779
3593
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32].pipeline;
2780
3594
  } break;
2781
3595
  case GGML_TYPE_IQ1_S:
2782
3596
  {
2783
- nth0 = 4;
2784
- nth1 = 16;
3597
+ nsg = N_SG_IQ1_S;
3598
+ nr0 = N_R0_IQ1_S;
2785
3599
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline;
2786
3600
  } break;
2787
3601
  case GGML_TYPE_IQ1_M:
2788
3602
  {
2789
- nth0 = 4;
2790
- nth1 = 16;
3603
+ nsg = N_SG_IQ1_M;
3604
+ nr0 = N_R0_IQ1_M;
2791
3605
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32].pipeline;
2792
3606
  } break;
2793
3607
  case GGML_TYPE_IQ4_NL:
2794
3608
  {
2795
- nth0 = 4;
2796
- nth1 = 16;
3609
+ nsg = N_SG_IQ4_NL;
3610
+ nr0 = N_R0_IQ4_NL;
3611
+ smem = 32*sizeof(float);
2797
3612
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
2798
3613
  } break;
2799
3614
  case GGML_TYPE_IQ4_XS:
2800
3615
  {
2801
- nth0 = 4;
2802
- nth1 = 16;
3616
+ nsg = N_SG_IQ4_XS;
3617
+ nr0 = N_R0_IQ4_XS;
3618
+ smem = 32*sizeof(float);
2803
3619
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
2804
3620
  } break;
2805
3621
  default:
@@ -2810,7 +3626,7 @@ static void ggml_metal_encode_node(
2810
3626
  };
2811
3627
 
2812
3628
  if (ggml_is_quantized(src0t)) {
2813
- GGML_ASSERT(ne00 >= nth0*nth1);
3629
+ GGML_ASSERT(ne00 >= nsg*nr0);
2814
3630
  }
2815
3631
 
2816
3632
  ggml_metal_kargs_mul_mv_id args = {
@@ -2843,43 +3659,12 @@ static void ggml_metal_encode_node(
2843
3659
  [encoder setBuffer:id_src2 offset:offs_src2 atIndex:4];
2844
3660
 
2845
3661
  const int64_t _ne1 = 1;
2846
- const int tgz = dst_rows;
3662
+ const int64_t ne123 = ne20*ne21;
2847
3663
 
2848
- if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
2849
- src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
2850
- src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
2851
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2852
- }
2853
- else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
2854
- const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
2855
- [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
2856
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2857
- }
2858
- else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
2859
- const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
2860
- [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
2861
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2862
- }
2863
- else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
2864
- const int mem_size = 32*sizeof(float);
2865
- [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
2866
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2867
- }
2868
- else if (src0t == GGML_TYPE_Q4_K) {
2869
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2870
- }
2871
- else if (src0t == GGML_TYPE_Q3_K) {
2872
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2873
- }
2874
- else if (src0t == GGML_TYPE_Q5_K) {
2875
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2876
- }
2877
- else if (src0t == GGML_TYPE_Q6_K) {
2878
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2879
- } else {
2880
- const int64_t ny = (_ne1 + nrows - 1)/nrows; // = _ne1
2881
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
3664
+ if (smem > 0) {
3665
+ [encoder setThreadgroupMemoryLength:smem atIndex:0];
2882
3666
  }
3667
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0*nsg - 1)/(nr0*nsg), (_ne1 + nr1 - 1)/nr1, ne123) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
2883
3668
  }
2884
3669
  } break;
2885
3670
  case GGML_OP_GET_ROWS:
@@ -2913,19 +3698,22 @@ static void ggml_metal_encode_node(
2913
3698
  default: GGML_ABORT("not implemented");
2914
3699
  }
2915
3700
 
2916
- // TODO: add ggml_metal_kargs struct
3701
+ ggml_metal_kargs_get_rows args = {
3702
+ /*.ne00 =*/ ne00,
3703
+ /*.nb01 =*/ nb01,
3704
+ /*.nb02 =*/ nb02,
3705
+ /*.ne10 =*/ ne10,
3706
+ /*.nb10 =*/ nb10,
3707
+ /*.nb11 =*/ nb11,
3708
+ /*.nb1 =*/ nb1,
3709
+ /*.nb2 =*/ nb2,
3710
+ };
3711
+
2917
3712
  [encoder setComputePipelineState:pipeline];
2918
3713
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2919
3714
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2920
3715
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
2921
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
2922
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
2923
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
2924
- [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
2925
- [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
2926
- [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
2927
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9];
2928
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10];
3716
+ [encoder setBytes:&args length:sizeof(args) atIndex:3];
2929
3717
 
2930
3718
  [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
2931
3719
  } break;
@@ -2963,6 +3751,42 @@ static void ggml_metal_encode_node(
2963
3751
 
2964
3752
  const int64_t nrows = ggml_nrows(src0);
2965
3753
 
3754
+ [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
3755
+ } break;
3756
+ case GGML_OP_L2_NORM:
3757
+ {
3758
+ GGML_ASSERT(ne00 % 4 == 0);
3759
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
3760
+
3761
+ float eps;
3762
+ memcpy(&eps, dst->op_params, sizeof(float));
3763
+
3764
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_L2_NORM].pipeline;
3765
+
3766
+ int nth = 32; // SIMD width
3767
+
3768
+ while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
3769
+ nth *= 2;
3770
+ }
3771
+
3772
+ nth = MIN(nth, ne00/4);
3773
+
3774
+ ggml_metal_kargs_l2_norm args = {
3775
+ /*.ne00 =*/ ne00,
3776
+ /*.ne00_4 =*/ ne00/4,
3777
+ /*.nb01 =*/ nb01,
3778
+ /*.eps =*/ eps,
3779
+ };
3780
+
3781
+ [encoder setComputePipelineState:pipeline];
3782
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
3783
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
3784
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
3785
+
3786
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
3787
+
3788
+ const int64_t nrows = ggml_nrows(src0);
3789
+
2966
3790
  [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2967
3791
  } break;
2968
3792
  case GGML_OP_GROUP_NORM:
@@ -2982,18 +3806,21 @@ static void ggml_metal_encode_node(
2982
3806
 
2983
3807
  id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline;
2984
3808
 
2985
- // TODO: add ggml_metal_kargs struct
3809
+ ggml_metal_kargs_group_norm args = {
3810
+ /*.ne00 =*/ ne00,
3811
+ /*.ne01 =*/ ne01,
3812
+ /*.ne02 =*/ ne02,
3813
+ /*.nb00 =*/ nb00,
3814
+ /*.nb01 =*/ nb01,
3815
+ /*.nb02 =*/ nb02,
3816
+ /*.n_groups =*/ n_groups,
3817
+ /*.eps =*/ eps,
3818
+ };
3819
+
2986
3820
  [encoder setComputePipelineState:pipeline];
2987
3821
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2988
3822
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2989
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
2990
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
2991
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
2992
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5];
2993
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6];
2994
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7];
2995
- [encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
2996
- [encoder setBytes:&eps length:sizeof( float) atIndex:9];
3823
+ [encoder setBytes:&args length:sizeof(args) atIndex:2];
2997
3824
  [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
2998
3825
 
2999
3826
  [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
@@ -3036,6 +3863,7 @@ static void ggml_metal_encode_node(
3036
3863
  } break;
3037
3864
  case GGML_OP_ROPE:
3038
3865
  {
3866
+
3039
3867
  // make sure we have one or more position id(ne10) per token(ne02)
3040
3868
  GGML_ASSERT(ne10 % ne02 == 0);
3041
3869
  GGML_ASSERT(ne10 >= ne02);
@@ -3062,20 +3890,42 @@ static void ggml_metal_encode_node(
3062
3890
  memcpy(&beta_fast, (const int32_t *) dst->op_params + 9, sizeof(float));
3063
3891
  memcpy(&beta_slow, (const int32_t *) dst->op_params + 10, sizeof(float));
3064
3892
 
3065
- const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
3893
+ const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
3894
+ const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
3895
+ const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
3896
+
3897
+ // mrope
3898
+ const int sect_0 = ((const int32_t *) dst->op_params)[11];
3899
+ const int sect_1 = ((const int32_t *) dst->op_params)[12];
3900
+ const int sect_2 = ((const int32_t *) dst->op_params)[13];
3901
+ const int sect_3 = ((const int32_t *) dst->op_params)[14];
3066
3902
 
3067
3903
  id<MTLComputePipelineState> pipeline = nil;
3068
3904
 
3069
- if (!is_neox) {
3905
+ if (is_neox) {
3070
3906
  switch (src0->type) {
3071
- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
3072
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
3907
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
3908
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
3909
+ default: GGML_ABORT("fatal error");
3910
+ };
3911
+ } else if (is_mrope && !is_vision) {
3912
+ GGML_ASSERT(ne10*4 >= ne02); // need at least 4 pos per token
3913
+ switch (src0->type) {
3914
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32].pipeline; break;
3915
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16].pipeline; break;
3916
+ default: GGML_ABORT("fatal error");
3917
+ };
3918
+ } else if (is_vision) {
3919
+ GGML_ASSERT(ne10*4 >= ne02); // need at least 4 pos per token
3920
+ switch (src0->type) {
3921
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32].pipeline; break;
3922
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16].pipeline; break;
3073
3923
  default: GGML_ABORT("fatal error");
3074
3924
  };
3075
3925
  } else {
3076
3926
  switch (src0->type) {
3077
- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
3078
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
3927
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
3928
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
3079
3929
  default: GGML_ABORT("fatal error");
3080
3930
  };
3081
3931
  }
@@ -3106,6 +3956,10 @@ static void ggml_metal_encode_node(
3106
3956
  /*.attn_factor =*/ attn_factor,
3107
3957
  /*.beta_fast =*/ beta_fast,
3108
3958
  /*.beta_slow =*/ beta_slow,
3959
+ /* sect_0 =*/ sect_0,
3960
+ /* sect_1 =*/ sect_1,
3961
+ /* sect_2 =*/ sect_2,
3962
+ /* sect_3 =*/ sect_3,
3109
3963
  };
3110
3964
 
3111
3965
  [encoder setComputePipelineState:pipeline];
@@ -3151,8 +4005,8 @@ static void ggml_metal_encode_node(
3151
4005
 
3152
4006
  const int32_t CHW = IC * KH * KW;
3153
4007
 
3154
- const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
3155
- const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
4008
+ const uint64_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
4009
+ const uint64_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
3156
4010
 
3157
4011
  id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline;
3158
4012
 
@@ -3174,27 +4028,30 @@ static void ggml_metal_encode_node(
3174
4028
  default: GGML_ABORT("fatal error");
3175
4029
  };
3176
4030
 
3177
- // TODO: add ggml_metal_kargs struct
4031
+ ggml_metal_kargs_im2col args = {
4032
+ /*.ofs0 =*/ ofs0,
4033
+ /*.ofs1 =*/ ofs1,
4034
+ /*.IW =*/ IW,
4035
+ /*.IH =*/ IH,
4036
+ /*.CHW =*/ CHW,
4037
+ /*.s0 =*/ s0,
4038
+ /*.s1 =*/ s1,
4039
+ /*.p0 =*/ p0,
4040
+ /*.p1 =*/ p1,
4041
+ /*.d0 =*/ d0,
4042
+ /*.d1 =*/ d1,
4043
+ /*.N =*/ N,
4044
+ /*.KH =*/ KH,
4045
+ /*.KW =*/ KW,
4046
+ /*.KHW =*/ KH * KW,
4047
+ };
4048
+
3178
4049
  [encoder setComputePipelineState:pipeline];
3179
4050
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
3180
4051
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
3181
- [encoder setBytes:&ofs0 length:sizeof(int32_t) atIndex:2];
3182
- [encoder setBytes:&ofs1 length:sizeof(int32_t) atIndex:3];
3183
- [encoder setBytes:&IW length:sizeof(int32_t) atIndex:4];
3184
- [encoder setBytes:&IH length:sizeof(int32_t) atIndex:5];
3185
- [encoder setBytes:&CHW length:sizeof(int32_t) atIndex:6];
3186
- [encoder setBytes:&s0 length:sizeof(int32_t) atIndex:7];
3187
- [encoder setBytes:&s1 length:sizeof(int32_t) atIndex:8];
3188
- [encoder setBytes:&p0 length:sizeof(int32_t) atIndex:9];
3189
- [encoder setBytes:&p1 length:sizeof(int32_t) atIndex:10];
3190
- [encoder setBytes:&d0 length:sizeof(int32_t) atIndex:11];
3191
- [encoder setBytes:&d1 length:sizeof(int32_t) atIndex:12];
4052
+ [encoder setBytes:&args length:sizeof(args) atIndex:2];
3192
4053
 
3193
4054
  if (is_gt_mttpt) {
3194
- [encoder setBytes:&N length:sizeof(int32_t) atIndex:13];
3195
- [encoder setBytes:&KH length:sizeof(int32_t) atIndex:14];
3196
- [encoder setBytes:&KW length:sizeof(int32_t) atIndex:15];
3197
-
3198
4055
  const uint64_t n_threads = MIN(pipeline.maxTotalThreadsPerThreadgroup, (uint64_t)N);
3199
4056
 
3200
4057
  const int64_t quotient = N / n_threads + (N % n_threads > 0 ? 1 : 0);
@@ -3234,16 +4091,20 @@ static void ggml_metal_encode_node(
3234
4091
  default: GGML_ABORT("fatal error");
3235
4092
  };
3236
4093
 
4094
+ ggml_metal_kargs_conv_transpose_1d args = {
4095
+ /*.IC =*/ IC,
4096
+ /*.IL =*/ IL,
4097
+ /*.K =*/ K,
4098
+ /*.s0 =*/ s0,
4099
+ /*.nb0 =*/ nb0,
4100
+ /*.nb1 =*/ nb1,
4101
+ };
4102
+
3237
4103
  [encoder setComputePipelineState:pipeline];
3238
4104
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
3239
4105
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
3240
4106
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
3241
- [encoder setBytes:&IC length:sizeof( int32_t) atIndex:3];
3242
- [encoder setBytes:&IL length:sizeof( int32_t) atIndex:4];
3243
- [encoder setBytes:&K length:sizeof( int32_t) atIndex:5];
3244
- [encoder setBytes:&s0 length:sizeof( int32_t) atIndex:6];
3245
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:7];
3246
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:8];
4107
+ [encoder setBytes:&args length:sizeof(args) atIndex:3];
3247
4108
 
3248
4109
  [encoder dispatchThreadgroups:MTLSizeMake(OL, OC, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
3249
4110
  } break;
@@ -3258,30 +4119,33 @@ static void ggml_metal_encode_node(
3258
4119
 
3259
4120
  const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline;
3260
4121
 
3261
- // TODO: add ggml_metal_kargs struct
4122
+ ggml_metal_kargs_upscale args = {
4123
+ /*.ne00 =*/ ne00,
4124
+ /*.ne01 =*/ ne01,
4125
+ /*.ne02 =*/ ne02,
4126
+ /*.ne03 =*/ ne03,
4127
+ /*.nb00 =*/ nb00,
4128
+ /*.nb01 =*/ nb01,
4129
+ /*.nb02 =*/ nb02,
4130
+ /*.nb03 =*/ nb03,
4131
+ /*.ne0 =*/ ne0,
4132
+ /*.ne1 =*/ ne1,
4133
+ /*.ne2 =*/ ne2,
4134
+ /*.ne3 =*/ ne3,
4135
+ /*.nb0 =*/ nb0,
4136
+ /*.nb1 =*/ nb1,
4137
+ /*.nb2 =*/ nb2,
4138
+ /*.nb3 =*/ nb3,
4139
+ /*.sf0 =*/ sf0,
4140
+ /*.sf1 =*/ sf1,
4141
+ /*.sf2 =*/ sf2,
4142
+ /*.sf3 =*/ sf3
4143
+ };
4144
+
3262
4145
  [encoder setComputePipelineState:pipeline];
3263
4146
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
3264
4147
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
3265
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
3266
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
3267
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
3268
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
3269
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
3270
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
3271
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
3272
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
3273
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
3274
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
3275
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
3276
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
3277
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
3278
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
3279
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
3280
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
3281
- [encoder setBytes:&sf0 length:sizeof(sf0) atIndex:18];
3282
- [encoder setBytes:&sf1 length:sizeof(sf1) atIndex:19];
3283
- [encoder setBytes:&sf2 length:sizeof(sf2) atIndex:20];
3284
- [encoder setBytes:&sf3 length:sizeof(sf3) atIndex:21];
4148
+ [encoder setBytes:&args length:sizeof(args) atIndex:2];
3285
4149
 
3286
4150
  const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
3287
4151
 
@@ -3293,26 +4157,29 @@ static void ggml_metal_encode_node(
3293
4157
 
3294
4158
  id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline;
3295
4159
 
3296
- // TODO: add ggml_metal_kargs struct
4160
+ ggml_metal_kargs_pad args = {
4161
+ /*.ne00 =*/ ne00,
4162
+ /*.ne01 =*/ ne01,
4163
+ /*.ne02 =*/ ne02,
4164
+ /*.ne03 =*/ ne03,
4165
+ /*.nb00 =*/ nb00,
4166
+ /*.nb01 =*/ nb01,
4167
+ /*.nb02 =*/ nb02,
4168
+ /*.nb03 =*/ nb03,
4169
+ /*.ne0 =*/ ne0,
4170
+ /*.ne1 =*/ ne1,
4171
+ /*.ne2 =*/ ne2,
4172
+ /*.ne3 =*/ ne3,
4173
+ /*.nb0 =*/ nb0,
4174
+ /*.nb1 =*/ nb1,
4175
+ /*.nb2 =*/ nb2,
4176
+ /*.nb3 =*/ nb3
4177
+ };
4178
+
3297
4179
  [encoder setComputePipelineState:pipeline];
3298
4180
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
3299
4181
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
3300
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
3301
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
3302
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
3303
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
3304
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
3305
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
3306
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
3307
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
3308
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
3309
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
3310
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
3311
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
3312
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
3313
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
3314
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
3315
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
4182
+ [encoder setBytes:&args length:sizeof(args) atIndex:2];
3316
4183
 
3317
4184
  const int nth = MIN(1024, ne0);
3318
4185
 
@@ -3327,24 +4194,31 @@ static void ggml_metal_encode_node(
3327
4194
 
3328
4195
  id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32].pipeline;
3329
4196
 
4197
+ ggml_metal_kargs_pad_reflect_1d args = {
4198
+ /*.ne00 =*/ ne00,
4199
+ /*.ne01 =*/ ne01,
4200
+ /*.ne02 =*/ ne02,
4201
+ /*.ne03 =*/ ne03,
4202
+ /*.nb00 =*/ nb00,
4203
+ /*.nb01 =*/ nb01,
4204
+ /*.nb02 =*/ nb02,
4205
+ /*.nb03 =*/ nb03,
4206
+ /*.ne0 =*/ ne0,
4207
+ /*.ne1 =*/ ne1,
4208
+ /*.ne2 =*/ ne2,
4209
+ /*.ne3 =*/ ne3,
4210
+ /*.nb0 =*/ nb0,
4211
+ /*.nb1 =*/ nb1,
4212
+ /*.nb2 =*/ nb2,
4213
+ /*.nb3 =*/ nb3,
4214
+ /*.p0 =*/ p0,
4215
+ /*.p1 =*/ p1
4216
+ };
4217
+
3330
4218
  [encoder setComputePipelineState:pipeline];
3331
4219
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
3332
4220
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
3333
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
3334
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
3335
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
3336
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
3337
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:6];
3338
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
3339
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
3340
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
3341
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
3342
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:11];
3343
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:12];
3344
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:13];
3345
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:14];
3346
- [encoder setBytes:&p0 length:sizeof(p0) atIndex:15];
3347
- [encoder setBytes:&p1 length:sizeof(p1) atIndex:16];
4221
+ [encoder setBytes:&args length:sizeof(args) atIndex:2];
3348
4222
 
3349
4223
  const int nth = MIN(1024, ne0);
3350
4224
 
@@ -3362,12 +4236,15 @@ static void ggml_metal_encode_node(
3362
4236
 
3363
4237
  id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline;
3364
4238
 
3365
- // TODO: add ggml_metal_kargs struct
4239
+ ggml_metal_kargs_arange args = {
4240
+ /*.ne0 =*/ ne0,
4241
+ /*.start =*/ start,
4242
+ /*.step =*/ step
4243
+ };
4244
+
3366
4245
  [encoder setComputePipelineState:pipeline];
3367
- [encoder setBuffer:id_dst offset:offs_dst atIndex:0];
3368
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:1];
3369
- [encoder setBytes:&start length:sizeof(start) atIndex:2];
3370
- [encoder setBytes:&step length:sizeof(step) atIndex:3];
4246
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:0];
4247
+ [encoder setBytes:&args length:sizeof(args) atIndex:1];
3371
4248
 
3372
4249
  const int nth = MIN(1024, ne0);
3373
4250
 
@@ -3384,13 +4261,16 @@ static void ggml_metal_encode_node(
3384
4261
 
3385
4262
  id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline;
3386
4263
 
3387
- // TODO: add ggml_metal_kargs struct
4264
+ ggml_metal_kargs_timestep_embedding args = {
4265
+ /*.nb1 =*/ nb1,
4266
+ /*.dim =*/ dim,
4267
+ /*.max_period =*/ max_period
4268
+ };
4269
+
3388
4270
  [encoder setComputePipelineState:pipeline];
3389
4271
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
3390
4272
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
3391
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:2];
3392
- [encoder setBytes:&dim length:sizeof(dim) atIndex:3];
3393
- [encoder setBytes:&max_period length:sizeof(max_period) atIndex:4];
4273
+ [encoder setBytes:&args length:sizeof(args) atIndex:2];
3394
4274
 
3395
4275
  const int nth = MIN(1024, half);
3396
4276
 
@@ -3423,12 +4303,15 @@ static void ggml_metal_encode_node(
3423
4303
  default: GGML_ABORT("fatal error");
3424
4304
  };
3425
4305
 
3426
- // TODO: add ggml_metal_kargs struct
4306
+ ggml_metal_kargs_argsort args = {
4307
+ /*.ncols =*/ ne00,
4308
+ /*.ncols_pad =*/ ne00_padded
4309
+ };
4310
+
3427
4311
  [encoder setComputePipelineState:pipeline];
3428
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
3429
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
3430
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
3431
- [encoder setBytes:&ne00_padded length:sizeof( int64_t) atIndex:3];
4312
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
4313
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
4314
+ [encoder setBytes:&args length:sizeof(args) atIndex:2];
3432
4315
  [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
3433
4316
 
3434
4317
  [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00_padded, 1, 1)];
@@ -3442,11 +4325,14 @@ static void ggml_metal_encode_node(
3442
4325
 
3443
4326
  id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline;
3444
4327
 
3445
- // TODO: add ggml_metal_kargs struct
4328
+ ggml_metal_kargs_leaky_relu args = {
4329
+ /*.slope =*/ slope
4330
+ };
4331
+
3446
4332
  [encoder setComputePipelineState:pipeline];
3447
4333
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
3448
4334
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
3449
- [encoder setBytes:&slope length:sizeof(slope) atIndex:2];
4335
+ [encoder setBytes:&args length:sizeof(args) atIndex:2];
3450
4336
 
3451
4337
  const int64_t n = ggml_nelements(dst);
3452
4338
 
@@ -3460,7 +4346,9 @@ static void ggml_metal_encode_node(
3460
4346
  GGML_ASSERT(src0->type == GGML_TYPE_F32);
3461
4347
  GGML_ASSERT(src1->type == src2->type);
3462
4348
 
3463
- GGML_ASSERT(ggml_are_same_shape (src1, src2));
4349
+ //GGML_ASSERT(ggml_are_same_shape (src1, src2));
4350
+ GGML_ASSERT(ne11 == ne21);
4351
+ GGML_ASSERT(ne12 == ne22);
3464
4352
 
3465
4353
  struct ggml_tensor * src3 = node->src[3];
3466
4354
 
@@ -3507,125 +4395,175 @@ static void ggml_metal_encode_node(
3507
4395
 
3508
4396
  // TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0)
3509
4397
  // for now avoiding mainly to keep the number of templates/kernels a bit lower
3510
- if (ne01 >= 4 || (ne00%128 != 0)) {
4398
+ // these are now trivial to add after: https://github.com/ggml-org/llama.cpp/pull/12612
4399
+ if (ne01 >= 20 || (ne00%128 != 0 && ne00 != 64 && ne00 != 96 && ne00 != 192 && ne00 != 576)) {
3511
4400
  switch (src1->type) {
3512
4401
  case GGML_TYPE_F16:
3513
4402
  {
3514
- switch (ne00) {
3515
- case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
3516
- case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
3517
- case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
3518
- case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
3519
- case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
3520
- case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
3521
- default:
3522
- {
3523
- GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
3524
- GGML_LOG_ERROR("add template specialization for this size\n");
3525
- GGML_ABORT("add template specialization for this size");
3526
- }
4403
+ if (ne00 == 192 && ne20 == 128) {
4404
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128].pipeline;
4405
+ } else if (ne00 == 576 && ne20 == 512) {
4406
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512].pipeline;
4407
+ } else {
4408
+ switch (ne00) {
4409
+ case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
4410
+ case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
4411
+ case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
4412
+ case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
4413
+ case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
4414
+ case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192].pipeline; break;
4415
+ case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
4416
+ default:
4417
+ {
4418
+ GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
4419
+ GGML_LOG_ERROR("add template specialization for this size\n");
4420
+ GGML_ABORT("add template specialization for this size");
4421
+ }
4422
+ }
3527
4423
  }
3528
4424
  } break;
3529
4425
  case GGML_TYPE_BF16:
3530
4426
  {
3531
- switch (ne00) {
3532
- case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break;
3533
- case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80 ].pipeline; break;
3534
- case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96 ].pipeline; break;
3535
- case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112].pipeline; break;
3536
- case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128].pipeline; break;
3537
- case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256].pipeline; break;
3538
- default:
3539
- {
3540
- GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
3541
- GGML_LOG_ERROR("add template specialization for this size\n");
3542
- GGML_ABORT("add template specialization for this size");
3543
- }
4427
+ if (ne00 == 192 && ne20 == 128) {
4428
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128].pipeline;
4429
+ } else if (ne00 == 576 && ne20 == 512) {
4430
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512].pipeline;
4431
+ } else {
4432
+ switch (ne00) {
4433
+ case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break;
4434
+ case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80 ].pipeline; break;
4435
+ case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96 ].pipeline; break;
4436
+ case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112].pipeline; break;
4437
+ case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128].pipeline; break;
4438
+ case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192].pipeline; break;
4439
+ case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256].pipeline; break;
4440
+ default:
4441
+ {
4442
+ GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
4443
+ GGML_LOG_ERROR("add template specialization for this size\n");
4444
+ GGML_ABORT("add template specialization for this size");
4445
+ }
4446
+ }
3544
4447
  }
3545
4448
  } break;
3546
4449
  case GGML_TYPE_Q4_0:
3547
4450
  {
3548
- switch (ne00) {
3549
- case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64 ].pipeline; break;
3550
- case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80 ].pipeline; break;
3551
- case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96 ].pipeline; break;
3552
- case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112].pipeline; break;
3553
- case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128].pipeline; break;
3554
- case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256].pipeline; break;
3555
- default:
3556
- {
3557
- GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
3558
- GGML_LOG_ERROR("add template specialization for this size\n");
3559
- GGML_ABORT("add template specialization for this size");
3560
- }
4451
+ if (ne00 == 192 && ne20 == 128) {
4452
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128].pipeline;
4453
+ } else if (ne00 == 576 && ne20 == 512) {
4454
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512].pipeline;
4455
+ } else {
4456
+ switch (ne00) {
4457
+ case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64 ].pipeline; break;
4458
+ case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80 ].pipeline; break;
4459
+ case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96 ].pipeline; break;
4460
+ case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112].pipeline; break;
4461
+ case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128].pipeline; break;
4462
+ case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192].pipeline; break;
4463
+ case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256].pipeline; break;
4464
+ default:
4465
+ {
4466
+ GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
4467
+ GGML_LOG_ERROR("add template specialization for this size\n");
4468
+ GGML_ABORT("add template specialization for this size");
4469
+ }
4470
+ }
3561
4471
  }
3562
4472
  } break;
3563
4473
  case GGML_TYPE_Q4_1:
3564
4474
  {
3565
- switch (ne00) {
3566
- case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64 ].pipeline; break;
3567
- case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80 ].pipeline; break;
3568
- case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96 ].pipeline; break;
3569
- case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112].pipeline; break;
3570
- case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128].pipeline; break;
3571
- case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256].pipeline; break;
3572
- default:
3573
- {
3574
- GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
3575
- GGML_LOG_ERROR("add template specialization for this size\n");
3576
- GGML_ABORT("add template specialization for this size");
3577
- }
4475
+ if (ne00 == 192 && ne20 == 128) {
4476
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128].pipeline;
4477
+ } else if (ne00 == 576 && ne20 == 512) {
4478
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512].pipeline;
4479
+ } else {
4480
+ switch (ne00) {
4481
+ case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64 ].pipeline; break;
4482
+ case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80 ].pipeline; break;
4483
+ case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96 ].pipeline; break;
4484
+ case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112].pipeline; break;
4485
+ case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128].pipeline; break;
4486
+ case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192].pipeline; break;
4487
+ case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256].pipeline; break;
4488
+ default:
4489
+ {
4490
+ GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
4491
+ GGML_LOG_ERROR("add template specialization for this size\n");
4492
+ GGML_ABORT("add template specialization for this size");
4493
+ }
4494
+ }
3578
4495
  }
3579
4496
  } break;
3580
4497
  case GGML_TYPE_Q5_0:
3581
4498
  {
3582
- switch (ne00) {
3583
- case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64 ].pipeline; break;
3584
- case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80 ].pipeline; break;
3585
- case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96 ].pipeline; break;
3586
- case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112].pipeline; break;
3587
- case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128].pipeline; break;
3588
- case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256].pipeline; break;
3589
- default:
3590
- {
3591
- GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
3592
- GGML_LOG_ERROR("add template specialization for this size\n");
3593
- GGML_ABORT("add template specialization for this size");
3594
- }
4499
+ if (ne00 == 192 && ne20 == 128) {
4500
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128].pipeline;
4501
+ } else if (ne00 == 576 && ne20 == 512) {
4502
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512].pipeline;
4503
+ } else {
4504
+ switch (ne00) {
4505
+ case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64 ].pipeline; break;
4506
+ case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80 ].pipeline; break;
4507
+ case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96 ].pipeline; break;
4508
+ case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112].pipeline; break;
4509
+ case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128].pipeline; break;
4510
+ case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192].pipeline; break;
4511
+ case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256].pipeline; break;
4512
+ default:
4513
+ {
4514
+ GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
4515
+ GGML_LOG_ERROR("add template specialization for this size\n");
4516
+ GGML_ABORT("add template specialization for this size");
4517
+ }
4518
+ }
3595
4519
  }
3596
4520
  } break;
3597
4521
  case GGML_TYPE_Q5_1:
3598
4522
  {
3599
- switch (ne00) {
3600
- case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64 ].pipeline; break;
3601
- case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80 ].pipeline; break;
3602
- case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96 ].pipeline; break;
3603
- case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112].pipeline; break;
3604
- case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128].pipeline; break;
3605
- case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256].pipeline; break;
3606
- default:
3607
- {
3608
- GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
3609
- GGML_LOG_ERROR("add template specialization for this size\n");
3610
- GGML_ABORT("add template specialization for this size");
3611
- }
4523
+ if (ne00 == 192 && ne20 == 128) {
4524
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128].pipeline;
4525
+ } else if (ne00 == 576 && ne20 == 512) {
4526
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512].pipeline;
4527
+ } else {
4528
+ switch (ne00) {
4529
+ case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64 ].pipeline; break;
4530
+ case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80 ].pipeline; break;
4531
+ case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96 ].pipeline; break;
4532
+ case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112].pipeline; break;
4533
+ case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128].pipeline; break;
4534
+ case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192].pipeline; break;
4535
+ case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256].pipeline; break;
4536
+ default:
4537
+ {
4538
+ GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
4539
+ GGML_LOG_ERROR("add template specialization for this size\n");
4540
+ GGML_ABORT("add template specialization for this size");
4541
+ }
4542
+ }
3612
4543
  }
3613
4544
  } break;
3614
4545
  case GGML_TYPE_Q8_0:
3615
4546
  {
3616
- switch (ne00) {
3617
- case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break;
3618
- case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80 ].pipeline; break;
3619
- case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96 ].pipeline; break;
3620
- case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112].pipeline; break;
3621
- case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128].pipeline; break;
3622
- case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256].pipeline; break;
3623
- default:
3624
- {
3625
- GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
3626
- GGML_LOG_ERROR("add template specialization for this size\n");
3627
- GGML_ABORT("add template specialization for this size");
3628
- }
4547
+ if (ne00 == 192 && ne20 == 128) {
4548
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128].pipeline;
4549
+ } else if (ne00 == 576 && ne20 == 512) {
4550
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512].pipeline;
4551
+ } else {
4552
+ switch (ne00) {
4553
+ case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break;
4554
+ case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80 ].pipeline; break;
4555
+ case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96 ].pipeline; break;
4556
+ case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112].pipeline; break;
4557
+ case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128].pipeline; break;
4558
+ case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192].pipeline; break;
4559
+ case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256].pipeline; break;
4560
+ default:
4561
+ {
4562
+ GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
4563
+ GGML_LOG_ERROR("add template specialization for this size\n");
4564
+ GGML_ABORT("add template specialization for this size");
4565
+ }
4566
+ }
3629
4567
  }
3630
4568
  } break;
3631
4569
  default:
@@ -3639,6 +4577,42 @@ static void ggml_metal_encode_node(
3639
4577
  use_vec_kernel = true;
3640
4578
 
3641
4579
  switch (ne00) {
4580
+ case 64:
4581
+ {
4582
+ switch (src1->type) {
4583
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64].pipeline; break;
4584
+ case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64].pipeline; break;
4585
+ case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64].pipeline; break;
4586
+ case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64].pipeline; break;
4587
+ case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64].pipeline; break;
4588
+ case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64].pipeline; break;
4589
+ case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64].pipeline; break;
4590
+ default:
4591
+ {
4592
+ GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
4593
+ GGML_LOG_ERROR("add template specialization for this type\n");
4594
+ GGML_ABORT("add template specialization for this type");
4595
+ }
4596
+ }
4597
+ } break;
4598
+ case 96:
4599
+ {
4600
+ switch (src1->type) {
4601
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96].pipeline; break;
4602
+ case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96].pipeline; break;
4603
+ case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96].pipeline; break;
4604
+ case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96].pipeline; break;
4605
+ case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96].pipeline; break;
4606
+ case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96].pipeline; break;
4607
+ case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96].pipeline; break;
4608
+ default:
4609
+ {
4610
+ GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
4611
+ GGML_LOG_ERROR("add template specialization for this type\n");
4612
+ GGML_ABORT("add template specialization for this type");
4613
+ }
4614
+ }
4615
+ } break;
3642
4616
  case 128:
3643
4617
  {
3644
4618
  switch (src1->type) {
@@ -3657,6 +4631,42 @@ static void ggml_metal_encode_node(
3657
4631
  }
3658
4632
  }
3659
4633
  } break;
4634
+ case 192:
4635
+ {
4636
+ if (ne20 == 128) {
4637
+ switch (src1->type) {
4638
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128].pipeline; break;
4639
+ case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK192_HV128].pipeline; break;
4640
+ case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK192_HV128].pipeline; break;
4641
+ case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK192_HV128].pipeline; break;
4642
+ case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK192_HV128].pipeline; break;
4643
+ case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK192_HV128].pipeline; break;
4644
+ case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128].pipeline; break;
4645
+ default:
4646
+ {
4647
+ GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
4648
+ GGML_LOG_ERROR("add template specialization for this type\n");
4649
+ GGML_ABORT("add template specialization for this type");
4650
+ }
4651
+ }
4652
+ } else {
4653
+ switch (src1->type) {
4654
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H192].pipeline; break;
4655
+ case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H192].pipeline; break;
4656
+ case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H192].pipeline; break;
4657
+ case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H192].pipeline; break;
4658
+ case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H192].pipeline; break;
4659
+ case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H192].pipeline; break;
4660
+ case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H192].pipeline; break;
4661
+ default:
4662
+ {
4663
+ GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
4664
+ GGML_LOG_ERROR("add template specialization for this type\n");
4665
+ GGML_ABORT("add template specialization for this type");
4666
+ }
4667
+ }
4668
+ }
4669
+ } break;
3660
4670
  case 256:
3661
4671
  {
3662
4672
  switch (src1->type) {
@@ -3675,12 +4685,36 @@ static void ggml_metal_encode_node(
3675
4685
  }
3676
4686
  }
3677
4687
  } break;
4688
+ case 576:
4689
+ {
4690
+ if (ne20 == 512) {
4691
+ switch (src1->type) {
4692
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512].pipeline; break;
4693
+ case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512].pipeline; break;
4694
+ case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512].pipeline; break;
4695
+ case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512].pipeline; break;
4696
+ case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512].pipeline; break;
4697
+ case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512].pipeline; break;
4698
+ case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512].pipeline; break;
4699
+ default:
4700
+ {
4701
+ GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
4702
+ GGML_LOG_ERROR("add template specialization for this type\n");
4703
+ GGML_ABORT("add template specialization for this type");
4704
+ }
4705
+ }
4706
+ } else {
4707
+ GGML_LOG_ERROR("unsupported size: %lld\n", ne20);
4708
+ GGML_LOG_ERROR("add template specialization for this size\n");
4709
+ GGML_ABORT("add template specialization for this size");
4710
+ }
4711
+ } break;
3678
4712
  default:
3679
- {
3680
- GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
3681
- GGML_LOG_ERROR("add template specialization for this size\n");
3682
- GGML_ABORT("add template specialization for this size");
3683
- }
4713
+ {
4714
+ GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
4715
+ GGML_LOG_ERROR("add template specialization for this size\n");
4716
+ GGML_ABORT("add template specialization for this size");
4717
+ }
3684
4718
  }
3685
4719
  }
3686
4720
 
@@ -3694,9 +4728,12 @@ static void ggml_metal_encode_node(
3694
4728
  /*.ne11 =*/ ne11,
3695
4729
  /*.ne_12_2 =*/ ne12,
3696
4730
  /*.ne_12_3 =*/ ne13,
3697
- /*.nb_12_1 =*/ nb11,
3698
- /*.nb_12_2 =*/ nb12,
3699
- /*.nb_12_3 =*/ nb13,
4731
+ /*.nb11 =*/ nb11,
4732
+ /*.nb12 =*/ nb12,
4733
+ /*.nb13 =*/ nb13,
4734
+ /*.nb21 =*/ nb21,
4735
+ /*.nb22 =*/ nb22,
4736
+ /*.nb23 =*/ nb23,
3700
4737
  /*.nb31 =*/ nb31,
3701
4738
  /*.ne1 =*/ ne1,
3702
4739
  /*.ne2 =*/ ne2,
@@ -3775,10 +4812,9 @@ static void ggml_metal_encode_node(
3775
4812
  // ne00*(nsg)
3776
4813
  // each simdgroup has a full f16 head vector in shared mem to accumulate results
3777
4814
  //
3778
- #define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + ne00*(nsg))*(sizeof(float)/2), 16))
4815
+ #define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + ne20*(nsg))*(sizeof(float)/2), 16))
3779
4816
 
3780
4817
  int64_t nsgmax = 2;
3781
-
3782
4818
  while (true) {
3783
4819
  const size_t smem = FATTN_SMEM(nsgmax);
3784
4820
  if (smem > device.maxThreadgroupMemoryLength) {
@@ -3810,10 +4846,6 @@ static void ggml_metal_encode_node(
3810
4846
  case GGML_OP_CPY:
3811
4847
  case GGML_OP_CONT:
3812
4848
  {
3813
- GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
3814
-
3815
- int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
3816
-
3817
4849
  id<MTLComputePipelineState> pipeline = nil;
3818
4850
 
3819
4851
  switch (src0t) {
@@ -3847,7 +4879,47 @@ static void ggml_metal_encode_node(
3847
4879
  switch (dstt) {
3848
4880
  case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_F32].pipeline; break;
3849
4881
  case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16].pipeline; break;
3850
- default: GGML_ASSERT(false && "not implemented");
4882
+ default: GGML_ABORT("not implemented");
4883
+ };
4884
+ } break;
4885
+ case GGML_TYPE_Q4_0:
4886
+ {
4887
+ switch (dstt) {
4888
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32].pipeline; break;
4889
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16].pipeline; break;
4890
+ default: GGML_ABORT("not implemented");
4891
+ };
4892
+ } break;
4893
+ case GGML_TYPE_Q4_1:
4894
+ {
4895
+ switch (dstt) {
4896
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32].pipeline; break;
4897
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16].pipeline; break;
4898
+ default: GGML_ABORT("not implemented");
4899
+ };
4900
+ } break;
4901
+ case GGML_TYPE_Q5_0:
4902
+ {
4903
+ switch (dstt) {
4904
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32].pipeline; break;
4905
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16].pipeline; break;
4906
+ default: GGML_ABORT("not implemented");
4907
+ };
4908
+ } break;
4909
+ case GGML_TYPE_Q5_1:
4910
+ {
4911
+ switch (dstt) {
4912
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32].pipeline; break;
4913
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16].pipeline; break;
4914
+ default: GGML_ABORT("not implemented");
4915
+ };
4916
+ } break;
4917
+ case GGML_TYPE_Q8_0:
4918
+ {
4919
+ switch (dstt) {
4920
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32].pipeline; break;
4921
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16].pipeline; break;
4922
+ default: GGML_ABORT("not implemented");
3851
4923
  };
3852
4924
  } break;
3853
4925
  default: GGML_ABORT("not implemented");
@@ -3877,7 +4949,11 @@ static void ggml_metal_encode_node(
3877
4949
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
3878
4950
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
3879
4951
 
4952
+ GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
4953
+ int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
4954
+
3880
4955
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
4956
+
3881
4957
  } break;
3882
4958
  case GGML_OP_SET:
3883
4959
  {
@@ -3982,21 +5058,24 @@ static void ggml_metal_encode_node(
3982
5058
  const int64_t n_threads = MIN((int64_t)[pipeline maxTotalThreadsPerThreadgroup], parallel_elements);
3983
5059
  const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads;
3984
5060
 
3985
- // TODO: add ggml_metal_kargs struct
5061
+ ggml_metal_kargs_pool_2d args_pool_2d = {
5062
+ /* .k0 = */ k0,
5063
+ /* .k1 = */ k1,
5064
+ /* .s0 = */ s0,
5065
+ /* .s1 = */ s1,
5066
+ /* .p0 = */ p0,
5067
+ /* .p1 = */ p1,
5068
+ /* .IH = */ IH,
5069
+ /* .IW = */ IW,
5070
+ /* .OH = */ OH,
5071
+ /* .OW = */ OW,
5072
+ /* .parallel_elements = */ parallel_elements
5073
+ };
5074
+
3986
5075
  [encoder setComputePipelineState:pipeline];
3987
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
3988
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
3989
- [encoder setBytes:&k0 length:sizeof(int32_t) atIndex:2];
3990
- [encoder setBytes:&k1 length:sizeof(int32_t) atIndex:3];
3991
- [encoder setBytes:&s0 length:sizeof(int32_t) atIndex:4];
3992
- [encoder setBytes:&s1 length:sizeof(int32_t) atIndex:5];
3993
- [encoder setBytes:&p0 length:sizeof(int32_t) atIndex:6];
3994
- [encoder setBytes:&p1 length:sizeof(int32_t) atIndex:7];
3995
- [encoder setBytes:&IH length:sizeof(int64_t) atIndex:8];
3996
- [encoder setBytes:&IW length:sizeof(int64_t) atIndex:9];
3997
- [encoder setBytes:&OH length:sizeof(int64_t) atIndex:10];
3998
- [encoder setBytes:&OW length:sizeof(int64_t) atIndex:11];
3999
- [encoder setBytes:&parallel_elements length:sizeof(int64_t) atIndex:12];
5076
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
5077
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
5078
+ [encoder setBytes:&args_pool_2d length:sizeof(args_pool_2d) atIndex:2];
4000
5079
 
4001
5080
  [encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
4002
5081
  } break;
@@ -4031,6 +5110,8 @@ static void ggml_metal_encode_node(
4031
5110
  GGML_ABORT("fatal error");
4032
5111
  }
4033
5112
  }
5113
+
5114
+ return true;
4034
5115
  }
4035
5116
 
4036
5117
  static enum ggml_status ggml_metal_graph_compute(
@@ -4084,25 +5165,25 @@ static enum ggml_status ggml_metal_graph_compute(
4084
5165
  }
4085
5166
 
4086
5167
  // the main thread commits the first few commands immediately
4087
- // command_buffer[n_cb]
5168
+ // cmd_buf[n_cb]
4088
5169
  {
4089
- id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
4090
- ctx->command_buffers[n_cb] = command_buffer;
5170
+ id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
5171
+ ctx->cmd_bufs[n_cb].obj = cmd_buf;
4091
5172
 
4092
- [command_buffer enqueue];
5173
+ [cmd_buf enqueue];
4093
5174
  ctx->encode_async(n_cb);
4094
5175
  }
4095
5176
 
4096
5177
  // prepare the rest of the command buffers asynchronously
4097
- // command_buffer[0.. n_cb)
5178
+ // cmd_buf[0.. n_cb)
4098
5179
  for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
4099
- id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
4100
- ctx->command_buffers[cb_idx] = command_buffer;
5180
+ id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
5181
+ ctx->cmd_bufs[cb_idx].obj = cmd_buf;
4101
5182
 
4102
5183
  // always enqueue the first two command buffers
4103
5184
  // enqueue all of the command buffers if we don't need to abort
4104
5185
  if (cb_idx < 2 || ctx->abort_callback == NULL) {
4105
- [command_buffer enqueue];
5186
+ [cmd_buf enqueue];
4106
5187
  }
4107
5188
  }
4108
5189
 
@@ -4111,14 +5192,14 @@ static enum ggml_status ggml_metal_graph_compute(
4111
5192
  // wait for completion and check status of each command buffer
4112
5193
  // needed to detect if the device ran out-of-memory for example (#1881)
4113
5194
  {
4114
- id<MTLCommandBuffer> command_buffer = ctx->command_buffers[n_cb];
4115
- [command_buffer waitUntilCompleted];
5195
+ id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[n_cb].obj;
5196
+ [cmd_buf waitUntilCompleted];
4116
5197
 
4117
- MTLCommandBufferStatus status = [command_buffer status];
5198
+ MTLCommandBufferStatus status = [cmd_buf status];
4118
5199
  if (status != MTLCommandBufferStatusCompleted) {
4119
5200
  GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status);
4120
5201
  if (status == MTLCommandBufferStatusError) {
4121
- GGML_LOG_INFO("error: %s\n", [[command_buffer error].localizedDescription UTF8String]);
5202
+ GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
4122
5203
  }
4123
5204
 
4124
5205
  return GGML_STATUS_FAILED;
@@ -4126,20 +5207,20 @@ static enum ggml_status ggml_metal_graph_compute(
4126
5207
  }
4127
5208
 
4128
5209
  for (int i = 0; i < n_cb; ++i) {
4129
- id<MTLCommandBuffer> command_buffer = ctx->command_buffers[i];
4130
- [command_buffer waitUntilCompleted];
5210
+ id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[i].obj;
5211
+ [cmd_buf waitUntilCompleted];
4131
5212
 
4132
- MTLCommandBufferStatus status = [command_buffer status];
5213
+ MTLCommandBufferStatus status = [cmd_buf status];
4133
5214
  if (status != MTLCommandBufferStatusCompleted) {
4134
5215
  GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
4135
5216
  if (status == MTLCommandBufferStatusError) {
4136
- GGML_LOG_INFO("error: %s\n", [[command_buffer error].localizedDescription UTF8String]);
5217
+ GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
4137
5218
  }
4138
5219
 
4139
5220
  return GGML_STATUS_FAILED;
4140
5221
  }
4141
5222
 
4142
- id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->command_buffers[i + 1] : nil);
5223
+ id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil);
4143
5224
  if (!next_buffer) {
4144
5225
  continue;
4145
5226
  }
@@ -4176,6 +5257,8 @@ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer)
4176
5257
  for (int i = 0; i < ctx->n_buffers; i++) {
4177
5258
  [ctx->buffers[i].metal release];
4178
5259
  }
5260
+
5261
+ ggml_backend_metal_buffer_rset_free(ctx);
4179
5262
  ggml_backend_metal_device_rel(buffer->buft->device->context);
4180
5263
 
4181
5264
  if (ctx->owned) {
@@ -4198,19 +5281,19 @@ static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) {
4198
5281
  static void ggml_backend_metal_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
4199
5282
  memset((char *)tensor->data + offset, value, size);
4200
5283
 
4201
- UNUSED(buffer);
5284
+ GGML_UNUSED(buffer);
4202
5285
  }
4203
5286
 
4204
5287
  static void ggml_backend_metal_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
4205
5288
  memcpy((char *)tensor->data + offset, data, size);
4206
5289
 
4207
- UNUSED(buffer);
5290
+ GGML_UNUSED(buffer);
4208
5291
  }
4209
5292
 
4210
5293
  static void ggml_backend_metal_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
4211
5294
  memcpy(data, (const char *)tensor->data + offset, size);
4212
5295
 
4213
- UNUSED(buffer);
5296
+ GGML_UNUSED(buffer);
4214
5297
  }
4215
5298
 
4216
5299
  static bool ggml_backend_metal_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) {
@@ -4220,7 +5303,7 @@ static bool ggml_backend_metal_buffer_cpy_tensor(ggml_backend_buffer_t buffer, c
4220
5303
  }
4221
5304
  return false;
4222
5305
 
4223
- UNUSED(buffer);
5306
+ GGML_UNUSED(buffer);
4224
5307
  }
4225
5308
 
4226
5309
  static void ggml_backend_metal_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
@@ -4246,7 +5329,7 @@ static struct ggml_backend_buffer_i ggml_backend_metal_buffer_i = {
4246
5329
  static const char * ggml_backend_metal_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
4247
5330
  return "Metal";
4248
5331
 
4249
- UNUSED(buft);
5332
+ GGML_UNUSED(buft);
4250
5333
  }
4251
5334
 
4252
5335
  static void ggml_backend_metal_log_allocated_size(id<MTLDevice> device, size_t size_aligned) {
@@ -4270,8 +5353,8 @@ static void ggml_backend_metal_log_allocated_size(id<MTLDevice> device, size_t s
4270
5353
  }
4271
5354
  #endif
4272
5355
  #endif
4273
- UNUSED(device);
4274
- UNUSED(size_aligned);
5356
+ GGML_UNUSED(device);
5357
+ GGML_UNUSED(size_aligned);
4275
5358
  }
4276
5359
 
4277
5360
  static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
@@ -4284,7 +5367,8 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
4284
5367
  size_aligned += (size_page - (size_aligned % size_page));
4285
5368
  }
4286
5369
 
4287
- id<MTLDevice> device = ggml_backend_metal_device_acq(buft->device->context);
5370
+ struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)buft->device->context;
5371
+ id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
4288
5372
 
4289
5373
  ctx->all_data = ggml_metal_host_malloc(size_aligned);
4290
5374
  ctx->all_size = size_aligned;
@@ -4307,7 +5391,14 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
4307
5391
  if (size_aligned > 0 && (ctx->all_data == NULL || ctx->buffers[0].metal == nil)) {
4308
5392
  GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
4309
5393
  free(ctx);
4310
- ggml_backend_metal_device_rel(buft->device->context);
5394
+ ggml_backend_metal_device_rel(ctx_dev);
5395
+ return NULL;
5396
+ }
5397
+
5398
+ if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
5399
+ GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
5400
+ free(ctx);
5401
+ ggml_backend_metal_device_rel(ctx_dev);
4311
5402
  return NULL;
4312
5403
  }
4313
5404
 
@@ -4318,7 +5409,7 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
4318
5409
 
4319
5410
  static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
4320
5411
  return 32;
4321
- UNUSED(buft);
5412
+ GGML_UNUSED(buft);
4322
5413
  }
4323
5414
 
4324
5415
  static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
@@ -4328,13 +5419,13 @@ static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_ty
4328
5419
 
4329
5420
  return max_size;
4330
5421
 
4331
- UNUSED(buft);
5422
+ GGML_UNUSED(buft);
4332
5423
  }
4333
5424
 
4334
5425
  static bool ggml_backend_metal_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
4335
5426
  return true;
4336
5427
 
4337
- UNUSED(buft);
5428
+ GGML_UNUSED(buft);
4338
5429
  }
4339
5430
 
4340
5431
  ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
@@ -4357,7 +5448,7 @@ ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
4357
5448
  static const char * ggml_backend_metal_buffer_from_ptr_type_get_name(ggml_backend_buffer_type_t buft) {
4358
5449
  return "Metal_Mapped";
4359
5450
 
4360
- UNUSED(buft);
5451
+ GGML_UNUSED(buft);
4361
5452
  }
4362
5453
 
4363
5454
  static ggml_backend_buffer_type_t ggml_backend_metal_buffer_from_ptr_type(void) {
@@ -4400,7 +5491,8 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
4400
5491
  size_aligned += (size_page - (size_aligned % size_page));
4401
5492
  }
4402
5493
 
4403
- id<MTLDevice> device = ggml_backend_metal_device_acq(&g_ggml_ctx_dev_main);
5494
+ struct ggml_backend_metal_device_context * ctx_dev = &g_ggml_ctx_dev_main;
5495
+ id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
4404
5496
 
4405
5497
  // the buffer fits into the max buffer size allowed by the device
4406
5498
  if (size_aligned <= device.maxBufferLength) {
@@ -4453,6 +5545,13 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
4453
5545
  }
4454
5546
  }
4455
5547
 
5548
+ if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
5549
+ GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
5550
+ free(ctx);
5551
+ ggml_backend_metal_device_rel(ctx_dev);
5552
+ return NULL;
5553
+ }
5554
+
4456
5555
  return ggml_backend_buffer_init(ggml_backend_metal_buffer_from_ptr_type(), ggml_backend_metal_buffer_i, ctx, size);
4457
5556
  }
4458
5557
 
@@ -4461,7 +5560,7 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
4461
5560
  static const char * ggml_backend_metal_name(ggml_backend_t backend) {
4462
5561
  return "Metal";
4463
5562
 
4464
- UNUSED(backend);
5563
+ GGML_UNUSED(backend);
4465
5564
  }
4466
5565
 
4467
5566
  static void ggml_backend_metal_free(ggml_backend_t backend) {
@@ -4504,8 +5603,9 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
4504
5603
 
4505
5604
  const int n_nodes_per_cb = ctx->n_nodes_per_cb;
4506
5605
 
4507
- id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx];
4508
- id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoder];
5606
+ id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[cb_idx].obj;
5607
+
5608
+ id<MTLComputeCommandEncoder> encoder = [cmd_buf computeCommandEncoder];
4509
5609
 
4510
5610
  int node_start = 0;
4511
5611
  int node_end = n_nodes_0;
@@ -4517,22 +5617,29 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
4517
5617
 
4518
5618
  const bool should_capture = ctx->capture_next_compute;
4519
5619
 
5620
+ struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool;
5621
+ ggml_metal_mem_pool_reset(mem_pool);
5622
+
4520
5623
  for (int idx = node_start; idx < node_end; ++idx) {
4521
5624
  if (should_capture) {
4522
5625
  [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
4523
5626
  }
4524
5627
 
4525
- ggml_metal_encode_node(backend, idx, encoder);
5628
+ const bool res = ggml_metal_encode_node(backend, idx, encoder, mem_pool);
4526
5629
 
4527
5630
  if (should_capture) {
4528
5631
  [encoder popDebugGroup];
4529
5632
  }
5633
+
5634
+ if (!res) {
5635
+ break;
5636
+ }
4530
5637
  }
4531
5638
 
4532
5639
  [encoder endEncoding];
4533
5640
 
4534
5641
  if (cb_idx < 2 || ctx->abort_callback == NULL) {
4535
- [command_buffer commit];
5642
+ [cmd_buf commit];
4536
5643
  }
4537
5644
  });
4538
5645
  }
@@ -4766,6 +5873,13 @@ static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_back
4766
5873
  }
4767
5874
  }
4768
5875
 
5876
+ if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
5877
+ GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
5878
+ free(ctx);
5879
+ ggml_backend_metal_device_rel(ctx_dev);
5880
+ return NULL;
5881
+ }
5882
+
4769
5883
  return ggml_backend_buffer_init(ggml_backend_metal_buffer_from_ptr_type(), ggml_backend_metal_buffer_i, ctx, size);
4770
5884
  }
4771
5885
 
@@ -4779,7 +5893,7 @@ static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml
4779
5893
  return buft->iface.get_name == ggml_backend_metal_buffer_type_get_name ||
4780
5894
  buft->iface.get_name == ggml_backend_metal_buffer_from_ptr_type_get_name;
4781
5895
 
4782
- UNUSED(dev);
5896
+ GGML_UNUSED(dev);
4783
5897
  }
4784
5898
 
4785
5899
  static bool ggml_backend_metal_device_offload_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {